# Grouped Query Attention

标准的多头注意力机制提供了强大的序列建模能力，其中“多头”能够捕捉多个角度的注意力关系。

1. 对于当前尺寸模型，头数有64头/128头级别，多头注意力关系是否有冗余？
2. KV-Cache 在 inference 阶段占用较多存储，其大小正比于 `bsz × seq_len × n_heads × head_dim × bit`。如果减少 `n_heads` 这一维度，是否还能保留原有精度？

对于标准的注意力机制是多头 q，k，v 计算注意力，考虑从头数优化。

## Multi Heads Attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

class MultiHeadsAttention(nn.Module):
    """Standard multi-head self-attention (MHA).

    Queries, keys and values are each projected to ``dim`` features and
    split into ``n_heads`` heads of ``head_dim = dim // n_heads`` features.

    Args:
        dim: Model (embedding) width of the input and output.
        n_heads: Number of attention heads; must divide ``dim``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // self.n_heads
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None, verbose=False):
        """Apply self-attention to ``x``.

        Args:
            x: Input tensor of shape ``(bsz, seq_len, dim)``.
            mask: Optional additive mask added to the attention scores.
                Two leading axes are prepended before the addition, so a
                ``(seq_len, seq_len)`` mask broadcasts over batch and heads.
            verbose: Unused here; kept so the signature matches the MQA/GQA
                variants.

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # Split the feature axis into heads: (bsz, n_heads, seq_len, head_dim).
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention: the dot product is taken over
        # head_dim features, so the scale is sqrt(head_dim), not sqrt(dim).
        s = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v

        # Concatenate the heads back into a single feature axis.
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)

        return self.wo(z)

# Demo configuration; these module-level names (and the input x below) are
# reused by the MQA and GQA cells further down.
bsz = 2
seq_len = 8
dim = 512
n_heads = 8
# NOTE: the module is built before x is sampled, so the seeded RNG is consumed
# by the layer initialisation first; reordering these lines changes the values.
MHA = MultiHeadsAttention(dim = dim, n_heads = n_heads)
x = torch.randn(bsz, seq_len, dim)
x_mha = MHA(x)
# Notebook cell result: the output keeps the input shape (bsz, seq_len, dim).
x_mha.shape

torch.Size([2, 8, 512])

## Multi Query Attention

1. 多query, 单kv
2. 单个 kv 头通过广播（broadcast）共享给所有 query 头

In [2]:
class MultiQueryAttention(nn.Module):
    """Multi-query attention (MQA): many query heads, one shared K/V head.

    The query projection keeps ``n_heads`` heads, while the key and value
    projections produce a single head of ``head_dim = dim // n_heads``
    features that is broadcast across all query heads.

    Args:
        dim: Model (embedding) width of the input and output.
        n_heads: Number of query heads; must divide ``dim``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = self.dim // self.n_heads
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, self.head_dim)  # single key head
        self.wv = nn.Linear(dim, self.head_dim)  # single value head
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None, verbose=False):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Input tensor of shape ``(bsz, seq_len, dim)``.
            mask: Optional additive mask added to the attention scores.
                Two leading axes are prepended before the addition, so a
                ``(seq_len, seq_len)`` mask broadcasts over batch and heads.
            verbose: If true, print the shape of the shared key tensor.

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # Only q is split into heads: (bsz, n_heads, seq_len, head_dim).
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # k and v get a singleton head axis, (bsz, 1, seq_len, head_dim), and
        # are shared with every query head through broadcasting (no copy).
        k = k[:, None, :, :]
        v = v[:, None, :, :]
        if verbose:
            print('kshape', k.shape)

        # Scaled dot-product attention: the dot product is taken over
        # head_dim features, so the scale is sqrt(head_dim), not sqrt(dim).
        s = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v

        # Concatenate the heads back into a single feature axis.
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)

        return self.wo(z)

# Build the MQA module with the same dim / n_heads as the MHA demo.
MQA = MultiQueryAttention(dim = dim, n_heads = n_heads)
# Reuse the input x sampled in the MHA cell rather than drawing a new one.
x_mqa = MQA(x, verbose=True)
# Notebook cell result: the output keeps the input shape (bsz, seq_len, dim).
x_mqa.shape

kshape torch.Size([2, 1, 8, 64])


torch.Size([2, 8, 512])

## Grouped Query Attention

分组 share kv

In [3]:
# Demo of repeat_interleave: each element is duplicated in place, so copies of
# the same source element end up adjacent to each other.
a = torch.tensor([1, 2, 3, 4])
a_share = a.repeat_interleave(2, dim=0)
print(a_share)

tensor([1, 1, 2, 2, 3, 3, 4, 4])


In [4]:
class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA): groups of query heads share a K/V head.

    There are ``n_heads`` query heads and ``n_kv_heads`` key/value heads;
    each K/V head serves ``n_heads // n_kv_heads`` consecutive query heads.

    Args:
        dim: Model (embedding) width of the input and output.
        n_heads: Number of query heads; must divide ``dim``.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads`` or
            ``n_heads`` is not divisible by ``n_kv_heads``.
    """

    def __init__(self, dim=512, n_heads=8, n_kv_heads=2):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        if n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by "
                f"n_kv_heads ({n_kv_heads})"
            )
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = self.dim // self.n_heads
        # Number of query heads served by each K/V head.
        self.share_heads = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, self.head_dim * self.n_kv_heads)  # grouped shared k
        self.wv = nn.Linear(dim, self.head_dim * self.n_kv_heads)  # grouped shared v
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None, verbose=False):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Input tensor of shape ``(bsz, seq_len, dim)``.
            mask: Optional additive mask added to the attention scores.
                Two leading axes are prepended before the addition, so a
                ``(seq_len, seq_len)`` mask broadcasts over batch and heads.
            verbose: If true, print the expanded key shape and the first
                five key features of token 0 for up to the first four heads
                (adjacent heads in one group print identical values).

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # Split into heads: q has n_heads, k/v have n_kv_heads.
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Expand each K/V head to its group of query heads; repeat_interleave
        # keeps the copies of one K/V head adjacent along the head axis.
        k = torch.repeat_interleave(k, self.share_heads, dim=1)
        v = torch.repeat_interleave(v, self.share_heads, dim=1)

        if verbose:
            print('kshape', k.shape)
            # Cap at the available head count so small configs do not raise.
            for head in range(min(4, k.shape[1])):
                print(k[0, head, 0, :5])

        # Scaled dot-product attention: the dot product is taken over
        # head_dim features, so the scale is sqrt(head_dim), not sqrt(dim).
        s = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v

        # Concatenate the heads back into a single feature axis.
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)

        return self.wo(z)

# With n_heads = 8 from the MHA cell, 4 K/V heads means each one is shared by
# 2 query heads.
n_kv_heads = 4
GQA = GroupQueryAttention(dim = dim, n_heads = n_heads, n_kv_heads = n_kv_heads)
# Reuse the input x sampled in the MHA cell rather than drawing a new one.
x_gqa = GQA(x, verbose=True)
# Notebook cell result: the output keeps the input shape (bsz, seq_len, dim).
x_gqa.shape

kshape torch.Size([2, 8, 8, 64])
tensor([-0.9566, -0.9000,  0.2131,  0.0966, -0.8631], grad_fn=<SliceBackward0>)
tensor([-0.9566, -0.9000,  0.2131,  0.0966, -0.8631], grad_fn=<SliceBackward0>)
tensor([-0.4187,  0.6545, -0.2934, -0.4444, -0.0613], grad_fn=<SliceBackward0>)
tensor([-0.4187,  0.6545, -0.2934, -0.4444, -0.0613], grad_fn=<SliceBackward0>)


torch.Size([2, 8, 512])

## 讨论

1. 为什么 MQA/GQA 能 work？ 高维特征冗余或注意力关系冗余
2. MQA/GQA 是否减少了计算量？ 1. 从投影角度减少 2. GQA 增加repeat计算（kernel可优化） 3.在注意力分数上，与标准的 MHA 计算量等同
3. MQA 和 GQA 减少了多少的 KV-Cache？ KV-Cache 缩小为 MHA 的 kv头数/q头数，即减少 q头数/kv头数 倍（MQA 的 kv头数为 1）
4. 如果将 KV 变换视为一种压缩手段，压缩的维度极限在哪里？究竟应该如何压缩？ （ MLA优化）
5. 从 SRAM 视角分析 MQA/GQA 加速？（从HBM 加载单头KV 到 SRAM，并在 SRAM 中将单头KV share到多个线程中，避免copy）
6. 考虑分布式注意力计算情况， 如何部署参数和计算策略？*
7. 考虑多层 attention 计算过程：虽然 GQA/MQA 在计算时会把 KV 头扩展（repeat）到与 q 相同的头数，但扩展后的张量只在当前 block 内临时存在，随扩展随消除；缓存中保存的仍是未扩展的 KV，因此对整体 KV-Cache 量影响不大