In [1]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json

class SimpsonsCNN(nn.Module):
    """Four-stage convolutional classifier followed by a linear head.

    Each stage applies two 3x3 Conv -> BatchNorm -> ReLU units and then a
    2x2 max-pool, with channel widths 32, 64, 128 and 256. The final feature
    map is averaged down to 1x1, flattened, passed through dropout (p=0.5)
    and projected to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One Conv -> BatchNorm -> ReLU unit; the conv carries no bias
            # because it is directly followed by BatchNorm.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def make_stage(c_in, c_out, pool=True):
            # Two stacked units, optionally followed by 2x2 max-pooling.
            # Layer order inside the Sequential fixes the state_dict keys.
            modules = conv_bn_relu(c_in, c_out) + conv_bn_relu(c_out, c_out)
            if pool:
                modules.append(nn.MaxPool2d(2))
            return nn.Sequential(*modules)

        self.block1 = make_stage(3, 32)
        self.block2 = make_stage(32, 64)
        self.block3 = make_stage(64, 128)
        self.block4 = make_stage(128, 256)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)

        pooled = torch.flatten(self.global_pool(x), 1)
        return self.fc(self.dropout(pooled))

In [2]:
from tqdm.auto import tqdm

def _collect_image_files(data_dir, valid_extensions):
    """Recursively list image files under ``data_dir`` in a deterministic order.

    Args:
        data_dir: Root directory to walk.
        valid_extensions: Tuple of lower-case filename suffixes to accept.

    Returns:
        A list of ``(full_path, rel_path)`` tuples, where ``rel_path`` is the
        path relative to ``data_dir``.
    """
    found = []
    for root, dirs, files in os.walk(data_dir):
        # os.walk yields entries in arbitrary order; sorting `dirs` in place
        # fixes the descent order and sorting `files` fixes the per-directory
        # order, so results are reproducible across machines.
        dirs.sort()
        for name in sorted(files):
            if name.lower().endswith(valid_extensions):
                full_path = os.path.join(root, name)
                found.append((full_path, os.path.relpath(full_path, data_dir)))
    return found


def infer(data_dir, model_path, output_file='results.json'):
    """Classify every image under ``data_dir`` with a saved SimpsonsCNN checkpoint.

    The name of each image's parent directory is recorded as its true label.
    Accuracy is computed only over images whose true label appears in the
    checkpoint's class list. Per-image predictions are written as JSON to
    ``output_file``, keyed by the image path relative to ``data_dir``.

    Args:
        data_dir: Directory searched recursively for .jpg/.jpeg/.png/.bmp files.
        model_path: Path to a checkpoint holding ``class_names``,
            ``config['img_size']`` and ``model_state_dict``.
        output_file: Destination path for the JSON results
            (defaults to ``'results.json'``).

    Returns:
        None. Prints a message and returns early when ``data_dir`` is not a
        directory or contains no matching images.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    rule = "=" * 80

    print(rule)
    print("INFERENCE MODE")
    print(rule)
    print(f"Device: {device}")
    print(f"Model: {model_path}")
    print(f"Data directory: {data_dir}")
    print(rule)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"✗ Model not found: {model_path}")

    print("\n📦 Loading model...")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    class_names = checkpoint['class_names']
    img_size = checkpoint['config']['img_size']

    model = SimpsonsCNN(num_classes=len(class_names))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # disable dropout and use BatchNorm running statistics
    print(f"✓ Model loaded | Classes: {len(class_names)} | Image size: {img_size}x{img_size}")

    # NOTE(review): presumably this mirrors the training-time preprocessing
    # (resize + per-channel normalisation) — confirm against the training code.
    infer_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    if not os.path.isdir(data_dir):
        print(f"✗ Directory not found: {data_dir}")
        return

    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    file_list = _collect_image_files(data_dir, valid_extensions)

    if len(file_list) == 0:
        print(f"✗ No images found in {data_dir}")
        return

    print(f"\n🖼️ Found {len(file_list)} images")
    print(rule)

    results = {}
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(
            file_list,
            desc="📊 Processing images",
            unit="img",
            bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] {postfix}',
            colour='green'
        )

        for full_path, rel_path in pbar:
            try:
                # The context manager closes the underlying file promptly;
                # convert() returns an independent in-memory copy.
                with Image.open(full_path) as raw_image:
                    image = raw_image.convert('RGB')
                input_tensor = infer_transform(image).unsqueeze(0).to(device)

                outputs = model(input_tensor)
                predicted_class = class_names[outputs.argmax(dim=1).item()]

                # The parent directory name doubles as the ground-truth label.
                true_label = os.path.basename(os.path.dirname(full_path))

                results[rel_path] = {
                    'predicted': predicted_class,
                    'true_label': true_label
                }

                # Only score images whose folder name is a known class.
                if true_label in class_names:
                    total += 1
                    if predicted_class == true_label:
                        correct += 1

                accuracy = (correct / total * 100) if total > 0 else 0
                pbar.set_postfix({'pred': predicted_class[:12], 'acc': f'{accuracy:.1f}%'})

            except Exception as e:
                # Best effort: report the unreadable/failed image and move on.
                print(f"\n⚠️ Skipping {rel_path}: {e}")

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\n{rule}")
    print("INFERENCE COMPLETE")
    print(rule)
    print(f"Total images processed: {len(results)}")
    if total > 0:
        print(f"Accuracy: {correct}/{total} = {correct/total*100:.2f}%")
    print(f"Results saved: {output_file}")
    print(rule)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
if __name__ == "__main__":
    # Run inference on the images under characters_train/bart_simpson
    # using the checkpoint stored in model.pth.
    infer('characters_train/bart_simpson', 'model.pth')

INFERENCE MODE
Device: cpu
Model: model.pth
Data directory: characters_train/bart_simpson

📦 Loading model...
✓ Model loaded | Classes: 42 | Image size: 128x128

🖼️ Found 1074 images


📊 Processing images: 100%|██████████| 1074/1074 [00:37<00:00, 28.35img/s] , pred=bart_simpson, acc=97.2%


✅ INFERENCE COMPLETE
Total images processed: 1074
Accuracy: 1044/1074 = 97.21%
Results saved: results.json



