# Chapter 2: The Attention Mechanism

Welcome to the second notebook in our LLM from Scratch series! In this chapter, we'll explore the **attention mechanism** - the revolutionary innovation that powers modern transformer models.

## What You'll Learn

1. **What is attention** and why it's revolutionary
2. **Scaled dot-product attention**: The core formula
3. **Multi-head attention**: Attending to multiple representation subspaces
4. **Causal masking**: Essential for autoregressive generation
5. **Hands-on visualization** of attention patterns
6. **Implementation details** from our codebase

This is the heart of the transformer architecture - let's understand it deeply!

## 1. What is Attention?

**Attention** allows each position in a sequence to look at other positions and gather relevant information.

### The Core Idea

When processing a word, we want to look at **all other words** in the sentence to understand context:

```
Sentence: "The cat sat on the mat"

When processing "cat":
- Look at "The" → what kind of cat?
- Look at "sat" → what is the cat doing?
- Look at "mat" → where is the cat?
```

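### Why is this Revolutionary?

Before transformers, sequence models such as RNNs and LSTMs processed tokens **sequentially**, one step at a time:
- ❌ Slow to train (step $t$ must wait for step $t-1$, so time steps cannot be parallelized)
- ❌ Limited memory (early tokens must survive many updates of a fixed-size hidden state, and gradients through those steps tend to vanish)
- ❌ Struggles with long-range dependencies

Attention processes **all positions in parallel**:
- ✅ Fast to train (a few large matrix multiplications, well suited to GPUs)
- ✅ Direct connections (any position can read any other position within the context window in a single step)
- ✅ Captures long-range dependencies much more easily

The trade-off is cost: comparing every position with every other takes $O(n^2)$ time and memory in the sequence length $n$, which is one reason context windows are finite.

This parallelism and direct access are a big part of why GPT, BERT, and other transformer models scale so well!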

## 2. Scaled Dot-Product Attention

The attention mechanism works with three sets of vectors:

### Query, Key, Value (Q, K, V)

For each position, we create three vectors (in a transformer, each is a learned linear projection of that position's embedding):
- **Query (Q)**: What I'm looking for
- **Key (K)**: What I can offer
- **Value (V)**: The actual information I contain

Think of it like a database lookup:
- Query: "Show me all documents about cats"
- Keys: Document tags/titles
- Values: Document contents
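A minimal sketch of the analogy, using made-up toy keys and values (not anything from our codebase): a database returns the single best match, while attention returns a softmax-weighted blend of *every* value.

```python
import torch
import torch.nn.functional as F

# Toy "database": 3 entries, each with a 4-dim key and a 2-dim value
keys = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0, 0.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# The query is most similar to key 0, somewhat similar to key 1
query = torch.tensor([2.0, 1.0, 0.0, 0.0])
similarity = keys @ query  # dot product with every key

# Hard lookup (classic database): return only the best-matching value
hard_result = values[similarity.argmax()]

# Soft lookup (attention): blend ALL values, weighted by softmax(similarity)
weights = F.softmax(similarity, dim=-1)
soft_result = weights @ values

print("similarity: ", similarity)
print("weights:    ", weights)
print("hard lookup:", hard_result)
print("soft lookup:", soft_result)
```

The soft version is differentiable (unlike `argmax`), which is what allows a model to learn its queries and keys by gradient descent.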

### The Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's break this down:

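1. **$QK^T$**: Compute similarity scores between all query-key pairs
   - Shape: (seq_len, seq_len) for each sequence in the batch
   - Element [i, j] is the raw score of query $i$ against key $j$; after the softmax it becomes how much position $i$ attends to position $j$

2. **$\frac{1}{\sqrt{d_k}}$**: Scale by the square root of the key dimension
   - Dot products of $d_k$-dimensional vectors grow in magnitude with $d_k$; dividing by $\sqrt{d_k}$ keeps them near unit scale
   - Keeps gradients stable (avoids softmax saturation)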

3. **$\text{softmax}(...)$**: Convert scores to probabilities
   - Each row sums to 1
   - Higher scores → higher attention weights

4. **$... V$**: Weighted sum of values
   - Aggregate information based on attention weights
   - Output: contextualized representation for each position
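Why divide by $\sqrt{d_k}$ in particular? A short variance argument (the one given in the original Transformer paper) makes it concrete, assuming the components $q_i$ and $k_i$ are independent random variables with mean 0 and variance 1:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i, \qquad \mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad \operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1$$

$$\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k \quad\Longrightarrow\quad \operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$$

With $d_k = 64$, for example, raw scores have a standard deviation of about 8, enough to push softmax toward a one-hot output; dividing by $\sqrt{64} = 8$ brings them back to roughly unit scale.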

## 3. Hands-On: Implementing Attention Step-by-Step

Let's implement attention from scratch to understand each component!

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

print("PyTorch version:", torch.__version__)

### 3.1 Simple Attention from Scratch

In [None]:
def simple_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple:
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Queries of shape (batch, seq_len, d_k)
        K: Keys of shape (batch, seq_len, d_k)
        V: Values of shape (batch, seq_len, d_v)
    
    Returns:
        Output and attention weights
    """
    # Step 1: Compute attention scores
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
    
    # Step 2: Scale by sqrt(d_k)
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # Step 3: Apply softmax
    attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
    
    # Step 4: Apply attention to values
    output = torch.matmul(attn_weights, V)  # (batch, seq_len, d_v)
    
    return output, attn_weights

# Test with simple example
batch_size = 1
seq_len = 4
d_model = 8

# Create random Q, K, V
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = simple_attention(Q, K, V)

print(f"Input shapes:")
print(f"  Q: {Q.shape}")
print(f"  K: {K.shape}")
print(f"  V: {V.shape}")
print(f"\nOutput shapes:")
print(f"  Output: {output.shape}")
print(f"  Attention weights: {attn_weights.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(attn_weights[0])
print(f"\nRow sums: {attn_weights[0].sum(dim=-1)}")
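As a sanity check, a from-scratch version can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later (where that function exists). The block re-defines a compact copy of the computation above, under the hypothetical name `attention_from_scratch`, and uses its own random generator so it runs on its own and leaves the notebook's `Q`, `K`, `V` untouched.

```python
import math
import torch
import torch.nn.functional as F

def attention_from_scratch(q, k, v):
    """Same computation as simple_attention: softmax(q k^T / sqrt(d_k)) v."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return F.softmax(scores, dim=-1) @ v

g = torch.Generator().manual_seed(0)
q_chk = torch.randn(2, 5, 16, generator=g)  # (batch, seq_len, d_k)
k_chk = torch.randn(2, 5, 16, generator=g)
v_chk = torch.randn(2, 5, 16, generator=g)

ours = attention_from_scratch(q_chk, k_chk, v_chk)
builtin = F.scaled_dot_product_attention(q_chk, k_chk, v_chk)  # default scale: 1/sqrt(d_k)

print("Max absolute difference:", (ours - builtin).abs().max().item())
print("Match:", torch.allclose(ours, builtin, atol=1e-5))
```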

### 3.2 Visualizing Attention Patterns

In [None]:
def plot_attention(attention_weights: torch.Tensor, title: str = "Attention Weights") -> None:
    """
    Visualize attention weights as a heatmap.
    
    Args:
        attention_weights: Tensor of shape (seq_len, seq_len)
        title: Plot title
    """
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        attention_weights.detach().cpu().numpy(),
        annot=True,
        fmt=".3f",
        cmap="YlOrRd",
        cbar=True,
        square=True,
        xticklabels=[f"Pos {i}" for i in range(attention_weights.size(0))],
        yticklabels=[f"Pos {i}" for i in range(attention_weights.size(0))],
    )
    plt.title(title)
    plt.xlabel("Key Position (attending TO)")
    plt.ylabel("Query Position (attending FROM)")
    plt.tight_layout()
    plt.show()

# Visualize the attention pattern
plot_attention(
    attn_weights[0],
    title="Attention Pattern (No Masking)"
)

### Interpreting the Heatmap:

- **Rows**: Query positions (which position is attending)
- **Columns**: Key positions (which positions are being attended to)
- **Values**: Attention weights (0 to 1, each row sums to 1)
- **Dark red cells**: High attention (the query draws heavily on this key)
- **Light yellow cells**: Low attention (this key contributes little)

Without a mask, each position can attend to **every** position, including later (future) ones!
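To see concretely what attending to future positions means, the sketch below uses toy random tensors and a hypothetical helper `unmasked_attention`. It changes only the last position's key and value, then checks whether the output at position 0 moves. Without a mask it does: an earlier position "sees" a later one, which is exactly the leak that causal masking prevents.

```python
import math
import torch
import torch.nn.functional as F

def unmasked_attention(q, k, v):
    # softmax(q k^T / sqrt(d_k)) v with no mask: every query sees every key
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return F.softmax(scores, dim=-1) @ v

g = torch.Generator().manual_seed(0)
q = torch.randn(1, 4, 8, generator=g)
k = torch.randn(1, 4, 8, generator=g)
v = torch.randn(1, 4, 8, generator=g)

out_before = unmasked_attention(q, k, v)

# Change ONLY the last ("future") position's key and value
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 5.0
v2[:, -1] += 5.0
out_after = unmasked_attention(q, k2, v2)

# Position 0's output changes even though only position 3 was edited
print("Position 0 changed:", not torch.allclose(out_before[:, 0], out_after[:, 0]))
```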

## 4. Causal Masking: Essential for Autoregressive Models

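For language generation (like GPT), we need **causal masking**:
- The output at position $i$ is used to predict token $i+1$, so it may only attend to positions $\le i$ (itself and earlier tokens)
- The model cannot "cheat" during training by looking at the tokens it is supposed to predict
- This lets the model train on whole sequences in parallel while still generating text **autoregressively**, one token at a time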

### Causal Mask Pattern

```
Position:  0  1  2  3
        0 [✓  ✗  ✗  ✗]  ← Can only see position 0
        1 [✓  ✓  ✗  ✗]  ← Can see positions 0, 1
        2 [✓  ✓  ✓  ✗]  ← Can see positions 0, 1, 2
        3 [✓  ✓  ✓  ✓]  ← Can see all positions
```

This creates a **lower triangular** pattern!
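The codebase provides `create_causal_mask`, used in the next cell; its implementation isn't shown in this notebook. One common way to build a mask with the same convention (`True` = masked, i.e. everything above the diagonal) is `torch.triu` with `diagonal=1`. The helper name `build_causal_mask` is hypothetical:

```python
import torch

def build_causal_mask(seq_len: int) -> torch.Tensor:
    """(seq_len, seq_len) bool mask: True above the diagonal marks future positions to hide."""
    # triu with diagonal=1 keeps the strictly upper triangle (j > i) and zeros the rest
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

demo_mask = build_causal_mask(4)
print(demo_mask.int())  # 1 = masked (future), 0 = visible
```

Row $i$ has exactly $i + 1$ visible entries, matching the ✓/✗ pattern above.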

In [None]:
# Import our causal mask function
from src.llm.attention import create_causal_mask

# Create a causal mask
seq_len = 6
causal_mask = create_causal_mask(seq_len)

print(f"Causal mask for sequence length {seq_len}:")
print(causal_mask.int())
print("\n1 (True) = masked (cannot attend), 0 (False) = visible (can attend)")

# Visualize the mask
plt.figure(figsize=(8, 6))
sns.heatmap(
    causal_mask.int().numpy(),
    annot=True,
    fmt="d",
    cmap="RdYlGn_r",
    cbar=False,
    square=True,
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Pos {i}" for i in range(seq_len)],
)
plt.title("Causal Mask (1 = Masked, 0 = Visible)")
plt.xlabel("Key Position")
plt.ylabel("Query Position")
plt.tight_layout()
plt.show()

### Applying Causal Mask to Attention

In [None]:
def causal_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> tuple:
    """
    Compute scaled dot-product attention with causal masking.
    """
    _, seq_len, d_k = Q.shape  # batch size is not needed here
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # Apply causal mask
    mask = create_causal_mask(seq_len, device=Q.device)
    scores = scores.masked_fill(mask.unsqueeze(0), -1e9)  # Set masked positions to very negative
    
    # Softmax and apply to values
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights

# Test causal attention
output_causal, attn_weights_causal = causal_attention(Q, K, V)

print(f"Causal attention weights:")
print(attn_weights_causal[0])

# Visualize
plot_attention(
    attn_weights_causal[0],
    title="Attention Pattern WITH Causal Masking"
)

### Observation:

Notice the **lower triangular pattern**:
- Top-right triangle is all zeros (future tokens masked)
- Each position only attends to itself and previous positions
- This is essential for autoregressive language generation!
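A quick way to check that a causal mask really blocks the future: perturb the last position's key and value, then confirm that outputs at earlier positions are unchanged. This sketch builds its own mask with `torch.triu` (`True` = masked, the same convention as `create_causal_mask`) and uses the hypothetical helper name `masked_attention`, so it runs independently of the codebase:

```python
import math
import torch
import torch.nn.functional as F

def masked_attention(q, k, v):
    seq_len, d_k = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # True above the diagonal = future positions, filled with a large negative number
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(mask, -1e9)
    return F.softmax(scores, dim=-1) @ v

g = torch.Generator().manual_seed(0)
q = torch.randn(1, 4, 8, generator=g)
k = torch.randn(1, 4, 8, generator=g)
v = torch.randn(1, 4, 8, generator=g)

out_before = masked_attention(q, k, v)

k2, v2 = k.clone(), v.clone()
k2[:, -1] += 5.0  # edit only the last position
v2[:, -1] += 5.0
out_after = masked_attention(q, k2, v2)

# Positions 0..2 cannot see position 3, so their outputs stay the same
print("Earlier positions unchanged:", torch.allclose(out_before[:, :-1], out_after[:, :-1]))
# Position 3 attends to itself, so its output does change
print("Last position changed:", not torch.allclose(out_before[:, -1], out_after[:, -1]))
```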

## 5. Multi-Head Attention

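**Multi-head attention** runs several attention operations in parallel:
- Each "head" has its own learned projections, so it can learn a different attention pattern (any such roles emerge during training; nothing assigns them)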
- Head 1 might focus on syntax
- Head 2 might focus on semantics
- Head 3 might focus on long-range dependencies

### How it Works:

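1. **Project** the input into queries, keys, and values with learned linear layers
2. **Split** each of Q, K, V along the feature dimension into $n_{heads}$ heads of size $d_{model} / n_{heads}$
3. **Run attention** on each head independently
4. **Concatenate** all head outputs
5. **Project** back to $d_{model}$ with a final linear layer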

### Example:

```
d_model = 512, n_heads = 8
→ Each head dimension = 512 / 8 = 64

Input: (batch, seq_len, 512)
  ↓ Project to Q, K, V, then split each into 8 heads
Heads: 8 × (batch, seq_len, 64)
  ↓ Run attention on each head
Outputs: 8 × (batch, seq_len, 64)
  ↓ Concatenate
Concat: (batch, seq_len, 512)
  ↓ Output projection
Final: (batch, seq_len, 512)
```
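The project / split / attend / concatenate / project steps can be sketched in a few dozen lines of PyTorch. This is a minimal illustration under our own assumptions (the class name `TinyMultiHeadAttention`, separate `nn.Linear` layers for Q, K, V, no dropout), not a copy of the codebase's `MultiHeadAttention`; the `view`/`transpose` reshapes are the part worth studying.

```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMultiHeadAttention(nn.Module):
    """Illustrative multi-head attention (hypothetical; not the codebase's class)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One learned projection each for Q, K, V, plus the output projection
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, head_dim)
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        # Scaled dot-product attention for all heads at once: (batch, n_heads, seq_len, seq_len)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:  # (seq_len, seq_len) bool, True = masked; broadcasts over batch and heads
            scores = scores.masked_fill(mask, -1e9)
        heads = F.softmax(scores, dim=-1) @ v  # (batch, n_heads, seq_len, head_dim)

        # Concatenate heads back to (batch, seq_len, d_model), then project
        concat = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.out_proj(concat)


tiny_mha = TinyMultiHeadAttention(d_model=512, n_heads=8)
x_demo = torch.randn(2, 10, 512)
causal = torch.triu(torch.ones(10, 10), diagonal=1).bool()
out_demo = tiny_mha(x_demo, mask=causal)
print(out_demo.shape)  # torch.Size([2, 10, 512])
```

Because the heads are carved out of the same $d_{model} \times d_{model}$ projections, using more heads adds no parameters compared with a single head of the same total width.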

## 6. Using Our Multi-Head Attention Implementation

In [None]:
from src.llm.attention import MultiHeadAttention

# Create multi-head attention module
d_model = 128
n_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=0.0)

print(f"Multi-Head Attention Configuration:")
print(f"  d_model: {d_model}")
print(f"  n_heads: {n_heads}")
print(f"  head_dim: {mha.head_dim}")
print(f"  scale factor: {mha.scale:.4f}")

# Create input
x = torch.randn(batch_size, seq_len, d_model)

# Forward pass without mask
output, _ = mha(x, return_attention=False)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {x.shape == output.shape}")

### 6.1 Visualizing Multi-Head Attention Patterns

In [None]:
# Get attention weights for all heads
mask = create_causal_mask(seq_len)
output, attn_weights = mha(x, mask=mask, return_attention=True)

print(f"Attention weights shape: {attn_weights.shape}")
print(f"  (batch_size, n_heads, seq_len, seq_len)")

# Visualize all heads for the first batch item (2 x 4 grid for n_heads = 8)
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle("Attention Patterns Across Different Heads", fontsize=16)

for head_idx in range(n_heads):
    row = head_idx // 4
    col = head_idx % 4
    
    ax = axes[row, col]
    head_attn = attn_weights[0, head_idx].detach().cpu().numpy()
    
    im = ax.imshow(head_attn, cmap="YlOrRd", aspect="auto")
    ax.set_title(f"Head {head_idx + 1}")
    ax.set_xlabel("Key Position")
    ax.set_ylabel("Query Position")
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

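### Observations:

This module is **untrained**: its projection weights are random, so the differences between heads here come from random initialization rather than anything learned. What the plots do show reliably:
- Each head produces a **different** attention pattern, because each has its own slice of the projections
- All heads respect the **causal mask** (lower triangular)

After training, heads often specialize:
- Some heads focus on **nearby positions** (local context)
- Some heads focus on **distant positions** (long-range dependencies)
- Some heads attend fairly **uniformly** (global context)

This diversity allows the model to capture different types of relationships!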

## 7. Practical Example: Sentence Processing

In [None]:
from src.llm import Tokenizer

# Create tokenizer and encode a sentence
tokenizer = Tokenizer()
text = "The cat sat on the mat"
token_ids = tokenizer.encode(text)
tokens = [tokenizer.decode([tid]) for tid in token_ids]

print(f"Sentence: {text}")
print(f"Tokens: {tokens}")
print(f"Token IDs: {token_ids}")

# Create embeddings (simplified - real model learns these)
vocab_size = tokenizer.vocab_size
embedding = torch.nn.Embedding(vocab_size, d_model)
token_tensor = torch.tensor([token_ids])
x = embedding(token_tensor)

print(f"\nEmbedding shape: {x.shape}")

# Apply multi-head attention with causal mask
seq_len = len(token_ids)
mask = create_causal_mask(seq_len)
output, attn_weights = mha(x, mask=mask, return_attention=True)

# Visualize attention for head 0
plt.figure(figsize=(10, 8))
sns.heatmap(
    attn_weights[0, 0].detach().cpu().numpy(),
    annot=True,
    fmt=".2f",
    cmap="YlOrRd",
    xticklabels=tokens,
    yticklabels=tokens,
    square=True,
)
plt.title(f'Attention Pattern for: "{text}" (Head 1)')
plt.xlabel("Attending TO")
plt.ylabel("Attending FROM")
plt.tight_layout()
plt.show()

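print("\nInterpretation:")
print("- Each row shows where that token is looking")
print("- Higher values mean stronger attention")
print("- The causal mask ensures each token only looks at itself and earlier tokens")
print("- The embeddings and attention are untrained, so this pattern is essentially random")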

## 8. Key Implementation Details

Let's examine some important implementation details from our codebase:

### 8.1 Why Scale by $\sqrt{d_k}$?

Without scaling, dot products can become very large, causing softmax to saturate:

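```python
import math
import torch
import torch.nn.functional as F

d_k = 64
Q = torch.randn(10, d_k)
K = torch.randn(10, d_k)

# Without scaling: entries of Q @ K.T have std ≈ sqrt(d_k) = 8
scores = Q @ K.T
probs = F.softmax(scores, dim=-1)  # often nearly one-hot → gradients close to 0

# With scaling: entries have std ≈ 1
scores = (Q @ K.T) / math.sqrt(d_k)
probs = F.softmax(scores, dim=-1)  # smoother distribution → healthier gradients
```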

In [None]:
# Demonstrate the effect of scaling
d_k = 64
q = torch.randn(1, 10, d_k)
k = torch.randn(1, 10, d_k)

scores_unscaled = torch.matmul(q, k.transpose(-2, -1))
scores_scaled = scores_unscaled / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

print(f"Unscaled scores - Mean: {scores_unscaled.mean():.2f}, Std: {scores_unscaled.std():.2f}")
print(f"Scaled scores   - Mean: {scores_scaled.mean():.2f}, Std: {scores_scaled.std():.2f}")
print(f"\nScaling factor: 1/√{d_k} = {1/torch.sqrt(torch.tensor(d_k, dtype=torch.float32)):.4f}")
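The standard deviations above only hint at the problem; the sketch below measures saturation directly. For each softmax row it reports (a) the average weight on the top key, where values near 1 mean a nearly one-hot distribution, and (b) the average total size of the softmax Jacobian $\partial p_i / \partial s_j = p_i(\delta_{ij} - p_j)$, which bounds how much gradient reaches the scores. The sizes (1000 queries, 10 keys) are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
d_k = 64
q_demo = torch.randn(1000, d_k, generator=g)  # 1000 query vectors
k_demo = torch.randn(10, d_k, generator=g)    # 10 key vectors


def softmax_stats(scores: torch.Tensor) -> tuple:
    """Return (average top weight, average sum of |Jacobian| entries per row) for softmax(scores)."""
    probs = F.softmax(scores, dim=-1)
    # Average weight on the best key: close to 1.0 means nearly one-hot (saturated)
    avg_max = probs.max(dim=-1).values.mean().item()
    # Softmax Jacobian per row: diag(p) - p p^T, shape (rows, keys, keys)
    jac = torch.diag_embed(probs) - probs.unsqueeze(-1) * probs.unsqueeze(-2)
    # Sum of |entries| per row: near 0 means almost no gradient flows back
    avg_jac = jac.abs().sum(dim=(-2, -1)).mean().item()
    return avg_max, avg_jac


raw_scores = q_demo @ k_demo.T
stats = {
    "unscaled": softmax_stats(raw_scores),
    "scaled": softmax_stats(raw_scores / math.sqrt(d_k)),
}
for name, (avg_max, avg_jac) in stats.items():
    print(f"{name:>8}: avg top weight = {avg_max:.3f}, avg |Jacobian| sum = {avg_jac:.3f}")
```

The unscaled scores should give the larger top weight and the smaller Jacobian: a saturated softmax passes back very little gradient.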

### 8.2 Handling NaN from Fully Masked Rows

When every key in a row is masked with `-inf` (for example, with padding masks), softmax produces NaN:

```python
import torch
import torch.nn.functional as F

# Problem: every score in the row is masked
scores = torch.full((3,), float("-inf"))
probs = F.softmax(scores, dim=-1)  # tensor([nan, nan, nan])

# Solution in our code (shown on the toy row): replace NaN with 0
attn_weights = torch.nan_to_num(probs, nan=0.0)  # tensor([0., 0., 0.])
```
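The fill value matters here. The `causal_attention` function earlier fills masked scores with `-1e9` rather than `-inf`, so a fully masked row produces a uniform distribution instead of NaN (still meaningless, but numerically safe). A small side-by-side comparison on an arbitrary three-element row:

```python
import torch
import torch.nn.functional as F

row_inf = torch.full((3,), float("-inf"))  # every position masked with -inf
row_big = torch.full((3,), -1e9)           # every position masked with a large negative number

p_inf = F.softmax(row_inf, dim=-1)           # NaN: softmax subtracts the max, and -inf - (-inf) is NaN
p_big = F.softmax(row_big, dim=-1)           # all scores equal, so weights are uniform (1/3 each)
p_fixed = torch.nan_to_num(p_inf, nan=0.0)   # NaN replaced by 0: the row attends to nothing

print("-inf fill:              ", p_inf)
print("-1e9 fill:              ", p_big)
print("-inf fill + nan_to_num: ", p_fixed)
```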

### 8.3 Dropout in Attention

Our implementation applies dropout to:
1. **Attention weights** (after softmax)
2. **Output** (after final projection)

```python
attn_weights = self.attn_dropout(attn_weights)  # Regularize attention
output = self.out_dropout(self.out_proj(attn_output))  # Regularize output
```

This acts as a regularizer and helps reduce overfitting!
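Two properties of PyTorch's `nn.Dropout` are worth seeing once: in training mode it zeros a random fraction `p` of entries and scales the survivors by `1 / (1 - p)` so the expected value is unchanged, and in evaluation mode (after `.eval()`) it passes inputs through untouched. The `p = 0.5` and the all-ones stand-in for an attention matrix are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
attn_demo = torch.ones(4, 6)  # stand-in for a matrix of attention weights

dropout.train()                  # training mode: dropout is active
train_out = dropout(attn_demo)   # each entry is either 0.0 or 1 / (1 - 0.5) = 2.0

dropout.eval()                   # evaluation mode: dropout is the identity
eval_out = dropout(attn_demo)

print("train mode:\n", train_out)
print("eval mode:\n", eval_out)
```

In training mode, individual attention rows no longer sum exactly to 1 (only in expectation), so each query sees a noisier mix of values; that noise is the regularizing effect.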

## 9. Key Takeaways

Let's recap what we've learned:

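1. **Attention allows parallel processing** of sequences:
   - Each position can attend directly to any allowed position in a single step
   - No step-by-step recurrence like RNNs, so training parallelizes across the sequence
   - This is a major reason transformers train so efficiently on GPUs!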

2. **Scaled dot-product attention** formula:
   $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
   - Query-Key similarity → attention weights
   - Attention weights → weighted sum of Values

3. **Causal masking** is essential for autoregressive models:
   - Prevents attending to future tokens
   - Creates lower triangular attention pattern
   - Enables left-to-right generation

4. **Multi-head attention** learns diverse patterns:
   - Multiple heads = multiple attention patterns
   - Different heads capture different relationships
   - Concatenate and project back to $d_{model}$

5. **Implementation details matter**:
   - Scale by $\sqrt{d_k}$ for gradient stability
   - Handle NaN from fully masked rows
   - Apply dropout for regularization

## Next Steps

Now that you understand attention, we're ready to build the **Transformer Block** - the complete building block that combines attention with feedforward networks!

Continue to **Notebook 03: Transformer Blocks** →

---

## Further Reading

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Original Transformer paper)
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) (Jay Alammar)
- [Visualizing Attention in Transformer Models](https://arxiv.org/abs/1904.02679)