# Task 3: The Interrogation (Grad-CAM)
## "To see is to know."

In this notebook, we implement **Grad-CAM (Gradient-weighted Class Activation Mapping)** from scratch to visualize *where* the model is looking.

### Hypothesis
- **Biased Input (Red 0)**: We expect the model to rely on **color**, with attention spread over colored pixels (the whole digit blob) rather than on specific shape features.
- **Conflicting Input (Green 0)**:
    - If it predicts **0** (Shape), the heatmap should focus on the **digit stroke**.
    - If it predicts **1** (Color), the heatmap may highlight **green pixels** indiscriminately rather than the shape-defining parts of the stroke.


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import cv2

# Add project root to path
sys.path.append(os.path.abspath('..'))

from src.models.simple_cnn import SimpleCNN
from src.gradcam import GradCAM
from src.data.biased_mnist import BiasedMNIST
import torch.nn.functional as F
import torchvision.transforms as transforms

# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Using device: {device}")


## 1. Load Model & Data
We redefine and quickly retrain the "cheater" model from Task 2 on the biased training set, and load the test set.


In [None]:
# Define Model
class SimpleCNN(nn.Module):
    """Small two-block CNN for 3-channel images that pool down to 7x7 (e.g. 28x28).

    Layout: conv(3->8) -> ReLU -> 2x2 max-pool -> conv(8->16) -> ReLU -> 2x2
    max-pool -> linear(16*7*7 -> num_classes). ``conv2`` is the layer that
    Grad-CAM hooks into later in this notebook.

    NOTE: this local definition shadows the ``SimpleCNN`` imported from
    ``src.models.simple_cnn`` in the first code cell.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        # A 28x28 input is halved twice by the pool -> 16 channels x 7 x 7.
        self.fc = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        """Return class logits of shape [batch, num_classes] for a [batch, 3, H, W] input."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. The former
        # view(-1, 16 * 7 * 7) would silently regroup features across samples
        # for wrongly sized inputs; now such inputs fail at self.fc instead.
        x = torch.flatten(x, 1)
        return self.fc(x)

model = SimpleCNN().to(device)

# Quick retrain (same recipe as Task 2) so this notebook has a trained model in memory.
import torch.optim as optim

NUM_EPOCHS = 2  # deliberately short: we only need the biased "cheater" behaviour

print("Re-training Cheater (Quickly)...")
# NOTE(review): bias_ratio=0.995 presumably is the fraction of training digits
# drawn in their class-correlated colour; confirm in src/data/biased_mnist.py.
train_dataset = BiasedMNIST(root='./data', train=True, download=True, bias_ratio=0.995)
# Test set, searched below for colour/label-conflicting examples
# (uses BiasedMNIST's default bias_ratio).
test_dataset = BiasedMNIST(root='./data', train=False, download=True)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    # Batches are (images, labels, color_indices); the colour is not a training target.
    for images, labels, _colors in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report progress so a silently failed retrain is visible in the notebook output.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}: mean train loss {mean_loss:.4f}")
print("Model Ready.")
model.eval()


## 2. Grad-CAM Setup
We hook into the last convolutional layer (`conv2`).


In [None]:
# Hook Grad-CAM into the last convolutional layer (conv2).
grad_cam = GradCAM(model, model.conv2)


def interpret_image(image_tensor, label, title_prefix=""):
    """Run Grad-CAM on one image and plot input, heatmap and overlay side by side.

    Args:
        image_tensor: one image with a leading batch dimension; the callers in
            this notebook pass shape [1, 3, 28, 28].
        label: ground-truth label, used only in the input panel's title.
        title_prefix: text prepended to the input panel's title.

    Returns:
        The predicted class reported by ``grad_cam.generate_cam``. It is called
        with ``target_class=None``, presumably meaning "explain the model's own
        prediction"; TODO confirm in src/gradcam.py.
    """
    heatmap, pred_class = grad_cam.generate_cam(image_tensor.to(device), target_class=None)

    # Drop only the batch dim, and detach/move to CPU so .numpy() also works
    # when the caller's tensor lives on the GPU or tracks gradients.
    img_np = image_tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    result = GradCAM.overlay_heatmap(img_np, heatmap)

    panels = (
        (img_np, None, f"{title_prefix} Input (True: {label})"),
        (heatmap, 'jet', f"Heatmap (Pred: {pred_class})"),
        (result, None, "Overlay"),
    )
    plt.figure(figsize=(10, 4))
    for position, (data, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

    return pred_class


### Experiment A: Biased Input (Red 0)
Let's see a "normal" training-distribution example. The model should predict this correctly (0).
Does it look at the shape or just the red color?


In [None]:
# Get a training-distribution example: a digit 0 drawn in red (color index 0).
# Loader batches are (images, labels, color_indices). train_loader shuffles, so
# a different Red 0 may be picked on each run.
found = False
for imgs, lbls, clrs in train_loader:
    is_red_0 = (torch.as_tensor(lbls) == 0) & (torch.as_tensor(clrs) == 0)
    hits = torch.nonzero(is_red_0, as_tuple=True)[0]
    if hits.numel() > 0:
        sample_red_0 = imgs[int(hits[0])].unsqueeze(0)
        found = True
        break

if found:
    interpret_image(sample_red_0, label=0, title_prefix="Biased (Red 0)")
else:
    # Previously this fell through to a NameError on sample_red_0.
    print("No Red 0 found in the training set.")


### Experiment B: Conflicting Input (Green 0)
Now the interesting part: a **0** drawn in **green**, a color that in the biased training set is paired with the digit 1.
- If it predicts **0**, it overcame the bias. Where did it look? (Shape?)
- If it predicts **1**, it succumbed to the bias. Where did it look? (Green color?)


In [None]:
# Get a test example that conflicts with the training bias: a digit 0 drawn in
# green (color index 1). How the test split assigns colours depends on
# BiasedMNIST's defaults (no bias_ratio is passed); TODO confirm in
# src/data/biased_mnist.py. The entire test loader is scanned.
found = False
for imgs, lbls, clrs in test_loader:
    is_green_0 = (torch.as_tensor(lbls) == 0) & (torch.as_tensor(clrs) == 1)
    hits = torch.nonzero(is_green_0, as_tuple=True)[0]
    if hits.numel() > 0:
        sample_green_0 = imgs[int(hits[0])].unsqueeze(0)
        found = True
        break

if found:
    pred = interpret_image(sample_green_0, label=0, title_prefix="Conflicting (Green 0)")
    print(f"Model Predicted: {pred} (Should be 0, likely 1)")
else:
    print("No Green 0 found in the test set.")


### Experiment C: The Trap (Red 1)
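A digit **1** drawn in **red**. In the biased training set red is paired with the digit 0, so a color-reliant model should be tempted to predict 0.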


In [None]:
# Hand-build a "Red 1": the first digit 1 in the raw training data, with its
# grayscale intensity placed in the red channel only.
raw_data = BiasedMNIST(root='./data', train=True, download=True)
positions_of_ones = torch.nonzero(raw_data.targets == 1, as_tuple=True)[0]
idx = positions_of_ones[0]
img_raw = raw_data.data[idx]  # 28x28

# PIL round trip: ToTensor returns a float tensor with a leading channel dim.
img_tensor = transforms.ToTensor()(transforms.ToPILImage()(img_raw))

# Broadcast an RGB weight of (1, 0, 0) over that single channel -> 3-channel red digit.
# NOTE(review): assumes BiasedMNIST colours the digit strokes the same way;
# if it colours differently, this input is also out-of-distribution. Verify.
red_color = torch.tensor([1.0, 0.0, 0.0])[:, None, None]
red_1 = img_tensor * red_color
red_1_batch = red_1[None, ...]

interpret_image(red_1_batch, label=1, title_prefix="Trap (Red 1)")
