In [None]:
import os
import json
from pathlib import Path


import torch
from PIL import Image
from torchvision import transforms


# Pick the compute device once for the whole notebook: the default CUDA GPU
# when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

In [None]:
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Small four-stage convolutional image classifier.

    The first three stages are Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2).
    The fourth stage ends in ``AdaptiveAvgPool2d((4, 4))``, so the head always
    receives ``256 * 4 * 4`` features whatever the input resolution.

    Args:
        num_classes: Number of logits produced by the final linear layer
            (must be >= 1).

    Input: float tensor of shape ``(N, 3, H, W)``. H and W must be at least 8
    so the three 2x max-pools leave a non-empty feature map.
    Output: tensor of shape ``(N, num_classes)``.

    The module order inside ``features`` and ``classifier`` is fixed, because
    saved ``state_dict`` keys (e.g. ``classifier.5.weight``) depend on it.
    """

    def __init__(self, num_classes):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f'num_classes must be >= 1, got {num_classes}')

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Fixed 4x4 output makes the flattened size independent of H, W.
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            # Previously hard-coded to 1 output, silently ignoring num_classes.
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalized) class logits for a batch of images."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
from torchvision import transforms


def infer(data_dir, model_path, label_map_path='label_map.json',
          output_path='results.json'):
    """Classify every image in ``data_dir`` with a trained SimpleCNN checkpoint.

    The checkpoint at ``model_path`` must be a dict containing
    ``'model_state_dict'``. It may contain ``'class_to_idx'`` and
    ``'img_size'`` (default 128). The model is rebuilt with an output layer
    sized to the number of classes and the weights are loaded. Then every
    ``.jpg``/``.jpeg``/``.png`` file directly inside ``data_dir`` (not
    recursive, processed in sorted order) is classified. Predictions are
    written as JSON to ``output_path``.

    Args:
        data_dir: Directory holding the images to classify.
        model_path: Path to the checkpoint saved by the training notebook.
        label_map_path: JSON file mapping ``"idx" -> class_name``. Used for
            class names when the checkpoint has no ``class_to_idx``.
        output_path: Destination of the ``{file_name: class_name}`` JSON.

    Returns:
        dict mapping each image file name to its predicted class name.

    Raises:
        FileNotFoundError: if ``label_map_path`` does not exist.
        RuntimeError: if neither the checkpoint nor the label map supply
            class names.
    """
    # --- class names ------------------------------------------------------
    if not os.path.exists(label_map_path):
        raise FileNotFoundError(
            f'{label_map_path} not found. Make sure train.ipynb saved it in cwd')
    with open(label_map_path, 'r', encoding='utf-8') as f:
        # JSON object keys are always strings; turn them back into int indices.
        label_map = {int(k): v for k, v in json.load(f).items()}

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location=device)
    class_to_idx = ckpt.get('class_to_idx')
    if class_to_idx is not None:
        # The checkpoint is authoritative: it was saved alongside the weights.
        idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    elif label_map:
        idx_to_class = label_map
    else:
        raise RuntimeError('Saved checkpoint missing class_to_idx')
    num_classes = len(idx_to_class)

    # --- model --------------------------------------------------------------
    model = SimpleCNN(num_classes=num_classes)
    # Ensure the output layer matches the checkpoint, whatever width the
    # constructor produced. Replacing index -1 in place keeps the state_dict
    # key ('classifier.5.*') that the checkpoint was saved with.
    head = model.classifier[-1]
    if head.out_features != num_classes:
        model.classifier[-1] = nn.Linear(head.in_features, num_classes)
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)
    model.eval()

    # --- preprocessing (must mirror the training-time validation transform) --
    img_size = ckpt.get('img_size', 128)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # --- inference ----------------------------------------------------------
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    # Sorted for a deterministic results order; is_file() skips directories
    # whose names happen to end in an image suffix.
    images = sorted(
        p for p in Path(data_dir).iterdir()
        if p.is_file() and p.suffix.lower() in image_suffixes
    )

    results = {}
    with torch.no_grad():
        for p in images:
            # Context manager closes the underlying file handle promptly.
            with Image.open(p) as img:
                x = transform(img.convert('RGB')).unsqueeze(0).to(device)
            pred = model(x).argmax(dim=1).item()
            results[p.name] = idx_to_class[pred]

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print('Wrote', output_path, 'with', len(results), 'predictions')
    return results