# KV cache skeleton (exercise)

Reference: [../KV_cache.ipynb](../KV_cache.ipynb).

**Goals:**
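- Implement **KVCache.update(layer, new_keys, new_values)**: write the new K/V, each of shape [B, nh, token_count, hd], into that layer's preallocated cache starting at position `current_length`. The cache is per-layer; `keys[layer]` and `values[layer]` have shape [B, nh, max_len, hd]. Do not advance `current_length` inside `update`: every layer writes at the same offset during one forward pass, and the caller (`_prefill` / `generate`) advances it once afterwards.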
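- In **Attention forward**: when `kv_cache` is not None and `layer_idx` is set, call `kv_cache.update(layer_idx, K, V)` and then read K, V back from the cache instead of using only the K, V computed in this forward. Because `current_length` is advanced only after the forward pass, slice up to `kv_cache.current_length + S` (S = number of tokens in this call) so the keys/values just written are included.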

In [None]:
# Setup: self-contained transformer (same as reference). Focus on the TODO cells below.
import torch
from typing import List

# Activation lookup: maps a config string to the functional op applied in the MLP.
# 'swish' is an alias for 'silu' (both entries point at the same function).
ACT2FN = dict(
    relu=torch.nn.functional.relu,
    gelu=torch.nn.functional.gelu,
    silu=torch.nn.functional.silu,
    swish=torch.nn.functional.silu,
)

class Attention(torch.nn.Module):
    """Multi-head self-attention with an optional (exercise) KV-cache hook.

    Args:
        D: Model width; must be divisible by ``head_dim``.
        layer_idx: Index of this layer inside the model; used as the key into
            the KV cache. ``None`` disables the cache path.
        head_dim: Width of each attention head (``nheads = D // head_dim``).
        causal: If True, a query may only attend to keys at or before its own
            position.
        device: Stored on the module for reference only. Computation follows
            the device of the input tensor, so a mismatched or unavailable
            default no longer breaks the forward pass.
        gqa: Stored but not used by the code in this class.
    """

    def __init__(self, D=768, layer_idx=None, head_dim=64, causal=True, device="cuda", gqa=False):
        super().__init__()
        self.D = D
        self.head_dim = head_dim
        self.gqa = gqa
        assert D % head_dim == 0, "D must be divisible by head_dim"
        self.nheads = D // head_dim
        self.Wq = torch.nn.Linear(D, D)
        self.Wk = torch.nn.Linear(D, D)
        self.Wv = torch.nn.Linear(D, D)
        self.causal = causal
        self.Wo = torch.nn.Linear(D, D)
        self.device = device
        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, kv_cache=None):
        """Apply self-attention to ``x`` of shape [B, S, D]; returns [B, S, D].

        Raises:
            NotImplementedError: while the KV-cache branch below is still the
                unimplemented exercise stub and a cache is supplied.
        """
        B, S, D = x.shape
        Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x)
        # Split the model dimension into heads: [B, S, D] -> [B, nh, S, hd].
        Q = Q.view(B, S, self.nheads, self.head_dim).transpose(1, 2)
        K = K.view(B, S, self.nheads, self.head_dim).transpose(1, 2)
        V = V.view(B, S, self.nheads, self.head_dim).transpose(1, 2)

        # TODO: If kv_cache is not None and self.layer_idx is not None:
        #   Call kv_cache.update(self.layer_idx, K, V).
        #   The caller advances kv_cache.current_length only AFTER the forward
        #   pass, so the tokens just written live at
        #   [current_length, current_length + S). Read them back with:
        #     end = kv_cache.current_length + S
        #     K = kv_cache.keys[self.layer_idx][:, :, :end, :]
        #     V = kv_cache.values[self.layer_idx][:, :, :end, :]
        if kv_cache is not None and self.layer_idx is not None:
            raise NotImplementedError("TODO: update cache and read K, V from cache")

        # Build the scale on the activations' device rather than self.device,
        # which may name a device the tensors are not on (or that is absent).
        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=Q.dtype, device=Q.device))
        logits = (Q @ K.transpose(-2, -1)) / scale
        if self.causal:
            # K may hold more positions than Q once cached keys are used.
            # Query row i then sits at absolute position (T - S + i), so the
            # first masked column is shifted by T - S. With T == S this is the
            # usual strictly-upper-triangular mask.
            T = K.size(-2)
            mask = torch.triu(
                torch.ones(S, T, device=logits.device), diagonal=T - S + 1
            ).bool()
            logits = logits.masked_fill(mask, float('-inf'))
        A = torch.nn.functional.softmax(logits, dim=-1)
        preout = torch.einsum('bnxy,bnyd->bnxd', A, V)
        # Merge heads back: [B, nh, S, hd] -> [B, S, D].
        preout = preout.transpose(1, 2).reshape(B, S, -1)
        return self.Wo(preout)

class MLP(torch.nn.Module):
    """Position-wise feed-forward block: widen, apply the activation, project back.

    ``act`` selects the activation by name from ``ACT2FN``. ``device`` is
    accepted for signature parity with the other layers but is not used here.
    """

    def __init__(self, D, hidden_multiplier=4, act='swish', device=None):
        super().__init__()
        hidden_width = D * hidden_multiplier
        self.D = D
        self.up_proj = torch.nn.Linear(D, hidden_width)
        self.down_proj = torch.nn.Linear(hidden_width, D)
        self.act = ACT2FN[act]

    def forward(self, x):
        widened = self.up_proj(x)
        activated = self.act(widened)
        return self.down_proj(activated)

class LN(torch.nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift.

    Args:
        D: Size of the normalised (last) dimension.
        eps: Constant added to the variance before the square root for
            numerical stability.
        device: Accepted for signature parity; not used here.

    Note:
        The variance comes from ``Tensor.var`` with its default settings, which
        differs slightly from ``torch.nn.LayerNorm`` (that uses the biased
        estimator). This is kept as-is so existing numerics do not change.
    """

    def __init__(self, D, eps=1e-9, device=None):
        super().__init__()
        # Previously `eps` was accepted but ignored in forward(); keep it so the
        # constructor argument actually takes effect.
        self.eps = eps
        self.mean_scale = torch.nn.Parameter(torch.zeros(D))
        self.std_scale = torch.nn.Parameter(torch.ones(D))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = (x.var(dim=-1, keepdim=True) + self.eps) ** 0.5
        # std_scale multiplies the normalised value; mean_scale is the additive shift.
        return (x - mean) / std * self.std_scale + self.mean_scale

class TransformerLayer(torch.nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        D: Model width.
        gqa: Forwarded to ``Attention``.
        device: Forwarded to the sub-modules. When omitted, the attention
            module is given "cuda" only if CUDA is actually available,
            otherwise "cpu".
    """

    def __init__(self, D, gqa=False, device=None):
        super().__init__()
        if device is None:
            # Do not assume a GPU exists: an unconditional "cuda" fallback
            # breaks the forward pass on CPU-only machines.
            attn_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            attn_device = device
        self.attn = Attention(D, gqa=gqa, device=attn_device)
        self.mlp = MLP(D, device=device)
        self.ln1 = LN(D, device=device)
        self.ln2 = LN(D, device=device)

    def forward(self, x, kv_cache=None):
        """Run the block; ``kv_cache`` is passed through to attention only."""
        x = x + self.attn(self.ln1(x), kv_cache=kv_cache)
        return x + self.mlp(self.ln2(x))

class PositionalEmbedding(torch.nn.Module):
    """Learned absolute position table added to a [B, S, D] input.

    ``device`` is accepted for signature parity but not used here.
    """

    def __init__(self, max_seq_len, D, device=None):
        super().__init__()
        table = torch.randn(max_seq_len, D)
        self.pos_embedding = torch.nn.Parameter(table)

    def forward(self, x):
        # Unpacking three names requires a rank-3 input.
        _, seq_len, _ = x.shape
        positions = self.pos_embedding[:seq_len]
        # [S, D] broadcasts across the batch dimension of x.
        return x + positions

class EmbeddingLayer(torch.nn.Module):
    """Token embedding table of shape [vocab_size, D], looked up by indexing.

    ``device`` is accepted for signature parity but not used here.
    """

    def __init__(self, vocab_size, D, device=None):
        super().__init__()
        weights = torch.randn(vocab_size, D)
        self.embedding = torch.nn.Parameter(weights)

    def forward(self, x):
        # Indexing with token ids appends a trailing D dimension to x's shape.
        table = self.embedding
        return table[x]

class UnembeddingLayer(torch.nn.Module):
    """Linear projection from hidden width D to vocabulary logits.

    ``device`` is accepted for signature parity but not used here.
    """

    def __init__(self, vocab_size, D, device=None):
        super().__init__()
        projection = torch.nn.Linear(D, vocab_size)
        self.unembedding = projection

    def forward(self, x):
        logits = self.unembedding(x)
        return logits

class Transformer(torch.nn.Module):
    """Decoder-style transformer: embed -> add positions -> N layers -> logits.

    Args:
        depth: Number of ``TransformerLayer`` blocks.
        hidden_dim: Model width.
        vocab_size: Vocabulary size for the embedding and output projection.
        max_seq_len: Number of rows in the learned positional table.
        device: Forwarded to the sub-modules and stored on the model.
        gqa: Forwarded to every layer.
    """

    def __init__(self, depth, hidden_dim, vocab_size, max_seq_len=16384, device=None, gqa=False):
        super().__init__()
        self.depth = depth
        self.hidden_dim = hidden_dim
        self.emb = EmbeddingLayer(vocab_size, hidden_dim, device=device)
        self.pos_emb = PositionalEmbedding(max_seq_len, hidden_dim, device=device)
        self.unemb = UnembeddingLayer(vocab_size, hidden_dim, device=device)
        self.layers = torch.nn.ModuleList([TransformerLayer(hidden_dim, gqa, device=device) for _ in range(depth)])
        # Each attention module needs its own index to address its cache slot.
        for i, layer in enumerate(self.layers):
            layer.attn.layer_idx = i
        self.device = device

    def forward(self, x, kv_cache=None):
        """Map token ids ``x`` to logits.

        With a ``kv_cache``, the input is treated as the continuation that
        starts at absolute position ``kv_cache.current_length``, so positional
        embeddings are taken from that offset.

        Raises:
            ValueError: if the cached continuation would run past the end of
                the positional table.
        """
        x = self.emb(x)
        if kv_cache is not None:
            pos_offset = kv_cache.current_length
            pos_end = pos_offset + x.size(1)
            table = self.pos_emb.pos_embedding
            # Slicing past the end truncates silently; for a one-token step
            # that yields an empty slice which then broadcasts the activations
            # down to an empty sequence. Fail loudly instead.
            if pos_end > table.size(0):
                raise ValueError(
                    f"positions [{pos_offset}, {pos_end}) exceed the positional "
                    f"table of length {table.size(0)}"
                )
            pos_emb = table[pos_offset:pos_end].unsqueeze(0)
            x = x + pos_emb
        else:
            x = self.pos_emb(x)
        for layer in self.layers:
            x = layer(x, kv_cache=kv_cache)
        return self.unemb(x)

In [None]:
'''
KV CACHE: preallocated tensor we fill during autoregressive decoding.
Prefill = process full prompt and populate cache. Decode = one token at a time, update cache with new K,V.
Cache shape per layer: keys/values [B, nh, max_seq_len, head_dim]. Advance current_length after append.
'''

In [None]:
class KVCache:
    """
    Preallocated K/V cache. keys and values: list of tensors [B, num_heads, max_seq_len, head_dim] per layer.

    Args:
        num_layers: Number of layers; one key tensor and one value tensor each.
        batch_size: Leading dimension of every cache tensor.
        num_heads: Number of heads stored per layer.
        head_dim: Width of each head.
        max_seq_len: Number of positions reserved per layer.
        device: Device the cache tensors are allocated on.
        dtype: Element type of the cache tensors. ``None`` keeps the previous
            behaviour (PyTorch's default dtype); pass the model's dtype when it
            is not float32 so cached K/V match the freshly computed Q.

    ``current_length`` counts the positions already filled. It starts at 0 and
    is advanced by the caller, not by ``update``.
    """
    def __init__(self, num_layers: int, batch_size: int, num_heads: int, head_dim: int, max_seq_len: int, device='cuda', dtype=None):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.device = device
        self.dtype = dtype
        self.current_length = 0
        # Preallocate: one tensor per layer for keys and values.
        # torch.empty leaves the contents uninitialised; only positions below
        # the filled length hold meaningful data.
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.keys = [torch.empty(shape, device=device, dtype=dtype) for _ in range(num_layers)]
        self.values = [torch.empty(shape, device=device, dtype=dtype) for _ in range(num_layers)]

    def update(self, layer: int, new_keys: torch.Tensor, new_values: torch.Tensor):
        """Write ``new_keys`` / ``new_values`` into the cache for ``layer`` (exercise stub)."""
        # TODO: Write new_keys and new_values into self.keys[layer] and self.values[layer]
        # at positions [:, :, seq_offset : seq_offset + token_count, :], where
        # seq_offset = self.current_length and token_count = new_keys.size(2).
        # Do NOT advance self.current_length here: every layer writes at the same
        # offset within one forward pass, and the caller advances it afterwards.
        # Consider raising if seq_offset + token_count > self.max_seq_len.
        raise NotImplementedError("TODO: implement KVCache.update")

class TransformerGenerator:
    """Greedy autoregressive decoding on top of a ``Transformer`` and a ``KVCache``.

    Args:
        model: The transformer to decode with; the generator uses the device of
            its first parameter.
        max_seq_len: Number of positions reserved in the KV cache.
    """

    def __init__(self, model: Transformer, max_seq_len: int = 4096):
        self.model = model
        self.device = next(model.parameters()).device
        self.max_seq_len = max_seq_len
        self.kv_cache = None
        # Last-position logits from the most recent prefill; generate() uses
        # them to choose the first new token.
        self._prefill_logits = None

    def _initialize_cache(self, batch_size: int):
        """Allocate a fresh cache sized from the model's first attention layer."""
        attn = self.model.layers[0].attn
        # Prefer a dedicated KV-head count when the attention module exposes one.
        num_heads = getattr(attn, 'num_kv_heads', None) or attn.nheads
        self.kv_cache = KVCache(
            self.model.depth, batch_size, num_heads, attn.head_dim,
            self.max_seq_len, self.device
        )
        self.model.kv_cache = self.kv_cache

    def _prefill(self, prompt: List[int]):
        """Run the whole prompt once to populate the cache; returns the cache.

        Raises:
            ValueError: if the prompt is empty or longer than the cache.
        """
        if len(prompt) == 0:
            raise ValueError("prompt must contain at least one token")
        if len(prompt) > self.max_seq_len:
            raise ValueError(
                f"prompt length {len(prompt)} exceeds max_seq_len {self.max_seq_len}"
            )
        prompt_tensor = torch.tensor(prompt, device=self.device).unsqueeze(0)
        self._initialize_cache(prompt_tensor.size(0))
        logits = self.model(prompt_tensor, kv_cache=self.kv_cache)
        # Keep the prediction for the token that follows the prompt.
        self._prefill_logits = logits[:, -1, :]
        self.kv_cache.current_length = prompt_tensor.size(1)
        return self.kv_cache

    def generate(self, prompt: List[int], max_new_tokens: int):
        """Greedily extend ``prompt`` by ``max_new_tokens`` tokens.

        Returns:
            A new list: the prompt followed by the generated token ids.

        Raises:
            ValueError: if the prompt is empty or the requested length does not
                fit in the cache.
        """
        # The final generated token is never fed back, so it needs no cache slot.
        cache_slots = len(prompt) + max(max_new_tokens - 1, 0)
        if cache_slots > self.max_seq_len:
            raise ValueError(
                f"prompt ({len(prompt)}) + max_new_tokens ({max_new_tokens}) "
                f"does not fit in max_seq_len {self.max_seq_len}"
            )
        kv_cache = self._prefill(prompt)
        generated = list(prompt)
        # The first new token comes from the prefill logits. Feeding the last
        # prompt token through the model again would cache it a second time at
        # the next position.
        next_logits = self._prefill_logits
        for step in range(max_new_tokens):
            next_token = int(torch.argmax(next_logits, dim=-1).item())
            generated.append(next_token)
            if step + 1 == max_new_tokens:
                break
            input_tensor = torch.tensor([[next_token]], device=self.device)
            next_logits = self.model(input_tensor, kv_cache=kv_cache)[:, -1, :]
            kv_cache.current_length += 1
        return generated

In [None]:
# Quick smoke test: run this once KVCache.update and the Attention kv_cache
# branch are implemented. It builds a tiny model and greedily decodes 5 tokens.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = Transformer(depth=2, hidden_dim=64, vocab_size=256, device=device)
model = model.to(device)
model.eval()
gen = TransformerGenerator(model, max_seq_len=128)
with torch.no_grad():
    out = gen.generate([1, 2, 3], max_new_tokens=5)
print("Generated token ids:", out)