# 4. Multi-Head Attention

**Combining multiple attention heads into a unified representation**

Alright. We've got attention outputs from both heads. Now what?

Time to combine them.

This is the "multi-head" part of multi-head attention (shocking, I know). Each head's been looking at the sequence through its own lens, learning different patterns and relationships. Now we need to merge these perspectives into a single, unified representation.

## Why Multiple Heads?

Think of it like having multiple experts examine the same data. Each one notices different things.

In trained models, different heads often specialize. In our two-head setup, for example:
- **Head 0** might focus on local patterns (adjacent words, nearby relationships)
- **Head 1** might capture long-range dependencies (distant relationships, document structure)

Some heads learn syntactic patterns, such as subject-verb agreement or other grammatical structure. Others capture semantic relationships: which concepts are related, which words tend to go together.

Our model isn't trained yet (obviously), so the heads haven't learned these specializations. But the architecture is ready for it.

## The Algorithm

The process is pretty straightforward:

1. **Concatenate** the outputs from all heads
2. **Project** the concatenated result through a learned linear transformation

That's it. Two steps.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_0, \text{head}_1) W_O$$
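The two steps generalize to any number of heads. Below is a minimal, self-contained sketch in plain Python (no NumPy). It assumes the same conventions as this notebook: one `[d_model, d_k]` weight matrix per head for Q, K and V, plus a causal mask. `W_O` is applied as in the formula, `concat @ W_O`. The function name `multi_head_attention` and the sizes in the usage lines are our own illustrative choices, not from any library.

```python
import math
import random

def matmul(A, B):
    # A is [m, n], B is [n, p] -> result is [m, p]
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def causal_softmax_rows(scores):
    # Row i may only look at columns 0..i; later columns get weight 0
    out = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        out.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return out

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    """X: [seq_len, d_model]; W_Q/W_K/W_V: lists of [d_model, d_k]; W_O: [d_model, d_model]."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):
        Q, K, V = matmul(X, Wq), matmul(X, Wk), matmul(X, Wv)
        d_k = len(Q[0])
        scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
        heads.append(matmul(causal_softmax_rows(scores), V))  # [seq_len, d_k]
    # Step 1: concatenate along the feature axis -> [seq_len, num_heads * d_k]
    concat = [[x for head in heads for x in head[i]] for i in range(len(X))]
    # Step 2: project with W_O -> [seq_len, d_model]
    return matmul(concat, W_O)

def rand(rows, cols):
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

# Usage with made-up sizes: 4 heads of width 4 on a 3-token, 16-dim input
random.seed(0)
d_model, num_heads, seq_len = 16, 4, 3
d_k = d_model // num_heads
X = rand(seq_len, d_model)
out = multi_head_attention(
    X,
    [rand(d_model, d_k) for _ in range(num_heads)],
    [rand(d_model, d_k) for _ in range(num_heads)],
    [rand(d_model, d_k) for _ in range(num_heads)],
    rand(d_model, d_model),
)
print(len(out), len(out[0]))  # 3 16
```

Nothing in the function depends on the head count. Only the loop length changes, and the concatenation always ends up `num_heads * d_k` wide.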

In [1]:
```python
import random
import math

# Set seed for reproducibility
random.seed(42)

# Model hyperparameters
VOCAB_SIZE = 6
D_MODEL = 16
MAX_SEQ_LEN = 5
NUM_HEADS = 2
D_K = D_MODEL // NUM_HEADS  # 8

TOKEN_NAMES = ["<PAD>", "<BOS>", "<EOS>", "I", "like", "transformers"]
```

In [2]:
```python
# Helper functions (same as previous notebooks)
def random_vector(size, scale=0.1):
    # Vector of Gaussian samples with mean 0 and std `scale`
    return [random.gauss(0, scale) for _ in range(size)]

def random_matrix(rows, cols, scale=0.1):
    # [rows, cols] matrix of Gaussian samples
    return [[random.gauss(0, scale) for _ in range(cols)] for _ in range(rows)]

def add_vectors(v1, v2):
    # Element-wise sum of two equal-length vectors
    return [a + b for a, b in zip(v1, v2)]

def matmul(A, B):
    # A is [m, n], B is [n, p]; the result is [m, p]
    m, n = len(A), len(A[0])
    p = len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]

def transpose(A):
    # [m, n] -> [n, m]
    return [[A[i][j] for i in range(len(A))] for j in range(len(A[0]))]

def softmax(vec):
    # Numerically stable softmax; -inf entries (masked positions) get probability 0.
    # Assumes at least one entry is finite.
    max_val = max(v for v in vec if v != float('-inf'))
    exp_vec = [math.exp(v - max_val) if v != float('-inf') else 0 for v in vec]
    sum_exp = sum(exp_vec)
    return [e / sum_exp for e in exp_vec]

def format_vector(vec, decimals=4):
    # Fixed-width string for printing
    return "[" + ", ".join([f"{v:7.{decimals}f}" for v in vec]) + "]"
```

In [3]:
```python
# Recreate everything from previous notebooks
# (the seed is fixed above, so the random draws are reproducible)
E_token = [random_vector(D_MODEL) for _ in range(VOCAB_SIZE)]
E_pos = [random_vector(D_MODEL) for _ in range(MAX_SEQ_LEN)]
tokens = [1, 3, 4, 5, 2]  # <BOS> I like transformers <EOS>
seq_len = len(tokens)
X = [add_vectors(E_token[tokens[i]], E_pos[i]) for i in range(seq_len)]

# QKV weights and projections (one [D_MODEL, D_K] matrix per head)
W_Q = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_K = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
W_V = [random_matrix(D_MODEL, D_K) for _ in range(NUM_HEADS)]
Q_all = [matmul(X, W_Q[h]) for h in range(NUM_HEADS)]
K_all = [matmul(X, W_K[h]) for h in range(NUM_HEADS)]
V_all = [matmul(X, W_V[h]) for h in range(NUM_HEADS)]

# Scaled dot-product attention with a causal mask, for one head
def compute_attention(Q, K, V):
    seq_len, d_k = len(Q), len(Q[0])
    scale = math.sqrt(d_k)
    scores = matmul(Q, transpose(K))
    scaled = [[s / scale for s in row] for row in scores]
    # Causal mask: position i may not attend to later positions j > i
    for i in range(seq_len):
        for j in range(seq_len):
            if j > i:
                scaled[i][j] = float('-inf')
    weights = [softmax(row) for row in scaled]
    return matmul(weights, V)

attention_output_all = [compute_attention(Q_all[h], K_all[h], V_all[h]) for h in range(NUM_HEADS)]
print("Recreated attention outputs from previous notebooks")
```

```
Recreated attention outputs from previous notebooks
```


## Step 1: Concatenate Head Outputs

Each head produced an output of shape $[5, 8]$ (5 tokens, 8 dimensions per head). We concatenate along the feature dimension to get $[5, 16]$.

We literally just stick the vectors together, end to end. Head 0's 8 dimensions followed by Head 1's 8 dimensions = 16 dimensions total.
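Concatenation loses nothing. Slicing each 16-dimensional row back into 8-dimensional chunks recovers every head's output exactly. Here is a small sketch with toy numbers (made up for illustration, not the notebook's values). `concat_heads` and `split_heads` are hypothetical helper names.

```python
def concat_heads(head_outputs):
    """head_outputs: list of [seq_len, d_k] matrices -> [seq_len, num_heads * d_k]."""
    seq_len = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[i]] for i in range(seq_len)]

def split_heads(concat, num_heads):
    """Inverse of concat_heads: cut each row into num_heads equal chunks."""
    d_k = len(concat[0]) // num_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in concat] for h in range(num_heads)]

# Two heads, two positions, d_k = 2 (toy values)
head0 = [[1.0, 2.0], [3.0, 4.0]]
head1 = [[5.0, 6.0], [7.0, 8.0]]
merged = concat_heads([head0, head1])
print(merged)                                    # [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
print(split_heads(merged, 2) == [head0, head1])  # True
```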

In [4]:
```python
# Concatenate head outputs along the feature dimension:
# for each position, head 0's D_K values followed by head 1's (and so on for more heads)
concat_output = []
for i in range(seq_len):
    row = []
    for h in range(NUM_HEADS):
        row.extend(attention_output_all[h][i])
    concat_output.append(row)

print("Concatenated Output")
print(f"Shape: [{len(concat_output)}, {len(concat_output[0])}]")
print()
for i, row in enumerate(concat_output):
    print(f"  {format_vector(row)}  # pos {i}: {TOKEN_NAMES[tokens[i]]}")
```

```
Concatenated Output
Shape: [5, 16]

  [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]  # pos 0: <BOS>
  [ 0.0683,  0.0368, -0.0263, -0.0574,  0.0152, -0.0174, -0.0084, -0.0760, -0.0199, -0.0151,  0.0026,  0.0107,  0.0091, -0.0204, -0.0320, -0.0193]  # pos 1: I
  [ 0.0247,  0.0789,  0.0074, -0.0635,  0.0180, -0.0098, -0.0184, -0.0173, -0.0320, -0.0102,  0.0178, -0.0153,  0.0433,  0.0026,  0.0002, -0.0198]  # pos 2: like
  [ 0.0254,  0.0511, -0.0182, -0.0322,  0.0103, -0.0126, -0.0282,  0.0018, -0.0111, -0.0085,  0.0093,  0.0101,  0.0440,  0.0237,  0.0056, -0.0311]  # pos 3: transformers
  [ 0.0325,  0.0367, -0.0202, -0.0262,  0.0188, -0.0040, -0.0321,  0.0167, -0.0119, -0.0013, -0.0069,  0.0016,  0.0480,  0.0233,  0.0096, -0.0121]  # pos 4: <EOS>
```


In [5]:
```python
# Show the concatenation for position 0
print("Example: Position 0 (<BOS>)")
print("="*60)
print()
print(f"Head 0 output: {format_vector(attention_output_all[0][0])}")
print(f"Head 1 output: {format_vector(attention_output_all[1][0])}")
print()
print(f"Concatenated:  {format_vector(concat_output[0])}")
```

```
Example: Position 0 (<BOS>)
============================================================

Head 0 output: [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737]
Head 1 output: [ 0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]

Concatenated:  [ 0.0800,  0.0257, -0.0117, -0.1056,  0.0339, -0.0891, -0.0083, -0.0737,  0.0107, -0.0291, -0.0100, -0.0312,  0.0214,  0.0372,  0.0105,  0.0279]
```


## Step 2: Output Projection

Now we've got 16-dimensional vectors, and we need to project them using a learned weight matrix $W_O$.

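**Wait, why project if we're already at the right dimension?**

Even though the dimensions match, the projection serves a critical purpose: it lets the model learn how to **mix information** from different heads.

Without this projection, each output dimension would come from exactly one head. Dimensions 0 to 7 would be Head 0's output and dimensions 8 to 15 would be Head 1's, with no feature combining what both heads found. The projection matrix $W_O$ makes every output dimension a learned weighted combination of all 16 concatenated features, so training can discover useful ways to combine the heads.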
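One way to see the mixing is to split $W_O$ into blocks of rows. In the formula's convention, multiplying the concatenation by $W_O$ equals giving each head its own block of rows and summing the results: $\text{Concat}(\text{head}_0, \text{head}_1)\, W_O = \text{head}_0\, W_O^{[0:8]} + \text{head}_1\, W_O^{[8:16]}$. Here $W_O^{[a:b]}$ (our notation) means rows $a$ through $b-1$. Every output feature therefore receives a contribution from both heads. The sketch below checks the identity numerically. The sizes match this notebook, but the values are fresh random draws with an arbitrary seed.

```python
import random

def matmul(A, B):
    # A is [m, n], B is [n, p] -> [m, p]
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def rand(rows, cols):
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

random.seed(1)
seq_len, d_k, num_heads = 5, 8, 2
d_model = d_k * num_heads
heads = [rand(seq_len, d_k) for _ in range(num_heads)]
W_O = rand(d_model, d_model)

# Full computation: concatenate, then project
concat = [[x for head in heads for x in head[i]] for i in range(seq_len)]
full = matmul(concat, W_O)

# Blockwise: head h only meets rows h*d_k .. (h+1)*d_k - 1 of W_O; sum the pieces
blockwise = [[0.0] * d_model for _ in range(seq_len)]
for h, head in enumerate(heads):
    part = matmul(head, W_O[h * d_k:(h + 1) * d_k])
    blockwise = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(blockwise, part)]

max_diff = max(abs(a - b) for r1, r2 in zip(full, blockwise) for a, b in zip(r1, r2))
print(max_diff < 1e-12)  # True
```

The two results agree up to floating-point rounding. The only difference is the order in which the terms are summed.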

In [6]:
```python
# Initialize output projection matrix
W_O = random_matrix(D_MODEL, D_MODEL)  # [16, 16]

print("Output Projection Matrix W_O")
print(f"Shape: [{D_MODEL}, {D_MODEL}]")
```

```
Output Projection Matrix W_O
Shape: [16, 16]
```
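For scale, here is a quick count of the learnable weights in this layer. The notebook's matrices have no bias terms; some library implementations (for example PyTorch's `nn.MultiheadAttention` with its default settings) add them.

```python
D_MODEL, NUM_HEADS = 16, 2
D_K = D_MODEL // NUM_HEADS

qkv_params = 3 * NUM_HEADS * D_MODEL * D_K   # W_Q, W_K, W_V for every head
out_params = D_MODEL * D_MODEL               # W_O
print(qkv_params, out_params, qkv_params + out_params)  # 768 256 1024
```

Because `NUM_HEADS * D_K == D_MODEL`, the Q/K/V weights cost `3 * D_MODEL**2` regardless of how many heads you use. $W_O$ accounts for the remaining quarter of the layer.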


In [7]:
```python
# Apply output projection: output = concat @ W_O^T
# W_O is treated as [out_features, in_features] (the layout PyTorch's nn.Linear uses
# for its weight), so this corresponds to the formula's Concat(...) W_O with W_O transposed.
W_O_T = transpose(W_O)
multi_head_output = matmul(concat_output, W_O_T)

print("Multi-Head Attention Output")
print(f"Shape: [{seq_len}, {D_MODEL}]")
print()
for i, row in enumerate(multi_head_output):
    print(f"  {format_vector(row)}  # pos {i}: {TOKEN_NAMES[tokens[i]]}")
```

```
Multi-Head Attention Output
Shape: [5, 16]

  [ 0.0334,  0.0033, -0.0041, -0.0073,  0.0185,  0.0074,  0.0169,  0.0107,  0.0277,  0.0060,  0.0222,  0.0241,  0.0074,  0.0067, -0.0067,  0.0063]  # pos 0: <BOS>
  [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]  # pos 1: I
  [ 0.0085,  0.0086,  0.0159, -0.0177,  0.0026,  0.0205, -0.0057, -0.0055,  0.0059, -0.0043,  0.0007, -0.0053,  0.0075, -0.0012, -0.0043, -0.0016]  # pos 2: like
  [ 0.0064, -0.0084,  0.0092, -0.0173,  0.0068,  0.0119, -0.0100, -0.0027,  0.0027, -0.0073,  0.0036, -0.0076,  0.0022, -0.0070, -0.0095, -0.0070]  # pos 3: transformers
  [ 0.0039, -0.0068,  0.0098, -0.0136,  0.0031,  0.0090, -0.0086, -0.0027, -0.0003, -0.0044, -0.0029, -0.0062,  0.0060, -0.0048, -0.0036, -0.0115]  # pos 4: <EOS>
```


## What Have We Accomplished?

Starting from the original embeddings, we've now:

1. **Projected** into queries, keys, and values for each head
2. **Computed attention** in each head independently
3. **Combined** the heads through concatenation and projection

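Each token's representation now contains:
- Information from itself and the earlier tokens it attended to (the causal mask hides later positions, so `<BOS>` at position 0 sees only itself)
- Contributions from both attention heads, mixed together by $W_O$
- Once the weights are trained, a richer, more context-aware representation than the original embeddings (with our random weights, the mixing is not yet meaningful)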
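The causal mask has a testable consequence: the output at position i depends only on inputs 0..i, because $W_O$ mixes features within a position, never across positions. The sketch below perturbs only the last token's input and checks which outputs move. It redefines compact versions of the helpers so it runs on its own. The sizes, seed, and the names `causal_head` and `mha` are illustrative choices, not the notebook's.

```python
import math
import random

def matmul(A, B):
    # A is [m, n], B is [n, p] -> [m, p]
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def causal_head(X, Wq, Wk, Wv):
    # One attention head; position i only scores positions 0..i (causal mask)
    Q, K, V = matmul(X, Wq), matmul(X, Wk), matmul(X, Wv)
    scale = math.sqrt(len(Q[0]))
    out = []
    for i in range(len(X)):
        scores = [sum(q * k for q, k in zip(Q[i], K[j])) / scale for j in range(i + 1)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(w[j] / total * V[j][c] for j in range(i + 1))
                    for c in range(len(V[0]))])
    return out

def mha(X, head_params, W_O):
    # Concatenate per-head outputs, then project with W_O
    heads = [causal_head(X, *p) for p in head_params]
    concat = [[x for h in heads for x in h[i]] for i in range(len(X))]
    return matmul(concat, W_O)

def rand(rows, cols):
    return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

random.seed(7)
d_model, num_heads, seq_len = 16, 2, 5
d_k = d_model // num_heads
head_params = [(rand(d_model, d_k), rand(d_model, d_k), rand(d_model, d_k))
               for _ in range(num_heads)]
W_O = rand(d_model, d_model)
X = rand(seq_len, d_model)

# Perturb only the last position's input
X_changed = [row[:] for row in X]
X_changed[-1] = [v + 1.0 for v in X_changed[-1]]

before = mha(X, head_params, W_O)
after = mha(X_changed, head_params, W_O)
unchanged = [before[i] == after[i] for i in range(seq_len)]
print(unchanged)  # [True, True, True, True, False]
```

Earlier positions are bit-for-bit identical because their computation never reads the changed row.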

In [8]:
```python
# Compare before and after for position 1
print("Position 1 ('I') - Before and After Attention")
print("="*60)
print()
print("Original embedding X[1]:")
print(f"  {format_vector(X[1])}")
print()
print("After multi-head attention:")
print(f"  {format_vector(multi_head_output[1])}")
```

```
Position 1 ('I') - Before and After Attention
============================================================

Original embedding X[1]:
  [-0.1254, -0.0720,  0.1255, -0.0556, -0.0678,  0.3698, -0.1265, -0.1463,  0.0866,  0.0181,  0.0726, -0.0374,  0.2312, -0.0091,  0.0860, -0.0251]

After multi-head attention:
  [ 0.0269,  0.0066,  0.0113, -0.0154,  0.0114,  0.0032, -0.0065, -0.0108,  0.0190, -0.0091,  0.0180,  0.0097, -0.0075,  0.0061, -0.0079,  0.0110]
```
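The "after" vector is not just different, it is much smaller. A quick way to quantify this is to compare L2 norms, using the values printed above (copied by hand, so rounded to four decimals).

```python
import math

# Position 1 vectors copied from the printed output above
x_before = [-0.1254, -0.0720, 0.1255, -0.0556, -0.0678, 0.3698, -0.1265, -0.1463,
            0.0866, 0.0181, 0.0726, -0.0374, 0.2312, -0.0091, 0.0860, -0.0251]
x_after = [0.0269, 0.0066, 0.0113, -0.0154, 0.0114, 0.0032, -0.0065, -0.0108,
           0.0190, -0.0091, 0.0180, 0.0097, -0.0075, 0.0061, -0.0079, 0.0110]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

print(f"before: {l2_norm(x_before):.3f}  after: {l2_norm(x_after):.3f}")
print(f"ratio:  {l2_norm(x_after) / l2_norm(x_before):.3f}")
```

The output comes out roughly an order of magnitude smaller than the input. With untrained 16-dimensional weights drawn at scale 0.1, each projection ($W_V$ going in, $W_O$ coming out) tends to shrink its input. This most likely reflects the initialization rather than anything learned. The residual connection listed under What's Next adds $X$ back in.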


## Dimensions At Each Stage

| Stage | Shape | Description |
|-------|-------|-------------|
| Input $X$ | $[5, 16]$ | Original embeddings |
| After Q/K/V projection (per head) | $[5, 8]$ | Each head projects to a smaller dimension, $d_k = 8$ |
| Attention weights (per head) | $[5, 5]$ | How much each position attends to others |
| Attention output (per head) | $[5, 8]$ | Weighted sum of values |
| After concatenation | $[5, 16]$ | Heads combined side-by-side |
| After output projection | $[5, 16]$ | Final multi-head attention output |

Notice we start at $[5, 16]$ and end at $[5, 16]$. The attention mechanism is a transformation that preserves the shape while enriching the content.
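The same bookkeeping works for any head count, provided `D_MODEL` is divisible by `NUM_HEADS`. The concatenation always has `NUM_HEADS * D_K == D_MODEL` columns, so $W_O$ can stay square. A shape-only sketch follows; `mha_shapes` is a hypothetical helper that does no arithmetic on values.

```python
def mha_shapes(seq_len, d_model, num_heads):
    """Return the shape after each stage of multi-head attention (shapes only)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    return {
        "input": (seq_len, d_model),
        "per-head Q/K/V": (seq_len, d_k),
        "per-head weights": (seq_len, seq_len),
        "per-head output": (seq_len, d_k),
        "concat": (seq_len, num_heads * d_k),
        "after W_O": (seq_len, d_model),
    }

for h in (1, 2, 4, 8):
    s = mha_shapes(5, 16, h)
    print(h, s["per-head Q/K/V"], s["concat"], s["after W_O"])
```

More heads means narrower heads, but the input and output shapes never change.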

## What's Next

Multi-head attention is done! But we're not finished with the transformer block yet:

1. **Feed-forward network** — Apply position-wise non-linear transformations
2. **Residual connections** — Add the original input back to prevent information loss
3. **Layer normalization** — Stabilize the activations

Then we'll project to vocabulary and compute the loss.

In [9]:
```python
# Store for next notebook
multi_head_data = {
    'X': X,
    'tokens': tokens,
    'multi_head_output': multi_head_output,
    'W_O': W_O
}
print("Multi-head data stored for next notebook.")
```

```
Multi-head data stored for next notebook.
```
