# Attention Mechanism Visualization

This notebook provides an interactive guide to understanding the attention mechanism, a core component of GPT.


In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make the project importable: the root is taken to be two directory
# levels above the current working directory.
working_dir = os.path.abspath('')
project_root = os.path.dirname(os.path.dirname(working_dir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import our model classes
from src.model.attention import MultiHeadAttention
from src.model.gpt import GPTModel
from src.config import GPTConfig
from src.data.tokenizer import get_tokenizer
import tiktoken

## Understanding Multi-Head Attention

Attention is the core mechanism that allows transformers to understand relationships between tokens. This notebook visualizes how attention works in our GPT model.

In [None]:
# Build a deliberately small attention layer to explore
embedding_dim, context_length, num_heads = 64, 32, 4

attention = MultiHeadAttention(
    number_of_heads=num_heads,
    context_length=context_length,
    input_dimension=embedding_dim,
    output_dimension=embedding_dim,
    dropout=0.0,  # dropout off for visualization
)

# One print call, one summary line per argument
print(
    "Attention layer created:",
    f"  Embedding dimension: {embedding_dim}",
    f"  Number of heads: {num_heads}",
    f"  Head dimension: {embedding_dim // num_heads}",
    f"  Context length: {context_length}",
    sep="\n",
)

In [None]:
# Look at the causal mask stored on the attention layer
causal_mask = attention.mask
print(f"Causal mask shape: {causal_mask.shape}")

print("\nCausal mask (first 10x10):")
top_left_corner = causal_mask[:10, :10]
print(top_left_corner.int())

print("\nThe mask prevents tokens from attending to future tokens (upper triangle is True/1)")

In [None]:
# Push one random sequence through the layer and compare shapes
batch_size, seq_len = 1, 8
x = torch.randn(batch_size, seq_len, embedding_dim)
print(f"Input shape: {x.shape}")

output = attention(x)
print(f"Output shape: {output.shape}")

shapes_match = x.shape == output.shape
print(f"Input and output shapes match: {shapes_match}")

## Extracting Attention Weights from a Model

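To visualize attention, we need the attention weights produced during the forward pass. A PyTorch forward hook only receives a module's inputs and outputs, so the hook below recomputes the weights from the module's query/key projections and causal mask each time the attention module runs.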

In [None]:
# Hook class to capture attention weights
class AttentionHook:
    """Forward hook that records attention weights from a MultiHeadAttention module.

    A forward hook only sees the module's inputs and output, so the weights
    are recomputed here from the module's query/key projections and its
    causal mask.

    NOTE(review): the recomputation (mask, then scale, then softmax) is
    presumed to mirror MultiHeadAttention.forward -- confirm against the
    module whenever its forward pass changes.
    """

    def __init__(self):
        # One array of shape [heads, tokens, tokens] is appended per hooked call.
        self.attention_weights = []
        # Label for the hooked layer; callers overwrite this after construction.
        self.layer_idx = 0

    def __call__(self, module, input, output):
        """Recompute and store the attention weights for one forward call.

        Args:
            module: The hooked module. Must expose ``W_query``, ``W_key``,
                ``number_of_heads``, ``head_dimension`` and ``mask``.
            input: Tuple of positional inputs given to the module; the first
                entry is read as a [batch, tokens, features] tensor.
            output: The module's output (unused).

        Returns:
            None, so the module's output is left untouched.
        """
        x = input[0]
        batch_size, num_tokens, _ = x.shape
        heads, head_dim = module.number_of_heads, module.head_dimension

        def split_heads(projected):
            # [batch, tokens, heads * head_dim] -> [batch, heads, tokens, head_dim]
            return projected.view(batch_size, num_tokens, heads, head_dim).transpose(1, 2)

        # The weights are only inspected, never trained on, so skip autograd
        # bookkeeping. The value projection is not needed for the weights and
        # is therefore not computed.
        with torch.no_grad():
            queries = split_heads(module.W_query(x))
            keys = split_heads(module.W_key(x))

            # [batch, heads, tokens, tokens]
            attention_scores = queries @ keys.transpose(-2, -1)

            # Masked positions become -inf so softmax assigns them zero weight.
            mask = module.mask[:num_tokens, :num_tokens]
            attention_scores = attention_scores.masked_fill(mask[None, None], float('-inf'))

            attention_weights = torch.softmax(attention_scores / head_dim ** 0.5, dim=-1)

        # Keep only the first sample of the batch, as a NumPy array.
        self.attention_weights.append(attention_weights[0].detach().cpu().numpy())

In [None]:
# Build a small two-layer GPT model for the visualizations
model_settings = dict(
    vocab_size=50257,
    context_length=128,
    embedding_dimension=128,
    number_of_heads=4,
    number_of_layers=2,
    dropout_rate=0.0,  # dropout off for visualization
)
config = GPTConfig(**model_settings)

model = GPTModel(config)
model.eval()

layer_count, head_count = config.number_of_layers, config.number_of_heads
print(f"Model created with {layer_count} layers and {head_count} heads per layer")

In [None]:
# Tokenize a short sample sentence and keep one readable string per token
tokenizer = get_tokenizer("gpt2")
text = "The cat sat on the mat. It was fluffy."
token_ids = tokenizer.encode(text)
# Decode each id on its own so the strings can label the heatmap axes
tokens = [tokenizer.decode([token_id]) for token_id in token_ids]

print(
    f"Text: '{text}'",
    f"Tokens: {tokens}",
    f"Number of tokens: {len(tokens)}",
    sep="\n",
)

# Add a leading batch dimension of size 1
input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
print(f"Input shape: {input_ids.shape}")

In [None]:
# Register one hook per transformer block, run a single forward pass, and
# collect the attention weights each hook recorded.
hooks = []            # removable handles returned by register_forward_hook
attention_hooks = []  # the AttentionHook instances, in layer order

try:
    for i, block in enumerate(model.transformer_blocks):
        hook = AttentionHook()
        hook.layer_idx = i
        attention_hooks.append(hook)
        hooks.append(block.attention.register_forward_hook(hook))

    # Forward pass; the output itself is not needed, only the hook side effects
    with torch.no_grad():
        _ = model(input_ids)
finally:
    # Always detach the hooks, even if the forward pass raised, so the model
    # is not left with stale hooks that fire on every later call.
    for handle in hooks:
        handle.remove()

# First recorded call per layer, or None when a hook never fired
all_weights = [
    hook.attention_weights[0] if hook.attention_weights else None
    for hook in attention_hooks
]

print(f"Extracted attention from {len(all_weights)} layers")
if all_weights and all_weights[0] is not None:
    print(f"Shape per layer: {all_weights[0].shape}  # [num_heads, seq_len, seq_len]")

## Visualizing Attention Patterns

In [None]:
def visualize_attention_head(attention_weights, tokens, layer_idx, head_idx, ax=None):
    """Draw one head's attention weights as a labelled heatmap.

    Args:
        attention_weights: 2-D array-like of weights; rows are drawn as
            queries and columns as keys.
        tokens: Token strings used as tick labels on both axes.
        layer_idx: Layer number shown in the title.
        head_idx: Head number shown in the title.
        ax: Axes to draw on. When None, a new 10x8 figure is created.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))

    # Fixed 0..1 colour range keeps heatmaps comparable across heads
    im = ax.imshow(attention_weights, cmap='Blues', aspect='auto', vmin=0, vmax=1)

    # One tick per token on both axes
    tick_positions = range(len(tokens))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
    ax.set_yticklabels(tokens, fontsize=8)

    # Attach the colorbar through the axes' own figure rather than pyplot's
    # "current figure", which may be a different one when the caller passes
    # in an ax.
    ax.figure.colorbar(im, ax=ax, label='Attention Weight')

    ax.set_title(f'Layer {layer_idx}, Head {head_idx}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Key (Attended To)', fontsize=10)
    ax.set_ylabel('Query (Attending From)', fontsize=10)

    return ax

# Plot head 0 of layer 0 on its own figure
first_layer_weights = all_weights[0]
if first_layer_weights is not None:
    fig, ax = plt.subplots(figsize=(10, 8))
    visualize_attention_head(first_layer_weights[0], tokens, layer_idx=0, head_idx=0, ax=ax)
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize all heads for a layer
def visualize_all_heads(attention_weights, tokens, layer_idx, cols=2):
    """Draw every head of one layer in a grid of heatmaps.

    Args:
        attention_weights: Array indexed as [head, query, key]; the number
            of heads is read from its first dimension.
        tokens: Token strings used as tick labels.
        layer_idx: Layer number shown in the titles.
        cols: Number of grid columns (default 2, the original layout).

    Returns:
        The created figure.
    """
    num_heads = attention_weights.shape[0]
    rows = (num_heads + cols - 1) // cols  # ceiling division

    # squeeze=False always yields a 2-D axes array, whatever the grid shape,
    # so no special-casing of single-row or single-column grids is needed.
    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows), squeeze=False)

    # Row-major walk over the grid: slot i holds head i, leftovers are hidden
    for slot, ax in enumerate(axes.ravel()):
        if slot < num_heads:
            visualize_attention_head(
                attention_weights[slot],
                tokens,
                layer_idx,
                slot,
                ax=ax
            )
        else:
            ax.axis('off')

    fig.suptitle(f'All Attention Heads - Layer {layer_idx}', fontsize=16, fontweight='bold')
    fig.tight_layout()
    return fig

# Show the full grid of heads for layer 0
first_layer_weights = all_weights[0]
if first_layer_weights is not None:
    fig = visualize_all_heads(first_layer_weights, tokens, layer_idx=0)
    plt.show()

In [None]:
# Compare head-averaged attention patterns across layers
if len(all_weights) > 1 and all_weights[0] is not None:
    # Remember each layer's real index so the titles stay correct even if
    # some layer produced no weights and is skipped.
    layer_indices = [idx for idx, w in enumerate(all_weights) if w is not None]
    # Average across all heads for each layer -> [seq_len, seq_len]
    average_weights = [all_weights[idx].mean(axis=0) for idx in layer_indices]

    # squeeze=False keeps axes 2-D even when only one layer is left
    fig, axes = plt.subplots(
        1, len(average_weights), figsize=(6 * len(average_weights), 6), squeeze=False
    )

    for ax, layer_idx, avg_weights in zip(axes[0], layer_indices, average_weights):
        im = ax.imshow(avg_weights, cmap='Blues', aspect='auto', vmin=0, vmax=1)
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(tokens, fontsize=8)
        ax.set_title(f'Layer {layer_idx} (Averaged)', fontsize=12, fontweight='bold')
        fig.colorbar(im, ax=ax, label='Attention Weight')

    fig.suptitle('Attention Patterns Across Layers', fontsize=16, fontweight='bold')
    fig.tight_layout()
    plt.show()

## Analyzing Attention Patterns

In [None]:
# Report, for each token, the tokens it attends to most strongly
if all_weights[0] is not None:
    # Average across all heads
    avg_weights = all_weights[0].mean(axis=0)  # [seq_len, seq_len]

    print("Top attention targets for each token (Layer 0, averaged across heads):\n")
    for token_idx, token in enumerate(tokens):
        # Only positions up to and including the token itself are considered
        # (later positions are excluded by the causal mask).
        attention_from_token = avg_weights[token_idx, :token_idx+1]

        # The first token can only attend to itself; nothing to rank.
        if len(attention_from_token) <= 1:
            continue

        # Up to three strongest targets, strongest first. The indices come
        # from a slice no longer than `tokens`, so they are always valid.
        top_indices = np.argsort(attention_from_token)[-3:][::-1]

        print(f"'{token}' attends most to:")
        for rank, idx in enumerate(top_indices, start=1):
            print(f"  {rank}. '{tokens[idx]}' (weight: {attention_from_token[idx]:.3f})")
        print()

In [None]:
# Plot how self-attention and attention to earlier tokens change per layer
if len(all_weights) > 1:
    # Use each layer's real index on the x-axis so points stay correctly
    # labelled even if some layer produced no weights and is skipped.
    layers = [idx for idx, w in enumerate(all_weights) if w is not None]
    average_weights = [all_weights[idx].mean(axis=0) for idx in layers]

    fig, ax = plt.subplots(figsize=(10, 6))

    # Average self-attention (diagonal) across all tokens for each layer
    self_attention = [np.mean(np.diag(avg_weights)) for avg_weights in average_weights]
    ax.plot(layers, self_attention, marker='o', linewidth=2, markersize=8, label='Average Self-Attention')

    # Average attention to previous tokens: strictly-lower triangle of the map
    previous_attention = []
    for avg_weights in average_weights:
        mask = np.tril(np.ones_like(avg_weights), k=-1).astype(bool)
        # A 1x1 map has no strictly-lower entries; report 0.0 instead of NaN
        previous_attention.append(np.mean(avg_weights[mask]) if mask.any() else 0.0)

    ax.plot(layers, previous_attention, marker='s', linewidth=2, markersize=8,
            label='Average Attention to Previous Tokens')

    ax.set_xlabel('Layer', fontsize=12)
    ax.set_ylabel('Attention Weight', fontsize=12)
    ax.set_title('Attention Patterns Across Layers', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(layers)

    fig.tight_layout()
    plt.show()