## Grouped-Query Attention (GQA)

In [14]:
from torch.nn.functional import scaled_dot_product_attention
import torch
from einops import rearrange, einsum

In [4]:
# Baseline multi-head attention with as many key/value heads as query heads.
# Tensor layout: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 256, 8, 64)
value = torch.randn(1, 256, 8, 64)

# scaled_dot_product_attention attends over the second-to-last axis, i.e. it
# expects (batch_size, num_heads, seq_len, head_dim). Passing the tensors in
# the layout above would mix information across the 8 heads instead of across
# the 256 positions, so swap the seq/head axes for the call and swap them back
# afterwards to keep the (batch_size, seq_len, num_heads, head_dim) layout.
output = scaled_dot_product_attention(
    query.transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
).transpose(1, 2)
print(output.shape) # torch.Size([1, 256, 8, 64])

torch.Size([1, 256, 8, 64])


In [7]:
# !pip install einops

In [10]:

# Grouped-query setup: 8 query heads, but only 2 key/value heads.
# Every tensor is laid out as (batch_size, seq_len, num_heads, head_dim).
query = torch.randn(1, 256, 8, 64)
key, value = torch.randn(1, 256, 2, 64), torch.randn(1, 256, 2, 64)

# How many query heads share each key/value head.
num_head_groups = query.size(2) // key.size(2)
print(num_head_groups)  # 8 query heads / 2 kv heads -> 4 query heads per group

4


In [11]:
# Swap the sequence and head axes of all three tensors in one pass:
# (batch, seq, heads, dim) -> (batch, heads, seq, dim).
query, key, value = (
    rearrange(tensor, "b n h d -> b h n d") for tensor in (query, key, value)
)

In [12]:
# Factor the query-head axis into (h, g), with g sized by num_head_groups,
# then move g in front of h so that the remaining (h, n, d) axes line up
# with the key/value layout "b h s d".
query = rearrange(
    query,
    "b (h g) n d -> b g h n d",
    g=num_head_groups,
)
print(query.shape)  # torch.Size([1, 4, 2, 256, 64])

torch.Size([1, 4, 2, 256, 64])


In [15]:
# Query/key dot products over the head_dim axis (d). The group axis g is
# absent from the output pattern, so einsum also sums over it, leaving one
# (n x s) score matrix per key/value head.
# NOTE(review): summing over g merges the query heads of each group before
# the softmax. If one attention map per query head is intended, the output
# pattern would have to keep g ("b g h n s") and the later steps adapted;
# confirm which behaviour is wanted.
scores = einsum(
    query,
    key,
    "b g h n d, b h s d -> b h n s",
)
print(scores.shape)  # torch.Size([1, 2, 256, 256])

torch.Size([1, 2, 256, 256])


In [17]:
import torch.nn.functional as F

# Scale the scores by sqrt(head_dim) and normalise over the key axis (last).
scale = query.shape[-1] ** 0.5
attention = F.softmax(scores.div(scale), dim=-1)

# Attention-weighted sum of the values: an ordinary batched matmul over s.
out = einsum(
    attention,
    value,
    "b h n s, b h s d -> b h n d",
)

# Put the head axis back behind the sequence axis:
# (batch_size, seq_len, num_kv_heads, head_dim).
out = rearrange(out, "b h n d -> b n h d")
print(out.shape)  # torch.Size([1, 256, 2, 64])

torch.Size([1, 256, 2, 64])
