# Building LLMs From Scratch (Part 9): Multi-Head Attention

Welcome to Part 9! In this notebook, we'll implement **Multi-Head Attention**, the attention mechanism at the core of modern Transformer LLMs such as the GPT family and Llama.

### 🔗 Quick Links
- **Medium Article**: [Part 9: Multi-Head Attention](https://medium.com/@soloshun/building-llms-from-scratch-part-9-multi-head-attention)
- **GitHub Repository**: [llm-from-scratch](https://github.com/soloeinsteinmit/llm-from-scratch)

### 📋 What We'll Cover
1. **The Concept**: Why we need multiple heads
2. **Two Approaches**: Wrapper vs. Efficient Weight Split
3. **The Dimensions**: Understanding `d_out`, `num_heads`, and `head_dim`
4. **Implementation**: Building `MultiHeadAttention` in PyTorch
5. **Shape Tracing**: Step-by-step tensor transformations


## Setup and Imports


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")


## The Concept: Why Multiple Heads?

Imagine reading a complex sentence. You might need to track:
1. **Grammar**: Which noun corresponds to this verb?
2. **Sentiment**: Is this sentence positive or negative?
3. **Facts**: What specific entities are mentioned?

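A single attention head produces just one set of attention weights per token, so it has to blend all of these relationships into a single pattern. **Multi-Head Attention** runs several heads in parallel, each with its own projections, so different heads can specialize in different types of relationships!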


## Setup: CausalAttention (from Part 8)

First, let's bring in our `CausalAttention` class from Part 8. We'll use it as a building block.


In [None]:
class CausalAttention(nn.Module):
    """Single-head causal attention from Part 8"""
    
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        
        context_vec = attn_weights @ values
        return context_vec

print("✅ CausalAttention class defined!")
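
To see what the `mask` buffer and the masking step do in isolation, here is a small sketch on a toy score matrix. The 4-token length and the all-zero scores are illustrative assumptions, not values from Part 8.

```python
import torch

num_tokens = 4  # toy sequence length (illustrative assumption)

# 1s above the diagonal mark "future" positions that must stay hidden
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# With all-zero scores, softmax spreads weight evenly over the visible positions
attn_scores = torch.zeros(num_tokens, num_tokens)
attn_scores = attn_scores.masked_fill(mask, float("-inf"))
attn_weights = torch.softmax(attn_scores, dim=-1)

print(attn_weights)
```

Row `i` ends up with weight `1 / (i + 1)` on each position up to and including `i`, and exactly zero on every later position, because `exp(-inf)` is 0 inside the softmax.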


## Approach 1: The Wrapper (Naive)

The simplest way to implement multi-head attention is to create multiple `CausalAttention` instances and concatenate their outputs.


In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """
    Multi-Head Attention using the Wrapper approach.
    
    Pros: Easy to understand
    Cons: Less efficient (many small matrix operations)
    """
    
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )
        
    def forward(self, x):
        # Run each head independently and concatenate
        return torch.cat([head(x) for head in self.heads], dim=-1)

print("✅ MultiHeadAttentionWrapper class defined!")


Let's test the wrapper approach:


In [None]:
# Setup
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

batch = torch.stack((inputs, inputs), dim=0)

print(f"Input shape: {batch.shape}")
print(f"  - Batch size: {batch.shape[0]}")
print(f"  - Sequence length: {batch.shape[1]}")
print(f"  - Embedding dim: {batch.shape[2]}")

# Create wrapper multi-head attention
torch.manual_seed(123)
d_in, d_out = 3, 2
num_heads = 2

mha_wrapper = MultiHeadAttentionWrapper(
    d_in=d_in,
    d_out=d_out,
    context_length=batch.shape[1],
    dropout=0.0,
    num_heads=num_heads
)

context_vecs = mha_wrapper(batch)

print(f"\nOutput shape: {context_vecs.shape}")
print(f"  - Note: d_out ({d_out}) × num_heads ({num_heads}) = {d_out * num_heads}")
print(f"\nOutput:\n{context_vecs}")

print("\n✅ Wrapper approach: Each head processes independently,")
print("   then outputs are concatenated!")


## Approach 2: Weight Splits (Efficient)

This is the approach that production implementations typically take, for example PyTorch's built-in `nn.MultiheadAttention`. Instead of separate layers, we:
1. Create ONE large set of Q, K, V weights
2. Reshape to split into multiple heads
3. Process all heads in parallel
4. Concatenate heads back together
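
Here is a quick sketch of steps 1 and 2, showing that one wide projection reshaped into heads gives the same numbers as separate per-head projections. The sizes and the names `big_proj` and `small_projs` are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, d_in = 2, 4, 3      # illustrative sizes
num_heads, head_dim = 2, 3
d_out = num_heads * head_dim

x = torch.randn(b, num_tokens, d_in)

# Step 1: ONE wide projection that covers all heads
big_proj = nn.Linear(d_in, d_out, bias=False)

# Equivalent per-head layers: each copies a block of rows from the wide weight
small_projs = []
for h in range(num_heads):
    layer = nn.Linear(d_in, head_dim, bias=False)
    with torch.no_grad():
        layer.weight.copy_(big_proj.weight[h * head_dim:(h + 1) * head_dim])
    small_projs.append(layer)

# Step 2: reshape the wide output into (b, num_tokens, num_heads, head_dim)
split = big_proj(x).view(b, num_tokens, num_heads, head_dim)

# Each slice along the head axis should match the corresponding small layer
for h, layer in enumerate(small_projs):
    print(h, torch.allclose(split[:, :, h, :], layer(x), atol=1e-6))
```

The check should print `True` for each head: the reshape does not change any values, it only relabels which output features belong to which head.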


### Understanding the Dimensions

Before coding, let's clarify the math:
- `d_in` = Input embedding dimension
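- `d_out` = Total output dimension (all heads concatenated)
- `num_heads` = Number of attention heads
- `head_dim` = `d_out // num_heads` (so `d_out` must be divisible by `num_heads`)

Example: If `d_out=6` and `num_heads=2`, then `head_dim=3`. Note the difference from the wrapper, where `d_out` was the size of *each* head and the output had `d_out * num_heads` features.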
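
Below is a minimal sketch of a `MultiHeadAttention` class that follows these dimensions. It assumes the constructor arguments and attribute names that the following cells rely on (`W_query`, `W_key`, `W_value`, `mask`). The output projection `out_proj` is an assumption: it is a common design choice, but nothing else in this notebook depends on it.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Causal multi-head attention using the weight-split approach (sketch)."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # One wide projection each for Q, K, V (covers all heads at once)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Assumption: a final linear layer that mixes the concatenated heads
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        def split_heads(t):
            return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_query(x))
        keys = split_heads(self.W_key(x))
        values = split_heads(self.W_value(x))

        # Per-head scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(mask_bool, float("-inf"))

        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out)
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)


torch.manual_seed(0)
mha_demo = MultiHeadAttention(d_in=3, d_out=6, context_length=6, dropout=0.0, num_heads=2)
print(mha_demo(torch.rand(2, 6, 3)).shape)  # torch.Size([2, 6, 6])
```

Scaling by `head_dim` rather than `d_out` matters here: each head's dot product runs over only `head_dim` features, so that is the dimension the softmax temperature should follow.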


Let's test the efficient implementation:


In [None]:
torch.manual_seed(789)
d_in, d_out = 3, 6
num_heads = 2

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=batch.shape[1],
    dropout=0.0,
    num_heads=num_heads
)

context_vecs = mha(batch)

print(f"Input shape: {batch.shape}")
print(f"Output shape: {context_vecs.shape}")
print(f"  - d_out remains {d_out}")
print(f"  - head_dim = d_out / num_heads = {d_out} / {num_heads} = {d_out // num_heads}")
print(f"\nOutput:\n{context_vecs}")

print("\n✅ Efficient approach works!")


## Shape Tracing

Let's trace the tensor transformations step by step to understand exactly what's happening:


In [None]:
# Configuration
b, num_tokens, d_in = 1, 3, 6
d_out = 6
num_heads = 2
head_dim = d_out // num_heads

print(f"Configuration:")
print(f"  - Batch size: {b}")
print(f"  - Sequence length: {num_tokens}")
print(f"  - Input dimension: {d_in}")
print(f"  - Output dimension: {d_out}")
print(f"  - Number of heads: {num_heads}")
print(f"  - Head dimension: {head_dim}")

# Seed first so both the dummy input and the weights are reproducible
torch.manual_seed(42)

# Create dummy input
x = torch.randn(b, num_tokens, d_in)

# Create model
mha = MultiHeadAttention(d_in, d_out, num_tokens, 0.0, num_heads)

# Manual forward pass with shape printing
print(f"\n📊 Step-by-Step Shape Transformations:")
print(f"1. Input:           {tuple(x.shape)}")

queries = mha.W_query(x)
print(f"2. After Linear:    {tuple(queries.shape)}")

queries = queries.view(b, num_tokens, num_heads, head_dim)
print(f"3. After Reshape:   {tuple(queries.shape)}")

queries = queries.transpose(1, 2)
print(f"4. After Transpose: {tuple(queries.shape)}")

# Simulate attention scores
keys = mha.W_key(x).view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
attn_scores = queries @ keys.transpose(2, 3)
print(f"5. Attention Scores: {tuple(attn_scores.shape)}")
print(f"   (Last 2 dims are {num_tokens}×{num_tokens} attention matrix)")

# Complete forward pass
output = mha(x)
print(f"6. Final Output:    {tuple(output.shape)}")

print(f"\n✅ Shape preserved (because d_in == d_out here)! Input {tuple(x.shape)} → Output {tuple(output.shape)}")
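
The trace above jumps from the attention scores straight to the final output. The step in between merges the heads back into one vector per token. Here is a minimal sketch of that merge, using a random stand-in tensor (`context` is a hypothetical name, and its values do not come from the model above).

```python
import torch

b, num_heads, num_tokens, head_dim = 1, 2, 3, 3
d_out = num_heads * head_dim

# Stand-in for (attn_weights @ values): one context vector per head per token
context = torch.randn(b, num_heads, num_tokens, head_dim)

# Put the heads next to each token again: (b, num_tokens, num_heads, head_dim)
context_t = context.transpose(1, 2)
print(context_t.is_contiguous())  # False: transpose only changes strides

# .view needs a compatible memory layout, so copy first, then flatten the heads
merged = context_t.contiguous().view(b, num_tokens, d_out)
print(tuple(merged.shape))  # (1, 3, 6)

# Same result as concatenating the per-head outputs along the last dim
concat = torch.cat([context[:, h] for h in range(num_heads)], dim=-1)
print(torch.equal(merged, concat))  # True
```

This is the same concatenation the wrapper performs with `torch.cat`, just expressed as a transpose plus reshape on a single tensor.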


## Visualizing Multi-Head Attention

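Let's visualize the attention weights of each head. The model here is untrained, so the patterns come from random initialization rather than learning, but they already show that each head computes its own attention matrix: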


In [None]:
import matplotlib.pyplot as plt

# Create a simple example
torch.manual_seed(42)
seq_len = 6
d_in = 8
d_out = 8
num_heads = 2

x = torch.randn(1, seq_len, d_in)
mha = MultiHeadAttention(d_in, d_out, seq_len, 0.0, num_heads)

# Get attention weights (we'll extract them manually)
b, num_tokens, _ = x.shape
keys = mha.W_key(x).view(b, num_tokens, num_heads, d_out // num_heads).transpose(1, 2)
queries = mha.W_query(x).view(b, num_tokens, num_heads, d_out // num_heads).transpose(1, 2)

attn_scores = queries @ keys.transpose(2, 3)
mask_bool = mha.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for head_idx in range(num_heads):
    ax = axes[head_idx]
    weights = attn_weights[0, head_idx].detach().numpy()
    
    im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Head {head_idx + 1} Attention Weights', fontsize=14)
    ax.set_xlabel('Key Position (attending TO)')
    ax.set_ylabel('Query Position (attending FROM)')
    
    # Add values
    for i in range(seq_len):
        for j in range(seq_len):
            text = ax.text(j, i, f'{weights[i, j]:.2f}',
                          ha="center", va="center",
                          color="white" if weights[i, j] > 0.5 else "black",
                          fontsize=9)
    
    plt.colorbar(im, ax=ax, shrink=0.8)

plt.tight_layout()
plt.show()

print("✅ Notice how each head produces a different attention pattern (random weights here; training is what makes them meaningful)!")


## Performance Comparison

Let's compare the wrapper approach vs. the efficient approach:


In [None]:
import time

# Setup
batch_size = 8
seq_len = 128
d_in = 512
d_out_per_head = 64
num_heads = 8
num_runs = 10

x = torch.randn(batch_size, seq_len, d_in)

def time_forward(module, x, num_runs):
    """Wall-clock time for num_runs forward passes (CPU, no gradients)."""
    with torch.no_grad():
        module(x)  # warm-up pass, excluded from the timing
        start = time.perf_counter()
        for _ in range(num_runs):
            module(x)
        return time.perf_counter() - start

# Wrapper approach
print("Testing Wrapper Approach...")
torch.manual_seed(123)
wrapper = MultiHeadAttentionWrapper(d_in, d_out_per_head, seq_len, 0.0, num_heads)
wrapper_time = time_forward(wrapper, x, num_runs)

# Efficient approach
print("Testing Efficient Approach...")
torch.manual_seed(123)
efficient = MultiHeadAttention(d_in, d_out_per_head * num_heads, seq_len, 0.0, num_heads)
efficient_time = time_forward(efficient, x, num_runs)

print(f"\n📊 Results ({num_runs} forward passes):")
print(f"  Wrapper Approach:   {wrapper_time:.4f}s")
print(f"  Efficient Approach: {efficient_time:.4f}s")
print(f"  Speedup: {wrapper_time / efficient_time:.2f}x")

# Report what was actually measured instead of assuming the outcome
if efficient_time < wrapper_time:
    print("\n✅ The efficient approach was faster in this run.")
else:
    print("\nNo speedup in this run. Small CPU benchmarks are noisy; try more runs.")
print("   Batching all heads into a few large matmuls is why this layout is the common choice.")


## Summary and Key Takeaways

### 🎯 What We Learned

1. **Why Multi-Head?**: A single head produces only one attention pattern per token, so it has to blend every kind of relationship into it. Multiple heads let the model capture different aspects (grammar, sentiment, facts, etc.) in parallel.

2. **Two Approaches**:
   - **Wrapper**: Easy to understand, but slower (many small matrix operations)
   - **Weight Split**: The standard approach, typically faster (a few large matrix operations)

3. **The Math**: 
   - `head_dim = d_out // num_heads` (`d_out` must be divisible by `num_heads`)
   - Example: `d_out=6`, `num_heads=2` → `head_dim=3`

4. **The Transformation**: 
   - Input: `(batch, tokens, d_in)`
   - Linear: `(batch, tokens, d_out)`
   - Reshape: `(batch, tokens, num_heads, head_dim)`
   - Transpose: `(batch, num_heads, tokens, head_dim)`
   - Attention: Process all heads in parallel
   - Combine: `(batch, tokens, d_out)`

### ✅ Our Implementation Now Has

- ✅ Input Embeddings
- ✅ Positional Encodings
- ✅ **Causal Multi-Head Attention** ← We are here!
- ⏭️ Dropout & Layer Normalization (coming soon)
- ⏭️ Feed Forward Networks (coming soon)

### 🔜 What's Next?

In **Part 10**, we'll zoom out and take a **Bird's Eye View of the LLM Architecture**. We'll see how all these pieces (embeddings, attention, feedforward) fit together to form the complete Transformer!

---

### 📚 Resources

- **Medium Article**: [Building LLMs From Scratch (Part 9): Multi-Head Attention](https://medium.com/@soloshun/building-llms-from-scratch-part-9-multi-head-attention)
- **GitHub**: [llm-from-scratch](https://github.com/soloeinsteinmit/llm-from-scratch)
- **Previous Part**: [Part 8: Causal Attention](https://medium.com/@soloshun/building-llms-from-scratch-part-8-causal-attention)

---

Thank you for following along! 🙏
