# Self-Attention Visualization

This notebook visualizes the self-attention maps of a Vision Transformer (ViT).
It uses `timm` to load a standard pre-trained model and wraps each attention module so that its attention weights are captured during the forward pass.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import timm

# Select the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(f"Device: {device}")

In [None]:
# --- 1. Load Model and Register Hooks ---

# Build the pre-trained timm ViT (16px patches, 224x224 input), move it to the
# selected device and switch it to eval mode (disables dropout).
# nn.Module.to() and .eval() both return the module itself, so they chain.
model_name = "vit_base_patch16_224"
model = timm.create_model(model_name, pretrained=True).to(device).eval()

# Populated during the forward pass: maps a layer key (e.g. "layer_0") to that
# layer's captured attention tensor.
attention_weights = {}

def get_attention_hook(name):
    """Return a placeholder forward hook for the module identified by ``name``.

    The returned callable has PyTorch's forward-hook signature
    ``(module, input, output)`` but deliberately does nothing and returns
    ``None``. The standard timm ViT forward pass does not expose the raw
    attention matrix to a hook on the Attention module's output, so this
    notebook captures it by replacing the Attention modules with
    ``AttentionWrapper`` instead. ``name`` is accepted but unused.
    """
    def _noop_hook(module, input, output):
        # A forward hook that returns None leaves the module output unchanged.
        return None

    return _noop_hook

# --- Simpler Alternative: Wrap timm Attention to capture weights ---
# AttentionWrapper replaces a timm ViT Attention module with an explicit
# (non-fused) re-implementation of its forward pass, so the post-softmax
# attention matrix can be recorded.
class AttentionWrapper(nn.Module):
    """Re-run a timm ``Attention`` forward pass explicitly and record its weights.

    The wrapped module's own submodules (``qkv``, ``attn_drop``, ``proj``,
    ``proj_drop``) are reused, so no parameters are duplicated.

    Args:
        attn_module: The attention module to wrap. Must expose ``qkv``,
            ``num_heads``, ``scale``, ``attn_drop``, ``proj`` and ``proj_drop``.
            ``q_norm``, ``k_norm`` and ``norm`` are applied when present.
        layer_id: Index used to build the storage key ``f"layer_{layer_id}"``.
        store: Dict that receives the captured weights. When ``None`` (the
            default), the notebook-level ``attention_weights`` dict is used,
            looked up at call time.
    """

    def __init__(self, attn_module, layer_id, store=None):
        super().__init__()
        self.attn_module = attn_module
        self.layer_id = layer_id
        self.store = store

    def forward(self, x, attn_mask=None):
        """Compute multi-head self-attention on ``x`` and record the weights.

        Args:
            x: Token tensor, unpacked below as ``(B, N, C)``.
            attn_mask: Optional mask, assumed broadcastable to
                ``(B, heads, N, N)``. Boolean masks mark positions to keep
                (``False`` -> ``-inf`` before softmax); other masks are added
                to the logits. Accepted as a keyword because some newer timm
                ``Block`` implementations pass it.

        Returns:
            Tensor of shape ``(B, N, C)`` after the output projection.
        """
        mod = self.attn_module
        B, N, C = x.shape
        head_dim = C // mod.num_heads
        qkv = mod.qkv(x).reshape(B, N, 3, mod.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, heads, N, head_dim)

        # Optional per-head q/k normalization found in some timm versions;
        # skipped when the wrapped module has no such submodule.
        q_norm = getattr(mod, 'q_norm', None)
        if q_norm is not None:
            q = q_norm(q)
        k_norm = getattr(mod, 'k_norm', None)
        if k_norm is not None:
            k = k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * mod.scale
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn = attn.masked_fill(~attn_mask, float('-inf'))
            else:
                attn = attn + attn_mask
        attn = attn.softmax(dim=-1)

        # Store a detached CPU copy so the capture holds neither the autograd
        # graph nor device memory.
        store = self.store if self.store is not None else attention_weights
        store[f'layer_{self.layer_id}'] = attn.detach().cpu()

        attn = mod.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Optional post-attention norm present on some timm Attention versions.
        norm = getattr(mod, 'norm', None)
        if norm is not None:
            x = norm(x)
        x = mod.proj(x)
        x = mod.proj_drop(x)
        return x

# Replace each block's attention module with a capturing AttentionWrapper.
# Unwrap first: re-running this cell would otherwise wrap a wrapper, and since
# the wrapper has no ``qkv`` of its own, the outer forward would fail.
for i, block in enumerate(model.blocks):
    if not hasattr(block, 'attn'):
        continue
    inner = block.attn
    while hasattr(inner, 'attn_module'):
        inner = inner.attn_module
    block.attn = AttentionWrapper(inner, i)
    print(f"Wrapped layer {i}")


In [None]:
# --- 2. Prepare Data ---

# Image to visualize. Replace 'sample_image.jpg' with your own image path;
# if the file does not exist, a random-noise placeholder is generated locally.
img_path = "sample_image.jpg"

# Create a dummy image if not exists
if not os.path.exists(img_path):
    print("Creating dummy image...")
    # randint's upper bound is exclusive, so 256 covers the full uint8 range.
    noise = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    Image.fromarray(noise).save(img_path)

# Use the preprocessing the pretrained weights were trained with. timm ViT
# weights do not necessarily use the ImageNet mean/std (some use 0.5/0.5), so
# read the normalization statistics and input size from the model's config.
from timm.data import resolve_data_config

data_config = resolve_data_config({}, model=model)
_, input_h, input_w = data_config["input_size"]
transform = transforms.Compose([
    # Plain resize to the model's input size (no center crop), so the whole
    # image is covered by the patch grid.
    transforms.Resize((input_h, input_w)),
    transforms.ToTensor(),
    transforms.Normalize(mean=data_config["mean"], std=data_config["std"]),
])

original_img = Image.open(img_path).convert('RGB')
input_tensor = transform(original_img).unsqueeze(0).to(device)
print(f"Input shape: {input_tensor.shape}")

In [None]:
# --- 3. Run Inference and Visualize ---

def calculate_nmi(attn):
    """Placeholder for a Normalized Mutual Information computation.

    No NMI is computed yet: the argument is handed back unchanged (the very
    same object), so callers effectively receive the raw attention tensor.
    """
    result = attn
    return result

# Drop captures from any previous run so stale layers cannot be picked up.
attention_weights.clear()

with torch.no_grad():
    _ = model(input_tensor)

print(f"Captured layers: {list(attention_weights.keys())}")
if not attention_weights:
    raise RuntimeError(
        "No attention weights were captured; run the cell that wraps the attention modules first."
    )

# Visualize the mean attention map of the last layer. The blocks run in order,
# so the last key inserted during this forward pass is the deepest layer.
last_layer_key = list(attention_weights.keys())[-1]
attn_map = attention_weights[last_layer_key][0]  # first image of the batch: (heads, N, N)

# Average over heads
attn_mean = attn_map.mean(dim=0)

# Attention from the CLS token (query 0) to the patch tokens. Skip every
# prefix token the model declares (CLS plus any extra tokens), defaulting to 1.
num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)
cls_attn = attn_mean[0, num_prefix_tokens:]

# Reshape to the square patch grid (14x14 for a 224 image with 16px patches).
num_patches = cls_attn.shape[0]
grid_size = int(round(np.sqrt(num_patches)))
if grid_size * grid_size != num_patches:
    raise ValueError(f"Cannot reshape {num_patches} patch tokens into a square grid.")
cls_attn_grid = cls_attn.reshape(grid_size, grid_size).numpy()

# Plot
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(original_img)
ax[0].set_title("Original")
ax[0].axis('off')

ax[1].imshow(cls_attn_grid, cmap='inferno')
# Keys look like "layer_11"; show just the index to avoid "Layer layer_11".
layer_index = last_layer_key.split('_')[-1]
ax[1].set_title(f"Attention (Layer {layer_index})")
ax[1].axis('off')

plt.tight_layout()
plt.show()