# Chapter 2.3 — Multi-Head Attention: Parallel “Views” of Meaning

*Understanding how multiple attention heads learn different aspects of context.*

Companion article: [Medium — Chapter 2.3](https://medium.com/@vadidsadikshaikh/chapter-2-3-multi-head-attention-parallel-views-of-meaning-5c47b51b9e73)

Reference: Sebastian Raschka, *Build a Large Language Model (From Scratch)*

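Purpose: Implement multi-head attention by projecting the input into queries, keys, and values, splitting each projection into multiple heads, computing scaled dot-product attention for all heads in parallel, and then concatenating the head outputs and passing them through a final output projection.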
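For reference, the same computation written in standard Transformer notation is shown below, with $h$ heads, $d_k = d_{\text{model}} / h$, and bias terms omitted. The per-head matrices $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$ correspond, in the class below, to consecutive $d_k$-wide slices of the single $d_{\text{model}} \times d_{\text{model}}$ layers `W_q`, `W_k`, and `W_v`. With the example values used later ($d_{\text{model}} = 64$, $h = 8$), each head works in $d_k = 8$ dimensions.

$$
\begin{aligned}
\text{head}_i &= \operatorname{softmax}\!\left(\frac{(Q W_i^Q)(K W_i^K)^\top}{\sqrt{d_k}}\right)\left(V W_i^V\right), \qquad i = 1, \dots, h \\
\operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O, \qquad W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}
\end{aligned}
$$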

```python
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Reuse scaled dot-product attention from Chapter 2.1
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (..., seq_len, d_k); positions where mask == 0 are blocked
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # (..., seq_len_q, seq_len_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)  # each row of weights sums to 1
    output = torch.matmul(attn, V)
    return output, attn
```
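As a quick sanity check, the hand-written function can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (assuming PyTorch 2.0 or later). The sizes and the causal mask below are illustrative choices, and the function is repeated so the cell runs on its own. For a boolean `attn_mask`, the built-in treats `True` as "may attend", which matches this notebook's convention of blocking positions where `mask == 0`.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same function as in the cell above: positions where mask == 0 are blocked.
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V), attn


torch.manual_seed(0)
B, H, L, d_k = 2, 4, 6, 8  # illustrative sizes: batch, heads, sequence length, head dim
Q, K, V = (torch.randn(B, H, L, d_k) for _ in range(3))

# Causal mask: True on and below the diagonal (allowed), False above it (future tokens).
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))

ours, attn = scaled_dot_product_attention(Q, K, V, mask=causal)
ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)  # PyTorch >= 2.0

print(torch.allclose(ours, ref, atol=1e-5))                # True
print(torch.allclose(attn.sum(-1), torch.ones(B, H, L)))   # True: every row sums to 1
print(attn[0, 0].triu(diagonal=1).abs().max().item())      # 0.0: no weight on future tokens
```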

```python
# Define Multi-Head Attention class
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head

        # One d_model x d_model projection each for queries, keys, and values;
        # every head later uses its own d_k-wide slice of the projected result.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # output projection after concatenation

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch, seq_len, d_model). mask must broadcast to
        # (batch, num_heads, seq_len_q, seq_len_k), with 0 marking blocked positions.
        batch_size = Q.size(0)

        # Linear projection and split into heads:
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k) -> (batch, num_heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Compute attention for all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final linear layer
        output = self.W_o(attn_output)
        return output, attn_weights
```
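The `view`/`transpose` steps in `forward` are where one wide projection becomes several per-head projections. The sketch below, using the same illustrative sizes as the example that follows and a standalone `nn.Linear` standing in for `W_q`, checks two things: head `h` sees exactly the output features `h*d_k` to `(h+1)*d_k - 1` of the projection (equivalent to a separate `d_model x d_k` projection per head), and the concatenation step exactly undoes the split.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 1, 5, 64, 8
d_k = d_model // num_heads

x = torch.rand(batch_size, seq_len, d_model)
W_q = nn.Linear(d_model, d_model)  # stands in for MultiHeadAttention.W_q

# Split exactly as in forward(): (B, L, d_model) -> (B, H, L, d_k)
q = W_q(x)
q_heads = q.view(batch_size, -1, num_heads, d_k).transpose(1, 2)

# Head h uses only its own d_k rows of the weight matrix
# (nn.Linear computes x @ weight.T + bias).
h = 3
rows = slice(h * d_k, (h + 1) * d_k)
q_head_h = x @ W_q.weight[rows].T + W_q.bias[rows]
print(torch.allclose(q_heads[:, h], q_head_h, atol=1e-6))  # True

# Merge exactly as in forward(): (B, H, L, d_k) -> (B, L, d_model)
merged = q_heads.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
print(torch.equal(merged, q))  # True: split and merge are inverses
```

Slicing one large layer rather than storing `num_heads` small ones gives the same math but lets every head be computed with a single matrix multiplication.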

```python
# Example usage
torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 1, 5, 64, 8

x = torch.rand(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)

output, attn = mha(x, x, x)

print('Output shape:', output.shape)
print('Attention shape:', attn.shape)
```

```
Output shape: torch.Size([1, 5, 64])
Attention shape: torch.Size([1, 8, 5, 5])
```
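PyTorch also ships a built-in `nn.MultiheadAttention` module. A minimal sketch with the same illustrative sizes shows that it returns the same output and per-head attention shapes; its weights are initialized differently, so the values will not match the class above. Its parameter count also matches: four `d_model x d_model` projections with biases, stored as one fused Q/K/V projection plus an output projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 1, 5, 64, 8
x = torch.rand(batch_size, seq_len, d_model)

# batch_first=True makes inputs and outputs (batch, seq, feature)
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

# average_attn_weights=False returns one attention map per head
output, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(output.shape)  # torch.Size([1, 5, 64])
print(attn.shape)    # torch.Size([1, 8, 5, 5])

# Fused Q/K/V projection (3*d*d + 3*d) plus output projection (d*d + d)
n_params = sum(p.numel() for p in mha.parameters())
print(n_params)  # 16640 == 4 * (64 * 64 + 64)
```

One caveat when swapping in the built-in module: for a boolean `attn_mask`, `nn.MultiheadAttention` treats `True` as a position that is *not* allowed to attend, the opposite of the `mask == 0` convention used in this notebook.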


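### Notes
- Each head attends over its own learned `d_k`-dimensional projection of the same input (here `d_k = 64 / 8 = 8`).
- Because each head has separate projections, different heads can learn to track different relationships, such as syntactic dependencies, semantic similarity, or relative position; which patterns actually emerge depends on training.
- Concatenating the head outputs and passing them through `W_o` combines what the heads found into a single `d_model`-dimensional representation per token.
- In GPT-style (decoder-only) architectures, multi-head attention with a causal mask is used in every transformer block.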
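To see what the causal mask does inside a GPT-style block, the sketch below passes a lower-triangular mask of shape `(1, 1, seq_len, seq_len)` to the multi-head attention class from this notebook (repeated so the cell runs on its own, with the head split factored into a small helper method). The mask broadcasts over batch and heads. Changing only the last token must leave the outputs at all earlier positions unchanged, because no earlier position is allowed to attend to it. Sizes are the same illustrative values as above.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))  # 0 = blocked
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V), attn


class MultiHeadAttention(nn.Module):
    # Same computation as the class defined earlier in this notebook.
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model, self.num_heads = d_model, num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size):
        # (B, L, d_model) -> (B, H, L, d_k)
        return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        b = Q.size(0)
        Q = self._split_heads(self.W_q(Q), b)
        K = self._split_heads(self.W_k(K), b)
        V = self._split_heads(self.W_v(V), b)
        out, attn = scaled_dot_product_attention(Q, K, V, mask)
        out = out.transpose(1, 2).contiguous().view(b, -1, self.d_model)
        return self.W_o(out), attn


torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 1, 5, 64, 8
mha = MultiHeadAttention(d_model, num_heads)

# 1 = may attend (current and earlier tokens), 0 = blocked (future tokens)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

x = torch.rand(batch_size, seq_len, d_model)
out, attn = mha(x, x, x, mask=causal_mask)

# Replace only the last token; earlier positions must not be affected.
x2 = x.clone()
x2[:, -1] = torch.rand(d_model)
out2, _ = mha(x2, x2, x2, mask=causal_mask)

print(attn.shape)                                            # torch.Size([1, 8, 5, 5])
print(torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-6))  # True
```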

**Next:** Chapter 2.4 — Dropout and Normalization: Keeping Learning Stable