# KV-Cache

**Optimizing inference: from O(n²) to O(n) tokens processed**

## The Problem: Slow Autoregressive Generation

When generating text, a transformer produces one token at a time. Without caching, each new token requires feeding the entire sequence back through the model, so the same keys and values are recomputed at every step.

### Example without cache:

```text
# Generate "The cat sat"

# Step 1: Generate token 3
Input: [The, cat]
Compute: K[The], V[The], K[cat], V[cat]
Output: "sat" ✓

# Step 2: Generate token 4
Input: [The, cat, sat]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat]  ← K, V for "The", "cat" are redundant!
Output: "on"

# Step 3: Generate token 5
Input: [The, cat, sat, on]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat], K[on], V[on]  ← K, V for "The", "cat", "sat" are redundant!
Output: "the"
```

To generate n tokens this way, the model processes 1 + 2 + 3 + ... + n = n(n+1)/2 token positions in total, which is **O(n²)**. Very slow!
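
As a rough illustration of that count, the sketch below tallies how many token positions pass through the model with and without a cache. The function name `tokens_processed` and its parameters are hypothetical, chosen only for this example, and it assumes the cached version runs one prefill pass over the prompt followed by single-token steps.

```python
def tokens_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions pushed through the model while generating new_tokens tokens."""
    if use_cache:
        # Prefill handles the prompt once; every later step handles one token.
        return prompt_len + (new_tokens - 1)
    # Without a cache, step t re-processes the prompt plus the t tokens generated so far.
    return sum(prompt_len + t for t in range(new_tokens))


# The "The cat" example above: inputs of length 2, 3, and 4.
print(tokens_processed(prompt_len=2, new_tokens=3, use_cache=False))  # 9
print(tokens_processed(prompt_len=2, new_tokens=3, use_cache=True))   # 4
```

The uncached count grows quadratically with the number of generated tokens, while the cached count grows linearly.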

## The Solution: KV-Cache

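**Key Insight:** In causal attention, K (Key) and V (Value) for past tokens never change, because they depend only on tokens at or before their own position. At each step, only the new token's query, key, and value need to be computed. We can cache K and V from previous steps and reuse them.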

## How It Works

**Two Modes:**

- **PREFILL:** Process the initial prompt in one pass, computing and caching K, V for all of its tokens
- **DECODE:** For each new token, compute only its Q, K, V, append the new K, V to the cached values, and attend over the full cache

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

class AttentionWithCache(nn.Module):
    """Multi-head attention with KV-cache support."""
    
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        """
        Args:
            x: (batch, seq_len, d_model) - new tokens to process
            kv_cache: tuple of (K_cached, V_cached) or None for first pass
        
        Returns:
            output: (batch, seq_len, d_model)
            new_kv_cache: tuple of updated (K, V) tensors
        """
        batch_size, seq_len, _ = x.shape
        
        # Compute Q, K, V for new tokens
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K_new = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V_new = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Concatenate with cache if available
        if kv_cache is not None:
            K_cached, V_cached = kv_cache
            K = torch.cat([K_cached, K_new], dim=2)  # Append new K
            V = torch.cat([V_cached, V_new], dim=2)  # Append new V
        else:
            K = K_new
            V = V_new
        
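        # Standard attention (Q attends to ALL K, V including cached).
        # Note: no causal mask is applied. Decode steps (seq_len == 1) are unaffected,
        # because the newest token may attend to every cached position, but during
        # prefill each prompt position also attends to later prompt tokens.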
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        
        # Reshape and project
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)
        
        # Return output AND updated cache
        return output, (K, V)
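
The forward pass above does not apply a causal mask. A full decoder would usually need one during prefill, and the mask has to account for the cached positions. The following is a minimal sketch of one way to build such a mask; the helper name `causal_mask_with_cache` and the `past_len` parameter are assumptions made for this example, not part of the class above.

```python
import torch


def causal_mask_with_cache(seq_len: int, past_len: int) -> torch.Tensor:
    """Boolean mask of shape (seq_len, past_len + seq_len); True means 'may attend'."""
    total_len = past_len + seq_len
    ones = torch.ones(seq_len, total_len)
    # Query i sits at absolute position past_len + i and may see keys 0..past_len + i.
    return torch.tril(ones, diagonal=past_len).bool()


prefill_mask = causal_mask_with_cache(seq_len=3, past_len=0)  # lower-triangular 3 x 3
decode_mask = causal_mask_with_cache(seq_len=1, past_len=3)   # single row, all True
print(prefill_mask)
print(decode_mask)
```

Under these assumptions the mask could be applied with `scores.masked_fill(~mask, float("-inf"))` just before the softmax. For a single-token decode step the mask is all `True`, which is why decoding works without it.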

In [2]:
# Demonstrate KV-cache in action
d_model = 64
num_heads = 4
attn = AttentionWithCache(d_model, num_heads)

print("=== PREFILL: Process initial prompt ===")
prompt = torch.randn(1, 5, d_model)  # 5-token prompt
output, kv_cache = attn(prompt, kv_cache=None)
print(f"Input: {prompt.shape}")
print(f"Output: {output.shape}")
print(f"Cache K shape: {kv_cache[0].shape}")
print(f"Cache V shape: {kv_cache[1].shape}")

=== PREFILL: Process initial prompt ===
Input: torch.Size([1, 5, 64])
Output: torch.Size([1, 5, 64])
Cache K shape: torch.Size([1, 4, 5, 16])
Cache V shape: torch.Size([1, 4, 5, 16])


In [3]:
print("\n=== DECODE: Generate tokens one at a time ===")
for i in range(3):
    # Only process the NEW token
    new_token = torch.randn(1, 1, d_model)
    output, kv_cache = attn(new_token, kv_cache=kv_cache)
    
    print(f"\nStep {i + 1}:")
    print(f"  New token input: {new_token.shape}")
    print(f"  Output: {output.shape}")
    print(f"  Cache K shape: {kv_cache[0].shape}")
    print(f"  (Cache grows by 1 position each step)")


=== DECODE: Generate tokens one at a time ===

Step 1:
  New token input: torch.Size([1, 1, 64])
  Output: torch.Size([1, 1, 64])
  Cache K shape: torch.Size([1, 4, 6, 16])
  (Cache grows by 1 position each step)

Step 2:
  New token input: torch.Size([1, 1, 64])
  Output: torch.Size([1, 1, 64])
  Cache K shape: torch.Size([1, 4, 7, 16])
  (Cache grows by 1 position each step)

Step 3:
  New token input: torch.Size([1, 1, 64])
  Output: torch.Size([1, 1, 64])
  Cache K shape: torch.Size([1, 4, 8, 16])
  (Cache grows by 1 position each step)


## Memory vs Speed Tradeoff

**Memory Cost:** For each layer, we cache K and V tensors with shape `(batch, num_heads, seq_len, d_k)`. For a 6-layer model with d_model=256, 4 heads, and a 200-token sequence, this is 2 × 6 × 200 × 256 values, or about 2.5 MB per example in float32. Very affordable!
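
The estimate above can be reproduced with a few lines of arithmetic. This is a sketch that assumes float32 storage (4 bytes per value) and a batch size of 1; the helper name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(num_layers: int, d_model: int, seq_len: int,
                   batch_size: int = 1, bytes_per_value: int = 4) -> int:
    """Total KV-cache size in bytes across all layers."""
    # Two tensors (K and V) per layer, each holding seq_len * d_model values
    # per example, because num_heads * d_k == d_model.
    return 2 * num_layers * batch_size * seq_len * d_model * bytes_per_value


size = kv_cache_bytes(num_layers=6, d_model=256, seq_len=200)
print(f"{size / 1e6:.2f} MB")  # 2.46 MB
```

Note that the number of heads drops out of the formula: splitting d_model across heads changes the tensor shape but not the total number of cached values.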

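**Speed Benefit:** Generating n tokens requires pushing only O(n) token positions through the model instead of O(n²). Each decode step still attends over the whole cache, so the attention cost per step grows linearly with sequence length, but the K and V projections (and, in a full model, the feed-forward layers) run on a single token.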

Rough, illustrative speedups (actual gains depend on model size, hardware, and batch size):
- Short sequences (10-20 tokens): 2-5x faster
- Medium sequences (50-100 tokens): 10-20x faster
- Long sequences (200+ tokens): 20-50x faster

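**Why production LLMs use KV-cache:** For a small model like this one, the memory cost is tiny compared to the model weights, while the speed improvement is large. In large models the cache grows with batch size and context length and can become a significant share of memory, but KV-caching remains standard practice in production systems (GPT, Claude, etc.) for generation.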

In [4]:
# Benchmark: with vs without cache
import time

class SimpleAttention(nn.Module):
    """Attention WITHOUT cache (for comparison)."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)

# Setup
d_model = 256
num_heads = 8
num_tokens_to_generate = 50

attn_no_cache = SimpleAttention(d_model, num_heads)
attn_with_cache = AttentionWithCache(d_model, num_heads)

# Share weights for fair comparison
attn_with_cache.W_q = attn_no_cache.W_q
attn_with_cache.W_k = attn_no_cache.W_k
attn_with_cache.W_v = attn_no_cache.W_v
attn_with_cache.W_o = attn_no_cache.W_o

In [5]:
# Benchmark WITHOUT cache
prompt = torch.randn(1, 10, d_model)
sequence = prompt.clone()

start = time.time()
for _ in range(num_tokens_to_generate):
    # Recompute attention over ENTIRE sequence each time
    output = attn_no_cache(sequence)
    new_token = output[:, -1:, :]  # Take last position
    sequence = torch.cat([sequence, new_token], dim=1)
time_no_cache = time.time() - start

print(f"WITHOUT cache:")
print(f"  Generated {num_tokens_to_generate} tokens in {time_no_cache:.3f}s")
print(f"  Final sequence length: {sequence.shape[1]}")

WITHOUT cache:
  Generated 50 tokens in 0.010s
  Final sequence length: 60


In [6]:
# Benchmark WITH cache
prompt = torch.randn(1, 10, d_model)

start = time.time()
# Prefill
output, kv_cache = attn_with_cache(prompt, kv_cache=None)
last_token = output[:, -1:, :]

# Decode
for _ in range(num_tokens_to_generate):
    # Only process the NEW token
    output, kv_cache = attn_with_cache(last_token, kv_cache=kv_cache)
    last_token = output
time_with_cache = time.time() - start

print(f"WITH cache:")
print(f"  Generated {num_tokens_to_generate} tokens in {time_with_cache:.3f}s")
print(f"  Final cache length: {kv_cache[0].shape[2]}")
print(f"\nSpeedup: {time_no_cache / time_with_cache:.1f}x faster!")

WITH cache:
  Generated 50 tokens in 0.007s
  Final cache length: 60

Speedup: 1.6x faster!
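
Timing aside, it is worth confirming that the cached path produces the same last-token output as full recomputation. The sketch below checks this with a stripped-down single-head attention. The weight matrices, the `attend` helper, and the sizes are all hypothetical and independent of the classes above.

```python
import math
import torch

torch.manual_seed(0)
d_model = 16
# Hypothetical projection weights for a single attention head.
W_q, W_k, W_v = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(3))


def attend(q, K, V):
    """Scaled dot-product attention for queries q over keys K and values V."""
    scores = q @ K.transpose(-2, -1) / math.sqrt(d_model)
    return torch.softmax(scores, dim=-1) @ V


x = torch.randn(1, 6, d_model)  # 6 token embeddings

# Recompute everything, then keep only the last position's output.
full = attend(x @ W_q, x @ W_k, x @ W_v)[:, -1:, :]

# Cached path: K, V for the first 5 tokens are computed once and stored.
K_cache, V_cache = x[:, :5] @ W_k, x[:, :5] @ W_v
x_new = x[:, 5:]
K = torch.cat([K_cache, x_new @ W_k], dim=1)
V = torch.cat([V_cache, x_new @ W_v], dim=1)
cached = attend(x_new @ W_q, K, V)

print(torch.allclose(full, cached, atol=1e-5))
```

The two results agree up to floating-point rounding because the keys and values of earlier tokens depend only on those tokens' own inputs.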


## Important Implementation Detail

The cache must correctly handle positional encodings! When processing the token at position N, the model must give it the position embedding for N, not 0. Implementations track the cache length and offset the positions of new tokens accordingly.
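
A minimal sketch of that offset, assuming learned absolute position embeddings: the `pos_emb` table, the `positions_for` helper, and the sizes are hypothetical. With the cache layout used above, `past_len` would be `kv_cache[0].shape[2]`.

```python
import torch
import torch.nn as nn

max_len, d_model = 32, 8
pos_emb = nn.Embedding(max_len, d_model)  # hypothetical learned position table


def positions_for(seq_len: int, past_len: int) -> torch.Tensor:
    """Absolute positions of the new tokens, offset by the current cache length."""
    return torch.arange(past_len, past_len + seq_len).unsqueeze(0)


prefill_pos = positions_for(seq_len=5, past_len=0)  # positions 0..4
decode_pos = positions_for(seq_len=1, past_len=5)   # position 5

new_token = torch.randn(1, 1, d_model)
x = new_token + pos_emb(decode_pos)  # uses the embedding for position 5, not 0
print(prefill_pos, decode_pos, x.shape)
```

Using `torch.arange(seq_len)` without the offset would give every decoded token position 0, which is the bug described above.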

## Next: Interpretability

Now that we can efficiently generate text, let's explore what our model has actually learned using interpretability techniques.