# 3. Multi-Head Attention

Multi-head attention runs multiple attention mechanisms in parallel, each with different learned projections.
This allows the model to attend to different types of information simultaneously!


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


## Why Multiple Heads?

A single attention head learns only one type of relationship. Multi-head attention learns several types at once:
- Head 1 might focus on syntactic relationships
- Head 2 might focus on semantic relationships
- Head 3 might focus on long-range dependencies
- etc.


In [None]:
# The model width (d_model) is divided evenly across the heads, so each
# head works in its own d_k-dimensional slice of the representation.
d_model, num_heads = 64, 8
d_k = d_model // num_heads  # per-head width

print(
    f"d_model: {d_model}",
    f"num_heads: {num_heads}",
    f"d_k (dimension per head): {d_k}",
    f"\nEach head operates in a {d_k}-dimensional subspace!",
    sep="\n",
)


## Multi-Head Attention Implementation

1. Project the input into Q, K, V and split each into multiple heads
2. Apply attention to each head independently
3. Concatenate all heads
4. Apply output projection


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention from scratch.

    The input is linearly projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended per head
    with scaled dot-product attention, concatenated back to ``d_model`` and
    passed through a final output projection.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of each head

        # Linear projections for Q, K, V (no bias)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection applied to the concatenated heads
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape [batch_size, seq_len, d_model] -> [batch_size, num_heads, seq_len, d_k]."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: Optional tensor broadcastable to
                [batch_size, num_heads, seq_len, seq_len]. Entries that are
                non-zero / True mark key positions a query may attend to;
                entries that are zero / False are blocked. For example,
                ``torch.tril(torch.ones(seq_len, seq_len))`` gives causal
                attention. ``None`` (default) means every position attends
                to every other position. A query row whose keys are all
                blocked produces NaN, because softmax has nothing to
                normalise over.
        Returns:
            output: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape

        # Project, then split into heads: [batch_size, num_heads, seq_len, d_k]
        Q = self._split_heads(self.W_q(x), batch_size, seq_len)
        K = self._split_heads(self.W_k(x), batch_size, seq_len)
        V = self._split_heads(self.W_v(x), batch_size, seq_len)

        # Scaled dot-product attention for each head.
        # Dividing by sqrt(d_k) keeps the logits from growing with head width.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        if mask is not None:
            # -inf logits become exactly zero weight after softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, seq_len, d_k]

        # Concatenate heads: [batch_size, seq_len, d_model]
        # (transpose yields a non-contiguous tensor, so .contiguous() is
        # required before .view())
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        # Output projection
        return self.W_o(attention_output)

# Smoke test: push a random batch through the layer and confirm that the
# output keeps the input's [batch_size, seq_len, d_model] shape.
d_model, num_heads = 64, 8
seq_len, batch_size = 10, 2

# Build the module before sampling the input so random-number consumption
# order is the same as in a straightforward top-to-bottom script.
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)
output = mha(x)

print(
    f"Input shape: {x.shape}",
    f"Output shape: {output.shape}",
    f"\nMulti-head attention successfully processes {num_heads} attention heads in parallel!",
    sep="\n",
)
