<a href="https://colab.research.google.com/github/shubham-bari/Language-Models/blob/main/MHA_KV_Cache.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [3]:
import math
from typing import Dict, Optional

In [15]:
class MultiHeadKV(nn.Module):
  """Causal multi-head self-attention with an optional key/value cache.

  The module projects the input into queries, keys and values, splits them
  into ``n_heads`` heads of size ``d_out // n_heads``, and applies scaled
  dot-product attention with a causal mask. Keys and values from earlier
  calls can be supplied through ``past_kv`` so that only the new tokens have
  to be projected during incremental decoding.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output size across all heads; must be divisible by
      ``n_heads``.
    context_len: Side length of the precomputed causal-mask buffer. Longer
      sequences are still supported; their mask is built on the fly.
    n_heads: Number of attention heads.
    dropout: Dropout probability applied to the attention weights.
  """

  def __init__(self, d_in, d_out, context_len, n_heads, dropout: float=0.0):
    super().__init__()
    assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

    self.d_in = d_in
    self.d_out = d_out
    self.n_heads = n_heads
    self.head_dim = d_out // n_heads
    self.context_len = context_len

    self.Wq = nn.Linear(d_in, d_out, bias=False)
    self.Wk = nn.Linear(d_in, d_out, bias=False)
    self.Wv = nn.Linear(d_in, d_out, bias=False)

    self.output_projection = nn.Linear(d_out, d_out, bias=False)
    self.dropout = nn.Dropout(dropout)

    # Entry (i, j) is 1 when key position j lies in the future of query
    # position i. Registered as a buffer so it follows the module across
    # devices and is part of the state dict.
    self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_len, context_len), diagonal=1)
    )

  def _split_heads(self, t):
    """Reshape (b, seq, d_out) into (b, n_heads, seq, head_dim)."""
    batch_size, seq_len, _ = t.shape
    return t.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

  def forward(self, x, past_kv: Optional[Dict[str, torch.Tensor]] = None):
    """Attend over the cached tokens plus the tokens in ``x``.

    Args:
      x: Input of shape (batch, n_new, d_in) holding only the new tokens.
      past_kv: Optional dict with entries ``'k'`` and ``'v'``, each either
        ``None`` or a tensor of shape (batch, n_heads, n_past, head_dim),
        as returned by a previous call.

    Returns:
      A tuple ``(context_vecs, present_kv)``. ``context_vecs`` has shape
      (batch, n_new, d_out). ``present_kv`` is a dict with the detached
      keys and values for all tokens seen so far, each of shape
      (batch, n_heads, n_past + n_new, head_dim).
    """
    batch_size, n_new, _ = x.shape

    q = self._split_heads(self.Wq(x))
    k = self._split_heads(self.Wk(x))
    v = self._split_heads(self.Wv(x))

    # Prepend cached keys/values along the sequence dimension (dim=2).
    if past_kv is not None:
      past_k = past_kv['k']
      past_v = past_kv['v']

      if past_k is not None:
        k = torch.cat((past_k, k), dim=2)

      if past_v is not None:
        v = torch.cat((past_v, v), dim=2)

    total_len = k.shape[2]
    # The new queries occupy absolute positions past_len .. total_len - 1.
    past_len = total_len - n_new

    attn_scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

    if total_len <= self.context_len:
      # Rows of the precomputed mask that belong to the new query positions.
      causal_mask = self.mask[past_len:total_len, :total_len].bool()
    else:
      # The buffer is too small for this sequence; build the same
      # "key is in the future of the query" mask directly from positions.
      key_pos = torch.arange(total_len, device=x.device)
      query_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
      causal_mask = key_pos > query_pos

    attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

    attn_weights = torch.softmax(attn_scores, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # (b, n_heads, n_new, head_dim) -> (b, n_new, n_heads, head_dim)
    context_vecs = torch.matmul(attn_weights, v).transpose(1, 2)

    # Merge the heads back using the actual number of new tokens, which
    # need not equal context_len (e.g. one token per decoding step).
    context_vecs = context_vecs.contiguous().view(batch_size, n_new, self.d_out)
    context_vecs = self.output_projection(context_vecs)

    # Detached so the cache does not keep earlier autograd graphs alive.
    present_kv = {'k': k.detach(), 'v': v.detach()}
    return context_vecs, present_kv








In [16]:
mha = MultiHeadKV(d_in=6, d_out=6, n_heads=2, context_len=3)
cache = None
# Run ten forward passes, each over a fresh random batch of three tokens.
# NOTE: the cache returned by one pass is not handed back to the next call
# (no past_kv argument is given), so every iteration attends only over its
# own three tokens and `cache` simply holds the most recent keys/values.
for _ in range(10):
    new_token_embed = torch.randn(2, 3, 6)   # batch=2, seq_len=3, d_in=6
    out, cache = mha(new_token_embed)

    print("\n\nOuputs are:\n", out)
    print("\n\nCache are:\n", cache)
    # out shape: (2, 3, 6)
    # cache['k'].shape -> (2, n_heads, 3, head_dim) = (2, 2, 3, 3)
    # cache['v'].shape -> (2, n_heads, 3, head_dim) = (2, 2, 3, 3)
    # pass out into subsequent MLP / logits head etc.





Ouputs are:
 tensor([[[ 0.0022, -0.0203,  0.2001,  0.3570,  0.0913, -0.1347],
         [-0.3427,  0.1300, -0.0896,  0.4539,  0.1334, -0.4004],
         [-0.1972,  0.0609, -0.0726,  0.2446,  0.1138, -0.2240]],

        [[-0.0747, -0.0401,  0.3982,  0.3793,  0.2116, -0.1631],
         [ 0.1301, -0.1844,  0.3920,  0.2036,  0.2023,  0.0905],
         [ 0.1228, -0.1421,  0.2808,  0.1890,  0.1380,  0.0028]]],
       grad_fn=<UnsafeViewBackward0>)


Cache are:
 {'k': tensor([[[[-1.0803, -0.1004, -0.6312],
          [-0.9167, -0.5473,  0.3107],
          [ 0.2617,  0.2593, -0.0087]],

         [[-0.9150, -0.3980,  0.5276],
          [-1.5294, -0.6495, -0.3476],
          [ 0.3810,  0.0409, -0.3407]]],


        [[[ 0.0320,  0.0768, -0.3775],
          [ 0.0827, -0.3101, -0.2378],
          [ 0.0571, -0.4368, -0.4180]],

         [[ 0.1320,  0.5006, -0.1276],
          [ 0.7761,  0.1697, -0.9751],
          [-0.6248,  0.4890,  0.2609]]]]), 'v': tensor([[[[-0.2846, -0.4014, -0.3111],
         