# Llama

Llama 发布于 2023 年 2 月。2022 年 11 月 ChatGPT 发布后，一直没有能与之性能对标的开源模型；Llama 一经发布，便在工业界和学术界掀起了开源热潮。从 Llama-2 开始，放开了商业使用许可（license）。

1. Llama 是一种 Decoder-only Transformer，相较 GPT-2/3 有多项改进，如 RoPE（旋转位置编码）、RMSNorm、SwiGLU 等
2. Llama 加速了微调、RAG、agent、推理（inference）、训练（training）、预训练（pretraining）、多模态、预训练数据工程、RL 等多个研究领域的开源发展
3. Llama-1/2/3 的网络结构稳定，是各类工具链的标准适配模型
4. 就 FFN 层而言，Llama 可以视为一种 dense（稠密）网络结构；而近年流行用 MoE（混合专家模型）来扩展 FFN 参数，与 dense 相对，这类模型被称为 sparse（稀疏）型（如 DeepSeek-V3）

本 notebook 实现：

1. Llama 整体模型
2. KV Cache + GQA + RoPE 的注意力组件
3. 实现 SwiGLU
4. 实现 Llama training & inference

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)  # seed the global RNG: weight init and the random demo inputs below draw from it

<torch._C.Generator at 0x10536cd70>

## config

In [34]:
from dataclasses import dataclass

@dataclass
class LlamaConfig:
    """Hyper-parameters for the toy Llama model built in this notebook.

    Raises:
        ValueError: from ``__post_init__`` if the head layout is inconsistent
            (``dim`` not divisible by ``n_heads``, ``n_heads`` not divisible by
            ``n_kv_heads``, or an odd per-head size, which RoPE cannot rotate in pairs).
    """
    learning_rate: float = 0.001
    vocab_size: int = 200
    max_len: int = 512  # longest sequence the causal mask / RoPE tables are built for
    dim: int = 512
    n_heads: int = 8
    n_kv_heads: int = 4  # number of shared key/value heads for GQA; must divide n_heads
    num_layers: int = 6
    position_encoding_base: float = 10000.0  # RoPE frequency base
    pad_token_id: int = 0
    attention_bias: bool = False  # attention projections without bias

    def __post_init__(self):
        if self.n_heads <= 0 or self.dim % self.n_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
            )
        if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of n_kv_heads ({self.n_kv_heads})"
            )
        if (self.dim // self.n_heads) % 2 != 0:
            raise ValueError(
                f"head_dim ({self.dim // self.n_heads}) must be even for RoPE"
            )


# An *instance* (the original bound the class object itself, which only worked
# because dataclass defaults happen to be stored as class attributes).
config = LlamaConfig()

## Toy Model

1. 输入层没有位置编码（位置信息由 block 内的 RoPE 注入）
2. block 里引入 rope; block里包含模块 GQA、RMSNorm、SwiGLU、RoPE
3. 输出层包含 norm 和 lm_head
4. Linear 层去除 bias

In [35]:
# A parameterized toy model used to study the Llama data flow:

class LlamaToyModel(nn.Module):
    """Skeleton of the Llama forward pass with every component mocked by a bias-free Linear.

    Data flow: token embedding (no positional encoding) -> pre-norm attention
    residual -> pre-norm feed-forward residual -> output norm -> LM head.
    """

    def __init__(self, config):
        super().__init__()
        width = config.dim
        self.embd = nn.Embedding(config.vocab_size, width)

        # Placeholders only: each Llama component is stood in for by a Linear layer.
        # Keep this creation order: Linear initialisation draws from the global RNG.
        self.RMSNorm1 = nn.Linear(width, width, bias=False)
        self.GQA = nn.Linear(width, width, bias=False)  # stands in for attention with RoPE
        self.RMSNorm2 = nn.Linear(width, width, bias=False)
        self.SwiGLU = nn.Linear(width, width, bias=False)

        self.OutputNorm = nn.Linear(width, width, bias=False)
        self.LM_head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, x):
        """Map token ids ``x`` of shape [batch, seq] to logits of shape [batch, seq, vocab_size]."""
        hidden = self.embd(x)  # no positional encoding at the input
        hidden = hidden + self.GQA(self.RMSNorm1(hidden))
        hidden = hidden + self.SwiGLU(self.RMSNorm2(hidden))
        return self.LM_head(self.OutputNorm(hidden))

# Smoke test: push a random batch of token ids of shape (2, 32) through the toy model.
model = LlamaToyModel(config)

x = torch.randint(config.vocab_size, (2, 32))
y_logits = model(x)

print(x.shape)  # batch_size, seq_len
print(y_logits.shape)  # batch_size, seq_len, vocab_size

torch.Size([2, 32])
torch.Size([2, 32, 200])


## RMS Normalization

In [36]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    ``out = gamma * x / sqrt(mean(x ** 2, dim=-1) + eps)``

    The mean square is computed in at least float32 and the normalized tensor is
    cast back to the input dtype before scaling by ``gamma``. For float16/bfloat16
    inputs this avoids overflow in ``x ** 2`` (|x| > 256 in fp16) and keeps
    ``eps`` from rounding to zero, which would turn an all-zero row into NaN.
    Float32/float64 inputs are processed exactly as before.

    Args:
        dim: size of the last (normalized) dimension; also the length of ``gamma``.
        eps: added to the mean square for numerical stability.
    """

    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable per-feature scale
        self.eps = eps

    def forward(self, x):
        # Upcast low-precision inputs only; float64 stays float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        mean_sq = (x_c ** 2).mean(-1, keepdim=True)
        normed = (x_c / torch.sqrt(mean_sq + self.eps)).type_as(x)  # root mean square
        return self.gamma * normed

## SwiGLU

In [37]:
# SiLU (a.k.a. Swish) is x * sigmoid(x): both printed rows should be identical.
a = torch.randn(1, 4)
print(F.silu(a))
print(torch.sigmoid(a) * a)

tensor([[ 0.5429, -0.2782,  0.0819, -0.2592]])
tensor([[ 0.5429, -0.2782,  0.0819, -0.2592]])


In [38]:
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w2(silu(w_act(x)) * w1(x))``.

    Args:
        dim: model width (input and output size).
        hidden_dim: inner width before rounding. Defaults to ``dim * 8 // 3``;
            with three ``dim x hidden`` projections this gives about ``8 * dim**2``
            weights, the same as a plain two-layer MLP with a ``4 * dim`` hidden layer.
        multiple_of: round the inner width up to a multiple of this value
            (e.g. for hardware-friendly matrix shapes). The default 1 leaves it unchanged.

    Raises:
        ValueError: if ``multiple_of`` < 1.
    """

    def __init__(self, dim, hidden_dim=None, multiple_of=1):
        super().__init__()
        if multiple_of < 1:
            raise ValueError(f"multiple_of must be >= 1, got {multiple_of}")
        if hidden_dim is None:
            hidden_dim = dim * 8 // 3
        # Round up to the next multiple (a no-op for the default multiple_of=1).
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)     # linear (value) branch
        self.w_act = nn.Linear(dim, hidden_dim, bias=False)  # gate branch, passed through SiLU
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)     # project back to dim

    def forward(self, x):
        gate = F.silu(self.w_act(x))
        return self.w2(gate * self.w1(x))

## RoPE

In [45]:
def create_rope(max_len=1024, dim=512, base=10000.0):
    """Precompute RoPE sine/cosine tables in interleaved-pair layout.

    Feature pair ``(2p, 2p + 1)`` at position ``m`` is rotated by the angle
    ``m * base ** (-2p / dim)``; both members of a pair share the same table entry.

    Args:
        max_len: number of positions to precompute.
        dim: per-head feature size; must be even because features rotate in pairs.
        base: frequency base of the geometric angle schedule.

    Returns:
        ``(sin, cos)`` -- note the order -- each of shape ``(max_len, dim)``.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"RoPE dim must be even, got {dim}")
    positions = torch.arange(0, max_len, 1)
    pair_idx = torch.arange(0, dim // 2, 1)
    theta = base ** (-2 * pair_idx / dim)  # one frequency per feature pair
    angles = torch.outer(positions, theta)  # (max_len, dim // 2)
    # Duplicate every column so features 2p and 2p+1 share angle p: [a0, a0, a1, a1, ...].
    sin = torch.sin(angles).repeat_interleave(2, dim=-1)
    cos = torch.cos(angles).repeat_interleave(2, dim=-1)
    return sin, cos

def apply_rope(X, sin, cos, start_pos=0):
    """Rotate interleaved feature pairs of ``X`` by their position-dependent angle.

    For each pair ``(x0, x1)`` at absolute position ``m``:
    ``y0 = x0 * cos - x1 * sin`` and ``y1 = x1 * cos + x0 * sin``.

    Args:
        X: tensor whose last two dims are ``(seq_len, d)``, e.g.
            ``(batch, heads, seq_len, d)``; ``d`` must match the tables' width.
        sin, cos: tables of shape ``(max_len, d)`` in interleaved-pair layout.
        start_pos: absolute position of ``X[..., 0, :]``. Use a non-zero value
            when only the newest tokens are processed (e.g. with a KV cache).

    Returns:
        Tensor with the same shape as ``X``, in ``X``'s dtype and device.

    Raises:
        ValueError: if the requested positions fall outside the tables.
    """
    seq_len = X.shape[-2]
    end = start_pos + seq_len
    if start_pos < 0 or end > sin.shape[0]:
        raise ValueError(
            f"RoPE tables cover {sin.shape[0]} positions, "
            f"but positions [{start_pos}, {end}) were requested"
        )
    sin_t = sin[start_pos:end].to(device=X.device, dtype=X.dtype)
    cos_t = cos[start_pos:end].to(device=X.device, dtype=X.dtype)
    # "rotate half" for interleaved pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x_even, x_odd = X[..., 0::2], X[..., 1::2]
    X_shift = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
    return cos_t * X + sin_t * X_shift

# Sanity check: build per-head RoPE tables and rotate a random
# (batch, heads, seq_len, head_dim) tensor.
sin, cos = create_rope(max_len=1024, dim=config.dim // config.n_heads)
print(sin.shape)  # (max_len, head_dim)
x = torch.randn(2, 8, 16, config.dim // config.n_heads)
x_rope = apply_rope(x, sin, cos)
print(x_rope.shape)  # same shape as the input

torch.Size([1024, 64])
torch.Size([2, 8, 16, 64])


## GQA

In [40]:
class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with optional rotary position embedding.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads: each K/V head
    serves ``n_heads // n_kv_heads`` consecutive query heads. Attention scores
    are scaled by ``1 / sqrt(head_dim)``.

    Args:
        config: needs ``dim``, ``n_heads``, ``n_kv_heads`` and ``attention_bias``.

    Raises:
        ValueError: if ``n_kv_heads`` does not evenly divide ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of n_kv_heads ({self.n_kv_heads})"
            )
        self.head_dim = self.dim // self.n_heads
        self.share_heads = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(self.dim, self.dim, bias=config.attention_bias)
        # K/V projections are narrower: one head_dim slice per shared K/V head.
        self.wk = nn.Linear(self.dim, self.head_dim * self.n_kv_heads, bias=config.attention_bias)
        self.wv = nn.Linear(self.dim, self.head_dim * self.n_kv_heads, bias=config.attention_bias)
        self.wo = nn.Linear(self.dim, self.dim, bias=config.attention_bias)

    def forward(self, x, mask=None, sin=None, cos=None):
        """Attend over ``x`` of shape (bsz, seq_len, dim); returns the same shape.

        Args:
            x: hidden states, (bsz, seq_len, dim).
            mask: optional mask broadcastable to (bsz, n_heads, seq_len, seq_len).
                A float mask is added to the scores (0 = keep, -inf = block);
                a boolean mask keeps the positions where it is True.
            sin, cos: RoPE tables forwarded to ``apply_rope``; if either is None,
                no rotary embedding is applied.
        """
        bsz, seq_len, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # Split into heads: (bsz, heads, seq_len, head_dim).
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # The same rotation is applied to every head, so rotating K before
        # expanding the groups gives an identical result with less work.
        if sin is not None and cos is not None:
            q = apply_rope(q, sin, cos)
            k = apply_rope(k, sin, cos)

        # A KV cache would store k/v here: after RoPE, before group expansion.
        k = torch.repeat_interleave(k, self.share_heads, dim=1)
        v = torch.repeat_interleave(v, self.share_heads, dim=1)

        # Scale by the per-head width (not the model width).
        s = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            if mask.dtype == torch.bool:
                s = s.masked_fill(~mask, float("-inf"))
            else:
                s = s + mask
        p = F.softmax(s, dim=-1)
        z = p @ v

        # Merge heads back: (bsz, seq_len, n_heads * head_dim).
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)

## Llama Blocks

In [30]:
class LlamaDecoderBlock(nn.Module):
    """One pre-norm Llama decoder layer.

    Computes ``h = X + attn(norm1(X))`` and then ``out = h + ffn(norm2(h))``,
    where ``attn`` is grouped-query attention (given the RoPE tables ``sin``/``cos``)
    and ``ffn`` is a SwiGLU feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        # Shape hyper-parameters, kept as attributes.
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.dim // self.n_heads

        # Keep this creation order: Linear initialisation inside the sub-modules
        # draws from the global RNG, so reordering would change the initial weights.
        self.norm1 = RMSNorm(dim=self.dim)
        self.attn = GroupQueryAttention(config)
        self.norm2 = RMSNorm(dim=self.dim)
        self.ffn = SwiGLU(dim=self.dim)

    def forward(self, X, mask=None, sin=None, cos=None):
        """Apply the attention and feed-forward residual sub-layers to ``X``."""
        _batch, _length, _ = X.shape  # unpacking rejects anything that is not rank-3
        residual = X + self.attn(self.norm1(X), mask, sin, cos)
        return residual + self.ffn(self.norm2(residual))

## Llama Model

In [31]:
class LlamaModel(nn.Module):
    """Decoder-only Llama language model.

    Pipeline: token embedding (no absolute position encoding) -> ``num_layers``
    ``LlamaDecoderBlock``s (RoPE is injected inside attention) -> final RMSNorm
    -> bias-free LM head producing per-token vocabulary logits.

    Args:
        config: hyper-parameters (``vocab_size``, ``dim``, ``n_heads``,
            ``num_layers``, ``max_len``, ...). Defaults to ``LlamaConfig()``.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            config = LlamaConfig()
        self.config = config
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.decoder = nn.ModuleList(
            [LlamaDecoderBlock(config) for _ in range(config.num_layers)]
        )
        self.ln = RMSNorm(config.dim)
        # No bias: the output projection does not learn a fixed offset over the vocabulary.
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Additive causal mask: 0 where query i may attend key j (j <= i) and
        # -inf above the diagonal, so softmax gives future tokens zero weight.
        # (A 0/1 tril mask would be *added* to the scores and mask nothing.)
        causal = torch.triu(
            torch.full((config.max_len, config.max_len), float("-inf")), diagonal=1
        )
        # Non-persistent buffers follow .to(device/dtype) with the model but stay
        # out of the state_dict.
        self.register_buffer("cache_mask", causal, persistent=False)
        rope_sin, rope_cos = create_rope(
            config.max_len,
            config.dim // config.n_heads,
            base=getattr(config, "position_encoding_base", 10000.0),
        )
        self.register_buffer("rope_sin", rope_sin, persistent=False)
        self.register_buffer("rope_cos", rope_cos, persistent=False)

    def forward(self, x):
        """Return logits (batch, seq_len, vocab_size) for token ids ``x`` of shape (batch, seq_len).

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.max_len``; the causal mask
                and RoPE tables are only precomputed up to that length.
        """
        _, seq_len = x.shape
        if seq_len > self.config.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds config.max_len={self.config.max_len}"
            )
        add_mask = self.cache_mask[:seq_len, :seq_len]

        h = self.embd(x)
        for block in self.decoder:
            h = block(h, mask=add_mask, sin=self.rope_sin, cos=self.rope_cos)
        return self.lm_head(self.ln(h))

In [32]:
# End-to-end smoke test: random token ids of shape (2, 32) -> logits (batch, seq_len, vocab_size).
model = LlamaModel(config)
input_ids = torch.randint(config.vocab_size, (2, 32))
y = model(input_ids)
print(y.shape)

torch.Size([2, 32, 200])
