In [52]:
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
from torch import nn
import os
from torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import  PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig
from datasets import load_dataset, load_from_disk

In [None]:
import wandb

# Read the W&B API key from the environment instead of hard-coding it in the notebook.
# If WANDB_API_KEY is unset, key=None is passed and wandb presumably falls back to its own
# lookup (netrc / interactive prompt) — see the wandb.login docs.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Start the W&B run that the Trainer below reports to (report_to='wandb').
# NOTE(review): the run name claims max_step=25000 and lr=2e-4, but the TrainingArguments cell
# sets max_steps=10 and does not set learning_rate (so the library default is used) — keep them in sync.
wandb.init(project="TinyLLM pretrain", name="batch_size=16 lr=2e-4 max_step=25000")

RMSNorm公式为:

$ \mathrm{RMSNorm}(x) = \frac{x}{\mathrm{RMS}(x)} \odot \gamma $  
$ \mathrm{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^d x_i^2 + \epsilon} $

其中 $\gamma$ 为可学习的逐通道缩放参数（初始化为 1），$\epsilon$ 为防止分母为 0 的小常数。

In [25]:
# RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    The mean square is computed in float32 for numerical stability. The normalised
    tensor is cast back to the input dtype before the weight is applied, so the layer
    no longer upcasts half/bfloat16 activations to float32 unconditionally.

    Args:
        hidden_size: size of the last (normalised) dimension.
        eps: constant added to the mean square so the denominator is never zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale (gamma), initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Keeps the denominator strictly positive.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Fixed: the original returned `self.weight * hidden_states.float()`, i.e. float32 regardless of input.
        return self.weight * hidden_states.to(input_dtype)

以二维向量 $v = (v_1, v_2)$ 为例，旋转 $\theta$ 角后的向量为

$v' = R(\theta) \cdot v = \begin{bmatrix} v_1 \cdot \cos(\theta) - v_2 \cdot \sin(\theta) \\ v_1 \cdot \sin(\theta) + v_2 \cdot \cos(\theta)\end{bmatrix}$

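对于 $d$ 维向量，将第 $i$ 维与第 $i + d/2$ 维两两配对，每一对按上式在各自的平面内旋转（`rotate_half` 把 $(x_1, x_2)$ 变为 $(-x_2, x_1)$），于是公式可写为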
$v' = R(\theta) \cdot v = v \cdot \cos(\theta) + \text{rotate\_half}(v) \cdot \sin(\theta)$

In [26]:
# RoPE
def rotate_half(x):
    """Split the last dimension into halves (a, b) and return concat(-b, a).

    Combined with ``x * cos + rotate_half(x) * sin`` (with cos/sin repeated across both
    halves) this rotates each feature pair (x[i], x[i + d/2]) by its angle.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)

In [27]:
# Apply RoPE
def apply_RoPE(q, k, cos, sin, unsqueeze_dim = 2):
    """Rotate queries and keys by the position angles encoded in ``cos`` / ``sin``.

    ``cos`` and ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they broadcast
    over the head axis. In this notebook they arrive from ``RoPE.forward`` as
    (1, seq_length, head_dim) while q/k are (batch, seq_length, heads, head_dim).

    Returns:
        (q_embed, k_embed), each computed as ``x * cos + rotate_half(x) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)  # e.g. (1, seq_length, 1, head_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_embed, k_embed = ((x * cos_b) + (rotate_half(x) * sin_b) for x in (q, k))
    return q_embed, k_embed

In [28]:
class RoPE(nn.Module):
    """Rotary position embedding with precomputed cos/sin tables.

    Args:
        dim: feature size the rotation is applied to (the per-head dim); assumed even,
            because angles are assigned per pair of features.
        max_seq_length: number of positions precomputed; longer inputs raise ValueError.
        base: base of the geometric frequency schedule (default 10000, unchanged).
    """
    def __init__(self, dim, max_seq_length = 2048, base = 10000):
        super(RoPE, self).__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length
        # Per-pair inverse frequencies base^(-2i/dim), i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_length).float().unsqueeze(1)  # (max_seq_length, 1)
        freqs = t @ inv_freq.unsqueeze(0)  # (max_seq_length, dim/2): angle = position * inv_freq
        # Duplicate so feature i and feature i + dim/2 share one angle (matches rotate_half's pairing).
        freqs = torch.cat((freqs, freqs), dim=-1)  # (max_seq_length, dim)

        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` according to their position along dim 1.

        Args:
            q, k: tensors whose dim 1 is the sequence axis and whose last dim is ``self.dim``
                (in this notebook: (batch, seq_length, heads, head_dim)).

        Returns:
            (q_embed, k_embed) as produced by ``apply_RoPE``.

        Raises:
            ValueError: if the sequence is longer than ``max_seq_length``.
        """
        seq_length = q.shape[1]
        if seq_length > self.max_seq_length:
            # Previously the slice below silently returned too few rows and failed later in broadcasting.
            raise ValueError(
                f"RoPE: sequence length {seq_length} exceeds max_seq_length={self.max_seq_length}"
            )
        # Cast the tables to the activation dtype so half/bf16 q/k are not promoted to float32.
        cos = self.cos_cached[:seq_length, :].unsqueeze(0).to(q.dtype)  # (1, seq_length, dim)
        sin = self.sin_cached[:seq_length, :].unsqueeze(0).to(q.dtype)
        return apply_RoPE(q, k, cos, sin)

In [29]:
# Grouped-query attention (GQA): several query heads share one K/V head,
# so the K/V heads are repeated to line up with the query heads.
def repeat_kv(hidden_states, n_rep):
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Args:
        hidden_states: tensor of shape (batch, seq_length, kv_heads, head_dim).
        n_rep: number of query heads per K/V head.

    Returns:
        ``hidden_states`` itself when ``n_rep == 1``; otherwise a tensor of shape
        (batch, seq_length, kv_heads * n_rep, head_dim) where each K/V head appears
        ``n_rep`` times consecutively.
    """
    b, s, n_kv, d = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat.
        return hidden_states
    expanded = hidden_states.unsqueeze(3).expand(b, s, n_kv, n_rep, d)
    return expanded.reshape(b, s, n_kv * n_rep, d)

In [30]:
# Attention
class Attention(nn.Module):
    """Causal self-attention with grouped-query K/V heads, RoPE and an optional KV cache.

    Reads from ``config``: dropout, hidden_size, attention_head_num, kvhead_num,
    flash_attn, attention_bias and, if present, head_dim (otherwise
    hidden_size // attention_head_num).
    """
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_num = config.attention_head_num
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.head_num)
        self.kvhead_num = config.kvhead_num
        # Number of query heads sharing each K/V head.
        self.kvgroup_num = self.head_num // self.kvhead_num
        # Un-rotated K/V projections from the previous cached call (batch, seq, kvhead_num * head_dim).
        self.k_cache, self.v_cache = None, None
        self.is_causal = True
        self.flash_attn = self.config.flash_attn

        # Projection matrices
        self.q_proj = nn.Linear(self.hidden_size, self.head_num * self.head_dim, bias = config.attention_bias)
        # GQA: K and V use kvhead_num heads instead of head_num.
        self.k_proj = nn.Linear(self.hidden_size, self.kvhead_num * self.head_dim, bias = config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kvhead_num * self.head_dim, bias = config.attention_bias)
        self.o_proj = nn.Linear(self.head_num * self.head_dim, self.hidden_size, bias = config.attention_bias)
        self.residual_dropout = nn.Dropout(self.dropout)
        self.attention_dropout = nn.Dropout(self.dropout)
        self.RoPE_emb = RoPE(self.head_dim)

    def forward(self, hidden_states, use_kv_cache = False):
        """Apply causal self-attention.

        Args:
            hidden_states: (batch, seq_length, hidden_size).
            use_kv_cache: when True and the module is not training, reuse the cached K/V
                projections if they cover exactly the first ``seq_length - 1`` positions of
                the same batch size, so only the newest token is projected. NOTE(review): the
                cache is keyed only on length/batch; reset ``k_cache``/``v_cache`` between
                unrelated prompts.

        Returns:
            Tensor of shape (batch, seq_length, hidden_size).
        """
        # Fixed: the original unpacked `hidden_states[:2]` (the first two batch rows), not the shape.
        batch, seq_length = hidden_states.shape[:2]

        # Fixed: the original tested `self.eval()`, which is always truthy and silently
        # switched the module (and thus dropout) to eval mode.
        if use_kv_cache and not self.training:
            cache_usable = (
                self.k_cache is not None
                and self.k_cache.shape[0] == batch
                and self.k_cache.shape[1] == seq_length - 1
            )
            if cache_usable:
                # Project only the newest token.
                token = hidden_states[:, -1:, :]  # (batch, 1, hidden_size)
                q_last = self.q_proj(token)
                # Earlier query rows are zero placeholders; their outputs do not correspond to
                # real queries (presumably only the last position is consumed by the caller).
                q = torch.cat((q_last.new_zeros(batch, seq_length - 1, q_last.shape[-1]), q_last), dim=1)
                # Append the new K/V projections to the cached prefix.
                k = torch.cat((self.k_cache, self.k_proj(token)), dim=1)
                v = torch.cat((self.v_cache, self.v_proj(token)), dim=1)
            else:
                q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
            # Update the cache (un-rotated projections; RoPE is applied below on the full sequence).
            self.k_cache, self.v_cache = k, v
        else:
            q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(batch, seq_length, self.head_num, self.head_dim)  # (batch, seq_length, head_num, head_dim)
        k = k.view(batch, seq_length, self.kvhead_num, self.head_dim)
        v = v.view(batch, seq_length, self.kvhead_num, self.head_dim)

        q, k = self.RoPE_emb(q, k)

        k = repeat_kv(k, self.kvgroup_num)
        v = repeat_kv(v, self.kvgroup_num)

        # (batch, heads, seq_length, head_dim) layout for attention.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.flash_attn:
            output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p = self.dropout if self.training else 0.0, is_causal = self.is_causal)
        else:
            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            # Fixed: the original added a non-existent `self.mask` built from a non-existent
            # `config.max_seq_length`; build the causal mask for this sequence length instead.
            causal_mask = torch.full((seq_length, seq_length), float("-inf"), device=scores.device).triu(diagonal=1)
            scores = F.softmax(scores.float() + causal_mask, dim=-1).type_as(q)
            scores = self.attention_dropout(scores)
            output = torch.matmul(scores, v)

        output = output.transpose(1, 2).contiguous().view(batch, seq_length, -1)
        output = self.o_proj(output)
        output = self.residual_dropout(output)
        return output

In [31]:
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Reads hidden_size, intermediate_size and mlp_bias from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = hidden = config.hidden_size
        self.intermediate_size = inner = config.intermediate_size
        use_bias = config.mlp_bias
        # Registration order (gate, up, down) is kept so parameter order and seeded init are unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=use_bias)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

In [32]:
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: ``h = x + Attn(RMSNorm(x))`` then ``h + MLP(RMSNorm(h))``."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Sub-modules are created in the original order so seeded initialisation is unchanged.
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size)
        self.post_attention_layernorm = RMSNorm(config.hidden_size)
        self.layer_idx = layer_idx

    def forward(self, hidden_states, use_kv_cache):
        # Attention sub-block with residual connection.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            use_kv_cache=use_kv_cache,
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out

In [33]:
class Config(PretrainedConfig):
    """Hyper-parameters of the small decoder-only model defined in this notebook.

    Any additional keyword arguments are passed through to ``PretrainedConfig``.
    """

    model_type = "small_model"

    def __init__(
        self,
        hidden_size = 512,
        attention_head_num = 16,
        kvhead_num = 8,
        flash_attn = True,
        attention_bias = False,
        max_seq_len = 512,
        intermediate_size = 2048,
        mlp_bias = False,
        vocab_size = 6400,
        n_layers = 8,
        dropout = 0.0,
        **kwargs
    ):
        # Model-wide sizes
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        # Attention: query heads, grouped K/V heads, SDPA switch, projection bias
        self.attention_head_num = attention_head_num
        self.kvhead_num = kvhead_num
        self.flash_attn = flash_attn
        self.attention_bias = attention_bias
        # Feed-forward
        self.intermediate_size = intermediate_size
        self.mlp_bias = mlp_bias
        # Base-class init runs last, after the model fields are set.
        super().__init__(**kwargs)

In [34]:
class LLM(PreTrainedModel):
    """Decoder-only causal LM: token embedding -> DecoderLayer stack -> RMSNorm -> vocab projection."""
    config_class = Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vocab_size = self.config.vocab_size
        self.n_layers = self.config.n_layers

        # Attribute name ("tokon") kept as-is so existing state_dict keys stay valid.
        self.tokon_embeddings = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        self.dropout = nn.Dropout(self.config.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_idx in range(self.n_layers):
            self.layers.append(DecoderLayer(self.config, layer_idx))
        self.norm = RMSNorm(self.config.hidden_size)
        self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.apply(self._init_weights)
        # Loss of the most recent forward call (None when no labels were given).
        self.loss = None

        # Scaled init (std 0.02 / sqrt(2 * n_layers)) for the projections whose outputs are added
        # to the residual stream (attention o_proj, MLP down_proj).
        # Fixed: the original matched 'w3.weight' / 'wo.weight', which no parameter here is named,
        # so this loop never did anything.
        residual_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for pn, p in self.named_parameters():
            if pn.endswith('down_proj.weight') or pn.endswith('o_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, use_kv_cache=False):
        """Run the model.

        Args:
            input_ids: LongTensor of token ids, (batch, seq_length).
            labels: optional targets flattened together with the logits in the loss; label id 0
                is ignored. Now optional (defaults to None) so inference calls need not pass it.
            use_kv_cache: forwarded to every decoder layer.

        Returns:
            CausalLMOutputWithPast(loss, logits). With labels, logits cover every position and
            loss is the cross-entropy; without, logits are only for the last position
            (batch, 1, vocab_size) and loss is None.
        """
        hidden_states = self.tokon_embeddings(input_ids)
        hidden_states = self.dropout(hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states, use_kv_cache=use_kv_cache)

        hidden_states = self.norm(hidden_states)

        if labels is not None:
            logits = self.output(hidden_states)
            self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0)
        else:
            # Only the last position is needed for next-token prediction.
            logits = self.output(hidden_states[:, [-1], :])
            self.loss = None

        return CausalLMOutputWithPast(self.loss, logits)

    @torch.inference_mode()
    def generate(self, inputs, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1., use_kv_cache=True):
        """Autoregressively sample tokens for a single sequence (batch size 1 is assumed).

        Args:
            inputs: mapping with an ``'input_ids'`` LongTensor of shape (1, prompt_len).
                Any ``'labels'`` entry is ignored (it used to be forwarded and crashed the loop).
            eos: token id that stops generation; it is not appended.
            max_new_tokens: NOTE(review): despite the name this bounds the TOTAL sequence
                length — the loop runs while the length is below ``max_new_tokens - 1``.
                Kept unchanged for backward compatibility.
            temperature: 0.0 means greedy decoding; otherwise logits are divided by it.
            top_k: if set, sample only among the ``top_k`` highest logits.
            stream: if True, yield the generated suffix after every new token; otherwise
                yield it once at the end.
            repetition_penalty: values > 1 discourage tokens already in the sequence.
            use_kv_cache: forwarded to ``forward``; the attention caches are reset first.
                The cache is only used when the model is in eval mode.

        Yields:
            LongTensor (1, n_generated) with only the newly generated tokens.
        """
        input_ids = inputs['input_ids']
        prompt_len = input_ids.shape[1]

        # Drop K/V caches left over from earlier calls so they cannot be mistaken for this prompt's prefix.
        for layer in self.layers:
            layer.self_attn.k_cache, layer.self_attn.v_cache = None, None

        while input_ids.shape[1] < max_new_tokens - 1:
            # Fixed: no labels here — the original forwarded inputs['labels'], whose length stops
            # matching input_ids after the first generated token.
            logits = self(input_ids, None, use_kv_cache=use_kv_cache).logits
            # Logits of the last position.
            logits = logits[:, -1, :]

            if repetition_penalty != 1.0:
                seen = torch.unique(input_ids[0])
                seen_logits = logits[:, seen]
                # Fixed: dividing a negative logit made the token MORE likely. Penalise in the
                # direction of the sign instead (CTRL-style, as in HF RepetitionPenaltyLogitsProcessor).
                logits[:, seen] = torch.where(
                    seen_logits < 0, seen_logits * repetition_penalty, seen_logits / repetition_penalty
                )

            if temperature == 0.0:
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1, generator=None)

            if idx_next == eos:
                break

            input_ids = torch.cat((input_ids, idx_next), dim=1)
            if stream:
                yield input_ids[:, prompt_len:]

        if not stream:
            yield input_ids[:, prompt_len:]

In [35]:
# Default hyper-parameters -> freshly initialised model.
config = Config()
model = LLM(config=config)

In [None]:
# Report the number of trainable parameters (requires_grad=True only).
print('模型参数量为：' + str(sum(param.numel() for param in model.parameters() if param.requires_grad)))

In [37]:
# Paths to fill in before running the cells below.
dataset_path, tokenizer_path = '', ''

In [56]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# Use dataset_path when it is set (it was previously defined but never read);
# otherwise fall back to the original hard-coded corpus file.
dataset = load_dataset('json', data_files=dataset_path or 'autodl-tmp/mobvoi_seq_monkey_general_open_corpus.jsonl')

In [None]:
def process_data(dataset, tokenizer, max_seq_len):
    """Tokenise raw text examples into fixed-length next-token-prediction pairs.

    Each example's ``'text'`` is wrapped as ``<s>...</s>``, encoded without extra special
    tokens, truncated or right-padded with id 0 to ``max_seq_len`` ids, and split into
    ``input_ids = ids[:-1]`` and ``labels = ids[1:]``.

    Args:
        dataset: iterable of mappings with a ``'text'`` field.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)`` returning a list of ints.
        max_seq_len: padded/truncated length before the one-token shift; both outputs
            therefore hold ``max_seq_len - 1`` ids.

    Returns:
        List of ``{'input_ids': LongTensor, 'labels': LongTensor}`` dicts.
    """
    # Fixed: `tqdm` was used but never imported. Import it locally and fall back to a
    # plain iterator if it is unavailable, so processing still works.
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    processed_data = []
    for example in tqdm(dataset, desc="Processing dataset"):
        text = f"<s>{example['text']}</s>"
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        # Truncate long texts; right-pad short ones with 0 (LLM.forward uses ignore_index=0).
        token_ids = token_ids[:max_seq_len] + [0] * (max_seq_len - len(token_ids))

        processed_data.append({
            'input_ids': torch.tensor(token_ids[:-1], dtype=torch.long),  # model input
            'labels': torch.tensor(token_ids[1:], dtype=torch.long),      # next-token targets
        })
    return processed_data

In [None]:
# Replace the raw HF dataset with the tokenised training split.
# NOTE: this rebinds `dataset`, so re-running this cell requires re-running the loading cell first.
dataset = process_data(dataset=dataset['train'], tokenizer=tokenizer, max_seq_len=config.max_seq_len)

In [43]:
data_collator = DefaultDataCollator()
args = TrainingArguments(
    output_dir='./results',
    report_to='wandb',
    do_train=True,
    # 16 sequences per device per micro-batch, 8 micro-batches per optimizer step.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=8,
    group_by_length=False,
    # NOTE(review): per the TrainingArguments docs a positive max_steps takes precedence over
    # num_train_epochs, so this configuration stops after 10 steps.
    num_train_epochs=20,
    max_steps=10,
    logging_steps=10,
)

In [None]:
# NOTE(review): recent transformers releases rename `tokenizer=` to `processing_class=`; confirm against the installed version.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=dataset,
    tokenizer=tokenizer,
)

In [None]:
# Run training; metrics are reported to W&B because TrainingArguments sets report_to='wandb'.
trainer.train()