
# KV Cache Hands-On Notebook
This Colab-ready notebook mirrors the experiments from `kv_cache_blog.md`. Run the sections in order from top to bottom (later cells reuse the imports and classes defined in earlier ones) to reproduce the results and compare them with the narrative in the blog post.



## 1. Baseline Attention (No Cache)


In [1]:

import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Seed both random number generators used in this notebook so a fresh
# Restart & Run All produces the same numbers.
np.random.seed(0)
torch.manual_seed(0)

# Shared matplotlib defaults applied to every figure created below.
plt.rcParams.update({
    "figure.figsize": (8, 4.5),
    "axes.grid": True,
})

class SimpleAttentionWithoutCache(nn.Module):
    """Causal multi-head self-attention that recomputes everything on each call.

    Queries, keys and values are projected from the full input every time
    ``forward`` runs; nothing is remembered between calls.

    Args:
        dim: Width of the input and output features.
        n_heads: Number of attention heads; each head is ``dim // n_heads`` wide.
    """
    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # Creation order is kept as-is: it determines the seeded initial weights.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def _split_heads(self, projected):
        """Reshape (batch, seq, dim) into (batch, n_heads, seq, head_dim)."""
        n_batch, n_tokens, _ = projected.shape
        per_head = projected.view(n_batch, n_tokens, self.n_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq, dim).

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        batch_size, seq_len, _ = x.shape

        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        scale = np.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # True above the diagonal marks future positions a token must not see.
        future = torch.ones(seq_len, seq_len, device=scores.device).triu(diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, values)

        # Merge the heads back into a single feature dimension.
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        return self.out_proj(merged)

# Smoke test for the uncached layer.
# The input is drawn BEFORE the model is built; both consume the global torch
# RNG, so swapping these two lines would change the printed sample values.
example_input = torch.randn(1, 4, 512)
baseline_model = SimpleAttentionWithoutCache()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    baseline_output = baseline_model(example_input)

# Library versions are printed so the recorded numbers can be tied to an environment.
print(f"Torch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print("Without cache output shape:", tuple(baseline_output.shape))
# First five output features of the first token, rounded to four decimals.
print("First token projection sample:", [round(x, 4) for x in baseline_output[0, 0, :5].tolist()])


Torch version: 2.8.0+cpu
NumPy version: 1.26.4
Without cache output shape: (1, 4, 512)
First token projection sample: [-0.288, -0.1835, -0.0284, 0.6146, 0.0041]



## 2. Attention with KV Cache


In [2]:

class SimpleAttentionWithCache(nn.Module):
    """Causal multi-head self-attention with a preallocated key/value cache.

    With ``use_cache=True`` the keys and values of the incoming tokens are
    appended to the cache at ``cache_position`` and each query attends to every
    cached position up to and including its own. The same code path therefore
    handles a whole prompt, a single decoded token, and a multi-token chunk that
    follows an already populated cache. With ``use_cache=False`` the cache is
    neither read nor written and plain causal attention over ``x`` is computed.

    The cache is intended for inference: call ``clear_cache()`` before starting
    a new sequence and run the module under ``torch.no_grad()``.

    Args:
        dim: Width of the input and output features.
        n_heads: Number of attention heads; each head is ``dim // n_heads`` wide.
        max_seq_len: Number of positions the cache can hold.
    """
    def __init__(self, dim=512, n_heads=8, max_seq_len=2048):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.max_seq_len = max_seq_len

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

        # Allocated for batch size 1; resized on first use with another batch size.
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_seq_len, self.head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_seq_len, self.head_dim))
        # Number of positions currently stored in the cache.
        self.cache_position = 0

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, dim) into (batch, n_heads, seq, head_dim)."""
        return projected.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    def _ensure_cache_batch(self, batch_size):
        """Resize the (empty) cache so its batch dimension matches ``batch_size``."""
        if self.k_cache.shape[0] == batch_size:
            return
        if self.cache_position != 0:
            raise ValueError(
                f"Batch size changed from {self.k_cache.shape[0]} to {batch_size} while the "
                "cache holds tokens; call clear_cache() before starting a new batch."
            )
        shape = (batch_size, self.n_heads, self.max_seq_len, self.head_dim)
        self.k_cache = torch.zeros(shape, dtype=self.k_cache.dtype, device=self.k_cache.device)
        self.v_cache = torch.zeros(shape, dtype=self.v_cache.dtype, device=self.v_cache.device)

    def forward(self, x, use_cache=True):
        """Attend over ``x`` of shape (batch, seq, dim), optionally using the cache.

        Args:
            x: New tokens, shape (batch, seq, dim).
            use_cache: When True, append the new keys/values to the cache and
                attend over everything cached so far.

        Returns:
            Tensor of shape (batch, seq, dim), one row per new token.

        Raises:
            ValueError: If the new tokens do not fit in the cache, or the batch
                size changes while the cache is populated.
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        K = self._split_heads(self.k_proj(x), batch_size, seq_len)
        V = self._split_heads(self.v_proj(x), batch_size, seq_len)

        if use_cache:
            start = self.cache_position
            end = start + seq_len
            # Validate before touching any state so a failed call leaves the cache intact.
            if end > self.max_seq_len:
                raise ValueError(
                    f"KV cache overflow: {start} cached + {seq_len} new tokens "
                    f"exceeds max_seq_len={self.max_seq_len}."
                )
            self._ensure_cache_batch(batch_size)

            self.k_cache[:, :, start:end, :] = K
            self.v_cache[:, :, start:end, :] = V
            K = self.k_cache[:, :, :end, :]
            V = self.v_cache[:, :, :end, :]
            self.cache_position = end
        else:
            start = 0

        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)

        # A single new token is the last position and may see every key, so it
        # needs no mask. For several new tokens, query i sits at absolute
        # position start + i and must not see keys beyond it.
        if seq_len > 1:
            total_len = start + seq_len
            mask = torch.triu(
                torch.ones(seq_len, total_len, device=scores.device), diagonal=start + 1
            ).bool()
            scores = scores.masked_fill(mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        return self.out_proj(attn_output)

    def clear_cache(self):
        """Zero the cache buffers and reset the write position to 0."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_position = 0

# Walk-through: prefill a 3-token prompt, then decode one extra token.
# Statement order matters: model construction and each torch.randn call
# consume the global RNG in sequence.
cached_model = SimpleAttentionWithCache(dim=512, n_heads=8, max_seq_len=32)
cached_model.clear_cache()
prompt_tokens = torch.randn(1, 3, 512)
with torch.no_grad():
    # Multi-token call: fills cache positions 0..2.
    prompt_output = cached_model(prompt_tokens, use_cache=True)
print("Prompt output shape:", tuple(prompt_output.shape))
print("Cache position after prompt:", cached_model.cache_position)

next_token = torch.randn(1, 1, 512)
with torch.no_grad():
    # Single-token call: appends at the current cache position and attends to the cached prompt.
    next_output = cached_model(next_token, use_cache=True)

print("Next-token output shape:", tuple(next_output.shape))
print("Cache position after one-step generation:", cached_model.cache_position)


Prompt output shape: (1, 3, 512)
Cache position after prompt: 3
Next-token output shape: (1, 1, 512)
Cache position after one-step generation: 4



## 3. Benchmarking Cache Speed


In [3]:

def benchmark_generation(model_class, seq_length=25, dim=256):
    """Time ``seq_length`` simulated decoding steps for one attention class.

    A model exposing ``clear_cache`` is treated as cached: it is fed one new
    random token per step with ``use_cache=True``. Any other model is fed the
    whole (random) sequence of length ``i + 1`` at step ``i``, which is what
    generation without a cache has to recompute.

    Only the forward pass is timed; building the random input and resetting the
    cache happen outside the timed region.

    Args:
        model_class: Callable accepting ``dim=`` and returning an ``nn.Module``.
        seq_length: Number of decoding steps to time.
        dim: Feature width passed to the model and used for the inputs.

    Returns:
        List of ``seq_length`` per-step durations in seconds.
    """
    model = model_class(dim=dim)
    model.eval()

    uses_cache = hasattr(model, "clear_cache")
    if uses_cache:
        model.clear_cache()

    times = []

    with torch.no_grad():
        for i in range(seq_length):
            new_tokens = 1 if uses_cache else i + 1
            x = torch.randn(1, new_tokens, dim)

            # perf_counter is monotonic and high resolution, unlike time.time().
            start = time.perf_counter()
            if uses_cache:
                _ = model(x, use_cache=True)
            else:
                _ = model(x)
            times.append(time.perf_counter() - start)

    return times

# Time both variants. The uncached run goes first; each run builds a model and
# draws random inputs from the global RNG, so the order is kept fixed.
no_cache_times = benchmark_generation(SimpleAttentionWithoutCache, seq_length=25, dim=256)
with_cache_times = benchmark_generation(SimpleAttentionWithCache, seq_length=25, dim=256)

steps = np.arange(1, len(no_cache_times) + 1)

# Cumulative time per token position, one curve per variant.
fig, ax = plt.subplots()
ax.plot(steps, np.cumsum(no_cache_times), label="Without KV Cache", linewidth=2)
ax.plot(steps, np.cumsum(with_cache_times), label="With KV Cache", linewidth=2)
ax.set_xlabel("Token position")
ax.set_ylabel("Cumulative time (s)")
ax.set_title("Generation Time: With vs Without KV Cache")
ax.legend()
fig.tight_layout()
plt.show()
plt.close(fig)

# Totals and their ratio.
total_no_cache = sum(no_cache_times)
total_with_cache = sum(with_cache_times)
print(f"Total time without cache: {total_no_cache:.4f}s")
print(f"Total time with cache: {total_with_cache:.4f}s")
print(f"Speedup: {total_no_cache / total_with_cache:.2f}x")


Total time without cache: 0.0190s
Total time with cache: 0.0186s
Speedup: 1.02x



## 4. Memory Footprint


In [4]:

def calculate_kv_cache_memory(seq_len, n_layers, n_heads, head_dim, batch_size=1, dtype_bytes=2):
    """Return the size of a full KV cache in mebibytes (1 MB = 1024 * 1024 bytes).

    Each layer stores one key tensor and one value tensor with
    ``batch_size * n_heads * seq_len * head_dim`` elements of ``dtype_bytes``
    bytes each.
    """
    bytes_per_mb = 1024 * 1024
    elements_per_tensor = batch_size * n_heads * seq_len * head_dim
    # Factor 2: keys and values are cached separately.
    layer_bytes = 2 * elements_per_tensor * dtype_bytes
    return (n_layers * layer_bytes) / bytes_per_mb

# Model shape fed to the estimate below (dtype_bytes keeps its default of 2).
# NOTE(review): 96 layers is the depth the GPT-3 paper lists for the 175B model;
# the 13B model named in the heading below is listed with 40 layers (40 heads,
# head_dim 128). Confirm which configuration is intended before quoting these
# numbers.
gpt3_base = dict(seq_len=2048, n_layers=96, n_heads=40, head_dim=128)

print("Estimated KV cache memory for a GPT-3 13B style model:")
for batch_size in (1, 8, 16, 32):
    memory_mb = calculate_kv_cache_memory(**gpt3_base, batch_size=batch_size)
    memory_gb = memory_mb / 1024
    print(f"Batch size {batch_size}: {memory_mb:.2f} MB ({memory_gb:.2f} GB)")


Estimated KV cache memory for a GPT-3 13B style model:
Batch size 1: 3840.00 MB (3.75 GB)
Batch size 8: 30720.00 MB (30.00 GB)
Batch size 16: 61440.00 MB (60.00 GB)
Batch size 32: 122880.00 MB (120.00 GB)



## 5. Minimal Transformer with KV Cache


In [5]:

class SimpleTransformerWithCache(nn.Module):
    """Minimal decoder-style model: embeddings, cached attention layers, LM head.

    Each layer is a ``SimpleAttentionWithCache`` wrapped in a residual
    connection. ``current_position`` tracks how many tokens have been fed with
    ``use_cache=True`` so that new tokens receive the right position embedding.

    Args:
        vocab_size: Number of token ids.
        dim: Embedding / model width.
        n_heads: Attention heads per layer.
        n_layers: Number of attention layers.
        max_seq_len: Number of positions in the position embedding and in each
            layer's cache.
    """
    def __init__(self, vocab_size=50257, dim=512, n_heads=8, n_layers=6, max_seq_len=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim)
        self.layers = nn.ModuleList([
            SimpleAttentionWithCache(dim, n_heads, max_seq_len=max_seq_len) for _ in range(n_layers)
        ])
        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.current_position = 0

    def forward(self, input_ids, use_cache=True):
        """Return logits of shape (batch, seq, vocab_size) for ``input_ids``.

        Args:
            input_ids: Integer token ids, shape (batch, seq).
            use_cache: When True, positions continue from ``current_position``
                and the layers use their caches; when False, positions start
                at 0 and no state is changed.

        Raises:
            ValueError: If the tokens would occupy positions at or beyond
                ``max_seq_len``.
        """
        _, seq_len = input_ids.shape
        start_pos = self.current_position if use_cache else 0
        end_pos = start_pos + seq_len

        # Fail before mutating any state; otherwise the position embedding
        # lookup would raise an opaque index error.
        if end_pos > self.max_seq_len:
            raise ValueError(
                f"Sequence needs positions up to {end_pos}, "
                f"but max_seq_len is {self.max_seq_len}."
            )

        positions = torch.arange(start_pos, end_pos, device=input_ids.device)
        if use_cache:
            self.current_position = end_pos

        x = self.token_embedding(input_ids)
        x = x + self.pos_embedding(positions).unsqueeze(0)

        for layer in self.layers:
            x = x + layer(x, use_cache=use_cache)

        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits

    def clear_cache(self):
        """Reset every layer's cache and the position counter."""
        for layer in self.layers:
            layer.clear_cache()
        self.current_position = 0

    def generate(self, prompt_ids, max_length=50, temperature=0.8):
        """Extend ``prompt_ids`` autoregressively up to ``max_length`` tokens in total.

        The model is switched to eval mode and its cache is cleared first.

        Args:
            prompt_ids: Integer token ids, shape (batch, prompt_len), prompt_len >= 1.
            max_length: Total length of the returned sequence, prompt included.
                If the prompt is already this long it is returned unchanged.
            temperature: Sampling temperature. ``0`` selects the most likely
                token (greedy decoding) instead of sampling.

        Returns:
            Tensor of shape (batch, max(prompt_len, max_length)).

        Raises:
            ValueError: On an empty prompt, a negative temperature, or a
                requested length that does not fit in ``max_seq_len``.
        """
        prompt_len = prompt_ids.shape[1]
        if prompt_len == 0:
            raise ValueError("prompt_ids must contain at least one token.")
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}.")
        if max(prompt_len, max_length) > self.max_seq_len:
            raise ValueError(
                f"Generating up to {max(prompt_len, max_length)} tokens "
                f"exceeds max_seq_len={self.max_seq_len}."
            )

        self.eval()
        generated = prompt_ids.clone()
        self.clear_cache()

        with torch.no_grad():
            # Prefill: one pass over the whole prompt populates the caches.
            logits = self(prompt_ids, use_cache=True)

            for _ in range(max_length - prompt_len):
                logits_last = logits[:, -1, :]
                if temperature == 0:
                    # Dividing by zero would give NaN probabilities; use argmax instead.
                    next_token = torch.argmax(logits_last, dim=-1, keepdim=True)
                else:
                    probs = torch.softmax(logits_last / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=1)
                # Decode step: only the new token is fed; history comes from the cache.
                logits = self(next_token, use_cache=True)

        return generated

# End-to-end demo on a small, untrained model: the generated ids are random
# samples and only show that cached generation runs.
# Model construction, the prompt draw and sampling all consume the global torch
# RNG, so the statement order determines the printed ids.
demo_model = SimpleTransformerWithCache(vocab_size=500, dim=128, n_heads=4, n_layers=2, max_seq_len=64)
demo_prompt = torch.randint(0, 500, (1, 5))
start = time.time()
# max_length is the total length (5 prompt tokens + 15 generated).
demo_output = demo_model.generate(demo_prompt, max_length=20, temperature=0.9)
elapsed = time.time() - start

print("Prompt token IDs:", demo_prompt.tolist()[0])
print("Generated token IDs:", demo_output.tolist()[0])
print(f"New tokens generated: {demo_output.shape[1] - demo_prompt.shape[1]}")
print(f"Generation time: {elapsed:.3f}s")


Prompt token IDs: [175, 419, 281, 170, 137]
Generated token IDs: [175, 419, 281, 170, 137, 147, 343, 45, 368, 445, 128, 298, 322, 245, 50, 280, 36, 221, 51, 105]
New tokens generated: 15
Generation time: 0.044s



## 6. Multi-Query and Grouped-Query Attention


In [6]:

class MultiQueryAttention(nn.Module):
    """Multi-query attention: per-head queries, one key/value head shared by all.

    The key and value projections output a single ``head_dim``-wide head, which
    is broadcast to all ``n_heads`` query heads. Only that single K/V head
    would need to be cached.

    Args:
        dim: Width of the input and output features.
        n_heads: Number of query heads; each is ``dim // n_heads`` wide.
    """
    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.q_proj = nn.Linear(dim, dim)
        # K and V project to one head only.
        self.k_proj = nn.Linear(dim, self.head_dim)
        self.v_proj = nn.Linear(dim, self.head_dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, causal=False):
        """Apply multi-query attention to ``x`` of shape (batch, seq, dim).

        Args:
            x: Input features, shape (batch, seq, dim).
            causal: When True, each position attends only to itself and earlier
                positions, like the other attention layers in this notebook.
                The default (False) attends over the whole sequence.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        batch_size, seq_len, _ = x.shape

        Q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        K = self.k_proj(x).view(batch_size, seq_len, 1, self.head_dim)
        V = self.v_proj(x).view(batch_size, seq_len, 1, self.head_dim)

        # expand() creates a view, so the shared head is not copied per query head.
        K = K.expand(-1, -1, self.n_heads, -1).transpose(1, 2)
        V = V.expand(-1, -1, self.n_heads, -1).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if causal:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        return self.out_proj(attn_output)

def compare_cache_memory(seq_len=2048, n_layers=32, n_heads=32, head_dim=128, dtype_bytes=2, n_groups=8):
    """Print the KV cache size for MHA, MQA and GQA at batch size 1.

    Sizes are printed in mebibytes together with the reduction relative to
    standard multi-head attention. Nothing is returned.

    Args:
        seq_len: Number of cached positions.
        n_layers: Number of layers.
        n_heads: K/V heads cached by standard multi-head attention.
        head_dim: Width of one head.
        dtype_bytes: Bytes per cached element.
        n_groups: K/V heads cached by grouped-query attention.
    """
    # The leading factor 2 counts keys and values; MQA caches a single K/V head.
    mha_cache = 2 * n_layers * n_heads * seq_len * head_dim * dtype_bytes
    mqa_cache = 2 * n_layers * 1 * seq_len * head_dim * dtype_bytes
    gqa_cache = 2 * n_layers * n_groups * seq_len * head_dim * dtype_bytes

    print(f"Standard MHA KV cache: {mha_cache / (1024**2):.2f} MB")
    print(f"MQA KV cache: {mqa_cache / (1024**2):.2f} MB ({mha_cache / mqa_cache:.1f}x smaller)")
    print(f"GQA KV cache ({n_groups} groups): {gqa_cache / (1024**2):.2f} MB ({mha_cache / gqa_cache:.1f}x smaller)")

# Shape check for the MQA layer on a batch of 2 sequences of 5 tokens.
# The model is built before the input is drawn; both consume the global RNG.
mqa_demo = MultiQueryAttention(dim=256, n_heads=8)
sample_tokens = torch.randn(2, 5, 256)
with torch.no_grad():
    mqa_output = mqa_demo(sample_tokens)

print("MQA output shape:", tuple(mqa_output.shape))
# Prints the MHA / MQA / GQA cache sizes for the function's default configuration.
compare_cache_memory()


MQA output shape: (2, 5, 256)
Standard MHA KV cache: 1024.00 MB
MQA KV cache: 32.00 MB (32.0x smaller)
GQA KV cache (8 groups): 256.00 MB (4.0x smaller)



## 7. Sliding Window Cache


In [7]:

class SlidingWindowCache:
    """Ring buffer that keeps only the ``window_size`` most recent key/value vectors.

    Args:
        window_size: Maximum number of entries kept; must be positive.
        dim: Length of each key/value vector.

    Raises:
        ValueError: If ``window_size`` is not positive.
    """
    def __init__(self, window_size=1024, dim=512):
        # A zero-sized window would otherwise fail later with a modulo-by-zero.
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}.")
        self.window_size = window_size
        self.k_cache = torch.zeros(window_size, dim)
        self.v_cache = torch.zeros(window_size, dim)
        # Total number of entries ever written (not capped at window_size).
        self.position = 0

    def update(self, k, v):
        """Store one key/value pair, overwriting the oldest entry once full."""
        idx = self.position % self.window_size
        self.k_cache[idx] = k
        self.v_cache[idx] = v
        self.position += 1

    def get_cache(self):
        """Return ``(keys, values)`` ordered oldest to newest.

        Both tensors are copies, so later ``update`` calls never change a
        previously returned snapshot. Before any update they have zero rows.
        """
        if self.position < self.window_size:
            # clone(): a plain slice would be a live view into the ring buffer.
            return self.k_cache[:self.position].clone(), self.v_cache[:self.position].clone()
        # Once full, the slot about to be overwritten holds the oldest entry.
        idx = self.position % self.window_size
        k = torch.cat([self.k_cache[idx:], self.k_cache[:idx]], dim=0)
        v = torch.cat([self.v_cache[idx:], self.v_cache[:idx]], dim=0)
        return k, v

    def clear(self):
        """Empty the cache so it can be reused for a new sequence."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.position = 0

# Push six entries through a four-slot window. Entry n stores +n as its key and
# -n as its value, so the final printout shows which entries survived.
window_cache = SlidingWindowCache(window_size=4, dim=3)
for step in range(6):
    marker = float(step + 1)
    k_vec = torch.full((3,), marker)
    v_vec = torch.full((3,), -marker)
    window_cache.update(k_vec, v_vec)
    current_k, current_v = window_cache.get_cache()
    print(f"Step {step + 1}: cache length {current_k.shape[0]}")

print("Latest K entries:", current_k.tolist())
print("Latest V entries:", current_v.tolist())


Step 1: cache length 1
Step 2: cache length 2
Step 3: cache length 3
Step 4: cache length 4
Step 5: cache length 4
Step 6: cache length 4
Latest K entries: [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0], [6.0, 6.0, 6.0]]
Latest V entries: [[-3.0, -3.0, -3.0], [-4.0, -4.0, -4.0], [-5.0, -5.0, -5.0], [-6.0, -6.0, -6.0]]



## 8. Dynamic Cache Allocation


In [8]:

class DynamicKVCache:
    """KV cache that grows by one entry per update instead of preallocating.

    ``n_heads`` and ``head_dim`` are stored for reference only; they are not
    used to validate the tensors passed to ``update``.
    """
    def __init__(self, n_heads, head_dim):
        self.n_heads, self.head_dim = n_heads, head_dim
        self.k_cache, self.v_cache = [], []

    def update(self, k, v):
        """Append one key tensor and one value tensor to the cache."""
        for store, entry in ((self.k_cache, k), (self.v_cache, v)):
            store.append(entry)

    def get_cache(self):
        """Return ``(keys, values)`` stacked along a new dimension 1.

        Returns ``(None, None)`` while no keys have been stored.
        """
        if len(self.k_cache) == 0:
            return None, None
        stacked_k = torch.stack(self.k_cache, dim=1)
        stacked_v = torch.stack(self.v_cache, dim=1)
        return stacked_k, stacked_v

    def clear(self):
        """Drop all stored entries by starting over with fresh lists."""
        self.k_cache, self.v_cache = [], []

# Store five timesteps. Timestep n holds +n in every key slot and -n in every
# value slot, one row per head.
dynamic_cache = DynamicKVCache(n_heads=4, head_dim=3)
for step in range(5):
    k = torch.full((4, 3), float(step + 1))
    v = -k
    dynamic_cache.update(k, v)

# Stacking puts the timestep axis at dimension 1: (heads, time, head_dim).
k_stack, v_stack = dynamic_cache.get_cache()
print("Dynamic cache K shape:", tuple(k_stack.shape))
latest_keys = k_stack[:, -1, :]
print("Latest timestep keys:", latest_keys.tolist())

dynamic_cache.clear()
print("Cache cleared. Stored entries:", len(dynamic_cache.k_cache))


Dynamic cache K shape: (4, 5, 3)
Latest timestep keys: [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
Cache cleared. Stored entries: 0



## 9. Debugging Helpers


In [9]:

def debug_kv_cache(model, input_sequence):
    """Sanity-check a cached model: cache writes, cache position, output parity.

    Checks, in order:
      1. feeding one token changes layer 0's ``k_cache``;
      2. after a cached forward over the first ``i + 1`` tokens, layer 0's
         ``cache_position`` equals ``i + 1``;
      3. the uncached forward matches both a cached forward over the whole
         sequence and token-by-token cached decoding, which is the path
         generation actually uses.

    All forwards run under ``torch.no_grad()`` and the cache is cleared again
    before returning. The result of check 3 is printed.

    Args:
        model: Object exposing ``clear_cache()``, ``layers[0].k_cache``,
            ``layers[0].cache_position`` and ``model(ids, use_cache=...)``.
        input_sequence: Token ids of shape (batch, seq) with seq >= 1.

    Raises:
        ValueError: If ``input_sequence`` has no tokens.
        AssertionError: If check 1 or check 2 fails.
    """
    seq_len = input_sequence.shape[1]
    if seq_len == 0:
        raise ValueError("input_sequence must contain at least one token.")

    with torch.no_grad():
        model.clear_cache()
        initial_cache = model.layers[0].k_cache.clone()
        _ = model(input_sequence[:, :1], use_cache=True)
        assert not torch.equal(initial_cache, model.layers[0].k_cache), "Cache not updating."

        for i in range(seq_len):
            model.clear_cache()
            _ = model(input_sequence[:, :i + 1], use_cache=True)
            cache_pos = model.layers[0].cache_position
            assert cache_pos == i + 1, f"Cache position mismatch at {i}: {cache_pos} != {i + 1}"

        model.clear_cache()
        prefill_output = model(input_sequence, use_cache=True)
        model.clear_cache()
        without_cache_output = model(input_sequence, use_cache=False)

        # Feed one token at a time so the incremental cache path is exercised too.
        model.clear_cache()
        step_outputs = [model(input_sequence[:, i:i + 1], use_cache=True) for i in range(seq_len)]
        incremental_output = torch.cat(step_outputs, dim=1)

        # Leave the model ready for a fresh sequence.
        model.clear_cache()

    cached_runs = {
        "full-sequence prefill": prefill_output,
        "token-by-token decode": incremental_output,
    }
    mismatched = [
        name for name, output in cached_runs.items()
        if not torch.allclose(output, without_cache_output, rtol=1e-5, atol=1e-5)
    ]

    if mismatched:
        # Tensor max (rather than Python max) so a NaN difference is reported as NaN.
        diff = torch.cat([
            (cached_runs[name] - without_cache_output).abs().reshape(-1) for name in mismatched
        ]).max()
        print("Warning: outputs differ between cached and uncached paths.")
        print(f"Max difference: {diff.item():.3e}")
        print(f"Differing path(s): {', '.join(mismatched)}")
    else:
        print("Outputs match between cached and uncached paths.")

    print("KV cache validation complete.")

# Run the cache checks on a small, untrained model with a random 6-token sequence.
# The model is built before the sequence is drawn; both consume the global torch RNG.
test_model = SimpleTransformerWithCache(vocab_size=300, dim=64, n_heads=4, n_layers=2, max_seq_len=32)
test_sequence = torch.randint(0, 300, (1, 6))
debug_kv_cache(test_model, test_sequence)


Outputs match between cached and uncached paths.
KV cache validation complete.
