[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14xNi1SJZm17TBc6rKaWvOJYwx8wUOTcK)

Author: **Safouane El Ghazouali**
- Ph.D. in AI
- Senior data scientist and researcher at TOELT LLC
- Lecturer at HSLU

---

# Visualizing Attention Maps in Transformers

Welcome to this hands-on notebook on visualizing attention maps in Vision Transformers (ViTs)! Attention maps are a powerful tool to understand how transformers focus on different parts of an input image, revealing what the model 'pays attention to' during inference.

![Attention Map Example](https://www.researchgate.net/publication/353612417/figure/fig3/AS:1051863327727618@1627795168385/The-output-attention-map-of-the-base-model-and-the-multi-attention-guided-method-on.png)

### Why Visualize Attention Maps?
- **Interpretability**: Understand which parts of an image the model focuses on.
- **Debugging**: Identify if the model is attending to irrelevant regions.
- **Insight into Transformers**: See the attention mechanism in action.

### What You'll Learn
- How Vision Transformers generate attention maps.
- How to extract attention weights from a pre-trained ViT model.
- How to visualize attention maps overlaid on images.
- How to interpret the results.
- How to explore and inspect the model structure.

# 🧰 Environment Setup

We'll use PyTorch and the `timm` library to load a pre-trained Vision Transformer and Matplotlib for visualization.

In [None]:
!pip install -q torch torchvision timm matplotlib

### Import Libraries

We import the necessary libraries for loading the model, processing images, and visualizing attention maps.

In [None]:
import torch
import timm
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch.nn.functional as F

print("PyTorch version:", torch.__version__)
print("timm version:", timm.__version__)

# 📦 Loading a Pre-trained Vision Transformer

We'll use a pre-trained Vision Transformer (ViT) model from `timm`. The model is pre-trained on ImageNet.

In [None]:
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)
model.eval()

# Explanation
# model: A pre-trained Vision Transformer with 16x16 patch size and 224x224 input resolution.
# eval(): Sets the model to evaluation mode for inference.

In [None]:
# Reuse the model created above and resolve the preprocessing settings
# (input size, mean, std) that match its pretrained weights
config = timm.data.resolve_data_config({}, model=model)

# Print the preprocessing settings
print(f"Model: {model.pretrained_cfg['architecture']}")
print(f"Input size: {config['input_size']}")
print(f"Normalization mean: {config['mean']}")
print(f"Normalization std: {config['std']}")

num_classes = model.head.out_features
print(f"The model is configured to predict {num_classes} classes.")


# 🔎 Exploring the Model Structure

Before diving into attention visualization, it's important to understand the model's architecture. Inspecting a model helps a data scientist debug it, customize it, and understand how it works internally.

### Why Explore Models?
- Identify layers and blocks for modification or hooking.
- Check dimensions, parameters, and configurations (e.g., number of heads).
- Facilitate transfer learning or fine-tuning.

### Printing the Entire Model

Printing the model gives a high-level overview of its components: patch embedding, positional embedding, transformer blocks, normalization, and head.

In [None]:
print(model)

# Explanation
# This displays the hierarchical structure of the ViT model, including the sequence of transformer blocks.
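
As a quick sanity check on the printed structure, the sequence length seen by the transformer blocks can be worked out by hand. The sketch below assumes a square input, non-overlapping patches, and a single prepended CLS token; `vit_token_count` is a hypothetical helper written for this notebook, not part of `timm`.

```python
def vit_token_count(img_size: int = 224, patch_size: int = 16, num_prefix_tokens: int = 1):
    """Return (grid_size, num_patches, num_tokens) for a square ViT input."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid_size = img_size // patch_size            # patches per side
    num_patches = grid_size * grid_size           # patches per image
    num_tokens = num_patches + num_prefix_tokens  # patches plus CLS token
    return grid_size, num_patches, num_tokens


grid_size, num_patches, num_tokens = vit_token_count()
print(f"{grid_size}x{grid_size} grid -> {num_patches} patches + CLS = {num_tokens} tokens")
```

Under these assumptions, the second dimension of `model.pos_embed.shape` would be expected to equal `num_tokens`.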

### Accessing Specific Blocks

The transformer consists of multiple blocks (layers). We can access them individually to inspect or modify them.

In [None]:
# Number of transformer blocks
num_blocks = len(model.blocks)
print(f"Number of transformer blocks: {num_blocks}")

# Access the first block
first_block = model.blocks[0]
print("\nFirst transformer block:")
print(first_block)
print("\n" + "-" * 24)  # visual separator between the two blocks
# Access the last block
last_block = model.blocks[-1]
print("\nLast transformer block:")
print(last_block)

# Explanation
# model.blocks: An nn.Sequential of Block modules, each containing attention, MLP, and normalization layers.
# We often focus on the last block for attention visualization as it captures high-level features.

### Accessing Layers Within a Block

Each block has sub-layers like attention (attn), MLP, and norms. Accessing them allows fine-grained control, e.g., for hooking or parameter inspection.

In [None]:
# Access the attention module in the last block
attn_module = last_block.attn
print("Attention module in last block:")
print(attn_module)

# Get key attributes
print(f"Number of attention heads: {attn_module.num_heads}")
print(f"Head dimension: {attn_module.head_dim}")

# Access the QKV linear layer
qkv_layer = attn_module.qkv
print("\nQKV linear layer:")
print(qkv_layer)

# Explanation
# attn_module: The Attention submodule handling self-attention.
# num_heads: Number of parallel attention heads (12 for ViT-Base).
# qkv: Linear layer projecting input to queries, keys, values.
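
The attributes printed above are tied together by simple arithmetic. The following sketch assumes the usual ViT-Base values (embedding dimension 768, 12 heads) and the standard `1 / sqrt(head_dim)` scaling; `attention_dims` is a hypothetical helper used only for illustration.

```python
def attention_dims(embed_dim: int = 768, num_heads: int = 12):
    """Return (head_dim, scale, qkv_out_features) for multi-head self-attention."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads  # channels handled by each head
    scale = head_dim ** -0.5           # factor applied to q @ k^T before the softmax
    qkv_out_features = 3 * embed_dim   # queries, keys, and values from one linear layer
    return head_dim, scale, qkv_out_features


head_dim, scale, qkv_out = attention_dims()
print(f"head_dim={head_dim}, scale={scale}, qkv out_features={qkv_out}")
```

If the assumptions hold, these numbers should agree with `attn_module.head_dim`, `attn_module.scale`, and `qkv_layer.out_features`.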

### Other Useful Exploration Techniques
- **List attributes/methods**: Use `dir(object)` to see available properties, e.g., `dir(model)`.
- **Parameter count**: `sum(p.numel() for p in model.parameters())`.
- **Layer names**: Use `model.named_modules()` to iterate over all modules with names.

In [None]:
# Total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# List some module names
for name, module in list(model.named_modules())[:5]:
    print(name, module)

# Explanation
# total_params: Helps gauge model size and complexity.
# named_modules(): Useful for targeting specific layers by name.
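
The total can also be reproduced by hand, which is a useful check that the printed structure has been read correctly. This is a sketch assuming the standard ViT-Base/16 layout (12 blocks, MLP ratio 4, biases on every linear layer, 197 position embeddings, 1,000 classes); the helper names are hypothetical.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus bias of a fully connected layer."""
    return in_features * out_features + out_features


def vit_param_breakdown(embed_dim=768, depth=12, mlp_ratio=4, patch_size=16,
                        in_chans=3, num_tokens=197, num_classes=1000):
    # The patch-embedding convolution has the same parameter count as a
    # linear layer applied to each flattened patch.
    patch_embed = linear_params(in_chans * patch_size * patch_size, embed_dim)
    cls_token = embed_dim
    pos_embed = num_tokens * embed_dim
    layer_norm = 2 * embed_dim  # scale and shift
    attn = linear_params(embed_dim, 3 * embed_dim) + linear_params(embed_dim, embed_dim)
    hidden = mlp_ratio * embed_dim
    mlp = linear_params(embed_dim, hidden) + linear_params(hidden, embed_dim)
    block = 2 * layer_norm + attn + mlp
    head = linear_params(embed_dim, num_classes)
    total = patch_embed + cls_token + pos_embed + depth * block + layer_norm + head
    return {"attn_per_block": attn, "block": block, "total": total}


counts = vit_param_breakdown()
for name, value in counts.items():
    print(f"{name}: {value:,}")
```

If the hand-computed total differs from `total_params`, one of the layout assumptions above does not match the loaded checkpoint.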

# 🖼️ Image Preprocessing

We need to preprocess an input image to match the model's expected input format (224x224 pixels, normalized).

In [None]:
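# Download a sample image and force 3-channel RGB (some files are grayscale or RGBA)
url = 'https://farm4.staticflickr.com/3427/3188200587_fbddbcecbb_z.jpg'
img = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Define preprocessing pipeline
# mean/std are taken from the pretrained config resolved earlier, so they match
# the statistics the checkpoint was trained with instead of hard-coded values
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config['mean'], std=config['std']),
])
input_tensor = transform(img).unsqueeze(0)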

# Display the image
plt.imshow(img)
plt.axis('off')
plt.show()

# Explanation
# img: The raw input image loaded from a URL.
# transform: Resizes and normalizes the image to match ViT's requirements.
# input_tensor: The processed image as a tensor with shape [1, 3, 224, 224].
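
Normalization maps every channel value `x` to `(x - mean) / std`, and the inverse `x * std + mean` recovers a displayable image. A minimal sketch of the round trip on a single pixel follows; the `mean` and `std` values are placeholders chosen for illustration, and the values to use in practice are the ones reported by the model's data config.

```python
def normalize_pixel(rgb, mean, std):
    """Map an RGB pixel in [0, 1] to the normalized space fed to the model."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))


def denormalize_pixel(rgb_norm, mean, std):
    """Invert normalize_pixel so the pixel can be displayed again."""
    return tuple(c * s + m for c, m, s in zip(rgb_norm, mean, std))


mean = (0.5, 0.5, 0.5)  # placeholder statistics, not read from any model
std = (0.5, 0.5, 0.5)

pixel = (1.0, 0.5, 0.0)
normalized = normalize_pixel(pixel, mean, std)
restored = denormalize_pixel(normalized, mean, std)
print("normalized:", normalized)
print("restored:  ", restored)
```

Using different constants for the forward and inverse steps would shift the colors of the displayed image, so both should come from the same source.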

# 🔍 Extracting Attention Weights

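To capture attention weights, we replace the `forward` method of the attention module in the last block with a wrapper. The wrapper recomputes the attention matrix explicitly and stores it on the module as `attn_map` (all tokens) and `cls_attn_map` (CLS token to patches).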

In [None]:
from typing import Optional

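def my_forward_wrapper(attn_obj):
    """Return a forward function for `attn_obj` that also stores the attention weights."""
    def my_forward(x, attn_mask: Optional[torch.Tensor] = None):
        # attn_mask is accepted to keep the signature compatible; it is not applied here.
        B, N, C = x.shape  # batch, tokens (CLS + patches), embedding dim
        # Project to Q, K, V and split into heads: each is [B, heads, N, head_dim]
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention: [B, heads, N, N], each row sums to 1
        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)  # no-op in eval mode
        attn_obj.attn_map = attn
        attn_obj.cls_attn_map = attn[:, :, 0, 1:]  # CLS token attention to patches

        # Weighted sum of values, merge heads, then apply the output projection
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward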

# Apply the wrapper to the last attention module
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)


# Explanation
# my_forward_wrapper: Replaces the original forward to compute and store attention maps as attributes.
# attn_map: Full attention matrix.
# cls_attn_map: Attention from CLS token to image patches.
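
To make the computation inside the wrapper concrete, here is a single-head toy version in plain Python. It is a sketch on made-up numbers: token 0 plays the role of the CLS token, the remaining three tokens stand in for patches, and `attention_weights` is a hypothetical helper, not a `timm` function.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, k):
    """softmax(q @ k^T / sqrt(d)) for one head; q and k are lists of token vectors."""
    scale = len(q[0]) ** -0.5
    scores = [
        [scale * sum(qi * ki for qi, ki in zip(q_tok, k_tok)) for k_tok in k]
        for q_tok in q
    ]
    return [softmax(row) for row in scores]


# Toy queries and keys: 4 tokens with 2 channels each (made-up values)
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
k = [[1.0, 0.0], [2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]

attn = attention_weights(q, k)
cls_to_patches = attn[0][1:]  # same slicing as attn[:, :, 0, 1:] for one head and one image
print("row sums:", [round(sum(row), 6) for row in attn])
print("CLS -> patches:", [round(w, 3) for w in cls_to_patches])
```

Because the CLS token's attention to itself is dropped by the slice, the entries of `cls_to_patches` generally sum to less than 1.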

### Running Inference to Capture Attention

Pass the image through the model; the wrapper will save the attention maps.

In [None]:
with torch.no_grad():
    output = model(input_tensor)

# Explanation
# output: Model logits; attention maps are now stored in the attention module.

# 🛠️ Processing Attention Weights

Retrieve the saved attention maps, average across heads, and prepare for visualization.

In [None]:
# Retrieve attention maps from the module
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_attn = model.blocks[-1].attn.cls_attn_map.mean(dim=1).squeeze(0).detach()

# Reshape CLS attention to patch grid (14x14)
grid_size = 14  # 224 / 16
cls_attn_map = cls_attn.view(grid_size, grid_size)

# Explanation
# attn_map: Full attention matrix averaged over heads.
# cls_attn_map: CLS token's attention to patches, reshaped to 14x14 grid.
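
The two operations above, averaging over heads and reshaping the flat patch vector into a grid, can be illustrated on a tiny example. The sketch assumes patches are ordered row by row (left to right, top to bottom); the helper names are hypothetical.

```python
def average_heads(per_head):
    """Average a list of per-head attention vectors element-wise."""
    num_heads = len(per_head)
    return [sum(values) / num_heads for values in zip(*per_head)]


def to_grid(flat, grid_size):
    """Reshape a flat, row-major list of grid_size**2 values into rows."""
    if len(flat) != grid_size * grid_size:
        raise ValueError("flat must contain grid_size * grid_size values")
    return [flat[r * grid_size:(r + 1) * grid_size] for r in range(grid_size)]


def patch_box(index, grid_size=14, patch_size=16):
    """Pixel box (left, top, right, bottom) covered by a flat patch index."""
    row, col = divmod(index, grid_size)
    return (col * patch_size, row * patch_size,
            (col + 1) * patch_size, (row + 1) * patch_size)


# Two heads attending over a 2x2 patch grid (made-up weights)
per_head = [[0.1, 0.2, 0.3, 0.4],
            [0.3, 0.2, 0.1, 0.4]]
avg = average_heads(per_head)
grid = to_grid(avg, grid_size=2)
print(grid)
print(patch_box(15))  # second row, second column of a 14x14 grid
```

Averaging can hide heads that attend to very different regions, so inspecting individual heads before averaging is often worthwhile.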

# 🎨 Visualizing the Attention Map

Overlay the CLS attention map on the original image.

In [None]:
# Resize CLS attention map to image size
cls_attn_resized = F.interpolate(cls_attn_map.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze()

# Normalize for visualization
cls_attn_resized = (cls_attn_resized - cls_attn_resized.min()) / (cls_attn_resized.max() - cls_attn_resized.min())

# Prepare the image for the overlay: resize the original image to the model's
# input resolution and scale it to [0, 1], so no normalization constants are needed
img_tensor = np.asarray(img.convert('RGB').resize((224, 224)), dtype=np.float32) / 255.0

# Plot
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(img_tensor)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(img_tensor)
plt.imshow(cls_attn_resized, cmap='jet', alpha=0.5)
plt.title('Attention Map Overlay')
plt.axis('off')

plt.show()

# Explanation
# cls_attn_resized: Upscaled to 224x224 using bilinear interpolation.
# Overlay: Shows attention intensity (red=high, blue=low) on the image.
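
The min-max step in the cell above divides by `max - min`, which is zero when the map is constant. A guarded variant might look like the following sketch on plain Python lists; `min_max_normalize` and its `eps` threshold are illustrative choices, not part of any library.

```python
def min_max_normalize(values, eps=1e-8):
    """Rescale values to [0, 1]; return all zeros if the input is (nearly) constant."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        # Avoid dividing by (almost) zero for a flat map
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


print(min_max_normalize([2.0, 4.0, 6.0]))
print(min_max_normalize([3.0, 3.0]))
```

Note that min-max scaling stretches every map to the full color range, so colors are comparable within one overlay but not across images or layers.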

# 🧠 Interpreting the Attention Map

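The overlay shows how strongly the CLS token in the last block attends to each image patch, averaged over all heads. High-attention areas (red) contribute most to the CLS token's representation in that block, which is what the classification head reads. Treat the map as a qualitative cue rather than a complete explanation: attention in a single layer does not account for information that earlier blocks have already mixed between patches.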
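
A simple way to turn the heatmap into something checkable is to list the most attended patches and their grid positions. The sketch below works on a flat list of weights and assumes row-major patch order; `top_k_patches` is a hypothetical helper.

```python
def top_k_patches(cls_attn, k=3, grid_size=14):
    """Return [(row, col, weight), ...] for the k largest attention weights."""
    ranked = sorted(range(len(cls_attn)), key=lambda i: cls_attn[i], reverse=True)[:k]
    return [(*divmod(i, grid_size), cls_attn[i]) for i in ranked]


# Made-up weights for a 2x2 grid
weights = [0.1, 0.6, 0.05, 0.25]
for row, col, weight in top_k_patches(weights, k=2, grid_size=2):
    print(f"patch (row={row}, col={col}) weight={weight}")
```

In this notebook it could be called as `top_k_patches(cls_attn.tolist(), k=5)` to see which of the 196 patches receive the most CLS attention.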

---
# 💡 Student Task

Use images from the [Microsoft COCO dataset explorer](https://cocodataset.org/#explore).

Tasks:
1. Load a different image and visualize its attention map.
2. Explore the model: Print the number of parameters in the attention module of the first block.
3. Access and print the positional embedding shape (`model.pos_embed.shape`).
4. Load another ViT variant (e.g., 'vit_small_patch16_224') and compare attention maps.
5. Modify the wrapper to capture attention from a different block (e.g., model.blocks[0]) and visualize it.
6. Discuss differences in attention across layers or models.

Tips:
- Use `timm.list_models('vit*')` for variants.
- For exploration: Use `sum(p.numel() for p in attn_module.parameters())` for param count.

In [None]:
# Starter code
print(timm.list_models('vit*'))

# Example: Parameter count in attention
attn_params = sum(p.numel() for p in model.blocks[0].attn.parameters())
print(f"Attention params in first block: {attn_params:,}")

# Your code here

In [None]:
import torch
import timm
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np

# Set up the model
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
model.eval()

# Download a sample image
url = 'http://farm8.staticflickr.com/7012/6597749473_03b2f736ac_z.jpg'
img = Image.open(requests.get(url, stream=True).raw)

# Define preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(img).unsqueeze(0)

# Display the image
plt.imshow(img)
plt.axis('off')
plt.show()

def get_grad_cam(model, target_layer, input_tensor, target_class=None):
    """Compute a Grad-CAM heatmap, resized to 224x224, for `target_class` at `target_layer`."""
    activations = {}
    gradients = {}

    def save_activation(module, input, output):
        activations['output'] = output.detach()
        
    def save_gradient(module, grad_in, grad_out):
        gradients['output'] = grad_out[0].detach()

    # Register the hooks (register_full_backward_hook replaces the deprecated register_backward_hook)
    hook_fwd = target_layer.register_forward_hook(save_activation)
    hook_bwd = target_layer.register_full_backward_hook(save_gradient)

    # Perform the forward pass to get the model's output
    # This must be done outside `torch.no_grad()` for the backward pass to work
    output = model(input_tensor)
    
    # If a target class isn't specified, use the predicted class
    if target_class is None:
        target_class = output.argmax()

    # Zero gradients and perform backward pass for the target class
    model.zero_grad()
    one_hot_output = torch.zeros_like(output)
    one_hot_output[0][target_class] = 1
    output.backward(gradient=one_hot_output, retain_graph=True)

    # Remove the hooks to avoid memory leaks
    hook_fwd.remove()
    hook_bwd.remove()

    # Get the feature maps and their gradients
    feature_maps = activations['output']
    grads = gradients['output']
    
    # Compute the weights by global average pooling the gradients
    weights = torch.mean(grads, dim=(2, 3), keepdim=True)
    
    # Combine feature maps and weights, then apply ReLU
    cam = F.relu(torch.sum(feature_maps * weights, dim=1))
    
    # Resize the CAM to the size of the original image
    cam = F.interpolate(cam.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze()

    return cam

# In a ResNet, a good target layer for Grad-CAM is the last convolutional layer.
# For a resnet50 from timm, this is typically `model.layer4[-1].conv3`.
target_layer = model.layer4[-1].conv3

# Run the Grad-CAM generation
# The forward pass happens inside the function, outside of any `no_grad` block.
grad_cam_map = get_grad_cam(model, target_layer, input_tensor)

# Normalize for visualization
grad_cam_map = (grad_cam_map - grad_cam_map.min()) / (grad_cam_map.max() - grad_cam_map.min())

# Convert input tensor to image for overlay
img_tensor = input_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
img_tensor = (img_tensor * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]))
img_tensor = np.clip(img_tensor, 0, 1)

# Plot
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(img_tensor)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(img_tensor)
plt.imshow(grad_cam_map, cmap='hot', alpha=0.5)
plt.title('Grad-CAM Overlay')
plt.axis('off')

plt.show()

# Explanation
# grad_cam_map: The generated Grad-CAM map, upscaled to the image size.
# The forward pass runs outside of a `no_grad` block so that the computation
# graph is built and the backward pass can compute gradients.
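
The weighting step inside `get_grad_cam` can be followed by hand on a tiny example. This sketch uses two 2x2 feature maps with made-up activations and gradients, and mirrors the three operations above: average the gradients per channel, take the weighted sum of the feature maps, then apply ReLU. `grad_cam_toy` is a hypothetical helper for illustration only.

```python
def grad_cam_toy(feature_maps, grads):
    """feature_maps, grads: nested lists indexed as [channel][row][col]."""
    # Channel weights: global average of the gradient over spatial positions
    weights = []
    for grad in grads:
        flat = [v for row in grad for v in row]
        weights.append(sum(flat) / len(flat))

    # Weighted sum of the feature maps across channels
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for weight, fmap in zip(weights, feature_maps):
        for r in range(height):
            for c in range(width):
                cam[r][c] += weight * fmap[r][c]

    # ReLU keeps only locations with positive evidence for the class
    return weights, [[max(0.0, v) for v in row] for row in cam]


feature_maps = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.0, 3.0], [1.0, 0.0]],
]
grads = [
    [[0.5, 0.5], [0.5, 0.5]],      # channel 0 supports the class
    [[-1.0, -1.0], [-1.0, -1.0]],  # channel 1 counts against it
]
weights, cam = grad_cam_toy(feature_maps, grads)
print("weights:", weights)
print("cam:", cam)
```

Locations where only the negatively weighted channel is active end up at zero after the ReLU, which is why Grad-CAM maps tend to highlight class-supporting regions only.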