<a href="https://colab.research.google.com/github/shubham-bari/Language-Models/blob/main/MultiQueryAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
def tensor_size_bytes(t):
    """Return the number of bytes occupied by the elements of tensor ``t``.

    Handy for measuring memory usage (e.g. of a KV cache). The figure is
    element count times bytes per element; it excludes tensor metadata and
    any extra storage a view may share with a larger base tensor.
    """
    bytes_per_element = t.element_size()
    return bytes_per_element * t.numel()

In [30]:
import math
from typing import Dict, Optional

import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    """Causal multi-query attention (MQA) with an optional key/value cache.

    All ``q_heads`` query heads attend over one shared key head and one shared
    value head, so the cached K/V tensors are ``q_heads`` times smaller than
    the per-head K/V of standard multi-head attention.

    Args:
        d_in: Feature size of the input ``x``.
        d_out: Feature size of the queries and of the output; must be
            divisible by ``q_heads``.
        context_len: Maximum number of positions (cached + new) that the
            causal mask covers.
        q_heads: Number of query heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_len, q_heads, dropout=0.0):
        super().__init__()
        assert d_out % q_heads == 0, "d_out must be divisible by q_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.context_len = context_len
        self.q_heads = q_heads
        self.head_dim = d_out // q_heads

        self.Wq = nn.Linear(d_in, d_out, bias=False)
        # Wk/Wv keep their full (d_out, d_in) weight shape (and creation order)
        # for checkpoint/seed compatibility, but only their first head_dim
        # output rows are used: they form the single shared K/V head.
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

        # Ones strictly above the diagonal mark future positions to hide.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_len, context_len), diagonal=1),
        )

    def forward(self, x, past_kv: Optional[Dict[str, torch.Tensor]] = None):
        """Attend causally over ``x`` plus any cached keys/values.

        Args:
            x: Input of shape (batch, num_tokens, d_in).
            past_kv: Optional cache returned by a previous call: a dict with
                keys 'k' and 'v', each of shape (batch, 1, past_len, head_dim)
                or None.

        Returns:
            ``(context_vecs, present_kv)``: the output of shape
            (batch, num_tokens, d_out), and a detached cache dict whose 'k' and
            'v' have shape (batch, 1, past_len + num_tokens, head_dim).

        Raises:
            ValueError: if past_len + num_tokens exceeds ``context_len``.
        """
        b, num_tokens, _ = x.shape

        # (b, T, d_out) -> (b, q_heads, T, head_dim)
        q = self.Wq(x).view(b, num_tokens, self.q_heads, self.head_dim)
        q = q.transpose(1, 2)

        # Shared K/V head = first head_dim output features of Wk/Wv. Project
        # onto just those weight rows instead of computing all d_out features
        # and discarding the rest. Result: (b, 1, T, head_dim).
        k = (x @ self.Wk.weight[:self.head_dim].T).unsqueeze(1)
        v = (x @ self.Wv.weight[:self.head_dim].T).unsqueeze(1)

        # Prepend cached positions along the sequence dim (dim=2).
        if past_kv is not None:
            past_k = past_kv.get('k')
            past_v = past_kv.get('v')
            if past_k is not None:
                k = torch.cat((past_k, k), dim=2)
            if past_v is not None:
                v = torch.cat((past_v, v), dim=2)

        seq_len_total = k.shape[2]
        if seq_len_total > self.context_len:
            raise ValueError(
                f"sequence length {seq_len_total} (cached + new) exceeds "
                f"context_len {self.context_len}"
            )

        # (b, q_heads, T, hd) @ (b, 1, hd, S) broadcasts to (b, q_heads, T, S).
        attn_scores = q @ k.transpose(2, 3)
        attn_scores = attn_scores / math.sqrt(self.head_dim)

        # The new queries sit at absolute positions start .. start+T-1
        # (start == 0 when there is no cache), so take those mask rows.
        start = seq_len_total - num_tokens
        causal_mask = self.mask.bool()[start:start + num_tokens, :seq_len_total]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, q_heads, T, hd) -> (b, T, q_heads, hd) -> (b, T, d_out); use the
        # actual token count, not context_len, so short inputs and cached
        # single-token steps work.
        context_vecs = (attn_weights @ v).transpose(1, 2)
        context_vecs = context_vecs.reshape(b, num_tokens, self.d_out)
        context_vecs = self.out_proj(context_vecs)

        # Detached so the cache does not keep this call's autograd graph alive.
        present_kv = {'k': k.detach(), 'v': v.detach()}
        return context_vecs, present_kv





In [32]:
import time

import torch

torch.manual_seed(123)

# ---------- Dummy input: 3 tokens with 6 features each ----------
inputs = torch.tensor([
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
    [9.0, 8.0, 7.0, 6.0, 5.0, 4.0],
], dtype=torch.float32)

batch = torch.stack((inputs, inputs), dim=0)   # shape = (2, 3, 6)
print("Batch shape:", batch.shape)

# ---------- Model ----------
batch_size, num_tokens, d_in = batch.shape
d_out = d_in

mqa = MultiQueryAttention(
    d_in=d_in,
    d_out=d_out,
    context_len=num_tokens,
    q_heads=2,
    dropout=0.0,
)

# ---------- Single MQA forward pass over the full sequence (no past cache) ----------
# perf_counter is the monotonic high-resolution clock meant for short timings;
# a single cold call is still a rough figure, not a benchmark.
start = time.perf_counter()
context_vecs, present_kv = mqa(batch)
elapsed = time.perf_counter() - start

print("\nContext vectors:\n", context_vecs)
print("\nShape:", context_vecs.shape)
print("\nPresent KV shapes:")
print("K:", present_kv['k'].shape)
print("V:", present_kv['v'].shape)

print("\nTime:", elapsed, "seconds")

k_mem = tensor_size_bytes(present_kv['k'])
v_mem = tensor_size_bytes(present_kv['v'])
total = k_mem + v_mem

# Baseline for comparison: standard multi-head attention caches one K and one
# V head per query head, i.e. (batch, q_heads, num_tokens, head_dim) each.
mha_total = (
    2 * batch_size * mqa.q_heads * num_tokens * mqa.head_dim
    * present_kv['k'].element_size()
)

print("K memory:", k_mem, "bytes")
print("V memory:", v_mem, "bytes")
print("Total KV memory:", total / 1024, "KB")
print("Equivalent MHA KV memory:", mha_total / 1024, "KB")
print("MQA / MHA ratio:", total / mha_total)


Batch shape: torch.Size([2, 3, 6])

Context vectors:
 tensor([[[ 0.8755,  0.1901,  1.4868, -0.5051, -0.0724,  0.6745],
         [ 0.8412,  0.1574,  1.3448, -0.4141, -0.1003,  0.6521],
         [ 2.1408, -0.7935,  3.4938, -2.1454,  1.0740,  0.0329]],

        [[ 0.8755,  0.1901,  1.4868, -0.5051, -0.0724,  0.6745],
         [ 0.8412,  0.1574,  1.3448, -0.4141, -0.1003,  0.6521],
         [ 2.1408, -0.7935,  3.4938, -2.1454,  1.0740,  0.0329]]],
       grad_fn=<ViewBackward0>)

Shape: torch.Size([2, 3, 6])

Present KV shapes:
K: torch.Size([2, 1, 3, 3])
V: torch.Size([2, 1, 3, 3])

Time: 0.0015461444854736328 seconds
K memory: 72 bytes
V memory: 72 bytes
Total KV memory: 0.140625 KB


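The output above shows that the KV cache takes exactly 1/q_heads of the memory a standard multi-head attention cache would need. Here that is a 50% reduction (144 bytes instead of 288 with q_heads=2), because all query heads share a single key head and a single value head.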

