In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA) implementation.

    GQA uses fewer key/value heads than query heads: every key/value head is
    shared by a group of ``num_query_heads // num_kv_heads`` consecutive query
    heads. The key and value projections are therefore narrower than the
    query projection (``head_dim * num_kv_heads`` instead of ``d_model``),
    while the attention scores are still computed per query head.
    """

    def __init__(self, d_model, num_query_heads, num_kv_heads, dropout=0.1):
        """
        Initialize the Grouped Query Attention module.

        Args:
            d_model (int): The dimension of the model (embedding dimension)
            num_query_heads (int): Number of query attention heads
            num_kv_heads (int): Number of key/value attention heads; must
                divide ``num_query_heads``
            dropout (float): Dropout probability applied to the attention weights

        Raises:
            AssertionError: If ``d_model`` is not divisible by either head
                count, or ``num_query_heads`` is not a multiple of
                ``num_kv_heads``.
        """
        super().__init__()

        # Ensure d_model is divisible by both the number of query and kv heads
        assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
        assert d_model % num_kv_heads == 0, "d_model must be divisible by num_kv_heads"
        # Ensure num_query_heads is a multiple of num_kv_heads
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be a multiple of num_kv_heads"

        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.kv_groups = num_query_heads // num_kv_heads  # How many query heads share a single kv head

        # Every head (query, key and value alike) has the same width, so that
        # the per-head contexts concatenate back to exactly d_model features.
        self.d_qk = d_model // num_query_heads  # Dimension per query/key head
        self.d_v = self.d_qk                    # Dimension per value head

        # Linear projections for Q, K, V. K and V only produce num_kv_heads heads.
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, self.d_qk * num_kv_heads)
        self.value_proj = nn.Linear(d_model, self.d_v * num_kv_heads)

        self.output_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for Grouped Query Attention

        Args:
            query: Query tensor (batch_size, seq_len_q, d_model)
            key: Key tensor (batch_size, seq_len_k, d_model)
            value: Value tensor (batch_size, seq_len_k, d_model); it is
                reshaped using the key's sequence length, so it must match it
            mask: Optional mask broadcastable to
                (batch_size, num_query_heads, seq_len_q, seq_len_k).
                Positions where ``mask == 0`` (or ``False``) are masked out.

        Returns:
            output: Attention output (batch_size, seq_len_q, d_model)
            attention_weights: Attention weights after dropout,
                (batch_size, num_query_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k = key.size(1)

        # Linear projections
        # Q: (batch_size, seq_len_q, d_model)
        q = self.query_proj(query)
        # K, V: (batch_size, seq_len_k, d_qk * num_kv_heads)
        k = self.key_proj(key)
        v = self.value_proj(value)

        # Split the feature dimension into heads
        # q: (batch_size, num_query_heads, seq_len_q, d_qk)
        q = q.view(batch_size, seq_len_q, self.num_query_heads, self.d_qk).transpose(1, 2)
        # k: (batch_size, num_kv_heads, seq_len_k, d_qk)
        k = k.view(batch_size, seq_len_k, self.num_kv_heads, self.d_qk).transpose(1, 2)
        # v: (batch_size, num_kv_heads, seq_len_k, d_v)
        v = v.view(batch_size, seq_len_k, self.num_kv_heads, self.d_v).transpose(1, 2)

        # Expand K, V to match the number of query heads.
        # repeat_interleave keeps copies of one kv head adjacent, so query
        # heads [g * kv_groups, (g + 1) * kv_groups) all attend to kv head g.
        # k, v: (batch_size, num_query_heads, seq_len_k, d_qk)
        k = torch.repeat_interleave(k, self.kv_groups, dim=1)
        v = torch.repeat_interleave(v, self.kv_groups, dim=1)

        # Scaled dot-product scores
        # (batch_size, num_query_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_qk)

        # Apply mask if provided
        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # rather than a literal such as -1e9, which is not representable
            # in float16 and makes masked_fill fail for half-precision scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention weights to values
        # (batch_size, num_query_heads, seq_len_q, d_v)
        context = torch.matmul(attention_weights, v)

        # Merge the heads back into the feature dimension
        # (batch_size, seq_len_q, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)

        # Final linear projection
        output = self.output_proj(context)

        return output, attention_weights

In [3]:
# Example usage and testing
def test_grouped_query_attention():
    """
    Smoke-test GroupedQueryAttention on random inputs with a causal mask.

    Builds an 8-query-head / 2-kv-head layer, runs one forward pass in eval
    mode (so dropout does not perturb the returned attention weights), prints
    the resulting shapes, and asserts that the shapes are as expected, that
    masked (future) positions receive zero weight, and that every attention
    row sums to one.

    Returns:
        output: Attention output (batch_size, seq_len, d_model)
        attention: Attention weights (batch_size, num_query_heads, seq_len, seq_len)
    """
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_query_heads = 8
    num_kv_heads = 2  # Each KV head is shared across 4 query heads

    # Create random input tensors
    query = torch.randn(batch_size, seq_len, d_model)
    key = torch.randn(batch_size, seq_len, d_model)
    value = torch.randn(batch_size, seq_len, d_model)

    # Causal mask: True on and below the diagonal (keep), False above it (mask).
    # Shaped (1, 1, seq_len, seq_len) so it broadcasts over batch and heads.
    mask = torch.ones(seq_len, seq_len).tril().bool().unsqueeze(0).unsqueeze(0)

    # Initialize GQA layer; eval() disables dropout so the weights can be checked
    gqa = GroupedQueryAttention(d_model, num_query_heads, num_kv_heads)
    gqa.eval()

    # Forward pass
    output, attention = gqa(query, key, value, mask)

    print(f"Input shape: {query.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention.shape}")
    print(f"Number of query heads: {num_query_heads}")
    print(f"Number of key/value heads: {num_kv_heads}")
    print(f"Grouping factor: {num_query_heads // num_kv_heads}")

    # Check if output maintains input dimensions
    assert output.shape == query.shape, "Output shape should match input shape"

    # One (seq_len x seq_len) attention map per query head
    expected_attention_shape = (batch_size, num_query_heads, seq_len, seq_len)
    assert attention.shape == expected_attention_shape, "Unexpected attention weights shape"

    # The causal mask must remove all weight from future positions
    future_weights = attention.masked_select(~mask.expand_as(attention))
    assert bool(torch.all(future_weights == 0)), "Masked positions should receive zero attention"

    # Each query position's weights form a probability distribution
    row_sums = attention.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5), "Attention rows should sum to 1"

    return output, attention


# Run the smoke test only when this code executes as the main module
# (not when it is imported); the shapes it prints appear as the cell output.
if __name__ == "__main__":
    test_grouped_query_attention()

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Attention weights shape: torch.Size([2, 8, 10, 10])
Number of query heads: 8
Number of key/value heads: 2
Grouping factor: 4
