# Llama3 GQA

![](./image/grouped-query-attention.png)

In [28]:
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

参数设置分析

```
嵌入词向量维度为：18
有6个Q头，【单个头】的维度为18/6 = 3
对于WQ(18x18)
Q=WQ(X)->(Q1,Q2,Q3,Q4,Q5,Q6)

有2个KV头，单头维度为3，3*2 = 6
对于WK(18x6)
对于WV(18x6)
K=WK(X)->(K1, K2)
V=WV(X)->(V1, V2)


对于输出
6个头*单头维度 = 6*3 =18
输出维度为:
Wo(18x18)
```

猜想

GQA中K/V的总维度为6，单个头维度为3（即2个KV头）。那么做法应是：将K1、V1各复制3份，K2、V2各复制3份，得到与Q相同的6个头，再进行标准的多头注意力计算。

In [29]:
@dataclass
class ModelArgs:
    """Hyper-parameters for the toy Llama-3 GQA model used in this notebook.

    The defaults give an 18-dim embedding with 6 query heads and 2 KV heads,
    i.e. head_dim = 18 // 6 = 3 and every KV head is shared by 6 // 2 = 3 query
    heads. Because this is a dataclass, any field can be overridden when the
    config is built, e.g. ``ModelArgs(model_parallel_size=2)``.
    """

    dim: int = 18
    n_layers: int = 1
    n_heads: int = 6
    # None conventionally means "same as n_heads" (plain multi-head attention).
    n_kv_heads: Optional[int] = 2
    vocab_size: int = -1
    # Llama rounds the SwiGLU hidden size up to a multiple of this value;
    # the attention-only code in this notebook does not use it.
    multiple_of: int = 10
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 2
    max_seq_len: int = 17
    # Number of GPUs the attention heads are split across (1 = no model parallelism).
    model_parallel_size: int = 1

In [30]:
config = ModelArgs()

In [32]:
# 特别注意这里的n_rep指的是，一个kv复制多少次
# 如果Q 6个头
# KV 2个头
# 那么n_rep为 6/2=3
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim) # 
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


# Demo: 2 KV heads of head_dim 3, each repeated 3 times -> 6 heads.
k = torch.randn(1, 7, 2, 3)  # (batch=1, seqlen=7, n_kv_heads=2, head_dim=3)
repeat_k = repeat_kv(k, 3)   # (batch=1, seqlen=7, n_kv_heads * 3 = 6, head_dim=3)
print(repeat_k.shape)        # torch.Size([1, 7, 6, 3])

# Heads 0-2 are copies of K1 and heads 3-5 are copies of K2 (repeat_interleave order).
assert torch.equal(repeat_k, torch.repeat_interleave(k, repeats=3, dim=2))
assert all(torch.equal(repeat_k[:, :, h], k[:, :, h // 3]) for h in range(6))

# First element of each of the 6 heads at position 0: three equal values, then three more.
repeat_k[0, 0, :, 0]

tensor([0.7358, 0.7358, 0.7358, 2.1032, 2.1032, 2.1032])

In [33]:
class Attention(nn.Module):
    """Grouped-query attention (GQA) in the style of Llama 3, with RoPE disabled.

    Queries have ``n_heads`` heads while keys/values have only ``n_kv_heads``
    heads. Each KV head is shared by ``n_rep = n_heads // n_kv_heads`` query
    heads: ``repeat_kv`` copies it just before the attention matmul. Tensor
    shapes are printed along the way so the GQA bookkeeping can be inspected.

    With ``args.model_parallel_size > 1`` the module models ONE model-parallel
    rank. The projections are sized for this rank's ``n_local_heads`` query
    heads and ``n_local_kv_heads`` KV heads, and ``wo`` only consumes those
    heads, so ``forward`` returns this rank's partial output. (In the Llama
    reference, ``wo`` is a RowParallelLinear that sums the partial outputs of
    all ranks.)
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # n_kv_heads=None falls back to one KV head per query head (plain MHA).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Number of (simulated) GPUs the heads are split across -- worth experimenting with.
        model_parallel_size = args.model_parallel_size

        # Fail fast: the floor divisions below would otherwise silently drop heads/dims.
        if args.dim % args.n_heads:
            raise ValueError(f"dim={args.dim} must be divisible by n_heads={args.n_heads}")
        if args.n_heads % self.n_kv_heads:
            raise ValueError(
                f"n_heads={args.n_heads} must be divisible by n_kv_heads={self.n_kv_heads}"
            )
        if self.n_kv_heads % model_parallel_size:
            raise ValueError(
                f"n_kv_heads={self.n_kv_heads} must be divisible by "
                f"model_parallel_size={model_parallel_size}"
            )

        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Query heads per KV head, i.e. how many times each KV head is copied (6 // 2 = 3).
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads  # 18 // 6 = 3

        # Projections are sized for the heads owned by this rank. With
        # model_parallel_size == 1 this is the full model: wq/wo map 18 -> 18 and
        # wk/wv map 18 -> 6 (2 KV heads * head_dim 3).
        self.wq = nn.Linear(in_features=args.dim, out_features=self.n_local_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(in_features=args.dim, out_features=self.n_local_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(in_features=args.dim, out_features=self.n_local_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(in_features=self.n_local_heads * self.head_dim, out_features=args.dim, bias=False)

        print(f'wq_shape: {self.wq.weight.shape}')
        print(f'wk_shape: {self.wk.weight.shape}')
        print(f'wv_shape: {self.wv.weight.shape}')
        print(f'wo_shape: {self.wo.weight.shape}')

        # As in the original Llama, the KV cache is allocated as soon as the module
        # is built, even though training does not need it. It only stores this
        # rank's KV heads (n_local_kv_heads), i.e. before repeat_kv.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Attend from the ``seqlen`` new tokens in ``x`` to every cached position.

        Args:
            x: input of shape (bsz, seqlen, dim).
            start_pos: cache position of ``x[:, 0]``; keys/values for positions
                [0, start_pos + seqlen) are attended to.
            freqs_cis: RoPE frequencies; ignored because RoPE is disabled here.
            mask: optional additive mask broadcastable to
                (bsz, n_local_heads, seqlen, start_pos + seqlen), or None.

        Returns:
            Tensor of shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Split the last axis into heads.
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # ignore RoPE

        # The caches are plain tensor attributes (not registered buffers), so move
        # them to the activations' device/dtype on every call.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Write the new keys/values into the preallocated KV cache.
        # A real implementation would branch on training vs. inference here.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # Shapes before the KV heads are repeated.
        print('repeat group 前的维度')
        print(f'q: {xq.shape}')
        print(f'k: {keys.shape}')
        print(f'v: {values.shape}')

        # Repeat k/v heads so there is one KV head per query head (no-op if n_rep == 1).
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)

        # Shapes right before the actual attention computation.
        print('真正计算Attention前各维度')
        print(f'q: {xq.shape}')
        print(f'k: {keys.shape}')
        print(f'v: {values.shape}')

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # Softmax in float32 for stability, then cast back to the activation dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)

        print(f'o: {output.shape}')
        # Concatenate the heads back: (bs, seqlen, n_local_heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        print(f'concat o: {output.shape}')

        return self.wo(output)
# Build the single-GPU attention layer and display its projection layers.
attn = Attention(args=config)
attn
# The shapes match the initial hypothesis: wq/wo map 18 -> 18, wk/wv map 18 -> 6.

Attention(
  (wq): Linear(in_features=18, out_features=18, bias=False)
  (wk): Linear(in_features=18, out_features=6, bias=False)
  (wv): Linear(in_features=18, out_features=6, bias=False)
  (wo): Linear(in_features=18, out_features=18, bias=False)
)

In [34]:
class TransformerBlock(nn.Module):
    """A stripped-down Llama transformer block: the attention sub-layer plus its residual.

    The reference Llama block additionally applies RMSNorm before attention and
    a SwiGLU feed-forward sub-layer afterwards; both are left out here to focus
    on GQA, so ``forward`` computes ``x + attention(x)``.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id  # index of this block in the layer stack
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Apply attention to ``x`` and add the residual connection.

        Arguments are forwarded unchanged to ``Attention.forward``; the result
        has the same shape as ``x``.
        """
        h = x + self.attention(x, start_pos, freqs_cis, mask)
        return h

In [35]:
llama_block = TransformerBlock(layer_id=1, args=config)  # one attention-only block for the shape walk-through

In [36]:
# Random input batch of shape (batch, seqlen, dim). Note: do not name a variable
# `len` -- it would shadow the built-in len() for the rest of the kernel session.
torch.manual_seed(0)  # reproducible input under Restart & Run All
bsz = 1
seqlen = 7
dim = config.dim  # 18; must match the model's embedding size
# The input has to fit in the KV cache that Attention preallocates.
assert bsz <= config.max_batch_size and seqlen <= config.max_seq_len
x_src = torch.randn(bsz, seqlen, dim)
x_src.shape

torch.Size([1, 7, 18])

In [37]:
# Causal mask as Llama builds it for a multi-token prefill: token i may only attend
# to positions <= i. Shape (n_tokens, start_pos + n_tokens) broadcasts over
# (batch, heads) inside Attention.forward.
start_pos = 0
n_tokens = x_src.shape[1]
causal_mask = torch.triu(torch.full((n_tokens, n_tokens), float("-inf")), diagonal=1)
# Positions already in the cache (before start_pos) are always visible.
causal_mask = torch.hstack([torch.zeros((n_tokens, start_pos)), causal_mask])
y = llama_block(x_src, start_pos=start_pos, freqs_cis=None, mask=causal_mask)

In [38]:
print(x_src.shape)
print(y.shape)
# The residual block must preserve the input shape (batch, seqlen, dim).
assert y.shape == x_src.shape

# 模型并行测试


In [39]:
# Simulate 2-way model parallelism with a separate config object, so that `config`
# (the single-GPU setup used above) is not silently replaced for later re-runs.
parallel_config = ModelArgs()
parallel_config.model_parallel_size = 2  # pretend there are 2 GPUs
attn_parallel = Attention(parallel_config)
print('each gpu has q heads: ', attn_parallel.n_local_heads)
print('each gpu has KV heads: ', attn_parallel.n_local_kv_heads)
# Each GPU keeps 6 // 2 = 3 query heads and 2 // 2 = 1 KV head, so n_rep stays 3.
assert (attn_parallel.n_local_heads, attn_parallel.n_local_kv_heads, attn_parallel.n_rep) == (3, 1, 3)

此时, 进行并行分配
- GPU0: Q1,Q2,Q3,  K1 V1
- GPU1: Q4,Q5,Q6,  K2 V2

  
计算时: 
- GPU0: Q1,Q2,Q3,  K1 K1 K1 V1 V1 V1 -> o1, o2, o3
- GPU1: Q4,Q5,Q6,  K2 K2 K2 V2 V2 V2 -> o4, o5, o6


输出时:
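- GPU1 -> o4, o5, o6 -> GPU0
- GPU0 o = (o1,o2,o3,o4,o5,o6)  Wo(o) -> attn_output

注：上面为便于理解，先把 o4~o6 汇总到 GPU0 再乘 Wo。原版 Llama 中 Wo 是 RowParallelLinear：每个 GPU 用本地的 (o1,o2,o3) 或 (o4,o5,o6) 乘以 Wo 对应的切片，得到部分结果后再通过 all-reduce 求和，得到 attn_output。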

# 结论

1. 符合猜想：先把 KV 头复制到与 Q 头相同的数量，再按完整的多头注意力计算（满头计算）。这与长文档里原版 Grouped Query Attention 的实现有所不同：此版本实现更简单，不需要过多的切分和拼接。

2. 特别关注，GQA在模型并行时候的分配策略。

reference: 

https://github.com/meta-llama/llama3/blob/main/llama/model.py

https://arxiv.org/pdf/2305.13245