动手实现LLaMA2大模型

In [22]:
import torch
import torch.nn as nn
from typing import Tuple
import math
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast

In [2]:
from transformers import PretrainedConfig#通过继承这个类来方便使用transformers库中的一些功能，也方便后续导入Hugging Face功能

class ModelConfig(PretrainedConfig):
    """Hyperparameters used to build the Tiny-K (LLaMA2-style) model.

    Every constructor argument is stored on the instance under the same name;
    any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        dim: Model (embedding) dimension.
        n_layers: Number of Transformer decoder layers.
        n_heads: Number of attention (query) heads.
        n_kv_heads: Number of key/value heads.
        vocab_size: Size of the vocabulary.
        hidden_dim: Hidden dimension of the MLP; ``None`` lets the MLP derive it.
        multiple_of: Multiple the derived MLP hidden dimension is rounded up to.
        norm_eps: Epsilon used by the normalization layers.
        max_seq_len: Maximum sequence length.
        dropout: Dropout probability.
        flash_attn: Flag indicating whether Flash Attention should be used.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "Tiny-K"

    def __init__(
        self,
        dim: int = 768,
        n_layers: int = 12,
        n_heads: int = 16,
        n_kv_heads: int = 8,
        vocab_size: int = 6144,
        hidden_dim: int = None,
        multiple_of: int = 64,
        norm_eps: float = 1e-5,
        max_seq_len: int = 512,
        dropout: float = 0.0,
        flash_attn: bool = True,
        **kwargs,
    ):
        # Architecture size.
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        # Feed-forward sizing.
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        # Numerics, context length and regularisation.
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        # The base-class initialiser runs last, exactly as in the original.
        super().__init__(**kwargs)
#后面会根据这些超参数来构建我们的模型

构建RMSNorm

In [24]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``.

    Args:
        dim: Size of the last (feature) dimension of the input.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        # Learnable gain, initialised to 1 so the layer starts as a pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS taken over its last dimension."""
        # rsqrt already yields 1/sqrt(.), so it must be MULTIPLIED in;
        # dividing by it would scale x by its RMS instead of normalizing it.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Return the RMS-normalized input multiplied by the learned gain.

        The normalization is computed in float32 and cast back to the input
        dtype before the gain is applied.
        """
        return self._norm(x.float()).type_as(x) * self.weight

In [25]:
# Default configuration shared by the demo cells below, plus an RMSNorm layer built from it.
args = ModelConfig()
norm = RMSNorm(dim=args.dim, eps=args.norm_eps)

构建LLaMA2的Attention模块：使用GQA（Grouped-Query Attention，分组查询注意力机制）来实现

首先，在LLaMA2模型中，我们需要将键和值的头数扩展到与查询的头数一致，这样才能进行注意力计算，也就是要实现repeat_kv

In [26]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        x: Tensor of shape (batch, seq_len, n_kv_heads, head_dim).
        n_rep: Number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        (batch, seq_len, n_kv_heads * n_rep, head_dim) in which the copies of
        each head are adjacent.
    """
    batch, seq, kv_heads, head_size = x.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
        expanded = x.unsqueeze(3).expand(batch, seq, kv_heads, n_rep, head_size)
        return expanded.reshape(batch, seq, kv_heads * n_rep, head_size)
    return x

旋转嵌入（RoPE）：通过对查询和键向量施加与位置相关的旋转，把位置信息注入注意力计算，可以为注意力机制提供更强的上下文信息

In [27]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the cosine and sine tables used by the rotary embedding.

    Args:
        dim: Per-head dimension; one frequency is produced per pair of channels.
        end: Number of positions to tabulate (positions ``0 .. end-1``).
        theta: Base of the geometric frequency progression.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape (end, dim // 2).
    """
    # Even channel indices 0, 2, 4, ... scaled by dim give the exponents of theta.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Position indices 0 .. end-1.
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product: entry [p, k] is position p times frequency k.
    angles = torch.outer(positions, inv_freq).float()
    # Cosine is the real part of the rotation, sine the imaginary part.
    return torch.cos(angles), torch.sin(angles)

调整张量的形状，使其在进行广播操作时与x的维度对齐，从而能进行正确的张量计算

In [32]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View a (seq_len, head_dim // 2) table so it broadcasts against ``x``.

    Args:
        freqs_cis: Table whose shape must equal ``(x.shape[1], x.shape[-1] // 2)``.
        x: Reference tensor; indexed as a 4-D tensor.

    Returns:
        A view of ``freqs_cis`` with shape ``(1, x.shape[1], 1, x.shape[3] // 2)``.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or the table
            shape does not match.
    """
    ndim = x.ndim
    # Axis 1 must exist on x.
    assert 0 <= 1 < ndim
    expected_shape = (x.shape[1], x.shape[-1] // 2)
    assert freqs_cis.shape == expected_shape, \
        f"Shape mismatch: {freqs_cis.shape} vs ({x.shape[1]}, {x.shape[-1]//2})"
    # Keep axis 1 and half of the last axis; collapse axes 0 and 2 to size 1.
    target = [1, x.shape[1], 1, x.shape[3] // 2]
    return freqs_cis.view(*target)

In [33]:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the given cosine/sine tables.

    Adjacent channel pairs along the last axis are treated as the real and
    imaginary parts of a complex number and rotated by the per-position angle.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cos: Cosine table, reshaped against ``xq`` for broadcasting.
        freqs_sin: Sine table, reshaped against ``xq`` for broadcasting.

    Returns:
        The rotated ``(xq, xk)``, each cast back to its input dtype.
    """
    # Work in float32 and split the last axis into (pairs, 2) = (real, imaginary).
    q_pairs = xq.float().reshape(xq.shape[:-1] + (-1, 2))
    k_pairs = xk.float().reshape(xk.shape[:-1] + (-1, 2))
    q_re, q_im = q_pairs.unbind(-1)
    k_re, k_im = k_pairs.unbind(-1)

    # Both tables are aligned against the query tensor's shape.
    cos = reshape_for_broadcast(freqs_cos, xq)
    sin = reshape_for_broadcast(freqs_sin, xq)

    def _rotate(re, im):
        # Complex multiplication (re + i*im) * (cos + i*sin), then re-interleave the pairs.
        rot_re = re * cos - im * sin
        rot_im = re * sin + im * cos
        return torch.stack([rot_re, rot_im], dim=-1).flatten(3)

    return _rotate(q_re, q_im).type_as(xq), _rotate(k_re, k_im).type_as(xk)

上面完成了旋转嵌入的实现，现在可以来构建Attention模块了

In [30]:
class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Queries use ``args.n_heads`` heads while keys and values use
    ``n_kv_heads`` heads that are repeated to match the query heads. When
    ``torch.nn.functional.scaled_dot_product_attention`` is available it is
    used; otherwise a manual implementation with an explicit causal mask runs.

    Args:
        args: Model configuration providing ``dim``, ``n_heads``,
            ``n_kv_heads``, ``dropout`` and ``max_seq_len``.
    """

    def __init__(self, args: ModelConfig):
        super().__init__()
        # Without an explicit n_kv_heads, use one key/value head per query head.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # The query heads must split evenly across the key/value heads.
        assert args.n_heads % self.n_kv_heads == 0
        # Model-parallel world size; fixed to 1 here.
        model_parallel_size = 1
        # Heads handled locally.
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # How many query heads share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        # Per-head dimension.
        self.head_dim = args.dim // args.n_heads
        # Projection matrices.
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Dropout on attention weights and on the block output.
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("Warning: using the slow attention implementation. Install PyTorch nightly to get the fast one.")

            # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float('-inf'))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer('mask', mask)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, dim).
            freqs_cos: Rotary cosine table for the ``seq_len`` positions.
            freqs_sin: Rotary sine table for the ``seq_len`` positions.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        baz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Split the projections into heads: (batch, seq_len, heads, head_dim).
        xq = xq.view(baz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(baz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(baz, seqlen, self.n_local_kv_heads, self.head_dim)
        # Rotary embedding is applied to queries and keys only.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        # Repeat key/value heads so that they match the number of query heads.
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # Move heads in front of the sequence axis: (batch, heads, seq_len, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, "mask")
            # Slice the mask with this call's own sequence length, not a global variable.
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)
        # Merge the heads back: (batch, seq_len, heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(baz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

In [34]:
# Smoke test: run one random batch through a freshly built Attention block and show the output shape.
attention_model = Attention(ModelConfig())
batch_size, seq_len, dim = 1, 50, args.dim
x = torch.rand(batch_size, seq_len, dim)
# The rotary tables are built for the per-head dimension and this sequence length.
freqs_cos, freqs_sin = precompute_freqs_cis(dim // args.n_heads, seq_len)
output = attention_model(x, freqs_cos, freqs_sin)
print(output.shape)

torch.Size([1, 50, 768])


In [11]:
class MLP(nn.Module):
    """LLaMA2-style gated feed-forward block (SwiGLU).

    Computes ``dropout(w2(silu(w1(x)) * w3(x)))``.

    Args:
        dim: Input and output feature dimension.
        hidden_dim: Hidden dimension. When ``None`` it is derived as
            ``int(2 * (4 * dim) / 3)`` rounded up to a multiple of
            ``multilpe_of``.
        multilpe_of: Multiple the derived hidden dimension is rounded up to.
            (The spelling is kept because callers pass it by keyword.)
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, dim: int, hidden_dim: int, multilpe_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            # Start from 4 * dim, take two thirds, then round up to the requested multiple.
            hidden_dim = dim * 4
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multilpe_of * ((hidden_dim + multilpe_of - 1) // multilpe_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the gated feed-forward transform of ``x`` (same shape as ``x``)."""
        # SwiGLU: the SiLU-activated gate multiplies the value path elementwise.
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

In [12]:
# Smoke test: the MLP must map (1, 50, dim) to a tensor of the same shape.
mlp = MLP(
    dim=args.dim,
    hidden_dim=args.hidden_dim,
    multilpe_of=args.multiple_of,
    dropout=args.dropout,
)
x = torch.randn(1, 50, args.dim)
output = mlp(x)
print(output.shape)

torch.Size([1, 50, 768])


构建LLaMA2的Decoder Layer：把我们的Attention模块和MLP模块组合在一起，实现一个完整的Transformer模块


In [39]:
class DecoderLayer(nn.Module):
    """One pre-norm Transformer decoder block: attention then MLP, each with a residual.

    Args:
        layer_id: Index of this layer; stored as ``self.layer_id``.
        args: Model configuration used to build the attention, MLP and norm layers.
    """

    def __init__(self, layer_id: int, args: ModelConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = self.dim // self.n_heads
        self.attention = Attention(args)
        self.feed_forward = MLP(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multilpe_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        # Separate normalization layers in front of the attention and MLP sub-blocks.
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        """Run the block on ``x``.

        Args:
            x: Input hidden states.
            freqs_cos: Rotary cosine table passed through to the attention module.
            freqs_sin: Rotary sine table passed through to the attention module.

        Returns:
            Hidden states after the attention and MLP residual updates.
        """
        # Call the submodules themselves rather than ``.forward`` so that
        # nn.Module.__call__ runs and any registered hooks fire.
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

In [40]:
# Smoke test: a single decoder layer must preserve the (1, seq_len, dim) input shape.
decoder_layer = DecoderLayer(layer_id=0, args=args)
dim, seq_len = args.dim, 50
x = torch.randn(1, seq_len, dim)
# The rotary tables are built for the per-head dimension and this sequence length.
freqs_cos, freqs_sin = precompute_freqs_cis(dim // args.n_heads, seq_len)
out = decoder_layer(x, freqs_cos, freqs_sin)
print(out.shape)

torch.Size([1, 50, 768])


In [49]:
from typing import Optional, List

class Transformer(nn.Module):
    """Decoder-only LLaMA2-style language model.

    Token embedding -> dropout -> ``n_layers`` decoder layers -> RMSNorm ->
    linear output head. The embedding matrix and the output head share one
    weight tensor.

    Attributes:
        last_loss: Loss computed by the most recent ``forward`` call, or
            ``None`` when no targets were supplied.
        OUT: The ``CausalLMOutputWithPast`` returned by the most recent
            ``forward`` call.
    """

    config_class = ModelConfig
    last_loss: Optional[torch.Tensor]

    def __init__(self, args: ModelConfig = None):
        """Build the model from ``args``; a default ``ModelConfig`` is used when it is ``None``."""
        super().__init__()
        if args is None:
            # The signature advertises None as the default, so make it usable.
            args = ModelConfig()
        self.args = args
        self.dim = args.dim
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embedding = nn.Embedding(self.vocab_size, self.dim)
        self.dropout = nn.Dropout(args.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(self.n_layers):
            self.layers.append(DecoderLayer(layer_id, args))
        self.norm = RMSNorm(self.dim, eps=args.norm_eps)
        self.output = nn.Linear(self.dim, self.vocab_size, bias=False)
        # Weight tying: the embedding and the output head share one parameter.
        self.tok_embedding.weight = self.output.weight
        # Rotary tables for every position up to max_seq_len; not saved in the state dict.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.dim // args.n_heads,
            args.max_seq_len,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
        self.apply(self._init_weights)
        # Re-initialise selected projections with a std scaled down by the depth.
        for pn, p in self.named_parameters():
            if pn.endswith("w3.weight") or pn.endswith("wo.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers)
                )
        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()
        self._no_split_modules = [name for name, _ in self.named_modules()]

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None, **keyargs):
        """Run the model once and return logits (and a loss when targets are given).

        Args:
            tokens: Token ids of shape (batch, seq_len). May be omitted when
                ``input_ids`` is passed as a keyword argument.
            targets: Optional target ids with the same number of elements as
                ``tokens``; positions equal to 0 are ignored by the loss.
            **keyargs: ``input_ids`` overrides ``tokens``; ``attention_mask``
                overrides ``targets``.

        Returns:
            A ``CausalLMOutputWithPast`` whose ``logits`` has shape
            (batch, seq_len, vocab_size) and whose ``loss`` is the unreduced
            per-token cross-entropy, or ``None`` when no targets were supplied.

        Raises:
            ValueError: If neither ``tokens`` nor ``input_ids`` is provided.
        """
        if "input_ids" in keyargs:
            tokens = keyargs["input_ids"]
        if "attention_mask" in keyargs:
            # NOTE(review): the attention mask is used as the loss targets here;
            # kept as in the original — confirm this is what callers intend.
            targets = keyargs["attention_mask"]
        if tokens is None:
            raise ValueError("forward() requires `tokens` or the `input_ids` keyword argument")

        _bsz, seq_len = tokens.shape
        h = self.tok_embedding(tokens)
        h = self.dropout(h)
        # Slice the rotary tables with this call's own sequence length.
        freqs_cos = self.freqs_cos[:seq_len]
        freqs_sin = self.freqs_sin[:seq_len]
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        logits = self.output(h)
        if targets is not None:
            # `reduction` must be a string; 'none' keeps the per-token losses.
            self.last_loss = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                targets.view(-1),
                ignore_index=0,
                reduction="none",
            )
        else:
            self.last_loss = None

        # A fresh output object per call, so earlier results are not overwritten.
        out = CausalLMOutputWithPast()
        out["logits"] = logits
        out["loss"] = self.last_loss
        self.OUT = out
        return out

    @torch.inference_mode()
    def generate(self, idx, stop_id=None, max_new_tokens=256, temperature=1.0, top_k=None):
        """Autoregressively extend ``idx`` and return only the new tokens.

        Args:
            idx: Prompt token ids of shape (batch, prompt_len).
            stop_id: Optional token id that ends generation; the stop token
                itself is not appended. With several sequences, generation
                stops once every sequence produced it in the same step.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Softmax temperature; ``0`` selects the most likely
                token instead of sampling.
            top_k: If given, sample only among the ``top_k`` highest logits.

        Returns:
            Tensor of shape (batch, n_generated) holding the generated ids.
        """
        index = idx.shape[1]
        for _ in range(max_new_tokens):
            # Keep only the most recent max_seq_len tokens as context.
            if idx.size(1) <= self.args.max_seq_len:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.args.max_seq_len:]
            logits = self(idx_cond).logits
            logits = logits[:, -1, :]
            if temperature == 0.0:
                # Greedy decoding; avoids a division by zero.
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask everything below the k-th largest logit.
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            if stop_id is not None and bool((idx_next == stop_id).all()):
                break
            idx = torch.cat((idx, idx_next), dim=-1)
        return idx[:, index:]