In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LlamaGQA(nn.Module):
    """Grouped-query attention (GQA) layer.

    ``num_query_heads`` query heads share ``num_kv_heads`` key/value heads:
    every key/value head serves ``num_query_heads // num_kv_heads``
    consecutive query heads. All four projections are bias-free linear
    layers.

    This module applies neither positional encoding nor an implicit causal
    mask; any masking must be supplied through ``attention_mask``.

    Args:
        dim: Size of the last axis of the input and of the output.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide
            ``num_query_heads``.
        head_dim: Per-head feature size. Defaults to
            ``dim // num_query_heads``.

    Raises:
        AssertionError: If ``num_query_heads`` is not a multiple of
            ``num_kv_heads``.
    """

    def __init__(self, dim, num_query_heads, num_kv_heads, head_dim=None):
        super().__init__()
        # Validate the head layout before deriving the group size from it.
        # Kept as an assert (same message) so callers see the same exception
        # type as before.
        assert num_query_heads % num_kv_heads == 0, "num_query_heads必须被num_kv_heads整除"

        self.dim = dim
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        # The default head size divides by the number of *query* heads,
        # because several query heads share one key/value head.
        self.head_dim = head_dim if head_dim is not None else dim // num_query_heads
        # Number of query heads served by each key/value head.
        self.kv_groups = num_query_heads // num_kv_heads

        # Projections: K and V are narrower than Q when num_kv_heads is
        # smaller than num_query_heads.
        self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
        # Standard 1/sqrt(d) scaling of the dot-product scores.
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x, attention_mask=None, cache=None):
        """Run grouped-query attention over ``x``.

        Args:
            x: Input of shape ``[batch_size, seq_len, dim]``.
            attention_mask: Optional mask broadcastable to the score tensor
                ``[batch_size, num_query_heads, seq_len, kv_len]``, where
                ``kv_len`` is ``seq_len`` plus the cached length.
                A floating-point mask is added to the scores (use ``-inf``
                or a large negative value to hide a position). A boolean
                mask marks positions that may be attended to with ``True``;
                ``False`` positions are set to ``-inf`` before the softmax.
            cache: Optional ``(past_k, past_v)`` pair, each of shape
                ``[batch_size, num_kv_heads, past_len, head_dim]``. The new
                keys/values are appended along the sequence axis. To start
                a cache, pass tensors whose sequence axis has length 0.

        Returns:
            The attention output of shape ``[batch_size, seq_len, dim]``.
            When ``cache`` is given, a tuple ``(output, (k, v))`` where
            ``k``/``v`` are the concatenated, un-repeated keys/values of
            shape ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        """
        batch_size, seq_len, _ = x.shape

        # Project, split the feature axis into heads, then move the head
        # axis in front of the sequence axis.
        q = self.q_proj(x).view(batch_size, seq_len, self.num_query_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        q = q.transpose(1, 2)  # [batch_size, num_query_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]

        # KV cache: prepend the past keys/values along the sequence axis.
        # The cache is captured *before* head repetition so it stays at
        # num_kv_heads heads.
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
            cache = (k, v)

        # Grouped attention: repeat every KV head so that query head i uses
        # KV head i // kv_groups.
        if self.kv_groups > 1:
            k = k.repeat_interleave(self.kv_groups, dim=1)
            v = v.repeat_interleave(self.kv_groups, dim=1)

        # Scaled dot-product scores: [batch_size, num_query_heads, seq_len, kv_len]
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Adding a boolean mask would only shift scores by 0/1 and
                # hide nothing; treat True as "keep" and block the rest.
                attn_scores = attn_scores.masked_fill(~attention_mask, float("-inf"))
            else:
                # Additive mask (0 for visible, large negative for hidden).
                attn_scores = attn_scores + attention_mask

        # Normalize over the key axis.
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values: [batch_size, num_query_heads, seq_len, head_dim]
        output = torch.matmul(attn_weights, v)

        # Merge the heads back into one feature axis and project to `dim`.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        if cache is not None:
            return output, cache
        return output