<a href="https://colab.research.google.com/github/shubham-bari/Language-Models/blob/main/MHA_KV_Cache.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

In [None]:
import math
from typing import Dict, Optional

In [17]:
def tensor_size_bytes(t):
    """Return the memory used by the elements of tensor ``t``, in bytes.

    Computed as (number of elements) x (bytes per element).
    """
    bytes_per_element = t.element_size()
    element_count = t.numel()
    return element_count * bytes_per_element

In [None]:
class MultiHeadKV(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    Queries, keys and values are produced by bias-free linear projections of
    the input, split into ``n_heads`` heads of size ``d_out // n_heads``, and
    combined with scaled dot-product attention under a causal mask.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the projected q/k/v and of the output; must be
            divisible by ``n_heads``.
        context_len: Maximum total sequence length (cached + new tokens);
            determines the size of the precomputed causal mask.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_len, n_heads, dropout: float = 0.0):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads
        self.context_len = context_len

        # NOTE: creation order of the layers is kept stable so that seeded
        # initialisation yields the same weights as before.
        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)

        self.output_projection = nn.Linear(d_out, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular matrix of ones: entry (i, j) is 1 when
        # key position j lies in the future of query position i.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_len, context_len), diagonal=1)
        )

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (b, seq, d_out) -> (b, n_heads, seq, head_dim)."""
        t = t.view(batch_size, seq_len, self.n_heads, self.head_dim)
        return t.transpose(1, 2)

    def forward(self, x, past_kv: Optional[Dict[str, torch.Tensor]] = None):
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: Input of shape (batch, seq_len, d_in) holding the new tokens.
            past_kv: Optional cache with entries ``'k'`` and ``'v'``, each
                either ``None`` or a tensor of shape
                (batch, n_heads, seq_len_past, head_dim). Cached entries are
                prepended to the newly computed keys/values.

        Returns:
            A tuple ``(context_vecs, present_kv)`` where ``context_vecs`` has
            shape (batch, seq_len, d_out) and ``present_kv`` is a dict with
            the detached full keys ``'k'`` and values ``'v'`` (cached + new),
            suitable for passing as ``past_kv`` on the next call.

        Raises:
            ValueError: If cached plus new tokens exceed ``context_len``.
        """
        batch_size, context_size, _ = x.shape

        q = self._split_heads(self.Wq(x), batch_size, context_size)
        k = self._split_heads(self.Wk(x), batch_size, context_size)
        v = self._split_heads(self.Wv(x), batch_size, context_size)

        # If past_kv is provided, concatenate along the sequence dim (dim=2).
        if past_kv is not None:
            past_k = past_kv['k']
            past_v = past_kv['v']

            if past_k is not None:
                k = torch.cat((past_k, k), dim=2)

            if past_v is not None:
                v = torch.cat((past_v, v), dim=2)

        seq_len_total = k.shape[2]
        if seq_len_total > self.context_len:
            # The precomputed mask only covers context_len positions; slicing
            # past it would give a wrongly shaped (or silently broadcast) mask.
            raise ValueError(
                f"total sequence length {seq_len_total} exceeds "
                f"context_len {self.context_len}"
            )

        # (b, n_heads, seq, head_dim) @ (b, n_heads, head_dim, total)
        attn_scores = torch.matmul(q, k.transpose(2, 3))
        attn_scores = attn_scores / math.sqrt(self.head_dim)  # scaling

        # The new queries occupy the last `context_size` absolute positions,
        # so take the matching rows of the mask. Without a cache start_idx is
        # 0 and this reduces to the plain causal mask.
        start_idx = seq_len_total - context_size
        causal_mask = self.mask[start_idx:seq_len_total, :seq_len_total].bool()
        attn_scores.masked_fill_(causal_mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, n_heads, seq, head_dim) -> (b, seq, n_heads, head_dim)
        context_vecs = (attn_weights @ v).transpose(1, 2)

        # Merge the heads back using the actual query length (not
        # self.context_len) so shorter inputs and single-token decoding steps
        # work.
        context_vecs = context_vecs.contiguous().view(
            batch_size, context_size, self.d_out
        )
        context_vecs = self.output_projection(context_vecs)

        # Detach so the cache does not keep the autograd graph alive.
        present_kv = {'k': k.detach(), 'v': v.detach()}
        return context_vecs, present_kv








In [19]:
import torch
import time

# Fixed seed so the randomly initialised projection weights (and therefore
# the printed context vectors) are reproducible.
torch.manual_seed(123)

# ---------- Dummy input: 3 tokens, 6 features each ----------
inputs = torch.tensor([
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
    [9.0, 8.0, 7.0, 6.0, 5.0, 4.0]
], dtype=torch.float32)

# Duplicate the sequence to form a batch of two identical samples.
batch = torch.stack((inputs, inputs), dim=0)   # shape = (2,3,6)
print("Batch shape:", batch.shape)

# ---------- Model ----------
batch_size, num_tokens, d_in = batch.shape
d_out = d_in

mha_kv = MultiHeadKV(
    d_in=d_in,
    d_out=d_out,
    context_len=num_tokens,
    n_heads=2,
    dropout=0.0
)

# ---------- Run MHA-KV without a pre-existing cache ----------
# perf_counter is the monotonic, high-resolution clock intended for timing
# short intervals; time.time() is wall-clock and can jump or be too coarse.
start = time.perf_counter()
context_vecs, present_kv = mha_kv(batch)
end = time.perf_counter()

print("\nContext vectors:\n", context_vecs)
print("\nShape:", context_vecs.shape)
print("\nPresent KV shapes:")
print("K:", present_kv['k'].shape)
print("V:", present_kv['v'].shape)

print("\nTime:", end - start, "seconds")

# ---------- Memory held by the returned KV cache ----------
k_mem = tensor_size_bytes(present_kv['k'])
v_mem = tensor_size_bytes(present_kv['v'])

total = k_mem + v_mem

print("K memory:", k_mem, "bytes")
print("V memory:", v_mem, "bytes")
print("Total KV memory:", total / 1024, "KB")



Batch shape: torch.Size([2, 3, 6])

Context vectors:
 tensor([[[ 0.0844,  0.2387,  1.2366,  0.5299, -0.7774,  0.2029],
         [ 0.0813,  0.1798,  1.0878,  0.5513, -0.7497,  0.2093],
         [ 0.1015,  0.1774,  1.0688,  0.5545, -0.7234,  0.2244]],

        [[ 0.0844,  0.2387,  1.2366,  0.5299, -0.7774,  0.2029],
         [ 0.0813,  0.1798,  1.0878,  0.5513, -0.7497,  0.2093],
         [ 0.1015,  0.1774,  1.0688,  0.5545, -0.7234,  0.2244]]],
       grad_fn=<UnsafeViewBackward0>)

Shape: torch.Size([2, 3, 6])

Present KV shapes:
K: torch.Size([2, 2, 3, 3])
V: torch.Size([2, 2, 3, 3])

Time: 0.0009682178497314453 seconds
K memory: 144 bytes
V memory: 144 bytes
Total KV memory: 0.28125 KB


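We can see that the K and V caches each take 144 bytes of memory (288 bytes in total): batch_size × n_heads × seq_len × head_dim × 4 bytes per float32 = 2 × 2 × 3 × 3 × 4. For a fixed head dimension, the cache size is directly proportional to the number of attention heads used (equivalently, to d_out = n_heads × head_dim), and it also grows linearly with the batch size and the sequence length.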
