In [None]:
'''
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
(https://arxiv.org/pdf/2305.13245)

Grouped-Query Attention (GQA) is a memory-efficient variant of Multi-Head Attention.
The idea is to save memory by not creating a unique Key and Value head for every Query head.
Instead, we create only a few K/V heads and have groups of Query heads share them. This is
especially useful for longer sequences, where the stored keys and values grow with sequence length.

The key difference from MHA is that the number of key/value heads is a hyperparameter, 
which must be a divisor of the number of query heads.
'''

In [1]:
import torch
import torch.nn as nn

torch.manual_seed(42)

sentence = 'If something is humanly possible, it is attainable by you too'

# Strip commas and split on whitespace once; the same token list feeds both
# the vocabulary and the ID lookup below.
tokens = sentence.replace(',', '').split()

# Map every distinct word to its rank in sorted order.
# NOTE: the name `dict` shadows the built-in type; it is kept because later
# cells read the vocabulary through this name.
dict = {word: rank for rank, word in enumerate(sorted(set(tokens)))}

# Encode the sentence as a 1-D tensor of vocabulary indices.
sentence_ids = torch.tensor([dict[word] for word in tokens])

print(f"Vocabulary: {dict}")
print(f"Sentence IDs: {sentence_ids}")

Vocabulary: {'If': 0, 'attainable': 1, 'by': 2, 'humanly': 3, 'is': 4, 'it': 5, 'possible': 6, 'something': 7, 'too': 8, 'you': 9}
Sentence IDs: tensor([0, 7, 4, 3, 6, 5, 4, 1, 2, 9, 8])


In [2]:
# Hyperparameters for the demo.
vocab_size = len(dict)  # number of distinct words in the vocabulary
embed_dim = 32          # width of each token embedding
num_q_heads = 8         # query heads
num_kv_heads = 2        # key/value heads, each shared by a group of query heads

# Look up one embedding vector per token, then add a leading batch axis.
embedding_layer = nn.Embedding(vocab_size, embed_dim)
token_embeddings = embedding_layer(sentence_ids)
embedded_input = token_embeddings.unsqueeze(0)  # Shape: [1, sequence length, embedding dim]

print(f"Input shape: {embedded_input.shape}")

Input shape: torch.Size([1, 11, 32])


In [3]:
class GroupedQueryAttention(nn.Module):
    """Grouped-Query self-attention.

    Queries are projected into ``num_q_heads`` heads, while keys and values are
    projected into only ``num_kv_heads`` heads. Each K/V head is shared by
    ``num_q_heads // num_kv_heads`` consecutive query heads. No attention mask
    and no dropout are applied.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
            Must be divisible by ``num_q_heads``.
        num_q_heads: Number of query heads. Must be divisible by
            ``num_kv_heads``.
        num_kv_heads: Number of key/value heads.

    Raises:
        AssertionError: If ``num_q_heads`` is not a multiple of
            ``num_kv_heads``, or ``embed_dim`` is not a multiple of
            ``num_q_heads``.
    """

    def __init__(self, embed_dim, num_q_heads, num_kv_heads):
        super().__init__()

        # number of query heads should be a multiple of the KV heads
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
        # embed_dim must split evenly across the query heads. Without this
        # check head_dim is silently floored and the failure only shows up
        # later as a size-mismatch error from the reshape in forward().
        assert embed_dim % num_q_heads == 0, "embed_dim must be divisible by num_q_heads"

        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_q_heads
        self.num_groups = num_q_heads // num_kv_heads  # query heads per K/V head

        # K and V projections are narrower than Q: they only produce
        # num_kv_heads heads of size head_dim each.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim)

        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, embed_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, embed_dim]``.
        """
        batch_size, seq_len, d_model = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape Q, K, V to separate heads
        # Q shape:    [batch_size, num_q_heads,  seq_len, head_dim]
        # K, V shape: [batch_size, num_kv_heads, seq_len, head_dim]
        q_heads = q.reshape(batch_size, seq_len, self.num_q_heads, self.head_dim).transpose(1, 2)
        k_heads = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v_heads = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat K and V heads to match Q heads. repeat_interleave keeps the
        # copies of one K/V head adjacent, so query heads
        # [g * num_groups, (g + 1) * num_groups) all attend with K/V head g.
        k_heads = k_heads.repeat_interleave(self.num_groups, dim=1)
        v_heads = v_heads.repeat_interleave(self.num_groups, dim=1)

        # Perform scaled dot-product attention
        attn_scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v_heads)

        # Concatenate heads and apply final projection
        # Reshape context to [batch_size, seq_len, d_model]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        output = self.o_proj(context)

        return output

In [4]:
# Build the attention module from the hyperparameters defined earlier.
gqa = GroupedQueryAttention(embed_dim=embed_dim, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads)

# Run self-attention over the embedded sentence.
output = gqa(embedded_input)

print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 11, 32])
