# 6. Grad-CAM: Visualizing What the Model Sees

## Overview

Deep learning models are often called "black boxes". In this notebook, we'll use **Grad-CAM** (Gradient-weighted Class Activation Mapping) to understand what our models are looking at.

### What is Grad-CAM?

Grad-CAM produces a heatmap showing which parts of an image are most important for the model's prediction.

**How it works:**
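1. Forward pass: Capture the feature maps from the last convolutional layer
2. Backward pass: Compute the gradients of the target class score with respect to those feature maps
3. Average each channel's gradients into one importance weight per channel, then take the weighted sum of the feature maps
4. Apply ReLU, upsample to the input size, and normalize to get a heatmap highlighting important regions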
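
In equation form, the steps above can be written as follows. The symbols are notation introduced here rather than names from the code: $A^k$ is the $k$-th feature map of the target layer, $y^c$ is the raw (pre-softmax) score for class $c$, and $Z$ is the number of spatial positions in one feature map.

$$
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c \, A^k \right)
$$

The weight $\alpha_k^c$ is the gradient of the class score averaged over one channel, and the ReLU keeps only the regions that push the score for class $c$ up. Upsampling to the input size and scaling to the 0-1 range are display steps applied afterwards.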

### Why This Matters:

| Question | Grad-CAM Helps Answer |
|----------|----------------------|
| Is the model cheating? | Check if it looks at faces or backgrounds |
| Why did it misclassify? | See what features confused it |
| What defines each emotion? | Which facial regions matter most |
| Can we trust the model? | Verify it learns meaningful patterns |

## Step 1: Import Libraries

In [None]:
import numpy as np
import pandas as pd
import pickle
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Step 2: Configuration

In [None]:
IMG_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

EMOTION_CLASSES = ["anger", "disgust", "fear", "happiness", "neutral", "sadness", "surprise"]
NUM_CLASSES = len(EMOTION_CLASSES)
IDX_TO_EMOTION = {i: e for i, e in enumerate(EMOTION_CLASSES)}

test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

print("Configuration loaded!")

## Step 3: Load Data and Model

In [None]:
# Load test data
with open('processed_data.pkl', 'rb') as f:
    data = pickle.load(f)

test_df = data['test_df'].reset_index(drop=True)
print(f"Test samples: {len(test_df)}")

In [None]:
# Define model classes (same as training notebooks)
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.pool(x)
        return x

class CustomCNN(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, dropout=0.5):
        super().__init__()
        self.conv1 = ConvBlock(3, 32)
        self.conv2 = ConvBlock(32, 64)
        self.conv3 = ConvBlock(64, 128)
        self.conv4 = ConvBlock(128, 256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class TransferLearningModel(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, freeze_backbone=False):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_features, num_classes))
    
    def forward(self, x):
        return self.backbone(x)

print("Model classes defined!")
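
Grad-CAM heatmaps are computed at the resolution of the last convolutional feature maps, so it helps to know how large those maps are. As a rough check, assuming a 224×224 input, the sketch below rebuilds the four conv blocks of `CustomCNN` with a hypothetical `conv_block` helper (not part of this notebook) and prints the feature-map shape after the fourth block.

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # Hypothetical helper with the same layer pattern as ConvBlock:
    # 3x3 conv (padding 1), batch norm, ReLU, 2x2 max-pool
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

features = nn.Sequential(
    conv_block(3, 32),
    conv_block(32, 64),
    conv_block(64, 128),
    conv_block(128, 256),
).eval()

with torch.no_grad():
    out = features(torch.zeros(1, 3, 224, 224))

custom_map_shape = tuple(out.shape)
print(custom_map_shape)  # (1, 256, 14, 14)
```

Each block halves the spatial size, so the Custom CNN heatmap starts as a 14×14 grid before upsampling. torchvision's ResNet18 downsamples a 224×224 input by a factor of 32, so its `layer4` output is 7×7 and its heatmaps are correspondingly coarser.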

In [None]:
# Try to load available models
models_loaded = {}

# Custom CNN
try:
    custom_cnn = CustomCNN()
    checkpoint = torch.load('custom_cnn_best.pth', map_location=device, weights_only=False)
    custom_cnn.load_state_dict(checkpoint['model_state_dict'])
    custom_cnn = custom_cnn.to(device)
    custom_cnn.eval()
    models_loaded['Custom CNN'] = custom_cnn
    print("Loaded Custom CNN")
except FileNotFoundError:
    print("Custom CNN checkpoint not found - run notebook 03 first")

# Transfer Learning
try:
    transfer_model = TransferLearningModel()
    checkpoint = torch.load('transfer_learning_best.pth', map_location=device, weights_only=False)
    transfer_model.load_state_dict(checkpoint['model_state_dict'])
    transfer_model = transfer_model.to(device)
    transfer_model.eval()
    models_loaded['ResNet18'] = transfer_model
    print("Loaded ResNet18")
except FileNotFoundError:
    print("Transfer Learning checkpoint not found - run notebook 04 first")

print(f"\nLoaded {len(models_loaded)} models")

## Step 4: Implement Grad-CAM

In [None]:
class GradCAM:
    """
    Grad-CAM implementation for visualizing CNN attention.
    
    How it works:
    1. Register hooks to capture feature maps and gradients
    2. Forward pass to get prediction
    3. Backward pass to get gradients of target class
    4. Weight feature maps by gradient importance
    5. Create heatmap showing important regions
    """
    
    def __init__(self, model, model_type='custom'):
        self.model = model
        self.model.eval()
        self.model_type = model_type
        
        # Storage for hooks
        self.feature_maps = None
        self.gradients = None
        
        # Register hooks on target layer
        self._register_hooks()
    
    def _register_hooks(self):
        """
        Register forward and backward hooks on the target layer.
        
        Target layer is the last convolutional layer:
        - Custom CNN: conv4
        - ResNet18: layer4
        """
        if self.model_type == 'custom':
            target_layer = self.model.conv4
        else:  # ResNet18
            target_layer = self.model.backbone.layer4
        
        # Forward hook: save feature maps
        def forward_hook(module, input, output):
            self.feature_maps = output.detach()
        
        # Backward hook: save gradients
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        
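        # Keep the handles so the hooks can be detached later
        self._handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]
    
    def remove_hooks(self):
        """Detach the hooks registered by this instance."""
        for handle in self._handles:
            handle.remove()
        self._handles = []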
    
    def generate_heatmap(self, image_tensor, target_class=None):
        """
        Generate Grad-CAM heatmap for an image.
        
        Args:
            image_tensor: Preprocessed image tensor (1, 3, H, W)
            target_class: Class to visualize (None = predicted class)
            
        Returns:
            heatmap: 2D numpy array (IMG_SIZE, IMG_SIZE) with values 0-1
            target_class: Class index the heatmap was computed for
                (the predicted class when target_class is None)
            confidence: Softmax probability of that class
        """
        # Forward pass
        output = self.model(image_tensor)
        probs = F.softmax(output, dim=1)
        
        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        confidence = probs[0, target_class].item()
        
        # Backward pass
        self.model.zero_grad()
        output[0, target_class].backward()
        
        # Get gradients and feature maps
        gradients = self.gradients  # (1, C, H, W)
        feature_maps = self.feature_maps  # (1, C, H, W)
        
        # Global average pooling of gradients -> importance weights
        weights = gradients.mean(dim=[2, 3], keepdim=True)  # (1, C, 1, 1)
        
        # Weighted combination of feature maps
        cam = (weights * feature_maps).sum(dim=1, keepdim=True)  # (1, 1, H, W)
        
        # ReLU to keep only positive contributions
        cam = F.relu(cam)
        
        # Resize to input size
        cam = F.interpolate(cam, size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
        
        # Normalize to 0-1
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam, target_class, confidence

print("GradCAM class defined!")
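
To see the same computation without hooks, here is a minimal sketch on a toy network. The `features`/`head` split, the layer sizes, and the random input are assumptions made for illustration and are not the models from this notebook; `torch.autograd.grad` stands in for the backward hook as a way to obtain the gradients with respect to the feature maps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network: a conv feature extractor and a pooled linear head
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 7))

x = torch.randn(1, 3, 32, 32)
fmap = features(x)                      # feature maps, shape (1, 8, 16, 16)
logits = head(fmap)                     # raw class scores, shape (1, 7)
target = logits.argmax(dim=1).item()

# Gradient of the target class score with respect to the feature maps
grads, = torch.autograd.grad(logits[0, target], fmap)

# One weight per channel: spatial average of the gradients
weights = grads.mean(dim=(2, 3), keepdim=True)             # (1, 8, 1, 1)

# Weighted sum over channels, then ReLU to keep positive evidence
cam = F.relu((weights * fmap).sum(dim=1, keepdim=True))    # (1, 1, 16, 16)

# Upsample to the input size and scale to 0-1
cam = F.interpolate(cam, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False)
cam = cam.squeeze().detach().numpy()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

print(cam.shape)
```

The hook-based class above does the same thing, but it can attach to a layer inside an existing model without changing that model's `forward` method.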

## Step 5: Helper Functions for Visualization

In [None]:
def denormalize_image(image_tensor):
    """
    Convert normalized tensor back to displayable image.
    """
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    image = image_tensor.cpu() * std + mean
    image = image.permute(1, 2, 0).numpy()
    image = np.clip(image * 255, 0, 255).astype(np.uint8)
    return image


def overlay_heatmap(image, heatmap, alpha=0.5):
    """
    Overlay heatmap on image.
    
    Args:
        image: RGB image (H, W, 3)
        heatmap: 2D array (H, W) with values 0-1
        alpha: Blending factor
        
    Returns:
        Blended image
    """
    # Apply colormap (jet: blue=low, red=high)
    colormap = cm.jet(heatmap)[:, :, :3]
    colormap = (colormap * 255).astype(np.uint8)
    
    # Blend
    blended = (1 - alpha) * image + alpha * colormap
    blended = blended.astype(np.uint8)
    
    return blended


def load_and_preprocess(image_path):
    """Load and preprocess an image."""
    image = Image.open(image_path).convert('RGB')
    tensor = test_transform(image).unsqueeze(0).to(device)
    return tensor

print("Helper functions defined!")
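
As a quick sanity check of the helpers, the following NumPy-only sketch mirrors the arithmetic in `denormalize_image` and `overlay_heatmap` on small synthetic arrays. The array sizes and pixel values are made up for illustration.

```python
import numpy as np

# Same constants as in the Configuration step, reshaped to broadcast over (C, H, W)
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Stand-in for a ToTensor() output: channel-first values in [0, 1]
rng = np.random.default_rng(0)
pixels = rng.random((3, 4, 4))

normalized = (pixels - mean) / std     # what transforms.Normalize computes
restored = normalized * std + mean     # what denormalize_image undoes
round_trip_ok = bool(np.allclose(restored, pixels))

# Alpha blending as in overlay_heatmap, on constant uint8 images
image = np.full((4, 4, 3), 200, dtype=np.uint8)
colormap = np.full((4, 4, 3), 100, dtype=np.uint8)
alpha = 0.5
blended = ((1 - alpha) * image + alpha * colormap).astype(np.uint8)

print(round_trip_ok, blended[0, 0])    # every blended value is 150
```

With `alpha=0.5` the image and the colormap contribute equally; a lower `alpha` keeps more of the face visible under the heatmap.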

## Step 6: Visualize Grad-CAM for Sample Images

In [None]:
def visualize_gradcam_samples(model, model_name, model_type, test_df, n_samples=8):
    """
    Generate Grad-CAM visualizations for random test samples.
    """
    gradcam = GradCAM(model, model_type)
    
    # Get random samples
    samples = test_df.sample(n=n_samples, random_state=42)
    
    # squeeze=False keeps axes 2D so axes[i, j] also works when n_samples == 1
    fig, axes = plt.subplots(n_samples, 4, figsize=(14, 3*n_samples), squeeze=False)
    
    for i, (_, row) in enumerate(samples.iterrows()):
        # Load image
        image_tensor = load_and_preprocess(row['image_path'])
        true_label = row['label']
        true_emotion = IDX_TO_EMOTION[true_label]
        
        # Generate Grad-CAM
        heatmap, pred_class, confidence = gradcam.generate_heatmap(image_tensor)
        pred_emotion = IDX_TO_EMOTION[pred_class]
        
        # Prepare display image
        display_image = denormalize_image(image_tensor.squeeze())
        overlay = overlay_heatmap(display_image, heatmap)
        
        # Determine if correct
        correct = pred_class == true_label
        color = 'green' if correct else 'red'
        
        # Plot
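        # Convert to RGB so grayscale files are not drawn with matplotlib's default colormap
        axes[i, 0].imshow(Image.open(row['image_path']).convert('RGB'))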
        axes[i, 0].set_title(f'Original\nTrue: {true_emotion}', fontsize=10)
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(display_image)
        axes[i, 1].set_title('Preprocessed', fontsize=10)
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(heatmap, cmap='jet')
        axes[i, 2].set_title('Grad-CAM Heatmap', fontsize=10)
        axes[i, 2].axis('off')
        
        axes[i, 3].imshow(overlay)
        symbol = '✓' if correct else '✗'
        axes[i, 3].set_title(f'Pred: {pred_emotion} {symbol}\n({confidence:.1%})', 
                             fontsize=10, color=color)
        axes[i, 3].axis('off')
    
    plt.suptitle(f'Grad-CAM Visualization: {model_name}', fontsize=14, y=1.01)
    plt.tight_layout()
    plt.show()

# Generate visualizations for each model
for model_name, model in models_loaded.items():
    model_type = 'custom' if 'Custom' in model_name else 'resnet'
    print(f"\nGenerating Grad-CAM for {model_name}...")
    visualize_gradcam_samples(model, model_name, model_type, test_df, n_samples=6)

## Step 7: Compare Grad-CAM Across Emotions

In [None]:
def visualize_by_emotion(model, model_name, model_type, test_df):
    """
    Show Grad-CAM for one correctly classified example of each emotion.
    """
    gradcam = GradCAM(model, model_type)
    
    fig, axes = plt.subplots(len(EMOTION_CLASSES), 3, figsize=(10, 3*len(EMOTION_CLASSES)))
    
    for i, emotion in enumerate(EMOTION_CLASSES):
        # Find a sample of this emotion
        emotion_samples = test_df[test_df['emotion'] == emotion]
        
        if len(emotion_samples) == 0:
            continue
        
        # Try to find a correctly classified sample
        found = False
        for _, row in emotion_samples.sample(frac=1, random_state=42).iterrows():
            image_tensor = load_and_preprocess(row['image_path'])
            heatmap, pred_class, confidence = gradcam.generate_heatmap(image_tensor)
            
            if pred_class == row['label']:
                found = True
                break
        
        if not found:
            # Use first sample even if misclassified
            row = emotion_samples.iloc[0]
            image_tensor = load_and_preprocess(row['image_path'])
            heatmap, pred_class, confidence = gradcam.generate_heatmap(image_tensor)
        
        # Display
        display_image = denormalize_image(image_tensor.squeeze())
        overlay = overlay_heatmap(display_image, heatmap)
        
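        axes[i, 0].imshow(display_image)
        axes[i, 0].set_ylabel(emotion.upper(), fontsize=12, rotation=0, ha='right', va='center')
        # axis('off') would hide the y-label as well, so hide only ticks and frame
        axes[i, 0].set_xticks([])
        axes[i, 0].set_yticks([])
        for spine in axes[i, 0].spines.values():
            spine.set_visible(False)
        
        axes[i, 1].imshow(heatmap, cmap='jet')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(overlay)
        # Show the prediction too, since the fallback sample may be misclassified
        axes[i, 2].set_title(f'{IDX_TO_EMOTION[pred_class]} ({confidence:.1%})', fontsize=10)
        axes[i, 2].axis('off')
    
    # Column headers; the overlay column already carries a per-row title
    axes[0, 0].set_title('Preprocessed', fontsize=11)
    axes[0, 1].set_title('Heatmap', fontsize=11)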
    
    plt.suptitle(f'Grad-CAM by Emotion: {model_name}', fontsize=14, y=1.01)
    plt.tight_layout()
    plt.show()

# Generate for each model
for model_name, model in models_loaded.items():
    model_type = 'custom' if 'Custom' in model_name else 'resnet'
    print(f"\nGrad-CAM by emotion for {model_name}...")
    visualize_by_emotion(model, model_name, model_type, test_df)
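
Each visualization function in this notebook constructs its own `GradCAM` object, and each construction registers another pair of hooks on the same layer. PyTorch's hook registration methods return a handle with a `remove()` method. The sketch below uses a standalone `nn.Conv2d` as a stand-in for the target layer (an assumption for illustration) to show hooks accumulating and then being removed.

```python
import torch
import torch.nn as nn

layer = nn.Conv2d(3, 4, kernel_size=3, padding=1)
calls = []

def forward_hook(module, inputs, output):
    # Record every invocation so accumulated hooks are visible
    calls.append(tuple(output.shape))

# Registering twice mimics building two GradCAM objects for the same layer
handle_a = layer.register_forward_hook(forward_hook)
handle_b = layer.register_forward_hook(forward_hook)

layer(torch.zeros(1, 3, 8, 8))
calls_with_two_hooks = len(calls)

# Removing the handles detaches the hooks again
handle_a.remove()
handle_b.remove()
calls.clear()
layer(torch.zeros(1, 3, 8, 8))
calls_after_removal = len(calls)

print(calls_with_two_hooks, calls_after_removal)  # 2 0
```

Stale hooks do not change the heatmaps, but they keep old feature maps alive and add work to every forward pass, so reusing one `GradCAM` object per model or removing hooks when finished is worth considering.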

## Step 8: Analyze Misclassifications

In [None]:
def visualize_misclassifications(model, model_name, model_type, test_df, n_samples=6):
    """
    Show Grad-CAM for misclassified examples.
    
    This helps understand WHY the model made mistakes.
    """
    gradcam = GradCAM(model, model_type)
    
    misclassified = []
    
    # Find misclassified samples
    for _, row in test_df.iterrows():
        image_tensor = load_and_preprocess(row['image_path'])
        
        with torch.no_grad():
            output = model(image_tensor)
            pred_class = output.argmax(dim=1).item()
        
        if pred_class != row['label']:
            misclassified.append({
                'image_path': row['image_path'],
                'true_label': row['label'],
                'pred_label': pred_class
            })
        
        if len(misclassified) >= n_samples:
            break
    
    if len(misclassified) == 0:
        print(f"No misclassifications found for {model_name}!")
        return
    
    # Visualize
    # squeeze=False keeps axes 2D even when there is a single row
    fig, axes = plt.subplots(len(misclassified), 3, figsize=(10, 3*len(misclassified)), squeeze=False)
    
    for i, sample in enumerate(misclassified):
        image_tensor = load_and_preprocess(sample['image_path'])
        heatmap, _, confidence = gradcam.generate_heatmap(image_tensor, sample['pred_label'])
        
        display_image = denormalize_image(image_tensor.squeeze())
        overlay = overlay_heatmap(display_image, heatmap)
        
        true_emotion = IDX_TO_EMOTION[sample['true_label']]
        pred_emotion = IDX_TO_EMOTION[sample['pred_label']]
        
        # Title every panel per row: axis('off') hides y-labels, and shared
        # column headers would overwrite the first row's titles
        axes[i][0].imshow(display_image)
        axes[i][0].set_title(f'True: {true_emotion}', fontsize=10)
        axes[i][0].axis('off')
        
        axes[i][1].imshow(heatmap, cmap='jet')
        axes[i][1].set_title('Heatmap', fontsize=10)
        axes[i][1].axis('off')
        
        axes[i][2].imshow(overlay)
        axes[i][2].set_title(f'Pred: {pred_emotion} ({confidence:.1%})', fontsize=10, color='red')
        axes[i][2].axis('off')
    
    plt.suptitle(f'Misclassification Analysis: {model_name}', fontsize=14, y=1.01)
    plt.tight_layout()
    plt.show()

# Analyze misclassifications
for model_name, model in models_loaded.items():
    model_type = 'custom' if 'Custom' in model_name else 'resnet'
    print(f"\nAnalyzing misclassifications for {model_name}...")
    visualize_misclassifications(model, model_name, model_type, test_df, n_samples=4)
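
The function above takes the first few misclassified samples in dataset order. One alternative, sketched here with made-up logits and labels rather than real model outputs, is to rank errors by the confidence of the wrong prediction, since confidently wrong cases are often the most informative to inspect.

```python
import torch
import torch.nn.functional as F

# Made-up logits for four samples and three classes (illustrative values only)
logits = torch.tensor([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.9, 0.2, 1.1],
    [0.1, 0.2, 3.0],
])
labels = torch.tensor([0, 2, 2, 1])

probs = F.softmax(logits, dim=1)
confidence, preds = probs.max(dim=1)

# Indices of the samples whose prediction differs from the label
wrong = (preds != labels).nonzero(as_tuple=True)[0]

# Sort those indices by the confidence of the wrong prediction, highest first
order = confidence[wrong].argsort(descending=True)
ranked = wrong[order]

print(wrong.tolist(), ranked.tolist())  # [1, 3] [3, 1]
```

In this notebook the logits would have to be collected over the test set first; that step is not shown here.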

## Step 9: Interpretation Guidelines

### What to Look For:

| Pattern | Interpretation |
|---------|---------------|
| **Focus on face** | Expected: the model is likely using relevant facial features |
| **Focus on background** | Warning sign: the model may be "cheating" by exploiting dataset bias |
| **Focus on eyes/brows** | Often seen for anger, fear, surprise |
| **Focus on mouth** | Often seen for happiness, disgust |
| **Diffuse attention** | The model may be uncertain, or has not learned localized features |

### Common Issues:

1. **Background focus**: Model learned dataset bias, not faces
2. **Wrong region**: Model looking at irrelevant features
3. **Too localized**: Missing important facial cues
4. **Too diffuse**: Model hasn't learned specific features
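
Judging "face focus" versus "background focus" by eye can be backed up with a simple number: the fraction of heatmap activation that falls inside a face region. The sketch below is one possible way to compute it. The helper `heatmap_mass_in_box`, the synthetic heatmap, and the face box coordinates are all hypothetical; in practice the box would come from a face detector or annotations, which this notebook does not provide.

```python
import numpy as np

def heatmap_mass_in_box(heatmap, box):
    """Fraction of total heatmap activation inside box = (top, left, bottom, right)."""
    top, left, bottom, right = box
    total = heatmap.sum()
    if total == 0:
        return 0.0
    return float(heatmap[top:bottom, left:right].sum() / total)

# Synthetic 224x224 heatmap: a hot square in the middle plus weaker
# activation in a background corner
heatmap = np.zeros((224, 224))
heatmap[80:140, 80:140] = 1.0
heatmap[0:20, 0:20] = 0.5

# Hypothetical face box (top, left, bottom, right)
face_box = (50, 50, 180, 180)
face_fraction = heatmap_mass_in_box(heatmap, face_box)

print(f"{face_fraction:.3f}")  # 0.947
```

A value near 1 suggests the heatmap is concentrated on the face, while a low value points toward background focus. Any cut-off between the two would need to be chosen by inspecting real examples.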

## Summary

### What We Learned:

1. **Grad-CAM reveals model attention** - Shows which image regions drive predictions

2. **Checks model learning** - Lets us verify whether the model focuses on faces rather than backgrounds

3. **Helps explain errors** - Shows which regions drove a misclassification

4. **Emotion-specific patterns** - Different emotions can activate different facial regions

### Key Takeaways:

- **Good models** focus on facial features (eyes, mouth, brows)
- **Misclassifications** often occur when attention is diffuse or lands on the wrong region
- **Transfer learning** models may show more focused attention; compare the ResNet18 and Custom CNN plots above to check
- **Grad-CAM** is a practical tool for understanding and building trust in deep learning models

### Project Complete!

You've now built and analyzed a complete facial expression recognition system:
1. Data exploration and preprocessing
2. Baseline model (HOG + SVM)
3. Custom CNN from scratch
4. Transfer learning with ResNet18
5. Comprehensive evaluation and comparison
6. Model interpretability with Grad-CAM