#### Next Token Prediction Task (**KV Cache**)

![Inference](media/Inference.jpg)

In transformer models, self-attention computes attention scores over the entire input sequence for every token generated or processed.

![QKV](media/QKV.jpg)

- Without caching, each autoregressive decoding step feeds the entire sequence so far back through the model, computing new Query (Q), Key (K), and Value (V) vectors for every token.
- This recomputes K and V for all past tokens at every step. That work is wasted: with causal masking, a past token's K and V never change once they have been computed.
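The reason cached values stay valid can be checked directly. A key projection is applied to each token independently, so appending a token leaves the keys of earlier tokens unchanged. The sketch below assumes a single bare `nn.Linear` as the key projection (`wk` and the tensor names are illustrative, not from a specific model). In a multi-layer decoder the same holds layer by layer, because causal masking keeps earlier hidden states independent of later tokens; this sketch only checks one projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 4
wk = nn.Linear(d_model, d_model, bias=False)  # illustrative key projection

prefix = torch.randn(1, 2, d_model)               # (batch, seq=2, d_model): token1, token2
new_token = torch.randn(1, 1, d_model)            # token3
extended = torch.cat([prefix, new_token], dim=1)  # (1, 3, d_model)

with torch.no_grad():
    k_prefix = wk(prefix)      # keys for token1, token2
    k_extended = wk(extended)  # keys recomputed for token1..token3

# The first two key rows match: recomputing them was redundant work.
print(torch.allclose(k_prefix, k_extended[:, :2]))  # True
```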

**Problem:**

- Computational inefficiency: Recomputing K and V for all previous tokens at every step means the total number of K/V projections over a generated sequence of length n grows quadratically, as 1 + 2 + ... + n = n(n+1)/2.
- Latency: This redundant work slows down autoregressive generation in tasks such as language modeling and translation.
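The arithmetic behind the quadratic growth can be sketched as a simple count. This assumes each "projection" means computing K and V for one token at one step, and it ignores the cost of the attention scores themselves.

```python
def kv_projections(num_steps: int, use_cache: bool) -> int:
    """Count per-token K/V projections over num_steps decoding steps.

    Without a cache, step t re-projects all t tokens seen so far;
    with a cache, it projects only the newest token.
    """
    if use_cache:
        return num_steps                     # 1 + 1 + ... + 1
    return num_steps * (num_steps + 1) // 2  # 1 + 2 + ... + n


for n in (3, 100, 1000):
    print(n, kv_projections(n, use_cache=False), kv_projections(n, use_cache=True))
```

For n = 1000 this gives 500,500 projections without a cache versus 1,000 with one.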

**KV Cache**

- A cache stores computed results so they can be reused.
- Instead of recomputing `K` and `V` for all previous tokens at every step, we store (cache) the `K` and `V` vectors as they are computed.
- When generating the next token, we compute `Q`, `K`, and `V` only for the new token, append its `K` and `V` to the cache, and attend over all cached `K` and `V` vectors.
- This drastically reduces computation, since `K` and `V` never have to be recomputed for earlier tokens.
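The steps above can be sketched for a single attention head. This is a minimal illustration, assuming random stand-in embeddings and one attention layer without an output projection (`wq`, `wk`, `wv`, and `full_causal_attention` are illustrative names). It decodes one token at a time with a growing K/V cache and checks the result against recomputing causal attention over the whole sequence.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, steps = 8, 5
wq = nn.Linear(d_model, d_model, bias=False)  # illustrative single-head projections
wk = nn.Linear(d_model, d_model, bias=False)
wv = nn.Linear(d_model, d_model, bias=False)
tokens = torch.randn(1, steps, d_model)  # stand-in embeddings, (batch, seq, d_model)


def full_causal_attention(x):
    """Recompute Q, K, V for the whole sequence and apply a causal mask."""
    q, k, v = wq(x), wk(x), wv(x)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)
    seq = x.size(1)
    mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


with torch.no_grad():
    reference = full_causal_attention(tokens)  # (1, steps, d_model)

    k_cache, v_cache, outputs = None, None, []
    for t in range(steps):
        x_t = tokens[:, t : t + 1]                 # only the newest token, (1, 1, d_model)
        q_t, k_t, v_t = wq(x_t), wk(x_t), wv(x_t)  # Q, K, V for the new token only
        k_cache = k_t if k_cache is None else torch.cat([k_cache, k_t], dim=1)
        v_cache = v_t if v_cache is None else torch.cat([v_cache, v_t], dim=1)
        # One query row against all cached keys. No mask is needed here,
        # because the cache holds only past and current tokens.
        scores = q_t @ k_cache.transpose(-2, -1) / math.sqrt(d_model)
        outputs.append(F.softmax(scores, dim=-1) @ v_cache)
    incremental = torch.cat(outputs, dim=1)

print(incremental.shape)  # torch.Size([1, 5, 8])
print(torch.allclose(incremental, reference, atol=1e-5))
```

The masked softmax in the full recomputation assigns zero weight to future positions, so each row equals the softmax over only the cached keys; the cached loop simply never computes the masked-out entries.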


*Illustration:*

| Step | Tokens in context         | Compute Q for | Compute K, V for | Use cached K, V for     |
| ---- | ------------------------- | ------------- | ---------------- | ----------------------- |
| 1    | \[token1]                 | token1        | token1           | none (cache is empty)   |
| 2    | \[token1, token2]         | token2        | token2           | token1                  |
| 3    | \[token1, token2, token3] | token3        | token3           | token1, token2          |

Without a cache, step 3 recomputes K and V for all 3 tokens. With a KV cache, step 3 reuses K1, K2, V1, V2 from the cache and computes only K3 and V3.


![Kv Cache](media/kvcache.jpg)



**Simple Multi Head Attention with KV Cache**

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
```

```python
class SimpleMultiHeadAttentionWithKV(nn.Module):
    def __init__(self, d_model, num_heads, context_len, qkv_bias=False, dropout=0.0):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.context_len = context_len
        self.head_dim = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.wk = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.wv = nn.Linear(d_model, d_model, bias=qkv_bias)

        # Dropout is only active in training mode; call model.eval() to disable it for inference.
        self.dropout = nn.Dropout(dropout)

        self.wo = nn.Linear(d_model, d_model)

        # KV cache, shape (b, num_heads, cached_len, head_dim) once populated
        self.cached_key = None
        self.cached_value = None

        # Causal mask (True = hidden future position). Precomputing the full
        # context_len x context_len matrix is simple but wasteful for long contexts;
        # it could instead be built on demand in forward().
        self.register_buffer("mask", torch.triu(torch.ones(context_len, context_len), diagonal=1).bool())

    def reset_cache(self):
        # Clear the cache before starting a new, unrelated sequence
        self.cached_key = None
        self.cached_value = None

    def forward(self, x):
        b, seq, _ = x.size()
        past_len = 0 if self.cached_key is None else self.cached_key.size(2)
        total_len = past_len + seq
        assert total_len <= self.context_len, "sequence exceeds context_len"

        Q = self.wq(x)  # (b, seq, d_model)
        K = self.wk(x)  # (b, seq, d_model)
        V = self.wv(x)  # (b, seq, d_model)

        # split heads (b, seq, d_model) -> (b, num_heads, seq, head_dim)
        Q = Q.view(b, seq, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(b, seq, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(b, seq, self.num_heads, self.head_dim).transpose(1, 2)

        # append the new K, V to the cache along the sequence dimension (dim=2)
        if self.cached_key is None:
            self.cached_key = K
            self.cached_value = V
        else:
            self.cached_key = torch.cat((self.cached_key, K), dim=2)
            self.cached_value = torch.cat((self.cached_value, V), dim=2)

        # scores: (b, num_heads, seq, head_dim) @ (b, num_heads, head_dim, total_len)
        #      -> (b, num_heads, seq, total_len)
        attn_score = Q @ self.cached_key.transpose(2, 3)
        # new queries sit at positions past_len..total_len-1 and keys span 0..total_len-1;
        # the (seq, total_len) mask broadcasts over batch and heads
        mask = self.mask[past_len:total_len, :total_len]
        attn_score = attn_score.masked_fill(mask, -torch.inf)

        # scaled softmax over the key dimension, then dropout
        attn_weight = F.softmax(attn_score / (self.head_dim ** 0.5), dim=-1)
        attn_weight = self.dropout(attn_weight)

        # (b, num_heads, seq, total_len) @ (b, num_heads, total_len, head_dim) -> (b, num_heads, seq, head_dim)
        context_vector = attn_weight @ self.cached_value
        # merge heads (b, num_heads, seq, head_dim) -> (b, seq, d_model); transpose makes the
        # tensor non-contiguous, so reshape is required instead of view when seq > 1
        context_vector = context_vector.transpose(1, 2).reshape(b, seq, self.d_model)

        output = self.wo(context_vector)
        return output, attn_weight
```
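The causal mask has to be sliced relative to how many tokens are already cached. When `seq` new tokens arrive after `past_len` cached ones, the new queries occupy positions `past_len` to `past_len + seq - 1`, while the keys span positions `0` to `past_len + seq - 1`. A `[:seq, :seq]` slice only lines up with that shape when the cache is empty, or when exactly one token arrives per pass (its single mask entry is `False` and simply broadcasts). The helper below is a hypothetical sketch of the general slice, assuming the same `torch.triu(..., diagonal=1)` boolean mask used above.

```python
import torch


def causal_mask_slice(context_len: int, past_len: int, seq: int) -> torch.Tensor:
    """Return the (seq, past_len + seq) mask for new queries vs. all cached keys.

    Hypothetical helper. True marks future keys to hide, matching a
    torch.triu(..., diagonal=1) boolean mask.
    """
    total = past_len + seq
    assert total <= context_len, "sequence exceeds context_len"
    full = torch.triu(torch.ones(context_len, context_len, dtype=torch.bool), diagonal=1)
    return full[past_len:total, :total]


# Two new tokens after two cached ones: the query at position 2 hides key 3,
# and the query at position 3 hides nothing.
print(causal_mask_slice(context_len=4, past_len=2, seq=2))
# A single decode step never hides anything.
print(causal_mask_slice(context_len=4, past_len=3, seq=1))
```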

```python
b, seq, d_model = 2, 1, 4
num_heads = 2
context_len = 4
num_pass = 4

kv_model = SimpleMultiHeadAttentionWithKV(d_model, num_heads, context_len)
kv_model.eval()  # disable dropout for inference

with torch.no_grad():
    for i in range(num_pass):
        data = torch.randn(b, seq, d_model)  # one new "token" per batch element
        out, attn_weight = kv_model(data)
        print(f"Pass: {i+1}, K:{kv_model.cached_key.shape} V:{kv_model.cached_value.shape}")
        print(f"Out:{out.shape}, attn_weight:{attn_weight.shape}")
        print("--" * 30)
```

```text
Pass: 1, K:torch.Size([2, 2, 1, 2]) V:torch.Size([2, 2, 1, 2])
Out:torch.Size([2, 1, 4]), attn_weight:torch.Size([2, 2, 1, 1])
------------------------------------------------------------
Pass: 2, K:torch.Size([2, 2, 2, 2]) V:torch.Size([2, 2, 2, 2])
Out:torch.Size([2, 1, 4]), attn_weight:torch.Size([2, 2, 1, 2])
------------------------------------------------------------
Pass: 3, K:torch.Size([2, 2, 3, 2]) V:torch.Size([2, 2, 3, 2])
Out:torch.Size([2, 1, 4]), attn_weight:torch.Size([2, 2, 1, 3])
------------------------------------------------------------
Pass: 4, K:torch.Size([2, 2, 4, 2]) V:torch.Size([2, 2, 4, 2])
Out:torch.Size([2, 1, 4]), attn_weight:torch.Size([2, 2, 1, 4])
------------------------------------------------------------
```


This example demonstrates a multi-head attention module with a KV cache, implemented in PyTorch. Key points:

* The model processes inputs of shape `[batch_size=2, seq_len=1, d_model=4]` with `num_heads=2` and a context length of 4.
* Over 4 passes, random inputs simulating one new token each are fed into the model in sequence.
* The KV cache (`cached_key` and `cached_value`) grows by one position along the sequence dimension per pass, accumulating the K and V projections of all previous tokens.
* After each pass, the shapes of the cached keys and values, the output, and the attention weights are printed.
* The output shape remains `[2, 1, 4]`, i.e. one output vector per new token.
* The attention weights grow from `[2, 2, 1, 1]` to `[2, 2, 1, 4]`: the single new query attends over an expanding set of cached keys.
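The printed cache shapes also give the cache's memory footprint, which grows linearly with the number of tokens. A small sketch, assuming float32 storage (4 bytes per element, PyTorch's default dtype) and a single attention layer; a model with several layers keeps one such cache per layer. The function name is illustrative.

```python
def kv_cache_elements(batch: int, num_heads: int, seq_len: int, head_dim: int) -> int:
    """Elements held by one attention layer's K and V caches."""
    per_tensor = batch * num_heads * seq_len * head_dim  # cached_key or cached_value
    return 2 * per_tensor                                # K plus V


# Shapes from the demo above: batch=2, num_heads=2, head_dim=2, one token per pass
for pass_idx in range(1, 5):
    n = kv_cache_elements(batch=2, num_heads=2, seq_len=pass_idx, head_dim=2)
    print(f"Pass {pass_idx}: {n} elements, {n * 4} bytes in float32")
```

This prints 16 elements (64 bytes) after pass 1, rising by the same amount each pass to 64 elements (256 bytes) after pass 4.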

---



This example effectively illustrates how KV caching enables efficient incremental processing in multi-head attention. By caching keys and values across passes, the model avoids recomputing these projections for previous tokens, improving computational efficiency during autoregressive generation or streaming scenarios.

* The cached K and V grow by one position per pass, matching the growing context as more tokens are processed.
* The stable output shape shows that each pass produces one output vector per new token. Shapes alone do not prove numerical correctness; that requires comparing the incremental outputs against a full recomputation over the same sequence.
* Caching removes redundant recomputation, but attention is not free: each step still attends over all cached positions, so per-step cost and cache memory both grow linearly with sequence length.
* Checking the KV cache shape is a cheap way to confirm that attention covers the full accumulated context. The cache must also be reset between unrelated sequences.

Overall, KV caching is a key optimization for transformer inference: it speeds up generation without changing the attention result, because the cached K and V are exactly the values that would otherwise be recomputed.
