In [None]:
import torch 

# MHSA (multi-head self-attention) baseline to compare GQA against
class mhsa(torch.nn.Module):
    """Multi-head self-attention baseline (used as the reference for GQA).

    Args:
        D: model width; also the total width of the q/k/v projections.
        head_dim: width of each attention head. Must divide ``D``.
        causal: when True, position ``i`` may only attend to positions ``<= i``.

    Shapes:
        input ``[B, S, D]`` -> output ``[B, S, D]``.
    """

    def __init__(self, D=512, head_dim=64, causal=True):
        super().__init__()
        self.D = D
        self.head_dim = head_dim
        assert self.D % self.head_dim == 0
        self.nheads = self.D // self.head_dim
        self.causal = causal

        self.wq = torch.nn.Linear(D, D)
        self.wk = torch.nn.Linear(D, D)
        self.wv = torch.nn.Linear(D, D)
        self.wo = torch.nn.Linear(D, D)

    def _split_heads(self, t):
        """[B, S, D] -> [B, nheads, S, head_dim].

        The feature axis must be split first and the head axis moved in front
        afterwards; reshaping straight to [B, nheads, S, head_dim] would mix
        sequence positions with feature slices.
        """
        B, S, _ = t.shape
        return t.reshape(B, S, self.nheads, self.head_dim).transpose(1, 2)

    def forward(self, x):  # BSD -> BSD
        """Apply (optionally causal) scaled dot-product self-attention to ``x``."""
        B, S, D = x.shape
        q = self._split_heads(self.wq(x))
        k = self._split_heads(self.wk(x))
        v = self._split_heads(self.wv(x))

        # [B, N, S, H] x [B, N, S, H] -> [B, N, S, S], scaled by sqrt(head_dim)
        logits = torch.einsum('bnij,bnkj->bnik', q, k) / (self.head_dim ** 0.5)

        # The mask has to be applied to the logits *before* softmax: -inf
        # logits become zero probabilities. Masking after softmax would put
        # -inf into the attention weights and produce -inf/NaN outputs.
        if self.causal:
            # [S, S] is enough; it broadcasts over batch and head axes.
            mask = torch.triu(torch.ones(S, S, device=logits.device), diagonal=1).bool()
            logits = logits.masked_fill(mask, float('-inf'))

        A = torch.nn.functional.softmax(logits, dim=-1)

        preout = torch.einsum('bnij,bnjd->bnid', A, v)  # BNSS @ BNSH -> BNSH
        # Move heads back next to the feature axis before concatenating them.
        preout = preout.transpose(1, 2).reshape(B, S, D)
        return self.wo(preout)

# Smoke test: run the MHSA baseline once and check the BSD -> BSD contract.
B, S, D = 8, 512, 768
attn = mhsa(D=D)
x = torch.randn(B, S, D)
with torch.no_grad():  # shape check only, so no autograd graph is needed
    out = attn(x)
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
out.shape  # BSD -> BSD


In [None]:
import numpy as np 

class gqa(torch.nn.Module):
    """Grouped-query self-attention.

    Queries use ``D // head_dim`` heads; keys and values use
    ``num_query_heads // group_size`` heads, and each key/value head is shared
    by ``group_size`` consecutive query heads (query head ``h`` reads
    key/value head ``h // group_size``).

    Args:
        D: model width; also the total width of the query projection.
        head_dim: width of each attention head. Must divide ``D``.
        causal: when True, position ``i`` may only attend to positions ``<= i``.
        group_size: number of query heads sharing one key/value head. Must be
            positive and divide the number of query heads.

    Shapes:
        input ``[B, S, D]`` -> output ``[B, S, D]``.
    """

    def __init__(self, D=512, head_dim=64, causal=True, group_size=4):
        super().__init__()
        self.D = D
        self.head_dim = head_dim
        assert self.D % self.head_dim == 0
        self.num_query_heads = self.D // self.head_dim
        self.group_size = group_size
        # Without this check a bad group_size builds a module whose repeated
        # k/v heads do not line up with the query heads, and the failure only
        # shows up later inside einsum.
        assert group_size > 0 and self.num_query_heads % group_size == 0, (
            f"group_size={group_size} must be positive and divide "
            f"num_query_heads={self.num_query_heads}"
        )
        # In GQA, keys and values use fewer heads:
        # each group of query heads will share a common key/value, so:
        self.num_kv_heads = self.num_query_heads // self.group_size
        self.causal = causal

        self.D_kv = self.num_kv_heads * self.head_dim

        self.wq = torch.nn.Linear(D, D)
        self.wk = torch.nn.Linear(D, self.D_kv)
        self.wv = torch.nn.Linear(D, self.D_kv)
        self.wo = torch.nn.Linear(D, D)

    def _split_heads(self, t, num_heads):
        """[B, S, num_heads * head_dim] -> [B, num_heads, S, head_dim].

        The feature axis is split first and the head axis moved in front
        afterwards; reshaping straight to [B, num_heads, S, head_dim] would mix
        sequence positions with feature slices.
        """
        B, S, _ = t.shape
        return t.reshape(B, S, num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):  # Input x: [B, S, D] -> Output: [B, S, D]
        """Apply (optionally causal) grouped-query self-attention to ``x``."""
        B, S, D = x.shape
        q = self._split_heads(self.wq(x), self.num_query_heads)  # [B, Nq, S, H]
        k = self._split_heads(self.wk(x), self.num_kv_heads)     # [B, Nkv, S, H]
        v = self._split_heads(self.wv(x), self.num_kv_heads)     # [B, Nkv, S, H]

        # Repeat each k/v head group_size times along the head axis so that
        # query head h lines up with key/value head h // group_size.
        k = torch.repeat_interleave(k, self.group_size, dim=1)
        v = torch.repeat_interleave(v, self.group_size, dim=1)

        # Scaled dot-product attention logits: [B, Nq, S, S]
        logits = torch.einsum('bnij,bnkj->bnik', q, k) / (self.head_dim ** 0.5)

        # Causal mask on the logits (before softmax); [S, S] broadcasts over
        # the batch and head axes.
        if self.causal:
            mask = torch.triu(torch.ones(S, S, device=logits.device), diagonal=1).bool()
            logits = logits.masked_fill(mask, float('-inf'))

        A = torch.nn.functional.softmax(logits, dim=-1)
        preout = torch.einsum('bnij,bnjd->bnid', A, v)  # [B, Nq, S, H]
        # Move heads back next to the feature axis, then concatenate: [B, S, D]
        preout = preout.transpose(1, 2).reshape(B, S, D)
        return self.wo(preout)

import matplotlib.pyplot as plt
import time

# Benchmark: mean forward latency of GQA vs MHSA across sequence lengths.
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
B, D = 1, 1024  # Keep batch size and dimension fixed

# Use the GPU when there is one; otherwise fall back to CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _time_forward(module, inputs, warmup=5, iters=10):
    """Return the mean wall-clock seconds per forward pass of `module`.

    CUDA kernels are launched asynchronously, so the device is synchronized
    both right before the timer starts (otherwise queued warmup work is
    counted) and right before it stops (otherwise only launch overhead is
    measured). Gradients are disabled because only inference latency is of
    interest and the autograd graph would otherwise hold on to activations.
    """
    on_cuda = inputs.is_cuda
    with torch.no_grad():
        for _ in range(warmup):
            module(inputs)
        if on_cuda:
            torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(iters):
            module(inputs)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / iters


gqa_times = []
mhsa_times = []

for S in seq_lengths:
    # Same input for both models at a given sequence length
    x = torch.randn(B, S, D, device=device)

    attn_gqa = gqa(D=D).to(device)
    gqa_times.append(_time_forward(attn_gqa, x))

    attn_mhsa = mhsa(D=D).to(device)
    mhsa_times.append(_time_forward(attn_mhsa, x))

# Plot results as grouped bars, one group per sequence length
bar_positions = np.arange(len(seq_lengths))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(bar_positions - width/2, gqa_times, width, label='GQA')
ax.bar(bar_positions + width/2, mhsa_times, width, label='MHSA')

ax.set_ylabel('Latency (seconds)')
ax.set_xlabel('Sequence Length')
ax.set_title('GQA vs MHSA Latency Comparison')
ax.set_xticks(bar_positions)
ax.set_xticklabels(seq_lengths)
ax.legend()

# Uncomment for a log-scale y-axis:
# plt.yscale('log')
plt.show()
