<a href="https://colab.research.google.com/github/shubham-bari/Language-Models/blob/main/GroupedQueryAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
def tensor_size_bytes(t):
    """Measure memory usage: number of elements in ``t`` times the byte size of one element."""
    element_count = t.nelement()
    bytes_per_element = t.element_size()
    return element_count * bytes_per_element

In [13]:
import torch
import torch.nn as nn
from typing import Dict

class GQA(nn.Module):
  """Causal grouped-query attention with an optional key/value cache.

  Queries are projected into ``q_heads`` heads, while keys and values are
  projected into only ``q_heads // groups`` heads. Each key/value head is
  shared by ``groups`` consecutive query heads.

  Args:
    d_in: feature size of the input tokens.
    d_out: feature size of the output; must be divisible by ``q_heads``.
    q_heads: number of query heads.
    groups: number of query heads sharing one key/value head; must divide
      ``q_heads``.
    context_len: maximum total sequence length (cached + new tokens); this
      is the side length of the causal mask buffer.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, d_in, d_out, q_heads, groups, context_len, dropout):

    super().__init__()
    assert d_out % q_heads == 0, "d_out must be divisible by q_heads"
    assert q_heads % groups == 0, "q_heads must be divisible by groups"

    self.d_in = d_in
    self.d_out = d_out
    self.q_heads = q_heads
    self.groups = groups
    self.context_len = context_len

    # The head width is derived from the projection output size (d_out), so
    # that q_heads * head_dim == d_out even when d_in != d_out.
    self.head_dim = d_out // q_heads  # same for q and kv
    self.kv_heads = q_heads // groups
    self.kv_out_dim = self.kv_heads * self.head_dim

    self.Wq = nn.Linear(d_in, d_out, bias=False)
    self.Wk = nn.Linear(d_in, self.kv_out_dim, bias=False)
    self.Wv = nn.Linear(d_in, self.kv_out_dim, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.out_proj = nn.Linear(d_out, d_out, bias=False)

    # 1 above the diagonal marks "future" positions that must not be attended to.
    self.register_buffer('mask', torch.triu(torch.ones(context_len, context_len), diagonal=1))

  def forward(self, x, past_kv: Dict[str, torch.Tensor] = None):
    """Compute attention for ``x``, optionally continuing from cached keys/values.

    Args:
      x: tensor of shape ``(batch, num_tokens, d_in)`` holding the new tokens.
      past_kv: optional dict with entries ``'k'`` and ``'v'`` as returned by a
        previous call, each of shape ``(batch, kv_heads, past_len, head_dim)``.
        Missing or ``None`` entries are ignored.

    Returns:
      A tuple ``(context_vecs, present_kv)``. ``context_vecs`` has shape
      ``(batch, num_tokens, d_out)``. ``present_kv`` is a dict with detached
      ``'k'`` and ``'v'`` tensors of shape
      ``(batch, kv_heads, past_len + num_tokens, head_dim)`` that can be
      passed back as ``past_kv``.

    Raises:
      ValueError: if cached plus new tokens exceed the causal mask size.
    """

    b, num_tokens, _ = x.shape

    # Project, split the feature axis into heads, and move heads before tokens:
    # (b, tokens, heads * head_dim) -> (b, heads, tokens, head_dim)
    q = self.Wq(x).view(b, num_tokens, self.q_heads, self.head_dim).transpose(1, 2)
    k = self.Wk(x).view(b, num_tokens, self.kv_heads, self.head_dim).transpose(1, 2)
    v = self.Wv(x).view(b, num_tokens, self.kv_heads, self.head_dim).transpose(1, 2)

    if past_kv is not None:
      past_k = past_kv.get('k')
      past_v = past_kv.get('v')

      # Prepend the cached entries along the token axis.
      if past_k is not None:
        k = torch.cat((past_k, k), dim=2)
      if past_v is not None:
        v = torch.cat((past_v, v), dim=2)

    total_len = k.shape[2]
    max_len = self.mask.shape[0]
    if total_len > max_len:
      raise ValueError(
        f"total sequence length {total_len} exceeds the causal mask size {max_len}"
      )

    # Cache the un-expanded kv heads: that is what a later call concatenates
    # with its own (b, kv_heads, ...) projections, and it keeps the cache
    # `groups` times smaller than the expanded tensors used below.
    present_kv = {'k': k.detach(), 'v': v.detach()}

    # Expand every kv head so it lines up with the query heads that share it.
    k_expanded = k.repeat_interleave(self.groups, dim=1)
    v_expanded = v.repeat_interleave(self.groups, dim=1)

    attn_scores = q @ k_expanded.transpose(2, 3)
    attn_scores = attn_scores / (self.head_dim ** 0.5)

    # The new tokens occupy absolute positions [start_idx, total_len); without
    # a cache start_idx is 0, so one slice covers both cases.
    start_idx = total_len - num_tokens
    causal_mask = self.mask.bool()[start_idx:total_len, :total_len]
    attn_scores.masked_fill_(causal_mask, -torch.inf)

    attn_weights = torch.softmax(attn_scores, dim=-1)
    attn_weights = self.dropout(attn_weights)

    context_vecs = (attn_weights @ v_expanded).transpose(1, 2)

    # Merge the heads back into one feature axis using the actual number of
    # input tokens, which may be smaller than the configured context_len.
    context_vecs = context_vecs.contiguous().view(b, num_tokens, self.d_out)
    context_vecs = self.out_proj(context_vecs)

    return context_vecs, present_kv






In [14]:
import torch
import time

# Seed before the model is built so the randomly initialised weights are reproducible.
torch.manual_seed(123)

# ---------- Dummy input: 3 tokens with 36 features each ----------
inputs = torch.tensor([
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
     7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
     13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
     19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
     25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
     31.0, 32.0, 33.0, 34.0, 35.0, 36.0],

    [0.5, 1.5, 2.5, 3.5, 4.5, 5.5,
     6.5, 7.5, 8.5, 9.5, 10.5, 11.5,
     12.5, 13.5, 14.5, 15.5, 16.5, 17.5,
     18.5, 19.5, 20.5, 21.5, 22.5, 23.5,
     24.5, 25.5, 26.5, 27.5, 28.5, 29.5,
     30.5, 31.5, 32.5, 33.5, 34.5, 35.5],

    [9.0, 8.0, 7.0, 6.0, 5.0, 4.0,
     3.0, 2.0, 1.0, 0.0, -1.0, -2.0,
     -3.0, -4.0, -5.0, -6.0, -7.0, -8.0,
     -9.0, -10.0, -11.0, -12.0, -13.0, -14.0,
     -15.0, -16.0, -17.0, -18.0, -19.0, -20.0,
     -21.0, -22.0, -23.0, -24.0, -25.0, -26.0]
], dtype=torch.float32)


# Two identical sequences stacked into a batch.
batch = torch.stack((inputs, inputs), dim=0)   # shape = (2, 3, 36)
print("Batch shape:", batch.shape)

# ---------- Model ----------
batch_size, num_tokens, d_in = batch.shape
d_out = d_in

gqa = GQA(
    d_in=d_in,
    d_out=d_out,
    context_len=num_tokens,
    q_heads=6,
    groups=2,
    dropout=0.0
)

# ---------- Run GQA without a KV cache ----------
# perf_counter is a monotonic, high-resolution clock meant for measuring
# short intervals; time.time() can jump and has coarser resolution.
start = time.perf_counter()
context_vecs, present_kv = gqa(batch)
end = time.perf_counter()

print("\nContext vectors:\n", context_vecs)
print("\nShape:", context_vecs.shape)
print("\nPresent KV shapes:")
print("K:", present_kv['k'].shape)
print("V:", present_kv['v'].shape)

print("\nTime:", end - start, "seconds")

# ---------- KV memory footprint ----------
k_mem = tensor_size_bytes(present_kv['k'])
v_mem = tensor_size_bytes(present_kv['v'])

total = k_mem + v_mem

print("K memory:", k_mem, "bytes")
print("V memory:", v_mem, "bytes")
print("Total KV memory:", total / 1024, "KB")


Batch shape: torch.Size([2, 3, 36])

Context vectors:
 tensor([[[ -6.1980,  -9.9308,   0.5260,   0.6715,   5.2618,   9.0740,  10.8081,
            7.8047,  -6.5301,   4.4666,  -4.1266,   7.8682,  -4.4953,  -1.3451,
           -4.5841,  -1.0412,  -6.2217,   1.9541,   5.6271,   7.7625,   1.4324,
            1.9269,  -7.8929,  -3.2784,  -6.2498, -10.0581,  -5.0259,  11.3834,
            0.2150,  -1.7109,  13.7492,   1.2524,   0.3525,   3.3179,  -6.1689,
            0.0749],
         [ -6.2260,  -9.8186,   0.6143,   0.7120,   5.2712,   8.9916,  10.7293,
            7.7545,  -6.4265,   4.4256,  -4.1391,   7.8505,  -4.4662,  -1.3120,
           -4.6133,  -1.0169,  -6.1926,   2.0974,   5.5877,   7.6623,   1.4087,
            1.9290,  -7.8667,  -3.2940,  -6.2568, -10.0813,  -5.0360,  11.3647,
            0.2924,  -1.6632,  13.6438,   1.3294,   0.3542,   3.2958,  -6.1349,
            0.0417],
         [  3.2579,  -5.7384,  -8.9809, -11.2953,  -2.1024,   0.2462,   4.1953,
            1.1100, -10

We can see that GQA uses 2x the memory of MQA, but its speed is increased by ~82%. This highly benefits the speed-accuracy tradeoff: GQA computes much faster than MQA while still storing ~3x less memory than MHA.
