# 动手实现LLaMA2

## 1. 超参数定义
自定义一个ModelConfig类，继承自transformers库中的PretrainedConfig类，用来存储和记录超参数

In [5]:
from transformers import PretrainedConfig

class ModelConfig(PretrainedConfig):
    model_type="Tiny-K"
    def __init__(
        self,
        dim: int =768, #模型维度
        n_layers: int = 12, #transformer 层数
        n_heads: int = 16,  #注意力头数
        n_kv_heads: int = 8,
        vocab_size: int =6144, #词汇表大小
        hidden_dim: int = None,
        multiple_of: int = 64,
        norms_eps: float = 1e-5,
        max_seq_len: int=512, #输入的最大序列长度
        dropout: float=0.0,
        flash_atten: bool =True,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norms_eps = norms_eps
        self.dropout = dropout
        self.flash_atten = flash_atten
        super().__init__(**kwargs)


args = ModelConfig()

## RMSNorm


$\operatorname{RMSNorm}(x)=\frac{x}{\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_{i}^{2}+\epsilon}} \cdot \gamma $

In [6]:
import torch.nn as nn
import torch

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self,x):
        return x * torch.rsqrt(x.pow(2).mean(-1,keepdim=True) + self.eps)
    
    def forward(self,x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

norm  =RMSNorm(args.dim, args.norms_eps)
x = torch.randn(1,50,args.dim)
output =norm(x)
print(output.shape)

torch.Size([1, 50, 768])


## LLaMA2 Attention
在LLaMA模型中，虽然只有LLaMA-70B模型使用了分组注意力机制（GQA),但我们此处依然使用GQA来构建LLaMA Attention 模块，以节省显存占用，提升模型效率

### 1.repeat_kv
将键和值扩展到和查询的维度一样

In [7]:
def repeat_kv(x:torch.Tensor, n_rep: int) -> torch.Tensor:
    #获取输入张量的形状：批量大小，序列长度、键、值对头的数量
    bs,slen,n_kv_heads, head_dim = x.shape

    #如果重复次数为1，则不需要重复，直接返回原始张量
    if n_rep ==1:
        return x
    
    #对张量进行扩展和重塑操作以重复键值对
    return (
        x[:,:,:,None,:]
        .expands(bs,slen,n_kv_heads,n_rep,head_dim)
        .reshape(bs,slen, n_kv_heads * n_rep, head_dim)
    )


### 2. 旋转嵌入
旋转嵌入可以为注意力机制提供更强的上下文

In [8]:
def precompute_freqs_cis(dim: int, end:int, theta: float=10000.0):

    #生成频率序列
    freqs = 1.0 / (theta ** (torch.arange(0,dim, 2)[:(dim //2)].float() /dim))
    #生成时间序列
    t = torch.arange(end, device=freqs.device)
    #计算频率的外积
    freqs = torch.outer(t,freqs).float()
    #计算频率的余弦值，得到实部
    freqs_cos = torch.cos(freqs)
    #计算频率的正弦值，得到虚部
    freqs_sin =torch.sin(freqs)

    return freqs_cos,freqs_sin


In [17]:
def reshape_for_broadcast(freqs_cis:torch.Tensor, x:torch.Tensor):
    '''
    调整张量freqs_cls的形状，使其在进行广播操作时与x的维度对齐
    '''
    #获取x的维度
    ndim = x.ndim

    assert 0<=1<ndim

    #确保freqs_cis的形状与x的第二维和最后一维形同
    assert freqs_cis.shape == (x.shape[1],x.shape[-1])

    #构造一个新的形状，除了第二维和最后一维，其他维度都为1，这样做是为了能够将freqs_cis与x进行广播对齐
    shape = [d if i==1 or i ==ndim -1 else 1 for i,d in enumerate(x.shape)]
    
    #将freqs_cis调整为新的形状，并返回
    return freqs_cis.view(shape)


In [15]:
from typing import Tuple

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor
)->Tuple[torch.Tensor,torch.Tensor]:
    '''实现旋转嵌入'''

    #将查询和键张量转换为浮点数，并重塑形状以分离实部和虚部
    xq_r,xq_i = xq.float().reshape(xq.shape[:-1] + (-1,2)).unbind(-1)
    xk_r,xk_i = xk.float().reshape(xk.shape[:-1] + (-1,2)).unbind(-1)

    #重新塑性频率张量已进行广播
    freqs_cos = reshape_for_broadcast(freqs_cos,xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin,xq_r)

    #应用旋转，分别计算旋转后的实部和虚部
    xq_out_r  =xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    #将最后两个维度合并，并还原为原始张量的形状
    xq_out = torch.stack([xq_out_r,xq_out_i],dim = -1).flatten(3)
    xk_out = torch.stack([xk_out_r,xk_out_i],dim = -1).flatten(3)

    return xq_out.type_as(xq),xk_out.type_as(xk)

In [25]:
xq = torch.randn(1,50,6,48) # bs, seq_len, dim//n_head, n_head_dim
xk = torch.randn(1,50,6,48)

cos,sin = precompute_freqs_cis(288//6,50)
print(cos.shape, sin.shape)
xq_out,xk_out  =apply_rotary_emb(xq,xk,cos,sin)

xq_out.shape,xk_out.shape


torch.Size([50, 24]) torch.Size([50, 24])


(torch.Size([1, 50, 6, 48]), torch.Size([1, 50, 6, 48]))

### 组装LLaMA2 Attention

In [None]:
import math
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self,args:ModelConfig):
        super().__init__()
        #根据是否指定n_kv_heads,确定用于键和值的头的数量
        self.n_kv_heads  =args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        #确保总头数可以被键值头数整除
        assert args.n_heads % self.n_kv_heads == 0

        #模型并行处理大小，默认为1
        model_parallel_size = 1
        #本地计算头数，等于总头数除以模型并行处理大小
        self.n_local_heads  =args.n_heads //model_parallel_size
        #本地键值头数，等于键值头数除以模型并行处理大小
        self.n_local_kv_heads = self.n_kv_heads //model_parallel_size
        #重复次数，用于扩展键和值的尺寸
        self.n_rep  =self.n_local_heads //self.n_local_kv_heads
        #每个头的维度，等于模型维度除以头的总数
        self.head_dim  =args.dim // args.n_heads

        # 定义权重矩阵
        self.wq  =nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim,bias=False)

        #输出权重矩阵
        self.wo = nn.Linear(args.n_heads * self.head_dim,args.dim,bias=False)

        #定义dropout
        self.attn_dropout  =nn.Dropout(args.dropout)
        self.resid_droput = nn.Dropout(args.dropout)

        #保存dropout概率
        self.dropout = args.dropout

        #检查是否使用Flash Attenion(需要Pytorch >=2.0)
        self.flash  =hasattr(torch.nn.functional,'scaled_dot_product_attention')

        if not self.flash:
            #若不支持Flash Attention,则使用手动实现的注意力机制，并设置mask
            print("Warning: using slow attention. Flash Attention requires PyTorch >= 2.0")
            #创建一个上三角矩阵，用于遮蔽未来信息
            mask = torch.full((1,1,args.max_seq_len,args.max_seq_len),float("-inf"))
            mask= torch.triu(mask,diagonal=1)

            # 注册为模型的缓冲区
            self.register_buffer("mask",mask)

    def forward(self,x:torch.Tensor, freqs_cos:torch.Tensor,freqs_sin:torch.Tensor):
        # 获取批次大小和序列长度，【batch_size,seq_len,dim]
        bsz,seqlen,_ = x.shape

        #计算Q\K\V
        xq,xk,xv = self.wq(x),self.wk(x),self.wv(x)

        #调整形状以适应头的维度
        xq = xq.view(bsz,seqlen,self.n_local_heads,self.head_dim)
        xk = xk.view(bsz,seqlen,self.n_local_kv_heads,self.head_dim)
        xv = xv.view(bsz,seqlen,self.n_local_kv_heads,self.head_dim)

        #应用旋转位置编码
        xq,xk =apply_rotary_emb(xq,xk,freqs_cos,freqs_sin)

        #对键和值进行扩展以适应重复次数
        xk  =repeat_kv(xk,self.n_rep)
        xv  =repeat_kv(xv,self.n_rep)

        xq = xq.transpose(1,2)
        xk = xk.transpose(1,2)
        xv = xv.transpose(1,2)

        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            #使用手动实现的注意力机制
            scores = torch.matmul(xq,xk.transpose(2,3)) / math.sqrt(self.head_dim)
            assert hasattr(self,"mask")
            scores =scores + self.mask[:,:,:seqlen,:seqlen]
            scores = F.softmax(scores.float(), dim = -1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores,xv)

        output  =output.transpose(1,2).contiguous().view(bsz,seqlen,-1)

        output = self.wo(output)
        output  =self.resid_droput(output)
        return output


attention_model = Attention(args)
batch_size =1
seq_len = 50
dim  =args.dim
x = torch.rand(batch_size,seq_len,dim) #随机生成输入张量

freqs_cos,freqs_sin = precompute_freqs_cis(dim // args.n_heads, seq_len)
output  =attention_model(x, freqs_cos,freqs_sin)

print("Output shape:",output.shape)


AttributeError: 'Tensor' object has no attribute 'expands'