# Attention Mechanisms - Hands-On Exploration

This notebook provides interactive exploration of attention mechanisms.

## Topics
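1. Basic Scaled Dot-Product Attention (with visualization)
2. Causal (Masked) Attention
3. Multi-Head Attention
4. Flash Attention vs Standard Attention (profiling)
5. Linear Attention (an efficient attention variant)
6. Attention Statistics and Debugging
7. Position Encodings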

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Report the runtime environment and choose the compute device for the notebook.
has_cuda = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")

if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = 'cuda'
else:
    device = 'cpu'

## 1. Basic Scaled Dot-Product Attention

The core formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Basic scaled dot-product attention.

    Computes softmax(Q K^T / sqrt(d_k)) V over the last two dimensions.

    Args:
        Q: (batch, seq_len, d_k)
        K: (batch, seq_len, d_k)
        V: (batch, seq_len, d_v)
        mask: (batch, seq_len, seq_len) or (seq_len, seq_len). Positions
            where the mask equals 0 are excluded from attention.

    Returns:
        output: (batch, seq_len, d_v) attention-weighted sum of the values.
        attn_weights: (batch, seq_len, seq_len) attention probabilities.
            A query row whose keys are ALL masked gets all-zero weights
            (and therefore a zero output row) instead of NaN.
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: Apply mask (optional)
    fully_masked = None
    if mask is not None:
        blocked = mask == 0
        # Rows with every key blocked would become softmax(all -inf) = NaN.
        # Leave their scores finite here and zero their weights after the
        # softmax, which also keeps gradients finite.
        fully_masked = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(blocked & ~fully_masked, float('-inf'))

    # Step 3: Softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)
    if fully_masked is not None:
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

    # Step 4: Apply attention to values
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

# Smoke test: push random inputs through unmasked attention and report shapes
batch_size, seq_len, d_model = 2, 6, 64
input_shape = (batch_size, seq_len, d_model)
Q = torch.randn(*input_shape)
K = torch.randn(*input_shape)
V = torch.randn(*input_shape)

output, weights = scaled_dot_product_attention(Q, K, V)

# Total attention mass of the first query of the first batch element
first_row_total = weights[0, 0].sum().item()

print(f"Input shapes: Q={Q.shape}, K={K.shape}, V={V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights sum to 1: {first_row_total:.4f}")

In [None]:
# Heatmap of the first batch element's attention matrix (rows = queries, columns = keys)
first_batch_weights = weights[0].detach().numpy()

plt.figure(figsize=(8, 6))
plt.imshow(first_batch_weights, cmap='Blues')
plt.colorbar(label='Attention Weight')
plt.title('Attention Weights (Batch 0)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

print("Each row sums to 1 (probability distribution over keys)")

## 2. Causal (Masked) Attention

For autoregressive models, each position can only attend to itself and to earlier positions, never to future ones.

In [None]:
# Create causal mask
def create_causal_mask(seq_len):
    """Return a (seq_len, seq_len) float mask: 1 on and below the diagonal, 0 above it."""
    all_ones = torch.ones(seq_len, seq_len)
    return all_ones.tril()

causal_mask = create_causal_mask(seq_len)
print("Causal Mask:")
print(causal_mask)

# Re-run attention on the same Q, K, V with future positions masked out
output_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

# Side-by-side heatmaps for batch element 0: unmasked vs. causal weights
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = (
    (weights, 'Full Attention'),
    (weights_causal, 'Causal Attention (Masked)'),
)
for ax, (attn_matrix, title) in zip(axes, panels):
    ax.imshow(attn_matrix[0].detach().numpy(), cmap='Blues')
    ax.set_title(title)
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')

plt.tight_layout()
plt.show()

print("Notice: Causal attention has zeros above diagonal (no future info)")

## 3. Multi-Head Attention

Multiple attention heads allow the model to attend to different aspects.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a (batch, seq_len, d_model) input.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended per head
    with scaled dot-product attention, then concatenated and projected back
    to ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: (batch, seq_len, d_model) input.
            mask: optional mask; positions equal to 0 are excluded. Accepted
                shapes are (seq_len, seq_len), shared by every batch element
                and head; (batch, seq_len, seq_len), shared by every head;
                and (batch, heads, seq_len, seq_len).

        Returns:
            output: (batch, seq_len, d_model)
            attn_weights: (batch, heads, seq_len, seq_len). A query row whose
                keys are all masked gets all-zero weights instead of NaN.
        """
        batch_size, seq_len, _ = x.shape

        # Project and reshape to (batch, heads, seq, head_dim)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        fully_masked = None
        if mask is not None:
            # Align the mask with (batch, heads, seq, seq). A plain
            # unsqueeze(1) on a 2-D mask would yield (seq, 1, seq), which
            # either fails to broadcast or lines up with the wrong axes.
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() != 4:
                mask = mask.unsqueeze(1)

            blocked = mask == 0
            # Rows with every key blocked would give softmax(all -inf) = NaN.
            # Keep their scores finite and zero their weights afterwards.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(blocked & ~fully_masked, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        context = torch.matmul(attn_weights, V)

        # Concat heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)

        return output, attn_weights

# Smoke test: 8 heads over a batch of 2 sequences of length 10
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
output, attn = mha(x)

print(
    f"Input: {x.shape}",
    f"Output: {output.shape}",
    f"Attention per head: {attn.shape}",
    sep="\n",
)

In [None]:
# One heatmap per head for the first batch element, laid out on a 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for head_idx, ax in enumerate(axes[:8]):
    head_weights = attn[0, head_idx].detach().numpy()
    ax.imshow(head_weights, cmap='Blues')
    ax.set_title(f'Head {head_idx}')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.suptitle('Different Attention Heads Learn Different Patterns', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Flash Attention vs Standard Attention

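Compare the speed of a naive attention implementation against PyTorch's fused `F.scaled_dot_product_attention` (SDPA), which can dispatch to a Flash Attention kernel when one is available for the inputs.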

In [None]:
def profile_attention(seq_len, batch=4, heads=8, head_dim=64, iterations=20):
    """Time hand-written attention against F.scaled_dot_product_attention on CUDA.

    Random float16 q/k/v tensors of shape (batch, heads, seq_len, head_dim)
    are created on the GPU. Both implementations are warmed up, then each is
    timed with CUDA events over ``iterations`` calls.

    Returns:
        (std_time, flash_time): mean milliseconds per call for the
        hand-written and the SDPA implementation, or (None, None) when CUDA
        is not available.
    """
    if not torch.cuda.is_available():
        return None, None

    shape = (batch, heads, seq_len, head_dim)
    q = torch.randn(*shape, device='cuda', dtype=torch.float16)
    k = torch.randn(*shape, device='cuda', dtype=torch.float16)
    v = torch.randn(*shape, device='cuda', dtype=torch.float16)

    def standard():
        # Materializes the full (seq_len, seq_len) score matrix
        scale = 1.0 / math.sqrt(head_dim)
        probs = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
        return torch.matmul(probs, v)

    def flash():
        # Fused kernel chosen by PyTorch's SDPA dispatcher
        return F.scaled_dot_product_attention(q, k, v)

    def mean_ms(fn):
        # Mean GPU time per call, measured between two CUDA events
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)
        begin.record()
        for _ in range(iterations):
            fn()
        finish.record()
        torch.cuda.synchronize()
        return begin.elapsed_time(finish) / iterations

    # Warmup so one-off setup costs are not included in the timings
    for _ in range(5):
        standard()
        flash()
    torch.cuda.synchronize()

    std_time = mean_ms(standard)
    flash_time = mean_ms(flash)
    return std_time, flash_time

# Print a timing table for increasing sequence lengths (GPU only)
if not torch.cuda.is_available():
    print("CUDA not available for profiling")
else:
    print("Sequence Length | Standard (ms) | Flash (ms) | Speedup")
    print("-" * 55)

    for seq_len in (512, 1024, 2048, 4096):
        std_time, flash_time = profile_attention(seq_len)
        if not std_time:
            continue
        speedup = std_time / flash_time
        print(f"{seq_len:^15} | {std_time:^13.2f} | {flash_time:^10.2f} | {speedup:.2f}x")

## 5. Linear Attention

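O(N) complexity in the sequence length, obtained by replacing the softmax with a non-negative feature map φ so that the matrix products can be reordered.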

In [None]:
def linear_attention(Q, K, V, feature_map='elu'):
    """
    Linear attention: O(N) complexity in the sequence length.

    Key insight: φ(Q) @ (φ(K)^T @ V) instead of softmax(QK^T) @ V

    Args:
        Q: (batch, seq_len, d_k)
        K: (batch, seq_len, d_k)
        V: (batch, seq_len, d_v)
        feature_map: 'elu' for φ(x) = elu(x) + 1, or 'relu' for φ(x) = relu(x).

    Returns:
        (batch, seq_len, d_v) output, normalized per query by φ(q) · Σ_j φ(k_j).

    Raises:
        ValueError: if ``feature_map`` is not 'elu' or 'relu'. Skipping the
            feature map would let negative similarities through and break
            the normalization.
    """
    # Apply feature map (must be non-negative)
    if feature_map == 'elu':
        Q = F.elu(Q) + 1
        K = F.elu(K) + 1
    elif feature_map == 'relu':
        Q = F.relu(Q)
        K = F.relu(K)
    else:
        raise ValueError(
            f"Unknown feature_map {feature_map!r}; expected 'elu' or 'relu'"
        )

    # Compute K^T @ V first: this costs O(N * d_k * d_v) and never builds the
    # (N, N) score matrix.
    KV = torch.einsum('bnd,bnv->bdv', K, V)  # (batch, d_k, d_v)

    # Q @ KV
    output = torch.einsum('bnd,bdv->bnv', Q, KV)  # (batch, seq, d_v)

    # Normalize: each query's total similarity to all keys, φ(q_i) · Σ_j φ(k_j)
    K_sum = K.sum(dim=1)  # (batch, d_k)
    normalizer = torch.einsum('bnd,bd->bn', Q, K_sum)  # (batch, seq)
    # The epsilon guards against division by zero (possible with 'relu')
    output = output / (normalizer.unsqueeze(-1) + 1e-6)

    return output

# Run softmax attention and linear attention on the same random inputs
compare_shape = (2, 100, 64)
Q = torch.randn(compare_shape)
K = torch.randn(compare_shape)
V = torch.randn(compare_shape)

out_standard = scaled_dot_product_attention(Q, K, V)[0]
out_linear = linear_attention(Q, K, V)

print(
    f"Standard output shape: {out_standard.shape}",
    f"Linear output shape: {out_linear.shape}",
    f"\nNote: Linear attention is an APPROXIMATION, not exact!",
    sep="\n",
)

## 6. Attention Statistics and Debugging

In [None]:
def compute_attention_stats(attn_weights):
    """
    Summarize an attention matrix with a few scalar diagnostics.

    Args:
        attn_weights: (batch, seq, seq) or (batch, heads, seq, seq).
            A 4-D input is averaged over the head dimension first.

    Returns:
        dict of Python floats, each a mean over all queries:
            'entropy': entropy of each query's distribution (higher = more uniform)
            'max_attention': largest weight in each query row
            'effective_context': exp(entropy) per query
            'diagonal_attention': weight each position puts on itself
    """
    # Collapse heads by averaging when they are present
    weights = attn_weights.mean(dim=1) if attn_weights.dim() == 4 else attn_weights

    # Per-query entropy; the small epsilon keeps log() finite at zero weights
    row_entropy = -(weights * torch.log(weights + 1e-10)).sum(dim=-1)
    row_peak, _ = weights.max(dim=-1)
    self_weight = torch.diagonal(weights, dim1=-2, dim2=-1)

    stats = {}
    stats['entropy'] = row_entropy.mean().item()
    stats['max_attention'] = row_peak.mean().item()
    stats['effective_context'] = torch.exp(row_entropy).mean().item()
    stats['diagonal_attention'] = self_weight.mean().item()
    return stats

# Build two contrasting attention patterns and print their statistics
seq_len = 20

# Pattern 1: unrelated random queries and keys
Q_rand = torch.randn(1, seq_len, 64)
K_rand = torch.randn(1, seq_len, 64)
V_rand = torch.randn(1, seq_len, 64)
attn_rand = scaled_dot_product_attention(Q_rand, K_rand, V_rand)[1]

# Pattern 2: keys are the queries scaled by 10, which makes the softmax peaky
Q_id = torch.randn(1, seq_len, 64)
attn_id = scaled_dot_product_attention(Q_id, Q_id * 10, V_rand)[1]

print("Attention Pattern Analysis:")
print("=" * 50)
for heading, pattern in (
    ("\nRandom Q, K:", attn_rand),
    ("\nIdentity-like (Q ≈ K):", attn_id),
):
    print(heading)
    for k, v in compute_attention_stats(pattern).items():
        print(f"  {k}: {v:.3f}")

In [None]:
# Side-by-side heatmaps of the random and the identity-like attention patterns
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

pattern_panels = (
    (attn_rand, 'Random Attention (distributed)'),
    (attn_id, 'Identity-like Attention (diagonal)'),
)
for ax, (pattern, title) in zip(axes, pattern_panels):
    ax.imshow(pattern[0].detach().numpy(), cmap='Blues')
    ax.set_title(title)
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.tight_layout()
plt.show()

## 7. Position Encodings Comparison

In [None]:
def sinusoidal_position_encoding(seq_len, d_model):
    """Sinusoidal position encoding from 'Attention Is All You Need'.

    Even columns hold sin(pos / 10000^(2i / d_model)) and odd columns hold
    the matching cosine.

    Args:
        seq_len: number of positions (rows).
        d_model: encoding width (columns); may be even or odd.

    Returns:
        (seq_len, d_model) float tensor of position encodings.
    """
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1).float()

    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (seq_len, ceil(d_model / 2))

    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer cosine column than sine columns, so
    # trim the angles to fit; for even d_model the slice is a no-op.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    return pe

# Build encodings for 100 positions of width 64 and their pairwise dot products
pe = sinusoidal_position_encoding(100, 64)
similarity = pe @ pe.t()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
encoding_ax, similarity_ax = axes

# Left panel: the raw encoding matrix (position x dimension)
encoding_ax.imshow(pe.numpy(), aspect='auto', cmap='RdBu')
encoding_ax.set_title('Sinusoidal Position Encoding')
encoding_ax.set_xlabel('Dimension')
encoding_ax.set_ylabel('Position')

# Right panel: dot-product similarity between every pair of positions
similarity_ax.imshow(similarity.numpy(), cmap='RdBu')
similarity_ax.set_title('Position Similarity (dot product)')
similarity_ax.set_xlabel('Position')
similarity_ax.set_ylabel('Position')

plt.tight_layout()
plt.show()

print("Note: Nearby positions have higher similarity (visible in right plot)")

## Summary

Key concepts explored:

1. **Scaled Dot-Product Attention** - Core formula with scaling
2. **Causal Masking** - For autoregressive models
3. **Multi-Head Attention** - Multiple attention perspectives
4. **Flash Attention** - IO-aware efficient implementation
5. **Linear Attention** - O(N) approximation
6. **Debugging** - Statistics and visualization
7. **Position Encodings** - Sinusoidal encodings and position similarity

In [None]:
# Closing cheat-sheet of the attention variants covered in this notebook
summary_banner = """
╔══════════════════════════════════════════════════════════════════╗
║                ATTENTION MECHANISMS SUMMARY                      ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  STANDARD ATTENTION: O(N²) time and memory                       ║
║    Attention(Q,K,V) = softmax(QK^T / √d) · V                     ║
║                                                                  ║
║  MULTI-HEAD: h parallel attention heads                          ║
║    More expressive, same asymptotic cost                         ║
║                                                                  ║
║  FLASH ATTENTION: Same output, O(N) memory                       ║
║    Use F.scaled_dot_product_attention()                          ║
║                                                                  ║
║  LINEAR ATTENTION: O(N) time                                     ║
║    Approximation - trades quality for speed                      ║
║                                                                  ║
║  DEBUGGING: Check entropy, visualize patterns                    ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
"""
print(summary_banner)