In [71]:
import torch
import torch.nn as nn
from model import LlaMA,LlaMAConfig
import torch.nn.functional as F

In [35]:
block_size = 1024

In [36]:
# 加载模型配置
config = LlaMAConfig.from_name("baby_llama")
config.block_size = block_size
config.vocab_size = 32000
config.padded_vocab_size = 32000
print("baby_llama config",config)

baby_llama config LlaMAConfig(block_size=1024, vocab_size=32000, padded_vocab_size=32000, n_layer=2, n_head=8, n_embd=128)


In [37]:
# 根据模型配置加载模型
model = LlaMA(config)
print(model)

LlaMA(
  (transformer): ModuleDict(
    (wte): Embedding(32000, 128)
    (h): ModuleList(
      (0-1): 2 x Block(
        (rms_1): RMSNorm()
        (attn): CasualSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=True)
          (c_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (rms_2): RMSNorm()
        (mlp): MLP(
          (fc_1): Linear(in_features=128, out_features=512, bias=False)
          (fc_2): Linear(in_features=128, out_features=512, bias=False)
          (c_proj): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=128, out_features=32000, bias=False)
)


In [38]:
# 加载数据
input = torch.load("./input.pt")
target = torch.load("./target.pt")
print("input.shape:",input.shape)
print("target.shape:",target.shape)
print(input[0,:10])
print(target[0,:10])

input.shape: torch.Size([16, 1024])
target.shape: torch.Size([16, 1024])
tensor([41, 18,  3, 39,  3, 22, 38, 11, 63, 14])
tensor([18,  3, 39,  3, 22, 38, 11, 63, 14, 33])


In [39]:
# 推理
logits = model(input)
print("logits.shape:",logits.shape)
print("vocab_size:",config.vocab_size)
loss = torch.nn.functional.cross_entropy(logits.view(-1,config.vocab_size),target.view(-1),ignore_index=-1)
print("loss:",loss)

logits.shape: torch.Size([16, 1024, 32000])
vocab_size: 32000
loss: tensor(10.5448, grad_fn=<NllLossBackward0>)


In [49]:
# model forward stepbystep
idx = input
B, T = idx.size()
print("batch:{}, length{} ".format(B,T))
print('---------------0. create RoPE, Mask----------------')
# 创建RoPE、mask矩阵
rope = model.rope_cache[:T,:T]
mask = model.mask_cache[:,:,:T,:T]
print("rope.shape:",rope.shape)
print("mask.shape:",mask.shape)
max_seq_length = config.block_size
print("max_seq_length:",max_seq_length)


print('---------------1.embding----------------')
x = model.transformer.wte(idx)
x_embd = x
print("n_embd: ", config.n_embd)
print("before embeding: ", idx.shape)
print("after embeding: ", x.shape)

print('---------------2.llama block attention ----------------')
print("block_size:",len(model.transformer.h))
print("n_layers:",config.n_layer)
for block in model.transformer.h:
    x,_ = block(x,rope,mask,max_seq_length)
    print("Llama Block:",x.shape)

print('---------------3.llama output ----------------')
x = model.transformer.ln_f(x)
print("rms_norm_out:",x.shape)
logits = model.lm_head(x)
print("output logits:",logits.shape)
print("vocab_size:",config.vocab_size)

batch:16, length1024 
---------------0. create RoPE, Mask----------------
rope.shape: torch.Size([1024, 8, 2])
mask.shape: torch.Size([1, 1, 1024, 1024])
max_seq_length: 1024
---------------1.embding----------------
n_embd:  128
before embeding:  torch.Size([16, 1024])
after embeding:  torch.Size([16, 1024, 128])
---------------2.llama block attention ----------------
block_size: 2
n_layers: 2
Llama Block: torch.Size([16, 1024, 128])
Llama Block: torch.Size([16, 1024, 128])
---------------3.llama output ----------------
rms_norm_out: torch.Size([16, 1024, 128])
output logits: torch.Size([16, 1024, 32000])
vocab_size: 32000


In [50]:
# debug block 结构
block = model.transformer.h[0]
print(block)
x,_ = block(x,rope,mask,max_seq_length)
print("rms_1 -> attention-> rms_2-> MLP")

Block(
  (rms_1): RMSNorm()
  (attn): CasualSelfAttention(
    (c_attn): Linear(in_features=128, out_features=384, bias=True)
    (c_proj): Linear(in_features=128, out_features=128, bias=True)
  )
  (rms_2): RMSNorm()
  (mlp): MLP(
    (fc_1): Linear(in_features=128, out_features=512, bias=False)
    (fc_2): Linear(in_features=128, out_features=512, bias=False)
    (c_proj): Linear(in_features=512, out_features=128, bias=True)
  )
)
rms_1 -> attention-> rms_2-> MLP


In [51]:
# block forward
block = model.transformer.h[0]
x = x_embd
x_rms_1 = block.rms_1(x)
x_attn, _ = block.attn(x_rms_1, rope, mask, max_seq_length, None, None)
x = x_embd + x_attn
print('block attention result:', x.shape)

x_rms_2 = block.rms_2(x)
x_block_out = x + block.mlp(x_rms_2)
print('x+mlp(x) result:', x_block_out.shape)

block attention result: torch.Size([16, 1024, 128])
x+mlp(x) result: torch.Size([16, 1024, 128])


In [52]:
# rms_norm 实现
rms_norm = model.transformer.h[0].rms_1
print(rms_norm)
x = x_embd
print("rms_norm.scale", rms_norm.scale.shape)
print("config.n_embd", config.n_embd)
print("rms_norm.eps", rms_norm.eps)
print("rms_norm.dim", rms_norm.dim)
norm_x = torch.mean(x * x, dim=rms_norm.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + rms_norm.eps)
x_rms = rms_norm.scale * x_normed
print("归一化前", x_embd.shape)
print("归一化后", x_rms.shape)

RMSNorm()
rms_norm.scale torch.Size([128])
config.n_embd 128
rms_norm.eps 1e-05
rms_norm.dim -1
归一化前 torch.Size([16, 1024, 128])
归一化后 torch.Size([16, 1024, 128])


In [61]:
# simple rope
seq_len = block_size
n_elem = config.n_embd
base = 10000
print(f"输入:句长:{seq_len},元素个数:{n_elem}")
theta = 1.0 / (10000 ** (torch.arange(0,n_elem,2)) / n_elem)
pos = torch.arange(seq_len)
idx_theta = torch.outer(pos,theta)
print("idx_theta.shape:",idx_theta.shape)
cache = torch.stack([torch.cos(idx_theta),torch.sin(idx_theta)],dim=-1)
print("cache.shape:",cache.shape)

输入:句长:1024,元素个数:128
idx_theta.shape: torch.Size([1024, 64])
cache.shape: torch.Size([1024, 64, 2])


In [62]:
# 创建RoPE位置编码
RoPECache = torch.Tensor

# print("输入:句长,单头维度")
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:
    
    print("输入:句长:{},单头维度:{}".format(seq_len, n_elem))
    
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
    print(torch.arange(0, n_elem, 2, dtype=dtype, device=device))
    print("theta:", theta)

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
    print("seqidx:", seq_idx)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()
    print("position idx* theta :", idx_theta.shape)
    print("idx_theta[:4,:4]:", idx_theta[:4,:4])

    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
    print("cache: ", cache.shape)
    print(cache[1,:4,:2])

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    print(cache.shape)
    print(cache.type)
    return cache

# Rope 实现
RoPECache = build_rope_cache(
    seq_len=model.config.block_size,
    n_elem=model.config.n_embd // model.config.n_head,
    dtype=idx.dtype,
    device=idx.device,
)

model.RoPECache = RoPECache

输入:句长:1024,单头维度:16
tensor([ 0,  2,  4,  6,  8, 10, 12, 14])
theta: tensor([1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03,
        1.0000e-03, 3.1623e-04])
seqidx: tensor([   0,    1,    2,  ..., 1021, 1022, 1023])
position idx* theta : torch.Size([1024, 8])
idx_theta[:4,:4]: tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.3162, 0.1000, 0.0316],
        [2.0000, 0.6325, 0.2000, 0.0632],
        [3.0000, 0.9487, 0.3000, 0.0949]])
cache:  torch.Size([1024, 8, 2])
tensor([[0.5403, 0.8415],
        [0.9504, 0.3110],
        [0.9950, 0.0998],
        [0.9995, 0.0316]])
torch.Size([1024, 8, 2])
<built-in method type of Tensor object at 0x7f14697d5210>


In [63]:
def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
    T = x.size(1)
    rope_cache = rope_cache[:T] # [T,head_size // 2,2]
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [bs,seq,head_num,head_size // 2,2]
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # [1,seq,1,head_size // 2,2]
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


In [73]:
# block attention 实现

block_attn = model.transformer.h[0].attn
print(block_attn)

x_attn, _ = block_attn(x_rms_1, rope, mask, max_seq_length, None, None)
print(x_attn.shape)

x = x_rms_1
B, T, C = x.size()
print("batch:{}, length:{}, n_embding:{}".format(B,T,C))

print('--------------1. attenion split------------------')
q, k, v = block_attn.c_attn(x).split(block_attn.n_embd, dim=2)
head_size = C // block_attn.n_head
k = k.view(B, T, block_attn.n_head, head_size)
q = q.view(B, T, block_attn.n_head, head_size)
v = v.view(B, T, block_attn.n_head, head_size)
print("batch, length, head: n_embding: {}".format(k.shape))

print('--------------2. qk RoPE 旋转相对位置编码------------------')
print('RoPE编码作用在每个block的attention计算QK里')
q_rope_before = q
q = apply_rope(q, rope)
q_rope_after = q
k = apply_rope(k, rope)
print("q_rope前:", q_rope_before.shape)
print("q_rope后:", q_rope_after.shape)

k = k.transpose(1, 2)  # (B, nh, T, hs)
q = q.transpose(1, 2)  # (B, nh, T, hs)
v = v.transpose(1, 2)  # (B, nh, T, hs)

print('--------------3. 计算scale dot product 和前向传播------------------')
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
# output projection
y = block_attn.c_proj(y)
print("attention output:", y.shape)

CasualSelfAttention(
  (c_attn): Linear(in_features=128, out_features=384, bias=True)
  (c_proj): Linear(in_features=128, out_features=128, bias=True)
)
torch.Size([16, 1024, 128])
batch:16, length:1024, n_embding:128
--------------1. attenion split------------------
batch, length, head: n_embding: torch.Size([16, 1024, 8, 16])
--------------2. qk RoPE 旋转相对位置编码------------------
RoPE编码作用在每个block的attention计算QK里
q_rope前: torch.Size([16, 1024, 8, 16])
q_rope后: torch.Size([16, 1024, 8, 16])
--------------3. 计算scale dot product 和前向传播------------------
attention output: torch.Size([16, 1024, 128])


In [76]:
## mlp silu
mlp = model.transformer.h[0].mlp
print(mlp)
x = x_rms_1
print("SiLU(x) = x * sigmoid(x)")
x = F.silu(mlp.fc_1(x)) * mlp.fc_2(x) 
print("c_fc1 is gate")
print("c_fc2 is up")
x = mlp.c_proj(x)
print("mlp output:", x.shape)

MLP(
  (fc_1): Linear(in_features=128, out_features=512, bias=False)
  (fc_2): Linear(in_features=128, out_features=512, bias=False)
  (c_proj): Linear(in_features=512, out_features=128, bias=True)
)
SiLU(x) = x * sigmoid(x)
c_fc1 is gate
c_fc2 is up
mlp output: torch.Size([16, 1024, 128])
