In [1]:
import torch
import torch.nn as nn

class GQA(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v, num_groups):
        super(GQA, self).__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.num_groups = num_groups
        self.W_q = nn.Linear(d_model, n_heads * d_k)
        self.W_k = nn.Linear(d_model, n_heads * d_k)
        self.W_v = nn.Linear(d_model, n_heads * d_v)
        self.fc = nn.Linear(n_heads * d_v, d_model)

    def forward(self, Q, K, V):
        batch_size = Q.size(0)
        print("Original Q shape:", Q.shape)

        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_v)
        print("Q shape after linear:", Q.shape)

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        print("Q shape after transpose:", Q.shape)

        Q_groups = torch.chunk(Q, self.num_groups, dim=2)
        K_groups = torch.chunk(K, self.num_groups, dim=2)
        V_groups = torch.chunk(V, self.num_groups, dim=2)
        print("Number of Q groups:", len(Q_groups))

        context = []
        for i, (Q_group, K_group, V_group) in enumerate(zip(Q_groups, K_groups, V_groups)):
            print(f"Group {i} - Q_group shape:", Q_group.shape)
            scores = torch.matmul(Q_group, K_group.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
            attn = torch.softmax(scores, dim=-1)
            context_group = torch.matmul(attn, V_group)
            context.append(context_group)
            print(f"Group {i} - context_group shape:", context_group.shape)

        context = torch.cat(context, dim=2)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.fc(context)
        print("Final output shape:", output.shape)
        return output

d_model = 64
n_heads = 4
d_k = d_v = 16
num_groups = 2
batch_size = 2
seq_len = 10

Q = torch.rand(batch_size, seq_len, d_model)
K = torch.rand(batch_size, seq_len, d_model)
V = torch.rand(batch_size, seq_len, d_model)

model = GQA(d_model, n_heads, d_k, d_v, num_groups)
output = model(Q, K, V)

Original Q shape: torch.Size([2, 10, 64])
Q shape after linear: torch.Size([2, 10, 4, 16])
Q shape after transpose: torch.Size([2, 4, 10, 16])
Number of Q groups: 2
Group 0 - Q_group shape: torch.Size([2, 4, 5, 16])
Group 0 - context_group shape: torch.Size([2, 4, 5, 16])
Group 1 - Q_group shape: torch.Size([2, 4, 5, 16])
Group 1 - context_group shape: torch.Size([2, 4, 5, 16])
Final output shape: torch.Size([2, 10, 64])
