# The "Springfield" Identity - Inference Notebook


In [None]:
# Mount Google Drive at /content/drive so the shared work directory on Drive
# (configured in the next cell) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
import os
# Working directory on the mounted Drive. It is created if missing and then made
# the current working directory, so relative paths used later in this notebook
# (e.g. 'class_mappings.json' and 'results.json') resolve inside it.
WORK_DIR = '/content/drive/MyDrive/shared/bonusHW'
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)


In [None]:
import torch


In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from pathlib import Path
import json


In [None]:
class SimpsonsCNN(nn.Module):
    """Five-stage convolutional classifier for Simpsons character images.

    Feature extractor: five stages of Conv3x3(padding=1) -> BatchNorm -> ReLU
    -> 2x2 max-pool. Channels grow 3 -> 32 -> 64 -> 128 -> 256 -> 512 while each
    pool halves the spatial size.

    Head: fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout -> fc3 (logits).

    ``fc1`` is sized for a 512 x 4 x 4 feature map, i.e. a 128 x 128 input
    (128 / 2**5 == 4); other input sizes will not fit ``fc1``.

    The attribute names conv1..conv5, bn1..bn5 and fc1..fc3 are the
    ``state_dict`` keys of saved checkpoints and must stay as they are.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Register conv{i}/bn{i} pairs in the same order as explicit assignment
        # would, so parameter creation order and state_dict keys are unchanged.
        widths = (3, 32, 64, 128, 256, 512)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'conv{stage}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'bn{stage}', nn.BatchNorm2d(c_out))

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        flat_features = 512 * 4 * 4  # 512 channels on a 4x4 map (128x128 input)
        self.fc1 = nn.Linear(flat_features, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in conv_stages:
            x = self.pool(self.relu(bn(conv(x))))

        # Flatten everything except the batch dimension for the FC head.
        x = x.view(x.shape[0], -1)

        for hidden in (self.fc1, self.fc2):
            x = self.dropout(self.relu(hidden(x)))
        return self.fc3(x)


In [None]:
def _load_idx_to_class():
    """Read ``class_mappings.json`` and return its ``idx_to_class`` table with int keys.

    The file is looked up in WORK_DIR first, then in the current directory.
    """
    class_mappings_path = os.path.join(WORK_DIR, 'class_mappings.json')
    if not os.path.exists(class_mappings_path):
        class_mappings_path = 'class_mappings.json'

    with open(class_mappings_path, 'r', encoding='utf-8') as f:
        class_mappings = json.load(f)

    # JSON object keys are always strings; the model's predicted indices are ints.
    return {int(k): v for k, v in class_mappings['idx_to_class'].items()}


def _list_image_files(data_dir):
    """Return the .jpg/.jpeg/.png files directly inside *data_dir*, sorted by path.

    Extensions are matched case-insensitively and all formats are collected
    together. A missing directory yields an empty list.
    """
    data_path = Path(data_dir)
    if not data_path.is_dir():
        return []
    suffixes = {'.jpg', '.jpeg', '.png'}
    return sorted(
        p for p in data_path.iterdir()
        if p.is_file() and p.suffix.lower() in suffixes
    )


def infer(data_dir, model_path, output_path='results.json'):
    """Predict a character class for every image in *data_dir*.

    Args:
        data_dir: directory containing the images (not searched recursively).
        model_path: path to a saved ``SimpsonsCNN`` state_dict.
        output_path: where the ``{file name: class name}`` JSON is written.
            A relative path resolves against the current working directory.

    Returns:
        dict mapping each image's file name to its predicted class name.
        An image that fails to load or predict is reported on stdout and
        assigned the class with the lowest index, so every file gets an entry.

    Raises:
        ValueError: if the class mapping file has no classes.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    idx_to_class = _load_idx_to_class()
    if not idx_to_class:
        raise ValueError('class_mappings.json has an empty idx_to_class table')
    # Deterministic fallback label for images that cannot be processed.
    fallback_class = idx_to_class[min(idx_to_class)]

    model = SimpsonsCNN(num_classes=len(idx_to_class)).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # 128x128 is required by the model's fc1 size; the ImageNet mean/std
    # presumably mirrors the training preprocessing -- confirm against training.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_files = _list_image_files(data_dir)
    total = len(image_files)
    if total == 0:
        print(f"Warning: no .jpg/.jpeg/.png images found in {data_dir}")

    results = {}
    with torch.no_grad():
        for count, img_path in enumerate(image_files, start=1):
            try:
                # The context manager closes the file handle PIL keeps open.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                image_tensor = transform(image).unsqueeze(0).to(device)
                predicted_idx = model(image_tensor).argmax(dim=1).item()
                results[img_path.name] = idx_to_class[predicted_idx]
            except Exception as e:  # best effort: one bad file must not abort the run
                print(f"Error processing {img_path}: {e}")
                results[img_path.name] = fallback_class

            # Outside the try so progress is reported even if this image failed.
            if count % 100 == 0:
                print(f"Processed {count}/{total} images...")

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"Predictions completed. Total: {len(results)}")
    return results


In [None]:
# results = infer('test_data_dir', os.path.join(WORK_DIR, 'model.pth'))
