### 1. 导入相关包

In [16]:

import torch
import torch.nn as nn
import torch.nn.functional as F

# from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset, load_from_disk
import math
import wandb
import os


wandb.init(project="gpt_mla", name='v2')

torch.manual_seed(1024)


<torch._C.Generator at 0x7ff9c19774d0>

In [17]:
# 前置代码
class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# 位置编码
class DeepseekV2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # 较小索引位置对应较低频率
        # 较大的索引位置有较高的频率
        
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def apply_rotary_pos_emb_v2(q: torch.Tensor, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed

### 2. 定义 GPT 相关参数

In [18]:
@dataclass
class GPTConfig:
    block_size: int = 512 # seq_len
    batch_size: int = 12    
    n_layer: int = 6
    n_head: int = 8
    n_embed: int = 2048
    # head_size: int = 128
    dropout: float = 0.1 
    vocab_size: int = 50257
    use_cache: bool = False


    hidden_size: int = 2048
    num_heads: int = 8
    max_position_embeddings: int = 512
    rope_theta: float = 128000
    attention_dropout: float = 0.1
    q_lora_rank: int = 256 # 得到低维度q的低秩变换矩阵
    qk_rope_head_dim: int = 64 # qv向量所需 rope 位置编码向量的维度
    kv_lora_rank: int = 64 # 得到低维度kv的低秩变换矩阵 
    v_head_dim: int  = 128 # v向量升秩后的维度
    qk_nope_head_dim: int = 128  # qk向量升秩后的维度  
    attention_bias: bool = False
    
    training: bool = True


@dataclass
class DeepseekConfig:
    hidden_size: int
    num_heads: int
    max_position_embeddings: int
    rope_theta: float

    attention_dropout: float

    q_lora_rank: int # 得到低维度q的低秩变换矩阵
    qk_rope_head_dim: int # qv向量所需 rope 位置编码向量的维度

    kv_lora_rank: int # 得到低维度kv的低秩变换矩阵 

    v_head_dim: int # v向量升秩后的维度
    qk_nope_head_dim: int # qk向量升秩后的维度  
    attention_bias: bool 
    
    training: bool = True





### 3. 模型结构

In [19]:
# class SingleHeadAttention(nn.Module):
#     def __init__(self, config):
#         super(SingleHeadAttention, self).__init__()
#         self.key = nn.Linear(config.n_embed, config.head_size)
#         self.query = nn.Linear(config.n_embed, config.head_size)
#         self.value = nn.Linear(config.n_embed, config.head_size)

#         self.head_size = config.head_size
#         self.use_cache = config.use_cache

#         # 不计算梯度的方式
#         self.register_buffer(
#             "attention_mask",
#             torch.tril(
#                 torch.ones(
#                     config.block_size, config.block_size
#                 )
#             )
#         )

#         self.dropout = nn.Dropout(config.dropout)

#         # 缓存k,v
#         self.k_cache = None
#         self.v_cache = None
#         self.out_cache = None

#     def forward(self, x):
#         batch_size, seq_len, embed_size = x.size()

#         # 使用缓存
#         if self.use_cache:
            
#             q_to_use = None

#             if self.k_cache is None:
#                 self.k_cache = self.key(x)
#                 self.v_cache = self.value(x)
#                 q_to_use = self.query(x)
#             else:
#                 # 添加k,v新的行
#                 k = self.key(x[:, -1, :]).unsqueeze(1)
#                 v = self.value(x[:, -1, :]).unsqueeze(1)

#                 self.k_cache = torch.cat([self.k_cache, k], dim=1)
#                 self.v_cache = torch.cat([self.v_cache, v], dim=1)
            
#                 q_to_use = self.query(x[:, -1, :]).unsqueeze(1)


#             weight = q_to_use @ self.k_cache.transpose(-2, -1) # [batch_size, 1 or init, h_size] @ [batch_size, h_size, seq_len]

#             weight = F.softmax(weight / (self.head_size ** 0.5), dim=-1) 
#             weight = self.dropout(weight)
#             new_token = weight @ self.v_cache

            
#             if self.out_cache is not None:
#                 # print("拼接outcache")
#                 self.out_cache = torch.cat([self.out_cache, new_token], dim=1)
#                 # print(f"self.out_cahe: {self.out_cache.shape}")
#             else:
#                 self.out_cache = new_token

#             return self.out_cache   
            


#         else:
#             k, q, v = self.key(x), self.query(x), self.value(x)

#             weight = q @ k.transpose(-2, -1)
            
#             # mask
#             weight = weight.masked_fill(
#                 self.attention_mask[:seq_len, :seq_len] == 0,
#                 float("-inf")
#             )
            
#             weight = F.softmax(weight / (self.head_size ** 0.5), dim=-1)
#             weight = self.dropout(weight)
#             out = weight @ v

#             return out

class MLA(nn.Module):
    '''包含矩阵吸收'''
    def __init__(self, config: DeepseekConfig):
        super(MLA, self).__init__()
        self.training = config.training

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim


        self.out_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False
        )

        # 压缩
        self.q_lora_rank = config.q_lora_rank  
        self.kv_lora_rank = config.kv_lora_rank

        
        self.q_down_proj = nn.Linear(
            self.hidden_size,
            self.q_lora_rank,
            bias=config.attention_bias
        )
        
        self.q_down_norm = DeepseekV2RMSNorm(self.q_lora_rank)

        # k在降维时得到位置编码
        self.kv_down_proj = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias
        )

    

        self.kv_down_norm = DeepseekV2RMSNorm(self.kv_lora_rank + self.qk_rope_head_dim)

        # 升维
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim # 

        # q向量在升维时得到位置编码
        self.q_up_proj = nn.Linear(
            self.q_lora_rank,
            self.num_heads * self.q_head_dim,
            bias= False, # 
        )

        # qk升维用的是同一个低秩向量，这里同时进行kv的升维，所以映射后的维度要包括q和v的
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (config.qk_nope_head_dim + config.v_head_dim),
            bias=False
        ) 

        self.rotary_emb = DeepseekV2RotaryEmbedding(
            config.qk_rope_head_dim,
            config.max_position_embeddings,
            config.rope_theta
        )

    def forward(self, hidden_states: torch.Tensor, position_ids, attention_mask=None):
        b, s, d = hidden_states.shape
        # q part
        q = self.q_down_proj(hidden_states)
        q = self.q_down_norm(q)
        q = self.q_up_proj(q) # num_heads * ( nope_dim + rope_dim )

        # [b, s, num_heads * (nope+rope)] -> [b, num_heads, s, nope+rope]
        q = q.view(b, s, self.num_heads, self.q_head_dim).transpose(1, 2)
        # split q to q_nope and q_rope
        q_nope, q_rope = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        
        # k v part
        c_kv_and_rope = self.kv_down_proj(hidden_states)
        c_kv_and_rope = self.kv_down_norm(c_kv_and_rope)
        c_kv, k_rope = c_kv_and_rope.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # [b, s, 1, rope_dim] -> [b, 1, s, rope_dim]
        k_rope = k_rope.view(b, s, 1, self.qk_rope_head_dim).transpose(1, 2)

        # 从 kv_up_proj 中分离出 W_UK 和 W_UV
        kv_b_proj = self.kv_up_proj.weight.view(
            self.num_heads, -1, self.kv_lora_rank
        )

        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]


        cos, sin = self.rotary_emb(q_rope, seq_len=s)
        q_rope = apply_rotary_pos_emb_v2(
            q_rope, cos, sin, position_ids
        )


        # W_UK被q_nope吸收
        q_nope = torch.einsum('hdc, bhqd->bhqc', q_absorb, q_nope)


        attn_weights = torch.matmul(q_rope, k_rope.transpose(-1, -2)) + torch.einsum('bhqc, blc->bhql', q_nope, c_kv)
        attn_weights = attn_weights / math.sqrt(self.q_head_dim)


        if attention_mask is not None:
            attn_weights = torch.masked_fill(
                attn_weights,
                attention_mask == 0,
                float('-inf')
            )
        
        attn_weights = F.softmax(attn_weights, dim=-1).to(hidden_states.dtype)

        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        o_  = torch.einsum('bhql,blc->bhqc', attn_weights, c_kv) # (4)
        o   = torch.einsum('bhqc,hdc->bhqd', o_, out_absorb)  # (5)
        u   = torch.einsum('hdD,bhqd->bqD', self.out_proj.weight.view(self.num_heads, self.v_head_dim, -1), o)     # (6)

        return u, attn_weights
    


In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        
        self.max_position_embeddings = config.max_position_embeddings
        self.mla = MLA(config)

        self.fc = nn.Linear(config.n_embed, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        output, _ = self.mla(x,torch.arange(
                self.max_position_embeddings,
                ).unsqueeze(0).expand(
                    x.size(0), -1
                ) 
        )

        output = self.fc(output)
        output = self.dropout(output)
        # print(f"multi ouput: {output.shape}")        

        return output

In [21]:
# MLP

class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embed, 4 * config.n_embed),
            nn.GELU(),
            nn.Linear(4 * config.n_embed, config.n_embed),
            nn.Dropout(config.dropout)
        ) 

    def forward(self, x):
        return self.mlp(x)
        

In [22]:
class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # 注意残差链接 GPT模型使用的是前归一化

        # print(f'x: {x.shape}, att: {self.att(self.ln1(x)).shape}')

        x = x + self.att(self.ln1(x))
        
        x = x + self.ffn(self.ln2(x))
        return x

### 4. 完整的GPT

In [23]:
class GPT(nn.Module):
    def __init__(self, config):
        super(GPT, self).__init__()
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)

        self.blocks = nn.Sequential(
            *[
                Block(config) for _ in range(config.n_layer)
            ]
        )

        self.ln_final = nn.LayerNorm(config.n_embed)
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)


    def forward(self, ids, targets=None):
        batch_size, seq_len = ids.size()
        token_embedding = self.token_embedding_table(ids)
        # position_ids = torch.arange(seq_len, device=ids.device).unsqueeze(0).expand(batch_size, seq_len)
        position_ids = torch.arange(seq_len, device=ids.device)

        position_embedding = self.position_embedding_table(position_ids)

        x = token_embedding + position_embedding # 广播机制

        x = self.blocks(x)

        x = self.ln_final(x)

        logits = self.lm_head(x)

        if targets is None:
            loss = None
        
        else:
            _, seq_len, vocab_size = logits.size()
            logits = logits.view(-1, vocab_size)
            targets = targets.view(-1)

            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
        
    def generate(self, idx, max_new_tokens):
        pass


### 数据集处理

In [24]:
# dataset_name = 'CausalLM/Refined-Anime-Text'
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# tokenizer.pad_token = tokenizer.eos_token
# # 对数据集加载并排序
# ds = load_dataset(dataset_name)
# ds_add_len = ds['train'].map(lambda x : {"length" : len(x['text']) }, num_proc=8)
# sorted_ds = ds_add_len.sort('length')
# # 取前半数据
# dataset = sorted_ds[: len(sorted_ds) // 2 ]
# dataset = Dataset.from_dict(dataset)
# dataset = dataset.remove_columns(['length'])
# tokenized_data = dataset.map(
#     lambda x: tokenizer(x['text'], return_tensors='pt', padding=True, truncation=True, max_length=2048),
#     batched=True,
#     remove_columns=['text'],
#     num_proc=8
# )
# selected_data = tokenized_data.shuffle().select(range(10))
# for data in selected_data:
#     print(len(data['input_ids']))
# block_size = 512
# def group_texts(examples):
#     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
#     total_length = len(concatenated_examples["input_ids"])
#     total_length = (total_length // (block_size + 1)) * (block_size + 1)
#     result = {
#         k: [t[i : i + block_size + 1] for i in range(0, total_length, block_size + 1)]
#         for k, t in concatenated_examples.items()
#     }
#     result["labels"] = [seq[1:] for seq in result["input_ids"]]
#     result["input_ids"] = [seq[:-1] for seq in result["input_ids"]]
#     return result

# chunked_data = tokenized_data.map(group_texts, batched=True, num_proc=8)

In [25]:
# chunked_data.save_to_disk('./dataset/chunked_data_seq512')
chunked_data = load_from_disk('./dataset/chunked_data_seq512/')

In [26]:
chunked_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [27]:
chunked_data[0]['input_ids']

tensor([  464,   464,   464,   464,   162,   253,   232, 42468, 25465,   164,
          251,   236, 41753,   100, 16764,   164,   232,   247,   164,   243,
          122, 42468, 20804, 33768,   237, 16764, 30585,   243,   162,   107,
          242, 42468, 20804, 33768,   237, 16764,   165,   247,   228,   162,
          110,   231, 42468, 26193,   222, 33768,   237, 16764,   162,    96,
          106,   162,   122,   250, 42468, 43718,   239,   162,   229,   240,
        16764, 20046,   242,   164,   233,   241, 42468, 20998,   234,   165,
          109,   120, 41753,   100, 16764,   164,   236,   104, 36685,   106,
        39355,    94, 42468, 20804, 33768,   237, 16764,   164,   236,   231,
          164,   236,   231, 42468,   163,   251,    94, 20804, 33768,   237,
        16764, 27670,   232,   164,   236,   231, 12859,   248, 42468, 20804,
        33768,   237, 16764, 24376, 43145, 10310,   119,   164,   100,   240,
        42468,   164,   233,   237,   162,   248,   244,   162, 

In [28]:
train_dataloader = DataLoader(chunked_data, batch_size=8, shuffle=False)

In [29]:
len(train_dataloader)

56567

### 5. 训练

In [30]:
model = GPT(GPTConfig())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

In [31]:
total_params = sum(p.numel() for p in model.parameters())

print(f'Total parameters: {total_params/ 1e6}M')

Total parameters: 453.970176M


In [32]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

In [33]:
def train(model, optimizer, sheduler, train_loader, test_loader, device, epoch):
    model.train()
  
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        
        ids = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)

        # 计算loss
        logits, loss = model(ids, targets)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 调整学习率
        sheduler.step()

        total_loss += loss.item()

        # 打印信息
        if batch_idx % 100 == 0:
            print(f' Batch: {batch_idx}, Loss: {loss.item():.4f}')

            checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'sheduler_state_dict': sheduler.state_dict(),
            'train_loss': loss,
        }
        
        if batch_idx % 10000 == 0 or batch_idx == len(train_dataloader) - 1:
            # 定义保存路径
            save_path = f'checkpoint/gpt_all/epoch{epoch}-{batch_idx}.pth'

            # 检查文件夹是否存在，如果不存在则创建
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # 保存模型
            torch.save(checkpoint, save_path)   


        wandb.log({"loss": loss.item()})

    return total_loss / len(train_loader)

In [34]:
def eval(model, test_loader, device):

    model.eval()

    val_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            ids = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            logits, loss = model(ids, targets)
            val_loss += loss.item()
            
            print(f'Val Loss: {val_loss / len(test_loader):.4f}')

            


    return val_loss / len(test_loader)

In [35]:
def train_gpt(epoch_num = 3, ):

    for epoch in range(epoch_num):

        train_loss = train(model, optimizer, sheduler, train_dataloader, None, device, epoch)
        # val_loss = eval(model, test_loader, device)

        print(f'Epoch: {epoch}, Train Loss: {train_loss}')


    
    print('finish training!!')


In [36]:
train_gpt(3)

 Batch: 0, Loss: 11.3040
 Batch: 100, Loss: 3.8349


KeyboardInterrupt: 

: 

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

: 

In [None]:
# import time

# # 进行推理
# # dataset = MyDataset('./dataset/anime_text/')
# chpt_path = './checkpoint/gpt_all/epoch0-100.pth'

# infer_config = GPTConfig()
# infer_config.use_cache = True
# model = GPT(infer_config)
# checkpoint = torch.load(chpt_path)
# model.load_state_dict(checkpoint['model_state_dict'])

# no_cache_config = GPTConfig()
# no_cache_config.use_cache = False
# model_no_cache = GPT(no_cache_config)
# checkpoint = torch.load(chpt_path)
# model_no_cache.load_state_dict(checkpoint['model_state_dict'])

# def inference(model, prompt, max_len):
#     model.eval()
#     model = model.to(device)
#     ids = torch.tensor(dataset.encode(prompt)).to(device)
#     ids = ids.unsqueeze(0)
#     # print(ids)
#     res = []
#     for _ in range(max_len):
#         b, sl = ids.size()
#         if sl > GPTConfig.block_size:
#             ids = ids[:, :-GPTConfig.block_size]
       
#         logits, _ = model(ids, targets=None)
#         logits = logits[:, -1, :]
#         probs = F.softmax(logits, dim = -1)
#         idx_new = torch.multinomial(probs, num_samples=1)
#         ids = torch.cat((ids, idx_new), dim=1)
#     ids = ids[:, -max_len:].flatten()
#     # print(ids)
#     return dataset.decode(ids)

# while 1:
#     prompt = input("我:")


#     start = time.time()
#     res = inference(model_no_cache, prompt, 300)
#     end = time.time()
#     print(f'无cache耗时： {end-start} \n 我：{prompt} \n chat-bot: {res}')


: 