# Attention from Scratch

**Learning Objectives:**
1. Implement scaled dot-product attention in NumPy
2. Build intuition for query, key, value through visualizations
3. Visualize attention weights as heatmaps
4. Implement causal (masked) attention for autoregressive models
5. Compare attention to RNN hidden state bottleneck

**Prerequisites:** [attention](../transformers/attention.md), [self-attention](../transformers/self-attention.md)

**Key Insight:** Attention replaces sequential processing with content-based routing. Every position can directly access every other position, weighted by learned relevance.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)

## 1. The Attention Equation

The core attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Breaking this down:
- **Q (Query):** What we're looking for
- **K (Key):** What's available to match against
- **V (Value):** What we actually retrieve
- **Scaling by sqrt(d_k):** Keeps the scores at roughly unit variance as d_k grows, so the softmax does not saturate
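
The formula can be traced by hand on a tiny case: one query, two keys, d_k = 2. The numbers are arbitrary illustrative values, chosen only so every step is easy to check.

```python
import numpy as np

# One query, two keys, d_k = 2 (illustrative values)
Q = np.array([[1.0, 0.0]])
K = np.array([[1.0, 0.0],    # key 0 points the same way as the query
              [0.0, 1.0]])   # key 1 is orthogonal to it
V = np.array([[10.0],        # value stored under key 0
              [20.0]])       # value stored under key 1

d_k = K.shape[-1]
scores = Q @ K.T / np.sqrt(d_k)                        # [[1/sqrt(2), 0]] ~ [[0.707, 0.0]]
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # ~ [[0.670, 0.330]]
output = weights @ V                                   # ~ [[13.30]], a convex mix of 10 and 20

print(scores.round(3), weights.round(3), output.round(3))
```

The output lands between the two values, closer to the one whose key matches the query.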

In [None]:
def softmax(x, axis=-1):
    """Numerically stable softmax."""
    exp_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.
    
    Args:
        Q: Queries (seq_len_q, d_k)
        K: Keys (seq_len_k, d_k)
        V: Values (seq_len_k, d_v)
        mask: Optional mask (seq_len_q, seq_len_k), True = attend
    
    Returns:
        output: (seq_len_q, d_v)
        weights: (seq_len_q, seq_len_k)
    """
    d_k = K.shape[-1]
    
    # Step 1: Compute similarity scores
    scores = Q @ K.T / np.sqrt(d_k)  # (seq_len_q, seq_len_k)
    
    # Step 2: Apply mask if provided
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    
    # Step 3: Convert to probabilities
    weights = softmax(scores, axis=-1)
    
    # Step 4: Weighted sum of values
    output = weights @ V
    
    return output, weights
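
One caveat with `np.where(mask, scores, -np.inf)`: if a query's mask row is all False, every score is -inf, the row max is -inf, and `x - x.max()` evaluates to NaN. The causal mask built later always keeps the diagonal, so it never hits this case, but padding-style masks can. A minimal sketch of one possible guard, assuming such rows should output all-zero weights; `safe_masked_softmax` is a hypothetical helper, not part of `attention()`.

```python
import numpy as np

def safe_masked_softmax(scores, mask, axis=-1):
    """Softmax over entries where mask is True; rows with no True entry become all zeros.

    Hypothetical helper for illustration; attention() in this notebook does not do this.
    """
    has_any = mask.any(axis=axis, keepdims=True)          # does this row attend to anything?
    # Fully masked rows get finite dummy scores so the max-subtraction cannot produce NaN
    safe = np.where(has_any, np.where(mask, scores, -np.inf), 0.0)
    exp_x = np.where(mask, np.exp(safe - safe.max(axis=axis, keepdims=True)), 0.0)
    denom = exp_x.sum(axis=axis, keepdims=True)
    return exp_x / np.where(denom == 0, 1.0, denom)       # all-zero rows stay zero

scores = np.array([[1.0, 2.0, 3.0],
                   [1.0, 2.0, 3.0]])
mask = np.array([[True, True, False],      # ordinary row: last key blocked
                 [False, False, False]])   # e.g. a padding query with nothing to attend to

# Naive version, as in attention(): the second row turns into NaN
with np.errstate(invalid="ignore"):
    naive = np.where(mask, scores, -np.inf)
    naive = np.exp(naive - naive.max(axis=-1, keepdims=True))
    naive = naive / naive.sum(axis=-1, keepdims=True)

print(naive)
print(safe_masked_softmax(scores, mask))
```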

## 2. Why Scale by sqrt(d_k)?

Without scaling, the variance of the dot products grows with d_k. Large scores push the softmax toward near-one-hot outputs, where its gradients are close to zero.
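
Under the usual assumption that the components of q and k are independent with mean 0 and variance 1 (which is how the demo below samples them), the variance follows in one line. The products q_i k_i are uncorrelated across i, so their variances add:

$$\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k, \qquad \operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$$

So the raw scores have standard deviation sqrt(d_k), about 22.6 at d_k = 512, while the scaled scores stay at unit variance for every d_k.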

In [None]:
# Demonstrate the scaling problem
dimensions = [8, 32, 128, 512, 2048]

fig, axes = plt.subplots(2, len(dimensions), figsize=(15, 5))
fig.suptitle('Effect of Dimension on Dot Product Distribution', fontsize=12)

for i, d in enumerate(dimensions):
    # Random vectors with i.i.d. N(0, 1) entries
    q = np.random.randn(1000, d)
    k = np.random.randn(1000, d)
    
    # Unscaled dot products
    unscaled = np.sum(q * k, axis=1)
    
    # Scaled dot products
    scaled = unscaled / np.sqrt(d)
    
    # Plot unscaled
    axes[0, i].hist(unscaled, bins=30, density=True, alpha=0.7, color='red')
    axes[0, i].set_title(f'd={d}\nVar={unscaled.var():.1f}')
    axes[0, i].set_xlim(-3 * np.sqrt(max(dimensions)), 3 * np.sqrt(max(dimensions)))  # +/-3 std of the largest d
    if i == 0:
        axes[0, i].set_ylabel('Unscaled')
    
    # Plot scaled
    axes[1, i].hist(scaled, bins=30, density=True, alpha=0.7, color='blue')
    axes[1, i].set_title(f'Var={scaled.var():.2f}')
    axes[1, i].set_xlim(-4, 4)
    if i == 0:
        axes[1, i].set_ylabel('Scaled by sqrt(d)')

plt.tight_layout()
plt.show()

print("\nWithout scaling: variance grows with dimension")
print("With scaling: variance stays around 1 regardless of dimension")

In [None]:
# Show effect on softmax
d = 512
seq_len = 10

Q = np.random.randn(seq_len, d)
K = np.random.randn(seq_len, d)

# Unscaled scores
scores_unscaled = Q @ K.T
weights_unscaled = softmax(scores_unscaled, axis=-1)

# Scaled scores
scores_scaled = scores_unscaled / np.sqrt(d)
weights_scaled = softmax(scores_scaled, axis=-1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Unscaled
im1 = axes[0].imshow(weights_unscaled, cmap='Blues', vmin=0, vmax=1)
axes[0].set_title(f'Unscaled: Near one-hot\nMax weight: {weights_unscaled.max():.4f}')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0])

# Scaled
im2 = axes[1].imshow(weights_scaled, cmap='Blues', vmin=0, vmax=1)
axes[1].set_title(f'Scaled: Smoother distribution\nMax weight: {weights_scaled.max():.4f}')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

## 3. Query-Key-Value Intuition

Think of attention like a **database lookup**:
- **Query:** Your search term
- **Keys:** Index entries
- **Values:** The actual data stored

Unlike a database, attention returns a *weighted combination* of all values, not just the best match.
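
The contrast with a database can be shown directly: a hard lookup returns only the value whose key scores highest, while attention returns a softmax-weighted average of all values. In this sketch the keys, values, and the `sharpness` multiplier are illustrative assumptions (sharpness is not part of the attention formula); raising it pushes the soft result toward the hard one.

```python
import numpy as np

keys = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]])
values = np.array([100.0, 200.0, 300.0])
query = np.array([0.9, 0.1])

scores = keys @ query                     # similarity of the query to each key

def hard_lookup(scores, values):
    # Database-style: return only the best match
    return values[np.argmax(scores)]

def soft_lookup(scores, values, sharpness=1.0):
    # Attention-style: softmax-weighted average of all values
    z = sharpness * scores
    w = np.exp(z - z.max())
    w /= w.sum()
    return w @ values

print("hard:", hard_lookup(scores, values))
for s in [1.0, 10.0, 100.0]:
    print(f"soft (sharpness={s}):", round(soft_lookup(scores, values, s), 2))
```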

In [None]:
# Simulate a "memory retrieval" system
# Memories: stored facts with associated vectors

memories = [
    ("Paris is the capital of France", np.array([0.9, 0.1, 0.0, 0.0])),
    ("Berlin is the capital of Germany", np.array([0.8, 0.2, 0.0, 0.0])),
    ("The Eiffel Tower is in Paris", np.array([0.7, 0.0, 0.3, 0.0])),
    ("French cuisine includes croissants", np.array([0.6, 0.0, 0.0, 0.4])),
    ("German cars are well-engineered", np.array([0.0, 0.9, 0.1, 0.0])),
]

# Keys: semantic vectors for matching
# Values: the actual facts (represented as one-hot for retrieval)
K = np.array([m[1] for m in memories])
V = np.eye(len(memories))  # One-hot encoding of each memory

# Query: "Tell me about France"
Q_france = np.array([[0.85, 0.0, 0.1, 0.05]])

# Query: "Tell me about German engineering"
Q_german = np.array([[0.0, 0.8, 0.2, 0.0]])

def show_retrieval(query, query_name):
    output, weights = attention(query, K, V)
    
    print(f"\n{query_name}:")
    print("-" * 50)
    for i, (fact, _) in enumerate(memories):
        bar = "#" * int(weights[0, i] * 40)
        print(f"{weights[0, i]:.3f} {bar}")
        print(f"      {fact}")

show_retrieval(Q_france, "Query: 'Tell me about France'")
show_retrieval(Q_german, "Query: 'Tell me about German engineering'")

## 4. Self-Attention: Sequence Talks to Itself

In self-attention, Q, K, and V all come from the same sequence. Each position asks: "Which other positions are relevant to me?"

In [None]:
class SelfAttention:
    """Self-attention with learnable projections."""
    
    def __init__(self, d_model, d_k=None, d_v=None):
        d_k = d_k or d_model
        d_v = d_v or d_model
        
        # Projection matrices (random here; learned during training in a real model)
        self.W_Q = np.random.randn(d_model, d_k) / np.sqrt(d_model)
        self.W_K = np.random.randn(d_model, d_k) / np.sqrt(d_model)
        self.W_V = np.random.randn(d_model, d_v) / np.sqrt(d_model)
        self.d_k = d_k
    
    def forward(self, X, mask=None):
        """
        Args:
            X: Input sequence (seq_len, d_model)
            mask: Optional mask (seq_len, seq_len)
        """
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        
        return attention(Q, K, V, mask)

# Test self-attention
d_model = 32
seq_len = 6

sa = SelfAttention(d_model)
X = np.random.randn(seq_len, d_model)

output, weights = sa.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print("\nAttention weights (each row sums to 1):")
print(weights.round(3))
print(f"\nRow sums: {weights.sum(axis=1).round(3)}")

## 5. Visualizing Attention on Real Text

Let's create a simple example where we can interpret what the attention is doing.

In [None]:
def plot_attention_heatmap(tokens, weights, title="Attention Weights"):
    """
    Visualize attention weights as a heatmap.
    
    Args:
        tokens: List of token strings
        weights: Attention matrix (n, n)
        title: Plot title
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    
    im = ax.imshow(weights, cmap='Blues', aspect='auto')
    
    # Set ticks
    ax.set_xticks(np.arange(len(tokens)))
    ax.set_yticks(np.arange(len(tokens)))
    ax.set_xticklabels(tokens, fontsize=10)
    ax.set_yticklabels(tokens, fontsize=10)
    
    # Rotate x labels
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    
    # Add text annotations
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax.text(j, i, f"{weights[i, j]:.2f}",
                    ha="center", va="center",
                    color="white" if weights[i, j] > 0.5 else "black",
                    fontsize=8)
    
    ax.set_xlabel("Keys (attending to)")
    ax.set_ylabel("Queries (from)")
    ax.set_title(title)
    
    plt.colorbar(im, ax=ax, label="Attention Weight")
    plt.tight_layout()
    return fig, ax

In [None]:
# Toy setup: hand-made embeddings, random (untrained) projections
# Create embeddings where similar tokens have similar vectors

tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Create embeddings with semantic structure
d_model = 16
embeddings = {
    "The": np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "the": np.array([0.9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "cat": np.array([0, 1, 0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "sat": np.array([0, 0, 0, 1, 0, 0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "on":  np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    "mat": np.array([0, 0.5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
}

# Build input matrix
X = np.array([embeddings.get(t, embeddings.get(t.lower(), np.random.randn(d_model))) for t in tokens])
X = X + np.random.randn(*X.shape) * 0.1  # Add noise

# Apply self-attention
sa = SelfAttention(d_model, d_k=8)
output, weights = sa.forward(X)

plot_attention_heatmap(tokens, weights, "Self-Attention: 'The cat sat on the mat'")
plt.show()

print("\nNotice: Each row shows which tokens that query position attends to.")
print("W_Q, W_K, W_V are random here, so the pattern is arbitrary; training is what makes it meaningful.")

## 6. Causal (Masked) Attention

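For autoregressive models (like GPT), position i may attend only to positions j <= i: when predicting the next token, the model must not see tokens that come after it. A causal mask enforces this by blocking every entry above the diagonal.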
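
Many implementations express the same mask additively: build a matrix that is 0 where attention is allowed and -inf where it is blocked, and add it to the scores before the softmax. A minimal sketch, with illustrative random scores, checking that this matches the `np.where` form used in `attention()`:

```python
import numpy as np

def stable_softmax(x, axis=-1):
    exp_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n = 5
scores = rng.standard_normal((n, n))

allowed = np.tril(np.ones((n, n), dtype=bool))    # causal: query i may see keys j <= i
additive = np.where(allowed, 0.0, -np.inf)         # 0 = keep, -inf = block

w_where = stable_softmax(np.where(allowed, scores, -np.inf))   # form used in attention()
w_add = stable_softmax(scores + additive)                      # additive form

print(np.allclose(w_where, w_add))
print(w_add.round(2))
```

A large finite negative number such as -1e9 is sometimes used instead of -inf; then a fully masked row yields uniform weights rather than NaN.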

In [None]:
def create_causal_mask(seq_len):
    """Create lower-triangular causal mask."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

# Visualize the mask
seq_len = 6
mask = create_causal_mask(seq_len)

fig, ax = plt.subplots(figsize=(6, 5))
ax.imshow(mask, cmap='Greens', aspect='auto')
ax.set_title('Causal Mask\n(True = can attend, False = blocked)')
ax.set_xlabel('Key Position (can see)')
ax.set_ylabel('Query Position')

for i in range(seq_len):
    for j in range(seq_len):
        symbol = "Y" if mask[i, j] else "N"
        color = "white" if mask[i, j] else "red"
        ax.text(j, i, symbol, ha="center", va="center", fontsize=14, color=color)

plt.tight_layout()
plt.show()

print("Position 0 can only see itself")
print("Position 3 can see positions 0, 1, 2, 3")
print("Position 5 can see all positions 0-5")

In [None]:
# Apply causal masking
tokens = ["The", "cat", "sat", "on", "the", "mat"]
mask = create_causal_mask(len(tokens))

output, weights_causal = sa.forward(X, mask=mask)

plot_attention_heatmap(tokens, weights_causal, "Causal Self-Attention (can't see future)")
plt.show()

print("\nNotice: Every entry above the diagonal is zero.")
print("Each position can only attend to itself and previous positions.")

In [None]:
# Compare causal vs bidirectional
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bidirectional (no mask)
_, weights_bi = sa.forward(X)
im1 = axes[0].imshow(weights_bi, cmap='Blues', aspect='auto')
axes[0].set_title('Bidirectional Attention\n(like BERT)')
axes[0].set_xticks(range(len(tokens)))
axes[0].set_yticks(range(len(tokens)))
axes[0].set_xticklabels(tokens)
axes[0].set_yticklabels(tokens)
axes[0].set_xlabel('Keys')
axes[0].set_ylabel('Queries')
plt.colorbar(im1, ax=axes[0])

# Causal (with mask)
im2 = axes[1].imshow(weights_causal, cmap='Blues', aspect='auto')
axes[1].set_title('Causal Attention\n(like GPT)')
axes[1].set_xticks(range(len(tokens)))
axes[1].set_yticks(range(len(tokens)))
axes[1].set_xticklabels(tokens)
axes[1].set_yticklabels(tokens)
axes[1].set_xlabel('Keys')
axes[1].set_ylabel('Queries')
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

print("BERT uses bidirectional attention (sees entire context)")
print("GPT uses causal attention (can only see past)")

## 7. Attention vs RNN: The Bottleneck Problem

RNNs compress all information into a fixed-size hidden state. Attention provides direct access to all positions.
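
The bottleneck is visible in the shapes alone. In the toy recurrence below (a plain tanh RNN with random weights and illustrative sizes, not a trained model), everything read so far must fit into one `d_h`-dimensional vector no matter how long the sequence is, whereas attention keeps one vector per position available as keys and values.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h = 8, 16                                    # illustrative sizes
W_x = rng.standard_normal((d_in, d_h)) / np.sqrt(d_in)
W_h = rng.standard_normal((d_h, d_h)) / np.sqrt(d_h)

def rnn_final_state(seq):
    """Run h_t = tanh(x_t W_x + h_{t-1} W_h) over the sequence; return only the last h."""
    h = np.zeros(d_h)
    for x_t in seq:                  # strictly sequential: step t needs h from step t-1
        h = np.tanh(x_t @ W_x + h @ W_h)
    return h

for n in [5, 50, 500]:
    seq = rng.standard_normal((n, d_in))
    h = rnn_final_state(seq)
    # Attention keeps one key/value pair per position instead of a single summary
    print(f"n={n:4d}  RNN summary shape: {h.shape}  vectors available to attention: {len(seq)}")
```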

In [None]:
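def simulate_rnn_information_flow(seq_len, decay=0.7):
    """
    Toy model of how much information from position i reaches position j in an RNN:
    assume a fixed fraction `decay` survives each recurrent step, so it shrinks as
    decay ** (j - i). This mimics the vanishing-gradient effect; it is not a trained RNN.
    """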
    flow = np.zeros((seq_len, seq_len))
    
    for i in range(seq_len):
        for j in range(i, seq_len):
            # Information from i to j decays with distance
            distance = j - i
            flow[j, i] = decay ** distance
    
    return flow

def simulate_attention_information_flow(seq_len):
    """
    In attention, any position can directly access any other: every source
    is one hop from every target, so nothing decays along the path.
    (How much of each source is actually used depends on the learned
    attention weights, which this toy model does not simulate.)
    """
    flow = np.ones((seq_len, seq_len))
    return flow

In [None]:
seq_len = 10

rnn_flow = simulate_rnn_information_flow(seq_len, decay=0.7)
attn_flow = simulate_attention_information_flow(seq_len)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# RNN (both panels share a 0-1 color scale so they are comparable)
im1 = axes[0].imshow(rnn_flow, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
axes[0].set_title('RNN Information Flow\n(exponential decay with distance)')
axes[0].set_xlabel('Source Position')
axes[0].set_ylabel('Target Position')
plt.colorbar(im1, ax=axes[0], label='Information preserved')

# Attention
im2 = axes[1].imshow(attn_flow, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
axes[1].set_title('Attention Information Flow\n(direct access, no decay)')
axes[1].set_xlabel('Source Position')
axes[1].set_ylabel('Target Position')
plt.colorbar(im2, ax=axes[1], label='Information preserved')

plt.tight_layout()
plt.show()

In [None]:
# Information preservation at different distances
distances = np.arange(0, 20)
decay_rates = [0.9, 0.7, 0.5]

plt.figure(figsize=(10, 5))

for decay in decay_rates:
    rnn_preservation = decay ** distances
    plt.plot(distances, rnn_preservation, 'o-', label=f'RNN (decay={decay})', alpha=0.7)

# Attention baseline (constant)
plt.axhline(y=1.0, color='green', linestyle='--', linewidth=2, label='Attention (direct)')

plt.xlabel('Distance between positions')
plt.ylabel('Information Preserved')
plt.title('RNN vs Attention: Long-Range Dependencies')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.1)
plt.show()

print("\nKey insight: RNN information decays exponentially with distance.")
print("Attention maintains constant O(1) path length between any two positions.")

## 8. Attention Patterns

Different attention patterns emerge for different relationships:

In [None]:
def create_attention_pattern(pattern_type, seq_len):
    """Create example attention patterns."""
    
    if pattern_type == "diagonal":
        # Each token attends only to itself
        weights = np.eye(seq_len)
        
    elif pattern_type == "previous":
        # Attend to previous token
        weights = np.zeros((seq_len, seq_len))
        for i in range(seq_len):
            weights[i, max(0, i-1)] = 1.0
            
    elif pattern_type == "first":
        # Attend to first token (like [CLS])
        weights = np.zeros((seq_len, seq_len))
        weights[:, 0] = 1.0
        
    elif pattern_type == "uniform":
        # Attend uniformly
        weights = np.ones((seq_len, seq_len)) / seq_len
        
    elif pattern_type == "local":
        # Attend to nearby tokens
        weights = np.zeros((seq_len, seq_len))
        for i in range(seq_len):
            for j in range(max(0, i-2), min(seq_len, i+3)):
                weights[i, j] = 1.0
        weights = weights / weights.sum(axis=1, keepdims=True)
        
    return weights

# Visualize different patterns
patterns = ["diagonal", "previous", "first", "uniform", "local"]
titles = [
    "Diagonal\n(attend to self)",
    "Previous\n(attend to prior token)",
    "First Token\n(like [CLS] aggregation)",
    "Uniform\n(average all)",
    "Local\n(nearby window)"
]

fig, axes = plt.subplots(1, 5, figsize=(18, 3.5))

for i, (pattern, title) in enumerate(zip(patterns, titles)):
    weights = create_attention_pattern(pattern, 8)
    im = axes[i].imshow(weights, cmap='Blues', aspect='auto', vmin=0, vmax=1)
    axes[i].set_title(title, fontsize=10)
    axes[i].set_xlabel('Key')
    if i == 0:
        axes[i].set_ylabel('Query')

plt.tight_layout()
plt.show()

print("\nThese patterns are hand-constructed; similar ones are often observed in trained models.")
print("Different attention heads often specialize in different patterns.")
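
One way to put a number on these patterns is the entropy of each query's row of weights: 0 nats for a one-hot row (the diagonal, previous and first-token patterns) and log(n) for a uniform row, with local windows in between. A minimal sketch; `row_entropy` is a hypothetical helper, and entropy is just one convenient summary of how spread out a head's attention is.

```python
import numpy as np

def row_entropy(weights, eps=1e-12):
    """Shannon entropy (nats) of each row of an attention matrix."""
    w = np.clip(weights, eps, 1.0)            # avoid log(0); zero weights still contribute 0
    return -(weights * np.log(w)).sum(axis=-1)

n = 8
one_hot = np.eye(n)                           # "diagonal" pattern
uniform = np.full((n, n), 1.0 / n)            # "uniform" pattern
local = np.zeros((n, n))                      # window of +/-2 positions, renormalized
for i in range(n):
    local[i, max(0, i - 2):min(n, i + 3)] = 1.0
local /= local.sum(axis=1, keepdims=True)

for name, w in [("diagonal", one_hot), ("local", local), ("uniform", uniform)]:
    print(f"{name:9s} mean entropy = {row_entropy(w).mean():.3f} nats")
print(f"log(n)    = {np.log(n):.3f}")
```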

## 9. Multi-Head Attention

Multiple attention heads learn different patterns in parallel:

In [None]:
class MultiHeadAttention:
    """Multi-head self-attention."""
    
    def __init__(self, d_model, n_heads):
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model
        
        # Combined projections for efficiency
        self.W_Q = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_K = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_V = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_O = np.random.randn(d_model, d_model) / np.sqrt(d_model)
    
    def forward(self, X, mask=None):
        """
        Args:
            X: Input (seq_len, d_model)
            mask: Optional mask (seq_len, seq_len)
        Returns:
            output: (seq_len, d_model)
            weights: (n_heads, seq_len, seq_len)
        """
        seq_len = X.shape[0]
        
        # Project and reshape for heads
        Q = (X @ self.W_Q).reshape(seq_len, self.n_heads, self.d_k).transpose(1, 0, 2)
        K = (X @ self.W_K).reshape(seq_len, self.n_heads, self.d_k).transpose(1, 0, 2)
        V = (X @ self.W_V).reshape(seq_len, self.n_heads, self.d_k).transpose(1, 0, 2)
        
        # Attention for each head: (n_heads, seq_len, seq_len)
        scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = np.where(mask[None, :, :], scores, -np.inf)
        
        weights = softmax(scores, axis=-1)
        
        # Weighted values: (n_heads, seq_len, d_k)
        head_outputs = np.matmul(weights, V)
        
        # Concatenate and project: (seq_len, d_model)
        concat = head_outputs.transpose(1, 0, 2).reshape(seq_len, self.d_model)
        output = concat @ self.W_O
        
        return output, weights

# Test multi-head attention
d_model = 64
n_heads = 8
seq_len = 6

mha = MultiHeadAttention(d_model, n_heads)
X = np.random.randn(seq_len, d_model)

output, all_weights = mha.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {all_weights.shape}")
print(f"  ({n_heads} heads, each with {seq_len}x{seq_len} attention matrix)")
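
The reshape/transpose in `forward` is the least obvious step. One way to check it is to compare against a plain loop that slices the projected Q, K, V into `n_heads` contiguous column blocks of width `d_k` and runs single-head attention on each. The sketch below is self-contained, uses its own random weights, and its function names are illustrative.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha_vectorized(X, W_Q, W_K, W_V, W_O, n_heads):
    """Same reshape/transpose scheme as MultiHeadAttention.forward (no mask)."""
    n, d_model = X.shape
    d_k = d_model // n_heads

    def split(M):
        # (n, d_model) -> (n, n_heads, d_k) -> (n_heads, n, d_k)
        return M.reshape(n, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    w = softmax_rows(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))    # (n_heads, n, n)
    concat = (w @ V).transpose(1, 0, 2).reshape(n, d_model)       # heads side by side
    return concat @ W_O

def mha_loop(X, W_Q, W_K, W_V, W_O, n_heads):
    """Reference: slice columns per head and run single-head attention in a loop."""
    n, d_model = X.shape
    d_k = d_model // n_heads
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    heads = []
    for h in range(n_heads):
        cols = slice(h * d_k, (h + 1) * d_k)                      # head h owns this column block
        w = softmax_rows(Q[:, cols] @ K[:, cols].T / np.sqrt(d_k))
        heads.append(w @ V[:, cols])
    return np.concatenate(heads, axis=1) @ W_O

rng = np.random.default_rng(0)
dm, nh, n_tok = 64, 8, 6              # distinct names so the notebook's d_model/n_heads are untouched
Ws = [rng.standard_normal((dm, dm)) / np.sqrt(dm) for _ in range(4)]   # W_Q, W_K, W_V, W_O
X_demo = rng.standard_normal((n_tok, dm))

print(np.allclose(mha_vectorized(X_demo, *Ws, nh), mha_loop(X_demo, *Ws, nh)))
```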

In [None]:
# Visualize attention patterns across heads (X is random, so the tokens are only axis labels)
tokens = ["The", "cat", "sat", "on", "the", "mat"]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for head_idx in range(min(8, n_heads)):
    im = axes[head_idx].imshow(all_weights[head_idx], cmap='Blues', aspect='auto')
    axes[head_idx].set_title(f'Head {head_idx}', fontsize=10)
    axes[head_idx].set_xticks(range(len(tokens)))
    axes[head_idx].set_yticks(range(len(tokens)))
    axes[head_idx].set_xticklabels(tokens, fontsize=8, rotation=45)
    axes[head_idx].set_yticklabels(tokens, fontsize=8)

plt.suptitle('Multi-Head Attention (untrained): Each Head Has Its Own Pattern', fontsize=12)
plt.tight_layout()
plt.show()

print("\nIn a trained model, heads can specialize in different relationships:")
print("- Syntactic patterns (subject-verb)")
print("- Positional patterns (previous token, next token)")
print("- Semantic patterns (related concepts)")

## 10. Positional Encoding

Self-attention is **permutation equivariant**: shuffling the input tokens just shuffles the outputs, so it has no notion of token order. We add positional information explicitly.
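
Permutation equivariance can be checked numerically: reordering the rows of X reorders the rows of the self-attention output in exactly the same way. Adding a position-dependent vector to each row breaks this. The projections and the position vectors below are random stand-ins (illustrative assumptions, not trained weights or the sinusoidal encoding).

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Single-head self-attention without any positional information."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    s = Q @ K.T / np.sqrt(K.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
n_tok, dm = 6, 16
Wq, Wk, Wv = (rng.standard_normal((dm, dm)) / np.sqrt(dm) for _ in range(3))
X_tok = rng.standard_normal((n_tok, dm))
perm = np.arange(n_tok)[::-1]                  # reverse the token order

out = self_attention(X_tok, Wq, Wk, Wv)
out_perm = self_attention(X_tok[perm], Wq, Wk, Wv)
print(np.allclose(out_perm, out[perm]))        # reordering the input only reorders the output

# With a position-dependent vector added to each row, the check no longer holds
pos = rng.standard_normal((n_tok, dm))
out_pos = self_attention(X_tok + pos, Wq, Wk, Wv)
out_pos_perm = self_attention(X_tok[perm] + pos, Wq, Wk, Wv)
print(np.allclose(out_pos_perm, out_pos[perm]))
```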

In [None]:
def sinusoidal_position_encoding(seq_len, d_model):
    """
    Sinusoidal positional encoding from "Attention Is All You Need".
    
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    
    return pe

# Visualize positional encodings
seq_len = 50
d_model = 64

pe = sinusoidal_position_encoding(seq_len, d_model)

plt.figure(figsize=(12, 6))
plt.imshow(pe, cmap='RdBu', aspect='auto')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Sinusoidal Positional Encoding')
plt.colorbar(label='Value')
plt.show()

print("Each position has a unique encoding.")
print("Low dimensions change quickly (high frequency).")
print("High dimensions change slowly (low frequency).")

In [None]:
# Show that nearby positions have similar encodings
def positional_similarity(pe):
    """Compute cosine similarity between position encodings."""
    norm = np.linalg.norm(pe, axis=1, keepdims=True)
    pe_normalized = pe / norm
    return pe_normalized @ pe_normalized.T

pe = sinusoidal_position_encoding(50, 64)
sim = positional_similarity(pe)

plt.figure(figsize=(8, 6))
plt.imshow(sim, cmap='RdYlGn', aspect='auto')
plt.xlabel('Position j')
plt.ylabel('Position i')
plt.title('Positional Encoding Similarity\n(nearby positions are more similar)')
plt.colorbar(label='Cosine Similarity')
plt.show()

print("\nKey property: Similarity depends only on the offset between positions and generally falls as the offset grows.")
print("This gives the model information about relative positions.")
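
The offset-only dependence follows from the formula: each (sin, cos) pair at frequency omega_i contributes sin(p*omega_i)*sin(q*omega_i) + cos(p*omega_i)*cos(q*omega_i) = cos((p - q)*omega_i) to the dot product, and every encoding has squared norm d_model/2. So the cosine similarity is (2/d_model) * sum_i cos((p - q)*omega_i). A quick numerical check, re-deriving the encoding so the block runs on its own:

```python
import numpy as np

def sinusoidal_pe(seq_len, d_model):
    """Same construction as sinusoidal_position_encoding; also returns the frequencies."""
    pos = np.arange(seq_len)[:, np.newaxis]
    omega = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

d_pe = 64
pe_, omega = sinusoidal_pe(50, d_pe)
unit = pe_ / np.linalg.norm(pe_, axis=1, keepdims=True)
sim_ = unit @ unit.T

# Every row has squared norm d/2, since sin^2 + cos^2 = 1 for each pair
print(np.allclose((pe_ ** 2).sum(axis=1), d_pe / 2))

# All entries at a fixed offset k (sim_[i, i + k]) equal (2/d) * sum_i cos(k * omega_i)
for k in [1, 5, 20]:
    print(k, np.allclose(np.diagonal(sim_, offset=k), (2 / d_pe) * np.cos(k * omega).sum()))
```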

## 11. Computational Complexity

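Self-attention compares every position with every other: the score matrix QK^T has n x n entries, so time and memory grow as O(n^2) in the sequence length n. That all-pairs comparison is what gives every position direct access to every other, and also what limits how long a context can be.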
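
Counting multiply-accumulates makes the n^2 concrete. A sketch, assuming a single head with Q/K width d_k and V width d_v and ignoring the softmax and memory traffic: the projections cost n*d_model*(2*d_k + d_v), which is linear in n, while QK^T and the weighted sum of V cost n^2*(d_k + d_v). The quadratic terms only take over once n is large compared with d_model, which is worth keeping in mind when reading a timing plot.

```python
def attention_macs(n, d_model, d_k, d_v):
    """Approximate multiply-accumulate counts for one self-attention head (softmax ignored)."""
    projections = n * d_model * (2 * d_k + d_v)    # X @ W_Q, X @ W_K, X @ W_V
    quadratic = n * n * d_k + n * n * d_v          # Q @ K.T, then weights @ V
    return projections, quadratic

# Sizes matching the benchmark: SelfAttention(64, d_k=32), whose d_v defaults to d_model = 64
for n in [16, 64, 256, 1024, 4096]:
    proj, quad = attention_macs(n, d_model=64, d_k=32, d_v=64)
    print(f"n={n:5d}  projections={proj:>12,}  n^2 terms={quad:>12,}  ratio={quad / proj:6.2f}")
```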

In [None]:
import time

def benchmark_attention(seq_lengths, d_model=64):
    """Measure attention computation time vs sequence length."""
    times = []
    
    for seq_len in seq_lengths:
        X = np.random.randn(seq_len, d_model)
        sa = SelfAttention(d_model, d_k=32)
        
        # Warm up
        _ = sa.forward(X)
        
        # Time multiple runs
        n_runs = 10
        start = time.perf_counter()  # monotonic, high-resolution timer
        for _ in range(n_runs):
            _ = sa.forward(X)
        elapsed = (time.perf_counter() - start) / n_runs
        times.append(elapsed)
    
    return times

seq_lengths = [16, 32, 64, 128, 256, 512]
times = benchmark_attention(seq_lengths)

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(seq_lengths, times, 'bo-', markersize=8)
plt.xlabel('Sequence Length')
plt.ylabel('Time (seconds)')
plt.title('Attention Computation Time')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.loglog(seq_lengths, times, 'bo-', markersize=8, label='Measured')
# O(n^2) reference line, anchored at the longest sequence, where the n^2 terms dominate most
theoretical = [(s/seq_lengths[-1])**2 * times[-1] for s in seq_lengths]
plt.loglog(seq_lengths, theoretical, 'r--', label='O(n^2) reference')
plt.xlabel('Sequence Length (log)')
plt.ylabel('Time (log)')
plt.title('Log-Log Scale (O(n^2) has slope 2)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

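print("\nAt small n, fixed overhead and the linear-in-n projections dominate, so the measured slope is usually below 2.")
print("Once the n^2 terms dominate, doubling the sequence length roughly quadruples compute time.")
print("This is why long-context models use efficient attention variants.")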

## Summary

| Concept | Formula | Purpose |
|---------|---------|---------|
| Attention | softmax(QK^T/sqrt(d_k))V | Content-based routing |
| Scaling | / sqrt(d_k) | Stabilize softmax gradients |
| Causal Mask | Lower triangular | Prevent seeing future |
| Multi-Head | Parallel attention | Learn different patterns |
| Position Encoding | X + PE | Inject order information |

**Key Insights:**
1. Attention provides **O(1) path length** between any positions (vs O(n) for RNN)
2. Self-attention is **permutation equivariant** — needs positional encoding
3. **Multi-head attention** learns diverse patterns in parallel
4. **Causal masking** enables autoregressive generation
5. **O(n^2) complexity** limits sequence length; efficient variants exist

**Next:** [08-minimal-transformer.ipynb](08-minimal-transformer.ipynb) builds a complete transformer block with attention, feedforward, and residual connections.

## Exercises

1. **Temperature in Attention:** Modify the attention function to accept a temperature parameter T. What happens with T < 1 (sharper)? T > 1 (softer)?

2. **Relative Position:** Implement relative positional encoding where attention scores include a bias based on position difference (i - j).

3. **Sparse Attention:** Implement a local attention pattern where each position only attends to a window of k neighbors. How does this reduce complexity?

4. **Cross-Attention:** Modify the code to support cross-attention where Q comes from one sequence and K, V from another (like encoder-decoder attention).