# Self-Attention from Scratch

Implementing scaled dot-product attention and multi-head attention in PyTorch — the core building block behind every modern vision transformer.

## Overview

**Builds:** `scaled_dot_product_attention` · `MultiHeadAttention`
**Concepts:** Q/K/V projections · scaled dot-product · why √d_k · multi-head attention · head specialization
**Requires:** nothing — standalone
**References:** [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762) · [CS231N Lecture 11](https://cs231n.stanford.edu/slides/2024/lecture_11.pdf)
**Output:** `assets/gradcam/attention_single_head.png` · `assets/gradcam/attention_multi_head.png`

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams['figure.dpi'] = 120
torch.manual_seed(42)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")

## Part 1 — Scaled Dot-Product Attention

A transformer doesn't process each token in isolation. Every token looks at every other token simultaneously and decides how much to attend to each.

Each input token is mapped by three learned projections to:
- **Query (Q)** — what this token is looking for
- **Key (K)** — what each token offers
- **Value (V)** — the content to aggregate
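A minimal sketch of what "three learned projections" means in code. The sizes (`d_model = 32`, `d_k = 16`) are hypothetical and chosen only for illustration. The projections are three independent `nn.Linear` layers applied to the same token embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes chosen for illustration only
batch, seq_len, d_model, d_k = 1, 6, 32, 16

x = torch.randn(batch, seq_len, d_model)   # one embedding vector per token

# Three independent learned linear maps applied to the same tokens
W_q = nn.Linear(d_model, d_k)
W_k = nn.Linear(d_model, d_k)
W_v = nn.Linear(d_model, d_k)

q, k, v = W_q(x), W_k(x), W_v(x)
print(q.shape)                 # torch.Size([1, 6, 16]); k and v match
print(torch.allclose(q, k))    # False: different weights give different views of each token
```

`MultiHeadAttention` in Part 2 follows the same pattern with output width `d_model`, then splits the result into heads.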

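The attention score between a query token and a key token is the dot product of the first token's query with the second token's key, divided by √d_k. A softmax over all keys turns each query's scores into weights that sum to 1, and the output for that query is the weighted sum of the V vectors.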

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
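To make the formula concrete, here is a tiny hand-checkable case with made-up numbers: two tokens, `d_k = 2`, and Q = K = the 2×2 identity. Each output row should equal the softmax-weighted average of the rows of V:

```python
import math
import torch
import torch.nn.functional as F

# Made-up inputs: 2 tokens, d_k = 2
Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
K = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
V = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

scores = Q @ K.T / math.sqrt(Q.size(-1))   # [[1/sqrt(2), 0], [0, 1/sqrt(2)]]
weights = F.softmax(scores, dim=-1)         # each row sums to 1
out = weights @ V                           # weighted sum of the rows of V

# The same computation by hand for token 0
w00 = math.exp(1 / math.sqrt(2)) / (math.exp(1 / math.sqrt(2)) + 1)
w01 = 1 - w00
manual_row0 = [w00 * 1.0 + w01 * 3.0, w00 * 2.0 + w01 * 4.0]

print(weights)
print(out[0], manual_row0)   # both approximately [1.660, 2.660]
```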

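**Why scale by √d_k?** If the components of a query q and a key k are independent with mean 0 and variance 1, their dot product q·k is a sum of d_k terms qᵢkᵢ, each with variance 1, so q·k has variance d_k. For large d_k the scores therefore spread over a wide range, softmax saturates toward a one-hot distribution, and its gradients become very small, which stalls learning. Dividing by √d_k brings the variance back to 1 regardless of d_k. ([Vaswani et al., 2017 — Section 3.2](https://arxiv.org/abs/1706.03762))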
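A quick numerical check of this argument, under its stated assumption that the entries of q and k are i.i.d. with mean 0 and variance 1. The sample size and the dimensions tried are arbitrary choices. The raw dot-product variance grows roughly like d_k, the scaled version stays near 1, and unscaled scores concentrate more softmax weight on a single key:

```python
import math
import torch

torch.manual_seed(0)
n_samples = 10_000
raw_vars, scaled_vars = {}, {}

for d_k in (4, 64, 512):
    # Assumed setting: i.i.d. entries with mean 0 and variance 1
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    raw = (q * k).sum(dim=-1)                 # n_samples independent dot products
    raw_vars[d_k] = raw.var().item()
    scaled_vars[d_k] = (raw / math.sqrt(d_k)).var().item()
    print(f"d_k={d_k:3d}  var(q.k)={raw_vars[d_k]:6.1f}  var(q.k/sqrt(d_k))={scaled_vars[d_k]:.2f}")

# Effect on softmax: one query against 10 keys with d_k = 512
d_k = 512
query, keys = torch.randn(d_k), torch.randn(10, d_k)
scores = keys @ query                         # shape (10,)
max_w_unscaled = torch.softmax(scores, dim=0).max().item()
max_w_scaled = torch.softmax(scores / math.sqrt(d_k), dim=0).max().item()
print(f"largest softmax weight  unscaled: {max_w_unscaled:.3f}  scaled: {max_w_scaled:.3f}")
```

Both softmaxes see the same scores up to the factor √d_k, so the unscaled one always puts at least as much weight on the top-scoring key.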

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Scaled dot-product attention.

    Args:
        q: Query  — shape (..., seq_len, d_k)
        k: Key    — shape (..., seq_len, d_k)
        v: Value  — shape (..., seq_len, d_v)
        mask: Optional mask broadcastable to (..., seq_len, seq_len); positions where mask is False/0 are blocked

    Returns:
        output:       shape (..., seq_len, d_v)
        attn_weights: shape (..., seq_len, seq_len)
    """
    d_k = q.size(-1)

    # Step 1: QK^T gives a raw similarity score for every (query, key) pair
    scores = torch.matmul(q, k.transpose(-2, -1))       # (..., seq, seq)

    # Step 2: Scale — prevents large dot products from saturating softmax
    scores = scores / math.sqrt(d_k)

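    # Step 3: Optional mask (e.g. causal mask in decoders, padding mask for variable-length inputs).
    # Blocked positions get a large negative score, so softmax assigns them ~0 weight.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)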

    # Step 4: Softmax — convert scores to weights that sum to 1
    attn_weights = F.softmax(scores, dim=-1)             # (..., seq, seq)

    # Step 5: Weighted sum of V
    output = torch.matmul(attn_weights, v)               # (..., seq, d_v)

    return output, attn_weights
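A sketch of the `mask` argument in use, with a causal (lower-triangular) mask built by `torch.tril`, where `True` means "may attend". The function is repeated in compact form so the cell runs on its own, and the sizes are arbitrary. Every weight above the diagonal should come out exactly zero, while each row still sums to 1:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Compact copy of the function above (same logic), so this cell is standalone
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    return attn_weights @ v, attn_weights

torch.manual_seed(0)
seq_len, d_k = 5, 8
q, k, v = (torch.randn(1, seq_len, d_k) for _ in range(3))

# Query i may attend to keys 0..i only
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
out, w = scaled_dot_product_attention(q, k, v, mask=causal_mask)

print(w[0])               # upper triangle is all zeros
print(w[0].sum(dim=-1))   # rows still sum to 1

# The first token can only see itself, so its output is exactly its own value vector
print(torch.allclose(out[0, 0], v[0, 0]))   # True
```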

In [None]:
# Verify shapes
batch_size = 2
seq_len    = 6
d_k        = 64

q = torch.randn(batch_size, seq_len, d_k)
k = torch.randn(batch_size, seq_len, d_k)
v = torch.randn(batch_size, seq_len, d_k)

output, attn_weights = scaled_dot_product_attention(q, k, v)

print(f"Q shape:              {q.shape}")
print(f"K shape:              {k.shape}")
print(f"V shape:              {v.shape}")
print(f"Output shape:         {output.shape}")          # (2, 6, 64)
print(f"Attn weights shape:   {attn_weights.shape}")    # (2, 6, 6)
print(f"Weights sum to 1:     {attn_weights[0].sum(dim=-1)}")  # all 1.0
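As a sanity check, assuming PyTorch 2.0 or newer, the result can be compared against the built-in `torch.nn.functional.scaled_dot_product_attention`. It returns only the output (not the weights), uses the same default 1/√d_k scale, and its boolean `attn_mask` follows the same convention as `mask` here (`True` = attend). The function is repeated so the cell is standalone:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Compact copy of the notebook's function
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    return attn_weights @ v, attn_weights

torch.manual_seed(0)
q, k, v = (torch.randn(2, 6, 64) for _ in range(3))

# Unmasked: our implementation vs PyTorch's fused kernel (PyTorch >= 2.0)
ours, _ = scaled_dot_product_attention(q, k, v)
builtin = F.scaled_dot_product_attention(q, k, v)

# Boolean mask: True = attend, in both implementations
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))
ours_causal, _ = scaled_dot_product_attention(q, k, v, mask=causal)
builtin_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

print(torch.allclose(ours, builtin, atol=1e-5))                # True
print(torch.allclose(ours_causal, builtin_causal, atol=1e-5))  # True
```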

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

for i, ax in enumerate(axes):
    im = ax.imshow(attn_weights[i].detach().numpy(), cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Attention weights — batch {i}')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    plt.colorbar(im, ax=ax)

plt.suptitle('Scaled dot-product attention (random Q, K, V; no learned weights)', y=1.02)
plt.tight_layout()
plt.savefig('../assets/gradcam/attention_single_head.png', bbox_inches='tight', dpi=150)
plt.show()

## Part 2 — Multi-Head Attention

A single attention head learns one way of relating tokens. Multi-head attention runs `h` attention operations in **parallel**, each in a lower-dimensional subspace (`d_k = d_model / h`).

Each head can specialize: one might capture local relationships, another long-range structure. The head outputs are concatenated, which restores width `d_model`, and then passed through an output projection `W_o`.

**Shape flow:**
```
Input:            (batch, seq_len, d_model)
Project Q, K, V:  (batch, seq_len, d_model)  via W_q, W_k, W_v
Split into heads: (batch, num_heads, seq_len, d_k)  where d_k = d_model / num_heads
Attention:        (batch, num_heads, seq_len, d_k)
Merge heads:      (batch, seq_len, d_model)
Output project:   (batch, seq_len, d_model)  via W_o
```

([Vaswani et al., 2017 — Section 3.3](https://arxiv.org/abs/1706.03762) · [CS231N Lecture 11](https://cs231n.stanford.edu/slides/2024/lecture_11.pdf))
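The split and merge steps in this shape flow are pure reshapes, so merging should exactly undo splitting. Here is a minimal sketch with the notebook's sizes (`d_model = 512`, `num_heads = 8`), using the same `view`/`transpose` pattern as the class below. It also shows that each head owns a contiguous slice of the feature dimension:

```python
import torch

batch, seq_len, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads                     # 64

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Merge: (batch, heads, seq, d_k) -> (batch, seq, d_model)
merged = heads.transpose(1, 2).reshape(batch, seq_len, d_model)

print(heads.shape)               # torch.Size([2, 8, 10, 64])
print(torch.equal(merged, x))    # True: nothing is lost or reordered

# Head h sees features h*d_k .. (h+1)*d_k - 1 of every token
h = 3
print(torch.equal(heads[:, h], x[..., h * d_k:(h + 1) * d_k]))   # True
```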

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model    = d_model
        self.num_heads  = num_heads
        self.d_k        = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # Project and split
        q = self.split_heads(self.W_q(x))   # (batch, heads, seq_len, d_k)
        k = self.split_heads(self.W_k(x))
        v = self.split_heads(self.W_v(x))

        # Attention across all heads
        attn_out, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        # attn_out: (batch, heads, seq_len, d_k)

        # Merge heads
        attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, self.d_model)

        # Output projection
        output = self.W_o(attn_out)          # (batch, seq_len, d_model)

        return output, attn_weights
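One way to test the class is to compare it against PyTorch's `nn.MultiheadAttention`. This sketch assumes that module (with `batch_first=True`) stacks W_q, W_k, W_v row-wise in `in_proj_weight`, splits heads into contiguous slices, and scales by 1/√d_k, which is how current releases implement it. Treat a match as evidence, not a guarantee. By default it returns attention weights averaged over heads. The class and function are repeated in compact form with smaller sizes so the cell runs standalone:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    w = F.softmax(scores, dim=-1)
    return w @ v, w

class MultiHeadAttention(nn.Module):
    """Compact copy of the notebook's class (same computation)."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.num_heads = d_model, num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        b, n, _ = x.shape
        return x.view(b, n, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        b, n, _ = x.shape
        q, k, v = (self.split_heads(W(x)) for W in (self.W_q, self.W_k, self.W_v))
        out, w = scaled_dot_product_attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(b, n, self.d_model)
        return self.W_o(out), w

torch.manual_seed(0)
d_model, num_heads = 64, 4
mha = MultiHeadAttention(d_model, num_heads)
ref = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

# Copy our weights into PyTorch's module; in_proj stacks [W_q; W_k; W_v] along dim 0
with torch.no_grad():
    ref.in_proj_weight.copy_(torch.cat([mha.W_q.weight, mha.W_k.weight, mha.W_v.weight]))
    ref.in_proj_bias.copy_(torch.cat([mha.W_q.bias, mha.W_k.bias, mha.W_v.bias]))
    ref.out_proj.weight.copy_(mha.W_o.weight)
    ref.out_proj.bias.copy_(mha.W_o.bias)

x = torch.randn(2, 10, d_model)
ours, ours_w = mha(x)
theirs, theirs_w = ref(x, x, x)   # weights averaged over heads by default

print(torch.allclose(ours, theirs, atol=1e-5))                   # expected: True
print(torch.allclose(ours_w.mean(dim=1), theirs_w, atol=1e-6))   # expected: True
```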

In [None]:
# Verify shapes
d_model    = 512
num_heads  = 8
seq_len    = 10
batch_size = 2

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
x   = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = mha(x)

print(f"Input shape:          {x.shape}")             # (2, 10, 512)
print(f"Output shape:         {output.shape}")         # (2, 10, 512)
print(f"Attn weights shape:   {attn_weights.shape}")   # (2, 8, 10, 10)
print(f"d_k per head:         {d_model} / {num_heads} = {d_model // num_heads}")
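Splitting into heads changes how the projections are used, not how many parameters there are. Here is a quick count that uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in, since it has the same four d_model × d_model projections plus biases as the class above. The total is 4·(d_model² + d_model) for any `num_heads` that divides `d_model`:

```python
import torch.nn as nn

d_model = 512
counts = []

for num_heads in (1, 8, 16):
    attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
    n_params = sum(p.numel() for p in attn.parameters())
    counts.append(n_params)
    print(f"num_heads={num_heads:2d}  params={n_params:,}")

print(f"4 * (d_model**2 + d_model) = {4 * (d_model**2 + d_model):,}")   # 1,050,624
```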

## Shape Summary

| Tensor | Shape | Notes |
|---|---|---|
| Q, K, V (single head) | `(batch, seq_len, d_k)` | d_k = 64 in this notebook |
| Attention scores | `(batch, seq_len, seq_len)` | before softmax |
| Attention weights | `(batch, seq_len, seq_len)` | after softmax — rows sum to 1 |
| Attention output | `(batch, seq_len, d_v)` | weighted sum of V; d_v = d_k here |
| Q, K, V (multi-head) | `(batch, num_heads, seq_len, d_k)` | split along d_model |
| MHA input / output | `(batch, seq_len, d_model)` | shape preserved — residual-friendly |
| Attention weights (MHA) | `(batch, num_heads, seq_len, seq_len)` | one matrix per head |
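The multi-head rows of this table can also be read directly from `torch.einsum` subscript strings, where each letter names an axis (b = batch, h = head, q/k = query/key position, d = d_k). This sketch computes all heads' scores and outputs in that notation, with sizes matching Part 2, and checks the scores against the `matmul` form used in the notebook:

```python
import math
import torch

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 8, 10, 64
q, k, v = (torch.randn(batch, heads, seq_len, d_k) for _ in range(3))

# Scores: (b, h, q, d) x (b, h, k, d) -> (b, h, q, k), summing over d
scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(d_k)
weights = scores.softmax(dim=-1)

# Output: (b, h, q, k) x (b, h, k, d) -> (b, h, q, d), summing over k
out = torch.einsum("bhqk,bhkd->bhqd", weights, v)

ref = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
print(scores.shape, out.shape)   # torch.Size([2, 8, 10, 10]) torch.Size([2, 8, 10, 64])
print(torch.allclose(scores, ref, atol=1e-5))
```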

In [None]:
# Visualize each head's attention pattern
weights = attn_weights[0].detach()   # (num_heads, seq_len, seq_len)

fig, axes = plt.subplots(2, 4, figsize=(14, 6))

for head_idx, ax in enumerate(axes.flat):
    im = ax.imshow(weights[head_idx].numpy(), cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Head {head_idx + 1}')
    ax.set_xlabel('Key pos')
    ax.set_ylabel('Query pos')

plt.suptitle('Multi-head attention — weights per head (random weights, untrained)', y=1.02)
plt.tight_layout()
plt.savefig('../assets/gradcam/attention_multi_head.png', bbox_inches='tight', dpi=150)
plt.show()

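print("Near-uniform attention is expected with random, untrained projections: the scaled scores are small, so softmax stays close to 1/seq_len.")
print("After training (e.g. in a ViT on image patches), heads often specialize to attend to different regions.")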
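"Near-uniform" can be made measurable with the entropy of each attention row: a uniform row over n keys has entropy log n, and a one-hot row has entropy 0. This sketch uses PyTorch's `nn.MultiheadAttention` as an untrained stand-in. Its default initialization differs from the four `nn.Linear` layers above, so exact numbers will differ. It passes `average_attn_weights=False` to keep one matrix per head:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads, seq_len = 512, 8, 10

attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)   # untrained stand-in
x = torch.randn(2, seq_len, d_model)

with torch.no_grad():
    _, w = attn(x, x, x, need_weights=True, average_attn_weights=False)   # (2, 8, 10, 10)

# Row entropy normalized so that 1.0 = uniform and 0.0 = one-hot
entropy = -(w * w.clamp_min(1e-12).log()).sum(dim=-1)
normalized = entropy / math.log(seq_len)

print(w.shape)
print(f"mean normalized entropy per head: {normalized.mean(dim=(0, 2))}")
```

Tracking the same quantity during training is one simple way to see heads move away from uniform attention.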