# 12: KV Cache Demo

## Learning Objectives

1. Understand why naive autoregressive generation is slow
2. Implement KV (Key-Value) caching from scratch
3. Benchmark the speedup from caching
4. Visualize what gets cached and why
5. Explore memory vs speed tradeoffs

**Prerequisites:** [attention](../transformers/attention.md), [GPT](../transformers/gpt.md), [inference](../modern-llms/inference.md)

**Framework:** PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

## Part 1: The Problem - Redundant Computation

In autoregressive generation, each new token requires attending to ALL previous tokens. Without caching, we recompute K and V for every previous token at every step.
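
To see why that recomputation is redundant, the minimal sketch below projects keys for a 5-token sequence, appends a sixth token, and projects again. The weight matrix `w_k` and the random hidden states are illustrative stand-ins, not values from the model built later, and NumPy is used so the snippet runs on its own. The rows for the first five tokens come out the same both times. In a causal transformer this holds at every layer, because a token's hidden state depends only on itself and earlier tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8
w_k = rng.standard_normal((d_model, d_model))  # stand-in key projection weights

x = rng.standard_normal((5, d_model))          # hidden states of 5 tokens
k_before = x @ w_k                             # keys computed at step 5

new_token = rng.standard_normal((1, d_model))
x_next = np.concatenate([x, new_token], axis=0)
k_after = x_next @ w_k                         # naive: re-project all 6 tokens

# Rows 0..4 are unchanged, so only the last row was new work.
rows_match = np.allclose(k_before, k_after[:5])
print("old keys unchanged:", rows_match)
```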

In [None]:
def visualize_redundant_computation():
    """Visualize the redundant work in naive generation."""
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Left: Naive approach
    ax = axes[0]
    seq_len = 6
    
    # Create a matrix showing what's computed at each step
    computation = np.zeros((seq_len, seq_len))
    for step in range(seq_len):
        for j in range(step + 1):
            computation[step, j] = step + 1  # Color by step
    
    im = ax.imshow(computation, cmap='Blues', aspect='auto')
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Generation Step')
    ax.set_title('Naive: Recompute K,V for ALL tokens each step\n(Darker = later step)')
    ax.set_xticks(range(seq_len))
    ax.set_yticks(range(seq_len))
    
    # Add text annotations
    for i in range(seq_len):
        for j in range(seq_len):
            if computation[i, j] > 0:
                ax.text(j, i, 'K,V', ha='center', va='center', fontsize=8, color='white' if computation[i,j] > 3 else 'black')
    
    # Right: With KV cache
    ax = axes[1]
    
    # Only new tokens are computed
    cached = np.zeros((seq_len, seq_len))
    for step in range(seq_len):
        cached[step, step] = step + 1  # Only diagonal
        # Show cached values as lighter
        for j in range(step):
            cached[step, j] = 0.3  # Cached
    
    im = ax.imshow(cached, cmap='Greens', aspect='auto', vmin=0, vmax=seq_len)
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Generation Step')
    ax.set_title('KV Cache: Only compute NEW token, reuse cached\n(Dark = new, Light = cached)')
    ax.set_xticks(range(seq_len))
    ax.set_yticks(range(seq_len))
    
    # Add text annotations
    for i in range(seq_len):
        for j in range(seq_len):
            if j == i:
                ax.text(j, i, 'NEW', ha='center', va='center', fontsize=8, color='white')
            elif j < i:
                ax.text(j, i, 'cache', ha='center', va='center', fontsize=7, color='darkgreen')
    
    plt.tight_layout()
    plt.savefig('kv_cache_concept.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Compute complexity comparison
    n = 1000  # sequence length
    d = 768   # hidden dimension
    
    naive_ops = sum((i+1) for i in range(n))  # 1 + 2 + 3 + ... + n = n(n+1)/2
    cached_ops = n  # 1 + 1 + 1 + ... = n
    
    print(f"\nComplexity comparison (n={n} tokens):")
    print(f"  Naive: O(n^2) = {naive_ops:,} K,V computations")
    print(f"  KV Cache: O(n) = {cached_ops:,} K,V computations")
    print(f"  Speedup: {naive_ops / cached_ops:.1f}x fewer computations")

visualize_redundant_computation()

## Part 2: Implementing a GPT with KV Cache

We'll build a small GPT model with both naive and cached generation modes.

In [None]:
@dataclass
class GPTConfig:
    vocab_size: int = 1000
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 4
    max_seq_len: int = 512
    dropout: float = 0.1


class KVCache:
    """Key-Value cache for efficient autoregressive generation."""
    
    def __init__(self, batch_size: int, max_seq_len: int, n_layers: int, 
                 n_heads: int, head_dim: int, device: torch.device):
        self.max_seq_len = max_seq_len
        self.n_layers = n_layers
        self.current_len = 0
        
        # Pre-allocate cache tensors for all layers
        # Shape: [n_layers, batch, n_heads, max_seq_len, head_dim]
        self.cache_k = torch.zeros(
            n_layers, batch_size, n_heads, max_seq_len, head_dim,
            device=device
        )
        self.cache_v = torch.zeros(
            n_layers, batch_size, n_heads, max_seq_len, head_dim,
            device=device
        )
    
    def update(self, layer_idx: int, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache with new K, V and return full cached K, V.
        
        Args:
            k, v: [batch, n_heads, seq_len, head_dim] - new K, V to add
        Returns:
            Full K, V including new values: [batch, n_heads, total_len, head_dim]
        """
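        seq_len = k.shape[2]
        start = self.current_len
        end = start + seq_len
        # Fail loudly instead of writing past the pre-allocated buffer
        if end > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: need {end} positions, capacity is {self.max_seq_len}"
            )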
        
        # Store new K, V in cache
        self.cache_k[layer_idx, :, :, start:end, :] = k
        self.cache_v[layer_idx, :, :, start:end, :] = v
        
        # Return full cached K, V up to current position
        return self.cache_k[layer_idx, :, :, :end, :], self.cache_v[layer_idx, :, :, :end, :]
    
    def increment(self, n: int = 1):
        """Increment the sequence length counter."""
        self.current_len += n
    
    def reset(self):
        """Reset the cache for a new sequence."""
        self.current_len = 0
        self.cache_k.zero_()
        self.cache_v.zero_()
    
    def memory_usage(self) -> float:
        """Return memory usage in MB."""
        k_mem = self.cache_k.numel() * self.cache_k.element_size()
        v_mem = self.cache_v.numel() * self.cache_v.element_size()
        return (k_mem + v_mem) / (1024 ** 2)
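
The pre-allocation pattern used by `KVCache` can be seen in isolation in the sketch below. `TinyCache` is a hypothetical single-layer toy written in NumPy, not part of the notebook's model. It shows the two properties that matter: new entries are written into a slice of a buffer allocated once, and the returned "full" keys are a view of that buffer rather than a copy. Unlike `KVCache`, the toy advances its offset inside `append`; the class above keeps `increment` separate because every layer writes at the same offset before the counter moves.

```python
import numpy as np

class TinyCache:
    """Toy single-layer key cache: one buffer, allocated once."""

    def __init__(self, n_heads: int, max_len: int, head_dim: int):
        self.buf = np.zeros((n_heads, max_len, head_dim), dtype=np.float32)
        self.cur = 0  # number of valid positions

    def append(self, new_k: np.ndarray) -> np.ndarray:
        """Write new_k ([n_heads, t, head_dim]) at the current offset and
        return a view of every valid position so far."""
        end = self.cur + new_k.shape[1]
        if end > self.buf.shape[1]:
            raise ValueError("cache overflow")
        self.buf[:, self.cur:end, :] = new_k
        self.cur = end
        return self.buf[:, :end, :]

cache = TinyCache(n_heads=2, max_len=8, head_dim=4)
after_prefill = cache.append(np.ones((2, 3, 4), dtype=np.float32))    # 3 prompt tokens
after_step = cache.append(np.full((2, 1, 4), 2.0, dtype=np.float32))  # 1 decoded token

print(after_prefill.shape, after_step.shape)    # (2, 3, 4) (2, 4, 4)
print(np.shares_memory(after_step, cache.buf))  # True: a view, not a copy
```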

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with optional KV caching."""
    
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.d_model = config.d_model
        
        # Combined Q, K, V projection
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
        self.proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        
        # Causal mask
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )
    
    def forward(self, x: torch.Tensor, cache: Optional[KVCache] = None, 
                layer_idx: int = 0) -> torch.Tensor:
        """
        Forward pass with optional KV caching.
        
        Args:
            x: [batch, seq_len, d_model]
            cache: Optional KV cache for generation
            layer_idx: Layer index for cache lookup
        """
        B, T, C = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x)  # [B, T, 3*C]
        q, k, v = qkv.split(self.d_model, dim=-1)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # [B, nh, T, hd]
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Use KV cache if provided
        if cache is not None:
            k, v = cache.update(layer_idx, k, v)
            # k, v now have shape [B, nh, cache_len + T, hd]
        
        # Compute attention
        # Q has T positions, K has (cache_len + T) positions
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply causal mask. Query i sits at absolute position (kv_len - T + i)
        # and may only attend to keys 0..(kv_len - T + i). One slice covers all cases:
        #   - no cache:       kv_len == T, the usual lower-triangular mask
        #   - cached prefill: cache was empty, so kv_len == T and the prompt is masked
        #   - cached decode:  T == 1, the single mask row is all ones (a no-op)
        # Skipping the mask during cached prefill would let prompt tokens see the
        # future and make cached generation disagree with naive generation.
        kv_len = k.shape[2]  # Total length including cache
        scores = scores.masked_fill(
            self.mask[:, :, kv_len - T:kv_len, :kv_len] == 0,
            float('-inf')
        )
        
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        
        # Apply attention to values
        out = weights @ v  # [B, nh, T, hd]
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        
        return self.proj(out)


class FeedForward(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, 4 * config.d_model)
        self.fc2 = nn.Linear(4 * config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(self, x):
        return self.dropout(self.fc2(F.gelu(self.fc1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = FeedForward(config)
    
    def forward(self, x, cache=None, layer_idx=0):
        x = x + self.attn(self.ln1(x), cache, layer_idx)
        x = x + self.ffn(self.ln2(x))
        return x
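
A decode step with a single query needs no mask because every stored key is already in the past. The sketch below checks that claim numerically under simple assumptions: one head, random `q`, `k`, `v` (illustrative values, not taken from the model), and NumPy instead of PyTorch so it runs standalone. It compares the last row of fully masked causal attention with an unmasked single-query step over the same keys and values. With more than one query, as in prefill, the mask is still required.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, d = 6, 4
q = rng.standard_normal((T, d))
k = rng.standard_normal((T, d))
v = rng.standard_normal((T, d))

# Full causal attention over all T positions.
scores = q @ k.T / np.sqrt(d)
mask = np.tril(np.ones((T, T), dtype=bool))
scores = np.where(mask, scores, -np.inf)
full_out = softmax(scores) @ v            # [T, d]

# Decode step: only the last query, attending to all stored keys and values.
step_scores = q[-1:] @ k.T / np.sqrt(d)   # [1, T], no mask applied
step_out = softmax(step_scores) @ v       # [1, d]

same = np.allclose(full_out[-1], step_out[0])
print("decode step matches last row:", same)
```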

In [None]:
class GPT(nn.Module):
    """GPT model with support for KV caching."""
    
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])
        
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        
        # Weight tying
        self.tok_emb.weight = self.head.weight
    
    def forward(self, idx: torch.Tensor, cache: Optional[KVCache] = None,
                start_pos: int = 0) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            idx: [batch, seq_len] token indices
            cache: Optional KV cache
            start_pos: Starting position for positional embeddings (for cached generation)
        """
        B, T = idx.shape
        
        # Embeddings
        pos = torch.arange(start_pos, start_pos + T, device=idx.device)
        x = self.dropout(self.tok_emb(idx) + self.pos_emb(pos))
        
        # Transformer blocks
        for layer_idx, block in enumerate(self.blocks):
            x = block(x, cache, layer_idx)
        
        x = self.ln_f(x)
        logits = self.head(x)
        
        return logits
    
    @torch.no_grad()
    def generate_naive(self, idx: torch.Tensor, max_new_tokens: int,
                       temperature: float = 1.0) -> torch.Tensor:
        """
        Naive generation: recompute everything at each step.
        """
        self.train(False)
        
        for _ in range(max_new_tokens):
            # Truncate if needed
            idx_cond = idx[:, -self.config.max_seq_len:]
            
            # Forward pass over ENTIRE sequence
            logits = self(idx_cond)
            
            # Sample next token
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            idx = torch.cat([idx, next_token], dim=1)
        
        return idx
    
    @torch.no_grad()
    def generate_cached(self, idx: torch.Tensor, max_new_tokens: int,
                        temperature: float = 1.0) -> torch.Tensor:
        """
        Cached generation: only compute new token at each step.
        """
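        self.train(False)
        B, T = idx.shape
        if max_new_tokens <= 0:
            return idx
        # Unlike generate_naive, this path cannot truncate: the cache and the
        # positional embeddings both have a fixed capacity of max_seq_len.
        assert T + max_new_tokens - 1 <= self.config.max_seq_len, \
            "prompt + generated tokens exceed max_seq_len"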
        
        # Initialize KV cache
        cache = KVCache(
            batch_size=B,
            max_seq_len=self.config.max_seq_len,
            n_layers=self.config.n_layers,
            n_heads=self.config.n_heads,
            head_dim=self.config.d_model // self.config.n_heads,
            device=idx.device
        )
        
        # Prefill: process entire prompt
        logits = self(idx, cache, start_pos=0)
        cache.increment(T)
        
        # Sample first new token
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_token], dim=1)
        
        # Decode: one token at a time using cache
        for i in range(max_new_tokens - 1):
            # Only process the NEW token
            logits = self(next_token, cache, start_pos=T + i)
            cache.increment(1)
            
            # Sample next token
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            idx = torch.cat([idx, next_token], dim=1)
        
        return idx


# Create model
config = GPTConfig()
model = GPT(config).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Part 3: Benchmarking Naive vs Cached Generation

In [None]:
def benchmark_generation(model, prompt_len, gen_len, n_runs=5):
    """
    Benchmark naive vs cached generation.
    
    Returns:
        naive_time, cached_time (both in seconds)
    """
    # Create prompt
    prompt = torch.randint(0, config.vocab_size, (1, prompt_len), device=device)
    
    # Warmup
    _ = model.generate_naive(prompt.clone(), 5)
    _ = model.generate_cached(prompt.clone(), 5)
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    # Benchmark naive
    naive_times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = model.generate_naive(prompt.clone(), gen_len)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        naive_times.append(time.perf_counter() - start)
    
    # Benchmark cached
    cached_times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = model.generate_cached(prompt.clone(), gen_len)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        cached_times.append(time.perf_counter() - start)
    
    return np.mean(naive_times), np.mean(cached_times)


# Benchmark for different generation lengths
print("Benchmarking generation speed...")
print("=" * 60)

prompt_len = 32
gen_lengths = [8, 16, 32, 64, 128]
results = []

for gen_len in gen_lengths:
    naive_t, cached_t = benchmark_generation(model, prompt_len, gen_len)
    speedup = naive_t / cached_t
    results.append({
        'gen_len': gen_len,
        'naive': naive_t,
        'cached': cached_t,
        'speedup': speedup
    })
    print(f"Gen {gen_len:3d} tokens: Naive={naive_t:.3f}s, Cached={cached_t:.3f}s, Speedup={speedup:.2f}x")

In [None]:
# Visualize benchmark results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

gen_lens = [r['gen_len'] for r in results]
naive_times = [r['naive'] for r in results]
cached_times = [r['cached'] for r in results]
speedups = [r['speedup'] for r in results]

# Time comparison
axes[0].plot(gen_lens, naive_times, 'r-o', label='Naive', linewidth=2, markersize=8)
axes[0].plot(gen_lens, cached_times, 'g-o', label='KV Cache', linewidth=2, markersize=8)
axes[0].set_xlabel('Tokens Generated')
axes[0].set_ylabel('Time (seconds)')
axes[0].set_title('Generation Time Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Speedup
bars = axes[1].bar(gen_lens, speedups, color='steelblue', width=10)
axes[1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
axes[1].set_xlabel('Tokens Generated')
axes[1].set_ylabel('Speedup (x)')
axes[1].set_title('KV Cache Speedup Factor')
axes[1].grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, speedup in zip(bars, speedups):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                 f'{speedup:.1f}x', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('kv_cache_speedup.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nAverage speedup: {np.mean(speedups):.2f}x")

## Part 4: Scaling Analysis - How Speedup Grows with Sequence Length

In [None]:
def theoretical_speedup(prompt_len, gen_len):
    """
    Compute the theoretical reduction in K,V projections from KV caching.
    
    Counts one unit per token whose K and V are projected in a forward pass.
    
    Naive: every step re-runs the model on the whole sequence, so step i
    re-projects K,V for roughly (prompt_len + i) tokens.
    Total: sum of (prompt_len + i) for i in 1..gen_len
           = gen_len * prompt_len + gen_len * (gen_len + 1) / 2
    
    Cached:
    - Prefill: project K,V once for each of the prompt_len prompt tokens
    - Each decode step: project K,V for the single new token
    Total: prompt_len + gen_len
    
    This count covers only the K,V projections. Caching also shrinks the
    attention computation (one query row per step instead of a full T x T
    score matrix) and the feed-forward work (one token per step instead of
    T), so treat the result as a simplified proxy for total compute.
    """
    # K,V computation counts
    naive_kv = sum(prompt_len + i for i in range(1, gen_len + 1))
    cached_kv = prompt_len + gen_len  # Prefill once, then one per token
    
    return naive_kv / cached_kv


# Visualize theoretical speedup
prompt_lens = [16, 32, 64, 128]
gen_lens = np.arange(1, 257, 4)

plt.figure(figsize=(10, 5))

for prompt_len in prompt_lens:
    speedups = [theoretical_speedup(prompt_len, g) for g in gen_lens]
    plt.plot(gen_lens, speedups, label=f'Prompt={prompt_len}', linewidth=2)

plt.xlabel('Tokens Generated')
plt.ylabel('Theoretical Speedup (K,V computations)')
plt.title('KV Cache Theoretical Speedup\n(Grows approximately linearly with generation length)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('kv_cache_theoretical_speedup.png', dpi=150, bbox_inches='tight')
plt.show()

print("Key insight: Speedup grows roughly linearly with generation length.")
print("For long generations, KV caching provides massive speedups.")
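
The curve plotted above has a simple closed form. Writing $p$ for the prompt length and $g$ for the number of generated tokens (symbols introduced here for brevity), the counts used in `theoretical_speedup` give:

$$
\text{speedup}(p, g) = \frac{\sum_{i=1}^{g} (p + i)}{p + g} = \frac{g\,p + \tfrac{g(g+1)}{2}}{p + g} = \frac{g\,(2p + g + 1)}{2\,(p + g)}
$$

Two limits explain the "roughly linear" shape: when $p \gg g$ the ratio approaches $g$, and when $g \gg p$ it approaches $g/2$. As a worked example, $p = 32$ and $g = 128$ give $(128 \cdot 32 + 128 \cdot 129 / 2) / 160 = 12352 / 160 = 77.2$. This counts K,V projections only, so measured wall-clock speedups will generally be smaller.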

## Part 5: Memory Usage Analysis

In [None]:
def compute_kv_cache_memory(batch_size, seq_len, n_layers, d_model, n_heads, dtype_bytes=4):
    """
    Compute KV cache memory in bytes.
    
    For each layer:
        K: [batch, n_heads, seq_len, head_dim]
        V: [batch, n_heads, seq_len, head_dim]
    """
    head_dim = d_model // n_heads
    per_layer = 2 * batch_size * n_heads * seq_len * head_dim * dtype_bytes
    total = n_layers * per_layer
    return total


def visualize_memory_tradeoff():
    """Show memory vs speed tradeoff."""
    
    # Simulate a larger model
    batch_sizes = [1, 2, 4, 8, 16, 32]
    seq_lens = [512, 1024, 2048, 4096, 8192]
    
    # Model config (simulating a ~7B model)
    d_model = 4096
    n_heads = 32
    n_layers = 32
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    # Memory vs sequence length (batch=1)
    ax = axes[0]
    memory_gb = [compute_kv_cache_memory(1, s, n_layers, d_model, n_heads) / 1e9 for s in seq_lens]
    
    ax.bar([str(s) for s in seq_lens], memory_gb, color='steelblue')
    ax.set_xlabel('Sequence Length')
    ax.set_ylabel('KV Cache Memory (GB)')
    ax.set_title('KV Cache Memory vs Sequence Length\n(7B model, batch=1, fp32)')
    ax.grid(True, alpha=0.3, axis='y')
    
    for i, (s, m) in enumerate(zip(seq_lens, memory_gb)):
        ax.text(i, m, f'{m:.1f} GB', ha='center', va='bottom', fontsize=9)
    
    # Memory vs batch size (seq=2048)
    ax = axes[1]
    memory_gb = [compute_kv_cache_memory(b, 2048, n_layers, d_model, n_heads) / 1e9 for b in batch_sizes]
    
    ax.bar([str(b) for b in batch_sizes], memory_gb, color='green')
    ax.set_xlabel('Batch Size')
    ax.set_ylabel('KV Cache Memory (GB)')
    ax.set_title('KV Cache Memory vs Batch Size\n(7B model, seq=2048, fp32)')
    ax.grid(True, alpha=0.3, axis='y')
    
    for i, (b, m) in enumerate(zip(batch_sizes, memory_gb)):
        ax.text(i, m, f'{m:.1f} GB', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    plt.savefig('kv_cache_memory.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nMemory analysis for 7B model (32 layers, d=4096):")
    print("-" * 50)
    print(f"Model weights (fp16): ~14 GB")
    print(f"KV cache (seq=4096, batch=1, fp16): ~{compute_kv_cache_memory(1, 4096, 32, 4096, 32, 2)/1e9:.1f} GB")
    print(f"KV cache (seq=4096, batch=8, fp16): ~{compute_kv_cache_memory(8, 4096, 32, 4096, 32, 2)/1e9:.1f} GB")
    print("\nKey insight: KV cache memory scales with batch_size * seq_len")
    print("For long sequences or large batches, cache can exceed model size!")

visualize_memory_tradeoff()
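
A handy way to reason about these numbers is the cache cost per token. The sketch below uses the same formula as `compute_kv_cache_memory`, reduced to a single token, and then asks how many tokens fit in a given memory budget. The helper names and the 10 GiB budget are hypothetical, chosen only for illustration, and the formula assumes standard multi-head attention where K and V each have `d_model` values per layer.

```python
def kv_bytes_per_token(n_layers: int, d_model: int, dtype_bytes: int) -> int:
    """Bytes of K and V stored per token across all layers."""
    return 2 * n_layers * d_model * dtype_bytes

def max_cached_tokens(budget_bytes: int, n_layers: int, d_model: int,
                      dtype_bytes: int) -> int:
    """Total tokens (summed over the batch) that fit in the budget."""
    return budget_bytes // kv_bytes_per_token(n_layers, d_model, dtype_bytes)

per_token = kv_bytes_per_token(n_layers=32, d_model=4096, dtype_bytes=2)
print(per_token)  # 524288 bytes, i.e. 0.5 MiB per token in fp16

budget = 10 * 1024**3  # hypothetical 10 GiB left over after loading weights
capacity = max_cached_tokens(budget, n_layers=32, d_model=4096, dtype_bytes=2)
print(capacity)   # 20480 tokens, shared across every sequence in the batch
```

Dividing the capacity by the batch size gives the longest sequence each request can hold, which is why batch size and context length trade off directly against each other.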

## Part 6: What's Actually Being Cached?

In [None]:
def visualize_cache_contents(model, prompt_len=16, gen_len=8):
    """Visualize what gets stored in the KV cache."""
    
    # Generate with cache and capture cache state
    prompt = torch.randint(0, config.vocab_size, (1, prompt_len), device=device)
    
    # Initialize cache
    cache = KVCache(
        batch_size=1,
        max_seq_len=config.max_seq_len,
        n_layers=config.n_layers,
        n_heads=config.n_heads,
        head_dim=config.d_model // config.n_heads,
        device=device
    )
    
    model.train(False)
    
    # Prefill
    with torch.no_grad():
        _ = model(prompt, cache, start_pos=0)
    cache.increment(prompt_len)
    
    # Capture cache state after prefill
    prefill_k = cache.cache_k[0, 0, 0, :prompt_len, :8].cpu().numpy()  # First layer, first head, first 8 dims
    prefill_v = cache.cache_v[0, 0, 0, :prompt_len, :8].cpu().numpy()
    
    # Continue generation
    next_token = torch.randint(0, config.vocab_size, (1, 1), device=device)
    for i in range(gen_len):
        with torch.no_grad():
            _ = model(next_token, cache, start_pos=prompt_len + i)
        cache.increment(1)
        next_token = torch.randint(0, config.vocab_size, (1, 1), device=device)
    
    # Capture full cache state
    total_len = prompt_len + gen_len
    full_k = cache.cache_k[0, 0, 0, :total_len, :8].cpu().numpy()
    full_v = cache.cache_v[0, 0, 0, :total_len, :8].cpu().numpy()
    
    # Visualize
    fig, axes = plt.subplots(2, 2, figsize=(14, 8))
    
    # K cache after prefill
    im1 = axes[0, 0].imshow(prefill_k.T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2)
    axes[0, 0].set_xlabel('Token Position')
    axes[0, 0].set_ylabel('Key Dimension')
    axes[0, 0].set_title(f'Key Cache After Prefill (Layer 0, Head 0)\n({prompt_len} prompt tokens)')
    axes[0, 0].axvline(x=prompt_len-0.5, color='green', linestyle='--', linewidth=2)
    plt.colorbar(im1, ax=axes[0, 0])
    
    # K cache after generation
    im2 = axes[0, 1].imshow(full_k.T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2)
    axes[0, 1].set_xlabel('Token Position')
    axes[0, 1].set_ylabel('Key Dimension')
    axes[0, 1].set_title(f'Key Cache After Generation\n({prompt_len} prompt + {gen_len} generated)')
    axes[0, 1].axvline(x=prompt_len-0.5, color='green', linestyle='--', linewidth=2, label='Prompt/Gen boundary')
    axes[0, 1].legend(loc='upper right')
    plt.colorbar(im2, ax=axes[0, 1])
    
    # V cache after prefill
    im3 = axes[1, 0].imshow(prefill_v.T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2)
    axes[1, 0].set_xlabel('Token Position')
    axes[1, 0].set_ylabel('Value Dimension')
    axes[1, 0].set_title(f'Value Cache After Prefill (Layer 0, Head 0)')
    axes[1, 0].axvline(x=prompt_len-0.5, color='green', linestyle='--', linewidth=2)
    plt.colorbar(im3, ax=axes[1, 0])
    
    # V cache after generation
    im4 = axes[1, 1].imshow(full_v.T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2)
    axes[1, 1].set_xlabel('Token Position')
    axes[1, 1].set_ylabel('Value Dimension')
    axes[1, 1].set_title(f'Value Cache After Generation')
    axes[1, 1].axvline(x=prompt_len-0.5, color='green', linestyle='--', linewidth=2)
    plt.colorbar(im4, ax=axes[1, 1])
    
    plt.tight_layout()
    plt.savefig('kv_cache_contents.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nCache statistics:")
    print(f"  Total K entries: {total_len}")
    print(f"  Total V entries: {total_len}")
    print(f"  Memory allocated (full pre-allocated cache, all layers): {cache.memory_usage():.2f} MB")

visualize_cache_contents(model)

## Part 7: Summary and Best Practices

In [None]:
print("""
KV CACHING: SUMMARY
===================

What is KV Caching?
-------------------
During autoregressive generation, cache the Key and Value tensors
from previous tokens. When generating the next token, only compute
K and V for the new token, reusing cached values for attention.

Why It Works:
-------------
- Attention: softmax(Q_new @ K_all.T / sqrt(head_dim)) @ V_all
- Without cache: Recompute K_all and V_all every step
- With cache: K_all and V_all already stored, just append new K,V

Complexity:
-----------
- Naive generation of n tokens: O(n^2) K,V computations
- Cached generation of n tokens: O(n) K,V computations
- Speedup grows roughly linearly with generation length

Memory Cost:
------------
- Per layer: 2 * batch * seq_len * d_model * dtype_size
- Total: n_layers * (above)
- For 7B model (32 layers, d=4096), seq=4096, batch=1:
  - fp16: ~2 GB
  - fp32: ~4 GB

Best Practices:
---------------
1. Pre-allocate cache to max_seq_len to avoid dynamic allocation
2. Use fp16 for cache to reduce memory by 2x
3. For very long sequences, consider chunked/paged attention
4. Batch generation requests when possible to amortize the cost of reading model weights

Advanced Techniques:
--------------------
- PagedAttention (vLLM): Virtual memory for KV cache
- Grouped-Query Attention: Share K,V across query heads
- Multi-Query Attention: Single K,V for all heads
- Sliding Window: Only cache recent N tokens
""")

## Exercises

1. **Batch generation**: Modify the cached generation to handle multiple prompts simultaneously. How does batch size affect speedup?

2. **Memory optimization**: Implement a version that uses fp16 for the cache. Measure memory savings.

3. **Sliding window**: Implement a sliding window cache that only stores the last N tokens. Compare quality vs memory.

4. **Grouped-query attention**: Implement GQA where K,V are shared across groups of query heads. How much memory is saved?

5. **Prefill chunking**: For very long prompts, implement chunked prefill to avoid memory spikes.
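
For exercise 4, a quick estimate gives a number to check an implementation against. The sketch below generalizes the memory formula from Part 5 by replacing `n_heads` with a separate `n_kv_heads` count; the function name and the choice of 8 KV heads for the GQA case are assumptions made for illustration, not values from a specific model.

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """KV cache size when only n_kv_heads distinct K,V heads are stored per layer."""
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

head_dim = 4096 // 32  # 128, as in the 7B-style config from Part 5
variants = [("MHA (32 KV heads)", 32), ("GQA (8 KV heads)", 8), ("MQA (1 KV head)", 1)]

sizes_gib = {}
for name, n_kv in variants:
    sizes_gib[name] = kv_cache_bytes(1, 4096, 32, n_kv, head_dim) / 1024**3
    print(f"{name}: {sizes_gib[name]:.4f} GiB")
# MHA (32 KV heads): 2.0000 GiB
# GQA (8 KV heads): 0.5000 GiB
# MQA (1 KV head): 0.0625 GiB
```

The cache shrinks in proportion to `n_kv_heads / n_heads`, while the number of query heads, and therefore the attention output shape, stays the same.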

## Summary

| Concept | Key Point |
|---------|----------|
| Problem | Naive generation recomputes K,V for all tokens each step |
| Solution | Cache K,V tensors, only compute for new tokens |
| Speedup | Roughly linear with generation length (n^2 -> n) |
| Memory | Grows with batch_size * seq_len * n_layers * d_model |
| Prefill | Process entire prompt once, populate cache |
| Decode | One token at a time, append to cache |

**Key insight:** KV caching is essential for practical LLM inference. Without it, generating 1000 tokens would require roughly 500x more K,V computations than necessary (n(n+1)/2 versus n). The memory cost is the tradeoff: for long sequences or large batches, the cache can become substantial.