In [None]:
# Occlusion Sensitivity Analysis (Part-B)
# Kernel: Python (ml_torch)
# Rebuilds the best model from best_model_checkpoint.pth (saved architecture config + weights)


In [None]:

import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder


In [None]:

DATASET_ROOT = "../dataset"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Occlusion analysis: side length of the square grey patch and the step
# between consecutive patch positions, both in (resized) pixels.
OCCLUSION_SIZE = 8
STRIDE = 2

# Preprocessing must match training: 84x84 input, ImageNet mean/std.
IMAGE_SIZE = (84, 84)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

test_dir = os.path.join(DATASET_ROOT, "test")
test_dataset = ImageFolder(test_dir, transform=transform)
class_names = test_dataset.classes


In [None]:

class ConfigurableCNN(nn.Module):
    """Conv feature stack followed by an MLP classifier, built from a config.

    The layer layout (and therefore the ``state_dict`` key names) must match
    the one used at training time, otherwise loading a checkpoint fails.

    Args:
        conv_filters: output channels of each 3x3 conv layer (padding 1), in order.
        fc_layers: hidden sizes of the fully connected layers, in order.
        use_pool: if True, a 2x2 max-pool follows every conv + ReLU.
        stride: stride applied to every conv layer.
        num_classes: size of the final linear layer. Required despite the
            ``None`` default (kept for backward compatibility).
        in_channels: number of channels of the input images (default 3).
        input_size: (height, width) or a single int for square inputs. Only
            used to infer the flattened feature size feeding the classifier.

    Raises:
        ValueError: if ``num_classes`` is not given.
    """

    def __init__(self, conv_filters, fc_layers, use_pool=True, stride=1, num_classes=None,
                 in_channels=3, input_size=(84, 84)):
        super().__init__()
        if num_classes is None:
            raise ValueError("num_classes must be provided")
        if isinstance(input_size, int):
            input_size = (input_size, input_size)

        layers = []
        channels = in_channels
        for f in conv_filters:
            layers.append(nn.Conv2d(channels, f, 3, stride=stride, padding=1))
            layers.append(nn.ReLU())
            if use_pool:
                layers.append(nn.MaxPool2d(2))
            channels = f
        self.features = nn.Sequential(*layers)

        # Infer the classifier's input width by pushing a dummy image through
        # the conv stack, so any filter/pool/stride combination just works.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_size)
            flat_dim = self.features(dummy).flatten(1).size(1)

        fc = []
        in_dim = flat_dim
        for h in fc_layers:
            fc.append(nn.Linear(in_dim, h))
            fc.append(nn.ReLU())
            in_dim = h
        fc.append(nn.Linear(in_dim, num_classes))
        self.classifier = nn.Sequential(*fc)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


In [None]:

# Rebuild the trained architecture from its saved config, then load weights.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("best_model_checkpoint.pth", map_location=DEVICE)
cfg = checkpoint["config"]

arch_kwargs = {key: cfg[key] for key in ("conv_filters", "fc_layers", "use_pool", "stride")}
model = ConfigurableCNN(**arch_kwargs, num_classes=len(class_names)).to(DEVICE)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference only: fixes dropout/batch-norm style layers if any

print("Loaded model config:", cfg)


In [None]:
# Show a few misclassified test images
N_MISCLASSIFIED = 5

# Images in test_dataset are mean/std-normalized; undo that for display,
# otherwise imshow clips values outside [0, 1] and the colours are wrong.
# Reuse the transform's own Normalize parameters so the two never drift apart.
_norm = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
_mean = torch.tensor(_norm.mean).view(-1, 1, 1)
_std = torch.tensor(_norm.std).view(-1, 1, 1)

misclassified = []
with torch.no_grad():  # model is already in eval mode (set when loaded)
    for img, label in test_dataset:
        pred = model(img.unsqueeze(0).to(DEVICE)).argmax(dim=1).item()
        if pred != label:
            misclassified.append((img, label, pred))
            if len(misclassified) == N_MISCLASSIFIED:
                break

if not misclassified:
    print("No misclassified test images found.")
else:
    fig, axes = plt.subplots(1, N_MISCLASSIFIED, figsize=(12, 4), squeeze=False)
    for ax in axes[0]:
        ax.axis("off")
    for ax, (img, true_label, pred_label) in zip(axes[0], misclassified):
        ax.imshow((img * _std + _mean).clamp(0, 1).permute(1, 2, 0).numpy())
        ax.set_title(f"T: {class_names[true_label]}\nP: {class_names[pred_label]}")
    fig.tight_layout()
    plt.show()


In [None]:

def occlusion_sensitivity(model, image, true_label, occlusion_size=None, stride=None,
                          batch_size=256):
    """Map how the true-class probability changes as a patch is hidden.

    A square patch is zeroed (in normalized space, i.e. set to the dataset
    mean colour) at every position on a ``stride``-spaced grid. The softmax
    probability of ``true_label`` is recorded for each occluded copy. Patches
    that touch the bottom/right border are clipped to the image.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
            It should already be in eval mode; inputs are moved to the device
            of its first parameter.
        image: a single image tensor of shape (C, H, W).
        true_label: index of the class whose probability is tracked.
        occlusion_size: patch side length in pixels; defaults to OCCLUSION_SIZE.
        stride: step between patch positions; defaults to STRIDE.
        batch_size: number of occluded copies evaluated per forward pass.

    Returns:
        float64 ndarray of shape (ceil(H / stride), ceil(W / stride)), where
        entry [r, c] is the probability with the patch's top-left corner at
        pixel (r * stride, c * stride).

    Raises:
        ValueError: if occlusion_size, stride or batch_size is not positive.
    """
    size = OCCLUSION_SIZE if occlusion_size is None else occlusion_size
    step = STRIDE if stride is None else stride
    if size <= 0 or step <= 0 or batch_size <= 0:
        raise ValueError("occlusion_size, stride and batch_size must be positive")

    param = next(model.parameters(), None)
    device = param.device if param is not None else image.device
    image = image.to(device)
    _, H, W = image.shape

    # Size the map from the positions actually visited (ceil division);
    # H // step undercounts when step does not divide H.
    rows = range(0, H, step)
    cols = range(0, W, step)
    positions = [(i, j) for i in rows for j in cols]
    if not positions:
        return np.zeros((len(rows), len(cols)))

    probs = []
    with torch.no_grad():
        for start in range(0, len(positions), batch_size):
            chunk = positions[start:start + batch_size]
            batch = image.unsqueeze(0).repeat(len(chunk), 1, 1, 1)
            for k, (i, j) in enumerate(chunk):
                batch[k, :, i:i + size, j:j + size] = 0.0
            out = model(batch)
            probs.append(torch.softmax(out, dim=1)[:, true_label].cpu())

    # Row-major order of `positions` matches reshape to (rows, cols).
    return torch.cat(probs).double().numpy().reshape(len(rows), len(cols))


In [None]:

# Occlusion maps for the first few test images. Dark regions in the heat map
# are where hiding a patch lowers the true-class probability, i.e. the regions
# the prediction relies on.
N_OCCLUSION_EXAMPLES = 10
indices = list(range(min(N_OCCLUSION_EXAMPLES, len(test_dataset))))

# Undo the transform's Normalize step so the input renders in true colours
# (imshow clips normalized values outside [0, 1]).
_norm = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
_mean = torch.tensor(_norm.mean).view(-1, 1, 1)
_std = torch.tensor(_norm.std).view(-1, 1, 1)

for idx in indices:
    img, label = test_dataset[idx]
    conf_map = occlusion_sensitivity(model, img, label)

    fig, (ax_img, ax_map) = plt.subplots(1, 2, figsize=(6, 3))
    ax_img.imshow((img * _std + _mean).clamp(0, 1).permute(1, 2, 0).numpy())
    ax_img.set_title(class_names[label])
    ax_img.axis("off")

    heat = ax_map.imshow(conf_map, cmap="hot")
    ax_map.set_title("Occlusion Sensitivity")
    ax_map.axis("off")
    fig.colorbar(heat, ax=ax_map, label=f"P({class_names[label]})")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop
