# Attention — Try it in PyTorch

This is an **optional** hands-on companion to [Chapter 7](https://learnai.robennals.org/07-attention). You'll implement attention from scratch: queries, keys, values, softmax, multiple heads, and rotary position encoding.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

## Dot Products Measure Similarity

Attention is built on dot products, which measure how closely two vectors line up. As a refresher: two vectors pointing in the same direction have a large positive dot product, perpendicular vectors have a dot product of zero, and opposite vectors have a negative one.

In [None]:
# Dot products measure similarity between vectors
a = torch.tensor([1.0, 0.0])   # pointing right
b = torch.tensor([0.8, 0.6])   # mostly right, a bit up
c = torch.tensor([0.0, 1.0])   # pointing up (perpendicular to a)
d = torch.tensor([-1.0, 0.0])  # pointing left (opposite of a)

print("Dot products with a = [1, 0] (pointing right):")
print(f"  a · b (similar direction):  {torch.dot(a, b).item():.2f}")
print(f"  a · c (perpendicular):      {torch.dot(a, c).item():.2f}")
print(f"  a · d (opposite):           {torch.dot(a, d).item():.2f}")
print("\nAttention uses dot products to find which words are relevant to each other.")
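A dot product is the sum of element-by-element products. Geometrically, it also equals the two vectors' lengths multiplied by the cosine of the angle between them, so it reflects length as well as direction. The short sketch below recreates `a` and `b` from the cell above and checks both views. Doubling `b` doubles the score even though its direction hasn't changed.

```python
import math
import torch

a = torch.tensor([1.0, 0.0])   # pointing right
b = torch.tensor([0.8, 0.6])   # mostly right, a bit up

# Dot product by hand: multiply matching entries, then add them up
manual = (a * b).sum().item()
builtin = torch.dot(a, b).item()

# Geometric view: |a| * |b| * cos(angle between a and b)
angle = math.atan2(0.6, 0.8) - math.atan2(0.0, 1.0)
geometric = a.norm().item() * b.norm().item() * math.cos(angle)

print(f"manual={manual:.2f}  torch.dot={builtin:.2f}  geometric={geometric:.2f}")

# Same direction, twice the length: the dot product doubles too
print(f"a · (2b) = {torch.dot(a, 2 * b).item():.2f}")
```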

## Softmax: Turning Scores into Percentages

After computing dot products, we need to turn them into weights that add up to 1. **Softmax** does this: it exponentiates each score and divides by the sum of all the exponentials. Because exponentiation exaggerates differences, the biggest score takes the largest share and much lower scores get very little.

In [None]:
scores = torch.tensor([2.0, 1.0, 0.1, -1.0])
weights = F.softmax(scores, dim=-1)

print(f"Raw scores:     {scores.tolist()}")
print(f"After softmax:  {[f'{w:.3f}' for w in weights.tolist()]}")
print(f"Sum:            {weights.sum().item():.3f}")
print("\nThe biggest score (2.0) gets most of the weight.")

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
labels = ["word A", "word B", "word C", "word D"]
ax1.bar(labels, scores.numpy())
ax1.set_title("Raw dot-product scores")
ax1.set_ylabel("Score")

ax2.bar(labels, weights.numpy(), color="orange")
ax2.set_title("After softmax (probabilities)")
ax2.set_ylabel("Weight")
ax2.set_ylim(0, 1)

plt.tight_layout()
plt.show()
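Softmax is simple enough to write by hand: exponentiate every score, then divide by the total. The sketch below also subtracts the largest score first. That doesn't change the result (the shift cancels in the division), but it keeps `exp` from overflowing on large scores, a common trick in practice. `softmax_manual` is a helper name chosen here, not a PyTorch function.

```python
import torch
import torch.nn.functional as F

def softmax_manual(scores):
    # Subtracting the max leaves the result unchanged (it cancels out in the
    # division) but prevents exp() from overflowing on large scores.
    shifted = scores - scores.max(dim=-1, keepdim=True).values
    exps = torch.exp(shifted)
    # Normalize so the weights along the last dimension sum to 1
    return exps / exps.sum(dim=-1, keepdim=True)

scores = torch.tensor([2.0, 1.0, 0.1, -1.0])
print(softmax_manual(scores))
print(F.softmax(scores, dim=-1))

# Large scores: a naive exp() overflows to inf, the shifted version does not
big = torch.tensor([1000.0, 999.0])
print(torch.exp(big))
print(softmax_manual(big))
```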

## Queries, Keys, and Values from Scratch

Each word gets three representations:
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I offer?"
- **Value (V)**: "What information do I give?"

These are computed by multiplying the word's embedding by three learned weight matrices. Let's build this step by step.

In [None]:
# A tiny sentence with pre-defined embeddings
words = ["The", "cat", "sat", "on", "it"]
seq_len = len(words)
embed_dim = 4

# Pretend embeddings (in reality these come from nn.Embedding)
torch.manual_seed(42)
embeddings = torch.randn(seq_len, embed_dim)

# Learned weight matrices that create Q, K, V
head_dim = 4
W_Q = nn.Linear(embed_dim, head_dim, bias=False)
W_K = nn.Linear(embed_dim, head_dim, bias=False)
W_V = nn.Linear(embed_dim, head_dim, bias=False)

# Compute Q, K, V for every word
Q = W_Q(embeddings)  # (seq_len, head_dim)
K = W_K(embeddings)
V = W_V(embeddings)

print(f"Embeddings shape: {embeddings.shape}")
print(f"Q shape: {Q.shape}  (one query vector per word)")
print(f"K shape: {K.shape}  (one key vector per word)")
print(f"V shape: {V.shape}  (one value vector per word)")
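`nn.Linear(..., bias=False)` is just a matrix multiply: it stores a weight of shape `(out_features, in_features)` and computes `x @ weight.T`. The sketch below rebuilds the same setup (seed 42, a 5 × 4 embedding matrix) and confirms that the layer and the explicit matrix product give the same queries.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
embeddings = torch.randn(5, 4)            # 5 words, 4-dim embeddings
W_Q = nn.Linear(4, 4, bias=False)         # learned query projection

with torch.no_grad():
    Q_layer = W_Q(embeddings)             # what the layer computes
    Q_matmul = embeddings @ W_Q.weight.T  # the same thing, written out

print(W_Q.weight.shape)                   # (out_features, in_features)
print(torch.allclose(Q_layer, Q_matmul, atol=1e-6))
```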

## Computing Attention Weights

Now the attention mechanism:
1. Dot product between each query and every key → relevance scores
2. Divide by √(head_dim) so the scores don't grow with the vector size (very large scores would make softmax put nearly all the weight on a single word)
3. Softmax to turn scores into weights
4. Weighted sum of values

In [None]:
# Step 1: Dot products between all queries and all keys
# Q @ K^T gives a (seq_len × seq_len) matrix of scores
scores = Q @ K.T
print(f"Attention scores (before scaling):\n{scores.detach().numpy().round(2)}")

# Step 2: Scale by sqrt(head_dim)
scale = head_dim ** 0.5
scores_scaled = scores / scale

# Step 3: Softmax along the key dimension
attn_weights = F.softmax(scores_scaled, dim=-1)
print("\nAttention weights (after softmax):")
print("Each row sums to 1: it shows how much that word attends to every word, including itself.")
for i, word in enumerate(words):
    weights_str = "  ".join(f"{words[j]}:{attn_weights[i,j].item():.2f}" for j in range(seq_len))
    print(f"  {word:>3} attends to: {weights_str}")

# Step 4: Weighted sum of values
output = attn_weights @ V
print(f"\nOutput shape: {output.shape}")
print("Each word now has a new representation enriched with context from relevant words.")

In [None]:
# Visualize the attention pattern
plt.figure(figsize=(5, 4))
plt.imshow(attn_weights.detach().numpy(), cmap="Blues", vmin=0, vmax=1)
plt.xticks(range(seq_len), words)
plt.yticks(range(seq_len), words)
plt.xlabel("Attending to (keys)")
plt.ylabel("From (queries)")
plt.title("Attention weights")
plt.colorbar(label="Weight")
for i in range(seq_len):
    for j in range(seq_len):
        plt.text(j, i, f"{attn_weights[i,j].item():.2f}", ha="center", va="center", fontsize=8)
plt.tight_layout()
plt.show()
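Why divide by √(head_dim)? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of `head_dim` such products, so its variance grows to about `head_dim` (standard deviation √head_dim). The quick simulation below uses random vectors rather than learned ones; it shows the raw spread growing with the dimension while the scaled scores stay near a standard deviation of 1.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
stds = {}

for head_dim in [4, 64, 512]:
    # Random queries and keys: every entry has mean 0 and variance 1
    q = torch.randn(n_samples, head_dim)
    k = torch.randn(n_samples, head_dim)
    raw = (q * k).sum(dim=-1)         # one dot product per sample
    scaled = raw / head_dim ** 0.5    # the attention scaling step
    stds[head_dim] = (raw.std().item(), scaled.std().item())
    print(f"head_dim={head_dim:>3}  std(raw)={stds[head_dim][0]:6.2f}  "
          f"std(scaled)={stds[head_dim][1]:.2f}")
```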

## Using PyTorch's Built-in Attention

PyTorch provides `F.scaled_dot_product_attention` (added in PyTorch 2.0), which performs the four attention steps in a single call once you hand it Q, K, and V. Let's verify that our manual version matches.

In [None]:
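# PyTorch's built-in scaled dot-product attention (added in PyTorch 2.0).
# By default it scales scores by 1/sqrt(head_dim), just like our manual version.
with torch.no_grad():
    # Our manual Q, K, V, with a batch dimension added
    out_builtin = F.scaled_dot_product_attention(
        Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
    ).squeeze(0)

print("Our manual output (first word):")
print(f"  {output[0].detach().numpy().round(4)}")
print("PyTorch's built-in output (first word):")
print(f"  {out_builtin[0].numpy().round(4)}")

# Compare every entry, not just the first row
if torch.allclose(output.detach(), out_builtin, atol=1e-5):
    print("\nThey match! Our manual implementation is correct.")
else:
    print("\nMismatch: check the manual steps above.")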
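PyTorch also has the higher-level `nn.MultiheadAttention` module, which bundles the Q/K/V projections, the attention step, and an output projection. As a sketch, assuming a single head with `embed_dim = head_dim = 4` as in our example, we can copy our `W_Q`, `W_K`, `W_V` weights into its packed `in_proj_weight`, set its output projection to the identity, and check that it reproduces the manual result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Rebuild the single-head example: seed 42, 5 words, 4-dim embeddings
torch.manual_seed(42)
embeddings = torch.randn(5, 4)
W_Q = nn.Linear(4, 4, bias=False)
W_K = nn.Linear(4, 4, bias=False)
W_V = nn.Linear(4, 4, bias=False)

with torch.no_grad():
    # Manual attention: scores, scale, softmax, weighted sum of values
    Q, K, V = W_Q(embeddings), W_K(embeddings), W_V(embeddings)
    manual = F.softmax(Q @ K.T / 4 ** 0.5, dim=-1) @ V

    # One head, no biases; inputs shaped (batch, seq_len, embed_dim)
    mha = nn.MultiheadAttention(embed_dim=4, num_heads=1, bias=False, batch_first=True)
    # in_proj_weight stacks the Q, K, V projection weights: shape (3 * 4, 4)
    mha.in_proj_weight.copy_(torch.cat([W_Q.weight, W_K.weight, W_V.weight], dim=0))
    # Make the output projection a no-op so only the attention step remains
    mha.out_proj.weight.copy_(torch.eye(4))

    x = embeddings.unsqueeze(0)  # add a batch dimension
    mha_out, mha_weights = mha(x, x, x)

print(torch.allclose(manual, mha_out.squeeze(0), atol=1e-5))
```

In a real model the output projection is learned rather than fixed to the identity; it is forced here only to make the comparison direct.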

## Multiple Attention Heads

One attention head can only ask one kind of question. Multiple heads run in parallel, each with its own Q, K, V weights. One head might track grammar, another might track pronoun references, another might find the sentence topic.

In [None]:
n_heads = 3
head_dim = 4

# Each head has its own Q, K, V projections
torch.manual_seed(0)
heads = []
for h in range(n_heads):
    wq = nn.Linear(embed_dim, head_dim, bias=False)
    wk = nn.Linear(embed_dim, head_dim, bias=False)
    wv = nn.Linear(embed_dim, head_dim, bias=False)
    heads.append((wq, wk, wv))

# Run each head and collect attention patterns
fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 3.5))

for h, (wq, wk, wv) in enumerate(heads):
    with torch.no_grad():
        q = wq(embeddings)
        k = wk(embeddings)
        scores = q @ k.T / (head_dim ** 0.5)
        weights = F.softmax(scores, dim=-1)

    ax = axes[h]
    ax.imshow(weights.numpy(), cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(seq_len))
    ax.set_xticklabels(words, fontsize=9)
    ax.set_yticks(range(seq_len))
    ax.set_yticklabels(words, fontsize=9)
    ax.set_title(f"Head {h + 1}")

plt.suptitle("Each head learns different attention patterns", y=1.02)
plt.tight_layout()
plt.show()

print("Different random weights → different attention patterns.")
print("After training, heads often specialize in different kinds of relationships.")
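The cell above only plots each head's weights. To produce the layer's output, each head also takes a weighted sum of its own values; the per-head results are then concatenated and mixed by one more learned matrix (called `W_O` here, a name chosen for illustration). A sketch with the same sizes as above (3 heads, `head_dim = 4`, 4-dim embeddings):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, embed_dim, n_heads, head_dim = 5, 4, 3, 4
embeddings = torch.randn(seq_len, embed_dim)

head_outputs = []
with torch.no_grad():
    for _ in range(n_heads):
        # Each head has its own Q, K, V projections
        wq = nn.Linear(embed_dim, head_dim, bias=False)
        wk = nn.Linear(embed_dim, head_dim, bias=False)
        wv = nn.Linear(embed_dim, head_dim, bias=False)
        q, k, v = wq(embeddings), wk(embeddings), wv(embeddings)
        weights = F.softmax(q @ k.T / head_dim ** 0.5, dim=-1)
        head_outputs.append(weights @ v)          # (seq_len, head_dim)

    # Concatenate along the feature dimension: (seq_len, n_heads * head_dim)
    concat = torch.cat(head_outputs, dim=-1)
    # Hypothetical output projection back to the embedding size
    W_O = nn.Linear(n_heads * head_dim, embed_dim, bias=False)
    layer_out = W_O(concat)

print(concat.shape)     # torch.Size([5, 12])
print(layer_out.shape)  # torch.Size([5, 4])
```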

## Attention Is Order-Blind

Here's a surprising fact: if we scramble the word order, the attention scores between the same pairs of words don't change. Attention only sees embeddings — it has no idea *where* words are in the sentence.

In [None]:
# Original order
original_order = [0, 1, 2, 3, 4]

# Scrambled order
scrambled_order = [4, 2, 0, 3, 1]

emb_original = embeddings[original_order]
emb_scrambled = embeddings[scrambled_order]

with torch.no_grad():
    q_orig = W_Q(emb_original)
    k_orig = W_K(emb_original)
    scores_orig = q_orig @ k_orig.T

    q_scram = W_Q(emb_scrambled)
    k_scram = W_K(emb_scrambled)
    scores_scram = q_scram @ k_scram.T

# Check: what's the score between "cat" (idx 1) and "it" (idx 4)?
cat_it_original = scores_orig[1, 4].item()  # cat=pos1, it=pos4 in original
# In scrambled: cat is at position 4, it is at position 0
cat_it_scrambled = scores_scram[4, 0].item()

print("Score between 'cat' and 'it':")
print(f"  Original order:  {cat_it_original:.4f}")
print(f"  Scrambled order: {cat_it_scrambled:.4f}")
print("  Identical! Attention doesn't know where words are.")
print("\n'The cat sat on it' and 'it sat The on cat' look the same to attention.")
print("We need position information!")
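The same is true of the full attention output, not just individual scores: scrambling the input rows only scrambles the output rows in the same way (self-attention without position information is *permutation-equivariant*). A sketch that rebuilds the single-head setup from earlier (seed 42, 5 × 4 embeddings) and checks this, using a helper `attend` defined just for the check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
embeddings = torch.randn(5, 4)
W_Q, W_K, W_V = (nn.Linear(4, 4, bias=False) for _ in range(3))

def attend(x):
    # Single-head scaled dot-product self-attention, no position information
    q, k, v = W_Q(x), W_K(x), W_V(x)
    return F.softmax(q @ k.T / 4 ** 0.5, dim=-1) @ v

scrambled_order = [4, 2, 0, 3, 1]
with torch.no_grad():
    out_original = attend(embeddings)
    out_scrambled = attend(embeddings[scrambled_order])

# Row i of the scrambled output equals row scrambled_order[i] of the original
print(torch.allclose(out_scrambled, out_original[scrambled_order], atol=1e-5))
```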

## Rotary Position Encoding (RoPE)

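One fix is to **rotate** each word's query and key vectors by an angle proportional to its position, with each pair of dimensions rotating at its own frequency. The dot product of two rotated vectors then depends only on the *difference* in their angles, which reflects the *relative distance* between the words.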

Two words 3 positions apart always have the same angular difference, whether at positions 1 & 4 or positions 50 & 53.
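Here is why only the gap matters, worked out in two dimensions (the same rotation `apply_rope` performs on each pair of dimensions). Write $R(\phi)$ for a rotation by angle $\phi$:

$$
R(\phi) = \begin{pmatrix} \cos\phi & -\sin\phi \\ \sin\phi & \cos\phi \end{pmatrix}
$$

Say a query $q$ at position $m$ is rotated by $m\theta$ and a key $k$ at position $n$ by $n\theta$. Using $R(\alpha)^\top = R(-\alpha)$ and $R(\alpha)R(\beta) = R(\alpha + \beta)$:

$$
\begin{aligned}
\big(R(m\theta)\,q\big)^\top \big(R(n\theta)\,k\big)
  &= q^\top R(m\theta)^\top R(n\theta)\,k \\
  &= q^\top R(-m\theta)\,R(n\theta)\,k \\
  &= q^\top R\big((n-m)\theta\big)\,k
\end{aligned}
$$

The positions $m$ and $n$ enter only through $n - m$. With more dimensions, each pair of coordinates is rotated this way at its own frequency $\theta$, and the full dot product is the sum over pairs, so the same argument applies.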

In [None]:
def apply_rope(x, positions, base=10.0):
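    """Apply rotary position encoding to vectors.

    x: (..., seq_len, dim), the vectors to rotate (dim must be even)
    positions: (seq_len,), the position of each vector
    base: how quickly rotation frequencies fall off across dimension pairs.
        The RoPE paper uses 10000; a small base like 10 keeps the rotations
        visible at the small positions used in this notebook.
    """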
    dim = x.shape[-1]
    # Each pair of dimensions gets rotated by a different frequency
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Angle = position × frequency
    angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)  # (seq_len, dim//2)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    # Rotate pairs of dimensions
    x1 = x[..., 0::2]  # even dimensions
    x2 = x[..., 1::2]  # odd dimensions
    rotated = torch.stack([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos,
    ], dim=-1).flatten(-2)
    return rotated

# Demonstrate: the dot product depends on RELATIVE position, not absolute
dim = 8
torch.manual_seed(7)
vec_a = torch.randn(1, dim)
vec_b = torch.randn(1, dim)

print("Dot product of the same two vectors at different absolute positions:")
print(f"{'Pos A':>6} {'Pos B':>6} {'Gap':>5} {'Dot Product':>12}")
print("-" * 35)

for pos_a, pos_b in [(1, 4), (10, 13), (50, 53), (100, 103)]:
    ra = apply_rope(vec_a, torch.tensor([float(pos_a)]))
    rb = apply_rope(vec_b, torch.tensor([float(pos_b)]))
    dot = (ra * rb).sum().item()
    print(f"{pos_a:>6} {pos_b:>6} {pos_b-pos_a:>5} {dot:>12.4f}")

print("\nSame gap = same dot product, regardless of absolute position!")

print("\nNow change the gap:")
print(f"{'Pos A':>6} {'Pos B':>6} {'Gap':>5} {'Dot Product':>12}")
print("-" * 35)
for gap in [1, 3, 5, 10, 20]:
    ra = apply_rope(vec_a, torch.tensor([0.0]))
    rb = apply_rope(vec_b, torch.tensor([float(gap)]))
    dot = (ra * rb).sum().item()
    print(f"{0:>6} {gap:>6} {gap:>5} {dot:>12.4f}")

print("\nDifferent gap = different dot product. The model can tell distance!")
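Because RoPE is a rotation, it changes a vector's direction but never its length. The sketch below copies `apply_rope` from above so it runs on its own, and checks that norms are unchanged at every position. It also applies the function to a batch shaped `(n_heads, seq_len, dim)`, where `positions` of shape `(seq_len,)` broadcasts across the heads.

```python
import torch

def apply_rope(x, positions, base=10.0):
    # Same rotary encoding as above: rotate each (even, odd) pair of dims
    dim = x.shape[-1]
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)   # (seq_len, dim//2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin,
                        x1 * sin + x2 * cos], dim=-1).flatten(-2)

torch.manual_seed(0)
n_heads, seq_len, dim = 2, 6, 8
x = torch.randn(n_heads, seq_len, dim)
positions = torch.arange(seq_len).float()

rotated = apply_rope(x, positions)   # positions broadcast over the heads dim
print(rotated.shape)                 # torch.Size([2, 6, 8])

# Rotations preserve length: every vector keeps its norm
print(torch.allclose(x.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))

# Position 0 means angle 0, so those vectors are left unchanged
print(torch.allclose(x[:, 0], rotated[:, 0]))
```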

In [None]:
# Visualize: rotate a 2D vector by different positions
# (with dim=2 there is a single frequency, so position p rotates by p radians)
vec = torch.tensor([1.0, 0.0])  # unit vector pointing right

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Left: vectors at different positions
positions = torch.arange(8).float()
colors = plt.cm.viridis(np.linspace(0, 0.9, len(positions)))
for i, pos in enumerate(positions):
    rotated = apply_rope(vec.unsqueeze(0), pos.unsqueeze(0)).squeeze()
    ax1.arrow(0, 0, rotated[0].item() * 0.9, rotated[1].item() * 0.9,
              head_width=0.05, color=colors[i], linewidth=2)
    ax1.annotate(f"pos {int(pos)}", (rotated[0].item(), rotated[1].item()),
                fontsize=8, ha='center')

ax1.set_xlim(-1.3, 1.3)
ax1.set_ylim(-1.3, 1.3)
ax1.set_aspect('equal')
ax1.set_title("Same vector rotated by position")
ax1.grid(True, alpha=0.3)

# Right: dot product vs distance
gaps = torch.arange(0, 20).float()
ra = apply_rope(vec_a, torch.tensor([0.0]))  # vec_a stays at position 0
dots = []
for gap in gaps:
    rb = apply_rope(vec_b, gap.reshape(1))   # vec_b at position `gap`
    dots.append((ra * rb).sum().item())

ax2.plot(gaps.numpy(), dots, 'o-', markersize=4)
ax2.set_xlabel("Distance between words (gap)")
ax2.set_ylabel("Dot product")
ax2.set_title("Dot product depends on relative distance")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Putting It All Together: Self-Attention with RoPE

Let's combine everything — Q/K/V projections, rotary positions, scaled dot-product attention, and multiple heads — into a complete self-attention module.

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.W_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.W_out = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        seq_len, dim = x.shape
        # Project to Q, K, V
        qkv = self.W_qkv(x)  # (seq_len, 3 * embed_dim)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multiple heads: (n_heads, seq_len, head_dim)
        q = q.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        k = k.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        v = v.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)

        # Apply rotary position encoding to queries and keys (not values).
        # positions has shape (seq_len,) and broadcasts across the heads dimension.
        positions = torch.arange(seq_len).float()
        q = apply_rope(q, positions)
        k = apply_rope(k, positions)

        # Scaled dot-product attention per head
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = attn @ v  # (n_heads, seq_len, head_dim)

        # Concatenate heads and project
        out = out.transpose(0, 1).contiguous().view(seq_len, -1)
        return self.W_out(out), attn

# Test it
torch.manual_seed(42)
attn_layer = SelfAttention(embed_dim=8, n_heads=2)
test_input = torch.randn(5, 8)  # 5 words, 8-dim embeddings

with torch.no_grad():
    output, attn_weights = attn_layer(test_input)

print(f"Input shape:  {test_input.shape}  (5 words, 8 dimensions)")
print(f"Output shape: {output.shape}  (same — enriched with context)")
print(f"Attention weights shape: {attn_weights.shape}  (2 heads, 5×5 attention matrix)")
print("\nEach word now carries information from the words it attended to.")
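As a final check, the sketch below repeats the earlier scrambling experiment on a module like this one. It redefines `apply_rope` and a compact `SelfAttention` so the cell runs on its own. Without position information, scrambled outputs were simply the original outputs in scrambled order; with RoPE they generally are not, because each word's scores now depend on where it sits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def apply_rope(x, positions, base=10.0):
    # Rotate each (even, odd) pair of dims by position * frequency
    dim = x.shape[-1]
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin,
                        x1 * sin + x2 * cos], dim=-1).flatten(-2)

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.W_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.W_out = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        seq_len = x.shape[0]
        q, k, v = self.W_qkv(x).chunk(3, dim=-1)
        # Split into heads: (n_heads, seq_len, head_dim)
        q, k, v = (t.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
                   for t in (q, k, v))
        positions = torch.arange(seq_len).float()
        q, k = apply_rope(q, positions), apply_rope(k, positions)
        attn = F.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        # Concatenate heads and project back to embed_dim
        out = (attn @ v).transpose(0, 1).contiguous().view(seq_len, -1)
        return self.W_out(out), attn

torch.manual_seed(42)
layer = SelfAttention(embed_dim=8, n_heads=2)
x = torch.randn(5, 8)
scrambled_order = [4, 2, 0, 3, 1]

with torch.no_grad():
    out_original, _ = layer(x)
    out_scrambled, _ = layer(x[scrambled_order])

# With RoPE, scrambling the words is no longer just a reshuffle of the outputs
print(torch.allclose(out_scrambled, out_original[scrambled_order], atol=1e-5))
```

This should print `False`: the same words in a different order now produce genuinely different outputs.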

---

*This notebook accompanies [Chapter 7: Attention](https://learnai.robennals.org/07-attention). The interactive widgets in the web version let you step through attention computations, see multiple heads, scramble word positions, and explore rotary encoding visually.*

*New to PyTorch? See the [PyTorch from Scratch](https://learnai.robennals.org/appendix-pytorch) appendix for a beginner-friendly introduction.*