### 1. 导入相关包

In [1]:

import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import wandb

# Start the experiment-tracking run; train() logs the per-step loss to it.
wandb.init(name='add_kv_cache', project="my_gpt2")

# Seed torch's global RNG (weight init, shuffling, sampling) for reproducibility.
SEED = 1024
torch.manual_seed(SEED)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mxyxustbcs[0m ([33mxyxustbcs-xx[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


<torch._C.Generator at 0x7f53b7d02d30>

### 2. 定义 GPT 相关参数

In [2]:
@dataclass
class GPTConfig:
    """Hyper-parameters shared by the model, the KV-cache switch and training.

    ``head_size`` defaults to ``n_embed // n_head`` and is derived per instance
    in ``__post_init__``, so overriding ``n_embed``/``n_head`` at construction
    time keeps the per-head width consistent. Each of the ``n_head`` heads
    outputs ``head_size`` features, and their concatenation is fed into a
    ``Linear(n_embed, n_embed)``. The product must therefore equal ``n_embed``.

    Raises:
        ValueError: if ``n_embed`` is not divisible by ``n_head`` (when
            ``head_size`` is derived) or ``n_head * head_size != n_embed``.
    """
    block_size: int = 512            # max context length: size of the position table and causal mask
    batch_size: int = 12             # not read by the model classes themselves
    n_layer: int = 6
    n_head: int = 12
    n_embed: int = 768
    head_size: Optional[int] = None  # None -> n_embed // n_head
    dropout: float = 0.1
    vocab_size: int = 50257          # GPT-2 BPE vocabulary (tiktoken 'gpt2')
    use_cache: bool = False          # enable the incremental KV-cache path in attention

    def __post_init__(self):
        if self.head_size is None:
            if self.n_embed % self.n_head != 0:
                raise ValueError(
                    f"n_embed ({self.n_embed}) must be divisible by n_head ({self.n_head})"
                )
            self.head_size = self.n_embed // self.n_head
        if self.n_head * self.head_size != self.n_embed:
            raise ValueError(
                f"n_head * head_size ({self.n_head} * {self.head_size}) must equal "
                f"n_embed ({self.n_embed})"
            )



### 3. 模型结构

In [3]:
class SingleHeadAttention(nn.Module):
    """One causal self-attention head with an optional incremental K/V cache.

    With ``config.use_cache`` False every call recomputes attention over the
    whole input ``x``. With it True, the module keeps the keys, values and
    outputs of positions it has already processed. On each call it projects
    only the trailing positions of ``x`` that are not cached yet. The returned
    tensor always covers the full sequence, so both modes return the same
    shape and (up to float rounding) the same values.

    The sequence length must not exceed ``config.block_size`` (the size of the
    causal-mask buffer).
    """

    def __init__(self, config):
        super(SingleHeadAttention, self).__init__()
        self.key = nn.Linear(config.n_embed, config.head_size)
        self.query = nn.Linear(config.n_embed, config.head_size)
        self.value = nn.Linear(config.n_embed, config.head_size)

        self.head_size = config.head_size
        self.use_cache = config.use_cache

        # Lower-triangular causal mask. A buffer is not trained, moves with
        # .to(device) and is saved in the state_dict.
        self.register_buffer(
            "attention_mask",
            torch.tril(
                torch.ones(
                    config.block_size, config.block_size
                )
            )
        )

        self.dropout = nn.Dropout(config.dropout)

        # Incremental-decoding state (plain attributes, not in the state_dict).
        self.k_cache = None
        self.v_cache = None
        self.out_cache = None

    def reset_cache(self):
        """Drop all cached keys/values/outputs; call before an unrelated new sequence."""
        self.k_cache = None
        self.v_cache = None
        self.out_cache = None

    def _attend(self, q, k, v, q_offset):
        """Causal scaled dot-product attention.

        ``q`` holds queries for absolute positions ``q_offset .. q_offset+q_len-1``
        and ``k``/``v`` hold positions ``0 .. k_len-1``. Each query may only
        attend to keys at or before its own position.
        """
        q_len, k_len = q.size(1), k.size(1)
        weight = q @ k.transpose(-2, -1)  # [batch, q_len, k_len]
        mask = self.attention_mask[q_offset:q_offset + q_len, :k_len]
        weight = weight.masked_fill(mask == 0, float("-inf"))
        weight = F.softmax(weight / (self.head_size ** 0.5), dim=-1)
        weight = self.dropout(weight)
        return weight @ v  # [batch, q_len, head_size]

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        if not self.use_cache:
            k, q, v = self.key(x), self.query(x), self.value(x)
            return self._attend(q, k, v, 0)

        # The cache can only be extended by a longer continuation of the same
        # batch. A shorter/equal-length input or a different batch size means a
        # new sequence, so start over. A *longer* unrelated sequence cannot be
        # detected here, so callers should call reset_cache() between prompts.
        if self.k_cache is not None and (
            self.k_cache.size(0) != batch_size or self.k_cache.size(1) >= seq_len
        ):
            self.reset_cache()

        past_len = 0 if self.k_cache is None else self.k_cache.size(1)
        x_new = x[:, past_len:, :]  # positions not seen yet (whole prompt on prefill)

        k_new, q_new, v_new = self.key(x_new), self.query(x_new), self.value(x_new)
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_new, v_new
        else:
            self.k_cache = torch.cat([self.k_cache, k_new], dim=1)
            self.v_cache = torch.cat([self.v_cache, v_new], dim=1)

        # Causal masking also applies to a multi-token prefill; the new queries
        # sit at absolute positions past_len.., hence the offset.
        new_out = self._attend(q_new, self.k_cache, self.v_cache, past_len)

        if self.out_cache is None:
            self.out_cache = new_out
        else:
            self.out_cache = torch.cat([self.out_cache, new_out], dim=1)
        return self.out_cache
    


In [4]:
class MultiHeadAttention(nn.Module):
    """Runs ``config.n_head`` attention heads side by side on the same input.

    Head outputs are concatenated along the feature axis, mixed by the linear
    projection ``fc`` and passed through dropout.
    """

    def __init__(self, config):
        super().__init__()
        head_list = [SingleHeadAttention(config) for _ in range(config.n_head)]
        self.heads = nn.ModuleList(head_list)
        self.fc = nn.Linear(config.n_embed, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        per_head = [attn_head(x) for attn_head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.fc(merged)
        return self.dropout(projected)

In [5]:
# MLP

class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x ``n_embed``, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embed
        layers = [
            nn.Linear(config.n_embed, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, config.n_embed),
            nn.Dropout(config.dropout),
        ]
        # A single nn.Sequential named `mlp` keeps the state_dict keys
        # (mlp.0.*, mlp.2.*) compatible with saved checkpoints.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
        

In [6]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # GPT-style pre-norm: each sub-layer sees a normalised copy of its input,
        # while the residual path carries the un-normalised stream.
        attn_out = self.att(self.ln1(x))
        hidden = x + attn_out
        mlp_out = self.ffn(self.ln2(hidden))
        return hidden + mlp_out

### 4. 完整的GPT

In [7]:
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Token embeddings plus learned absolute position embeddings, ``config.n_layer``
    ``Block``s, a final LayerNorm and a bias-free linear LM head (not tied to the
    token embedding).
    """

    def __init__(self, config):
        super(GPT, self).__init__()
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)

        self.blocks = nn.Sequential(
            *[
                Block(config) for _ in range(config.n_layer)
            ]
        )

        self.ln_final = nn.LayerNorm(config.n_embed)
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)

    def forward(self, ids, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            ids: integer token ids of shape (batch, seq_len). seq_len must not
                exceed the position table size (config.block_size).
            targets: optional integer ids with the same number of elements as ``ids``.

        Returns:
            (logits, loss). Without targets, logits are (batch, seq_len, vocab)
            and loss is None. With targets, logits are flattened to
            (batch * seq_len, vocab) and loss is the mean cross-entropy.
        """
        batch_size, seq_len = ids.size()
        token_embedding = self.token_embedding_table(ids)
        position_ids = torch.arange(seq_len, device=ids.device)
        position_embedding = self.position_embedding_table(position_ids)

        # (seq_len, n_embed) position embeddings broadcast over the batch axis.
        x = token_embedding + position_embedding

        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            vocab_size = logits.size(-1)
            logits = logits.view(-1, vocab_size)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def _reset_kv_cache(self):
        """Clear incremental-decoding state on every submodule that keeps a KV cache."""
        for module in self.modules():
            if hasattr(module, "k_cache"):
                module.k_cache = None
                module.v_cache = None
                module.out_cache = None

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Sampling is multinomial over the softmax of the last position's logits.
        Only the most recent ``block_size`` tokens are fed to the model. Once
        the window starts sliding, positions shift, so any KV cache is cleared
        and recomputed for that step. The caller chooses train/eval mode
        (call ``model.eval()`` to disable dropout).

        Args:
            idx: integer token ids of shape (batch, seq_len), seq_len >= 1.
            max_new_tokens: number of tokens to append.

        Returns:
            Tensor of shape (batch, seq_len + max_new_tokens): ``idx`` followed
            by the sampled tokens.
        """
        block_size = self.position_embedding_table.num_embeddings
        # Every call is a fresh sequence; drop state left by earlier calls.
        self._reset_kv_cache()
        for _ in range(max_new_tokens):
            if idx.size(1) > block_size:
                context = idx[:, -block_size:]
                self._reset_kv_cache()
            else:
                context = idx
            logits, _ = self(context)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [8]:
from datasets import load_from_disk

class MyDataset(Dataset):
    """Next-token-prediction dataset built from a Hugging Face dataset saved on disk.

    Every text in the ``train`` split's ``text`` column is encoded with the
    tiktoken ``gpt2`` encoding and followed by an ``<|endoftext|>`` token. The
    resulting token stream is cut into windows of ``block_size + 1`` tokens
    starting every ``block_size`` tokens, so consecutive windows share one
    token. Each window yields an (input, target) pair of length ``block_size``,
    and the last window is right-padded with ``<|endoftext|>``.
    """

    def __init__(self, path, block_size=512):
        import tiktoken

        self.enc = tiktoken.get_encoding('gpt2')
        self.block_size = block_size

        self.eos_token = self.enc.encode(
            '<|endoftext|>', allowed_special={"<|endoftext|>"})[0]

        # Not filled anywhere; kept so code that reads the attribute still works.
        self.encoded_data = []

        raw_data = load_from_disk(path)

        # One flat token stream. Iterate the column directly rather than
        # materialising an extra full copy of it with `[:]`.
        eos_encoded = []
        for text in raw_data['train']['text']:
            eos_encoded.extend(self.enc.encode(text))
            eos_encoded.append(self.eos_token)

        # NOTE: the attribute keeps its original (misspelled) name so external
        # readers of `exmaples` keep working.
        self.exmaples = []
        window = block_size + 1
        for start in range(0, len(eos_encoded), block_size):
            chunk = eos_encoded[start:start + window]
            if len(chunk) < window:
                chunk += [self.eos_token] * (window - len(chunk))
            self.exmaples.append(chunk)

    def __len__(self):
        return len(self.exmaples)

    def __getitem__(self, idx):
        """Return ``{"input_ids": window[:-1], "labels": window[1:]}`` as tensors."""
        chunk = self.exmaples[idx]
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return {
            "input_ids": x,
            "labels": y
        }

    def encode(self, text):
        """Encode ``text`` to a list of token ids."""
        return self.enc.encode(text)

    def decode(self, ids):
        """Decode token ids back to text.

        ``ids`` may be any iterable of ints or a torch tensor. Tensors are
        flattened and converted to plain Python ints first.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.flatten().tolist()
        return self.enc.decode(list(ids))




In [9]:
dataset = MyDataset('./dataset/anime_text/', block_size=GPTConfig.block_size)

In [10]:
# Number of (block_size + 1)-token training windows.
len(dataset)

355791

In [11]:
# A dedicated, fixed-seed generator makes the 90/10 split independent of how
# often (or in which order) earlier cells consumed torch's global RNG.
split_generator = torch.Generator().manual_seed(1024)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.9, 0.1], generator=split_generator
)

In [12]:
# Sizes of the train / test subsets produced by random_split.
len(train_dataset), len(test_dataset)

(320212, 35579)

In [13]:
# Shuffle only the training data; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=GPTConfig.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=GPTConfig.batch_size, shuffle=False)

### 5. 训练

In [14]:
# model = GPT(GPTConfig())
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = model.to(device)

In [15]:
# total_params = sum(p.numel() for p in model.parameters())
# 
# print(f'Total parameters: {total_params/ 1e6}M')

In [16]:
# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

In [17]:
# next(iter(train_loader))

In [18]:
# for batch_idx, batch in enumerate(train_loader):
#     ids = batch['input_ids'].to(device)
#     targets = batch['labels'].to(device)
#     logits, loss = model(ids, targets)
#     print(loss)
#     break

In [19]:
# Batches per epoch, i.e. optimizer (and scheduler) steps taken by train().
len(train_loader)

26685

In [20]:
def train(model, optimizer, sheduler, train_loader, test_loader, device, epoch,
          log_interval=100, save_interval=100, ckpt_dir='checkpoint/gpt_all',
          log_fn=None):
    """Run one training epoch and return the mean per-batch training loss.

    The scheduler is stepped once per batch.

    Args:
        model: callable as ``model(ids, targets) -> (logits, loss)``.
        optimizer, sheduler: torch optimizer and LR scheduler.
        train_loader: iterable of dicts with ``input_ids`` and ``labels``.
        test_loader: unused here; accepted for interface compatibility.
        device: device the batches are moved to.
        epoch: epoch index, stored in checkpoints and used in file names.
        log_interval: print the loss every N batches (0/None disables).
        save_interval: write a checkpoint every N batches (0/None disables).
            NOTE(review): each checkpoint holds model + optimizer + scheduler
            state under a new file name, so the default of 100 uses a lot of disk.
        ckpt_dir: directory for ``epoch{epoch}-{batch_idx}.pth`` files.
        log_fn: per-step metrics sink taking a dict; defaults to ``wandb.log``.

    Returns:
        Mean loss over the epoch, or nan if ``train_loader`` is empty.
    """
    if log_fn is None:
        log_fn = wandb.log

    model.train()
    if save_interval:
        os.makedirs(ckpt_dir, exist_ok=True)

    total_loss = 0.0
    for batch_idx, batch in enumerate(train_loader):
        ids = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)

        # Forward pass / loss.
        _, loss = model(ids, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Adjust the learning rate.
        sheduler.step()

        # Single host sync per step; the Python float is reused below.
        loss_value = loss.item()
        total_loss += loss_value

        if log_interval and batch_idx % log_interval == 0:
            print(f' Batch: {batch_idx}, Loss: {loss_value:.4f}')

        if save_interval and batch_idx % save_interval == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'sheduler_state_dict': sheduler.state_dict(),
                # A plain float, not the loss tensor (which is device-bound and
                # attached to the autograd graph).
                'train_loss': loss_value,
            }
            save_path = os.path.join(ckpt_dir, f'epoch{epoch}-{batch_idx}.pth')
            torch.save(checkpoint, save_path)

        log_fn({"loss": loss_value})

    n_batches = len(train_loader)
    return total_loss / n_batches if n_batches else float('nan')

In [21]:
def eval(model, test_loader, device):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    Runs without gradients and leaves the model in eval mode. The average is
    printed once after the full pass. Note that this name shadows Python's
    builtin ``eval`` within the notebook.

    Args:
        model: callable as ``model(ids, targets) -> (logits, loss)``.
        test_loader: iterable of dicts with ``input_ids`` and ``labels``.
        device: device the batches are moved to.

    Returns:
        Mean loss, or nan if ``test_loader`` is empty.
    """
    model.eval()

    n_batches = len(test_loader)
    if n_batches == 0:
        return float('nan')

    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            ids = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            _, loss = model(ids, targets)
            val_loss += loss.item()

    mean_loss = val_loss / n_batches
    print(f'Val Loss: {mean_loss:.4f}')
    return mean_loss

In [22]:
def train_gpt(epoch_num = 3):
    """Train for ``epoch_num`` epochs, evaluating on the test split after each one.

    Reads the notebook globals ``model``, ``optimizer``, ``sheduler``,
    ``train_loader``, ``test_loader`` and ``device``.
    """
    for epoch_idx in range(epoch_num):
        epoch_train_loss = train(model, optimizer, sheduler, train_loader, test_loader, device, epoch_idx)
        epoch_val_loss = eval(model, test_loader, device)
        print(f'Epoch: {epoch_idx}, Train Loss: {epoch_train_loss}, Val Loss: {epoch_val_loss}')

    print('finish training!!')


In [23]:
# train_gpt(3)

In [24]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [25]:
import time

# 进行推理
# dataset = MyDataset('./dataset/anime_text/')
chpt_path = './checkpoint/gpt_all/epoch0-100.pth'

# Load the checkpoint once, onto the CPU, so it also works on a machine without
# the GPU it was saved from; `inference` moves each model to `device` itself.
checkpoint = torch.load(chpt_path, map_location='cpu')

# Same weights, two attention modes: with and without the incremental KV cache.
infer_config = GPTConfig()
infer_config.use_cache = True
model = GPT(infer_config)
model.load_state_dict(checkpoint['model_state_dict'])

no_cache_config = GPTConfig()
no_cache_config.use_cache = False
model_no_cache = GPT(no_cache_config)
model_no_cache.load_state_dict(checkpoint['model_state_dict'])

def _reset_kv_cache(model):
    """Clear incremental-decoding state on every submodule that keeps a KV cache."""
    for module in model.modules():
        if hasattr(module, "k_cache"):
            module.k_cache = None
            module.v_cache = None
            module.out_cache = None


@torch.no_grad()
def inference(model, prompt, max_len):
    """Sample ``max_len`` tokens after ``prompt`` and return only the generated text.

    Uses the notebook globals ``device`` (target device) and ``dataset``
    (``encode``/``decode``). Sampling is plain multinomial over the softmax of
    the last position's logits. Gradients are disabled so cached K/V tensors
    do not keep autograd graphs alive.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    model.eval()
    model = model.to(device)
    # Longest context the learned position table can embed.
    block_size = model.position_embedding_table.num_embeddings

    prompt_ids = dataset.encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt encodes to zero tokens")
    ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    # Each call is an independent sequence: drop KV state from earlier prompts.
    _reset_kv_cache(model)

    generated = []
    for _ in range(max_len):
        if ids.size(1) > block_size:
            # Keep the most recent block_size tokens. Positions shift, so any
            # cached keys/values are no longer valid and must be rebuilt.
            ids = ids[:, -block_size:]
            _reset_kv_cache(model)

        logits, _ = model(ids, targets=None)
        probs = F.softmax(logits[:, -1, :], dim=-1)
        idx_new = torch.multinomial(probs, num_samples=1)
        ids = torch.cat((ids, idx_new), dim=1)
        generated.append(idx_new.item())

    return dataset.decode(generated)

# Interactive loop. An empty line, "exit" or "quit" ends it cleanly, so the
# cell does not have to be stopped with a KeyboardInterrupt.
while True:
    prompt = input("我:")
    if prompt.strip().lower() in ("", "exit", "quit"):
        break

    # perf_counter is a monotonic clock meant for measuring elapsed time.
    start = time.perf_counter()
    res = inference(model_no_cache, prompt, 300)
    end = time.perf_counter()
    print(f'无cache耗时： {end-start} \n 我：{prompt} \n chat-bot: {res}')


无cache耗时： 8.544002056121826 
 我：你好 
 chat-bot: 能暴击够觉星钱。

总的来说，峰是一个具有漠火狮子，有时会陷入异纱狱的异系魔王。她希望能够帮助一口问题，并且在我们的玩伴们忘怀，但也会感受到惊要的秘密。<|endoftext|>七宫圣嘿尔是游戏《洛克人系列》中的大登场角色。他们是一个认真太亲，具有能力的能力。由于任何地穿着其他人和公主魔光形的魅力，停止者是青梅竹马的少女。在�
无cache耗时： 9.520055532455444 
 我：n 
 chat-bot: er for his latest energetic and optimistic teacher. He sets his teammates, a squad woman why Keito's nice normal environment, showcased his own situations. He Yonoro Kudarbu Sweep his occupation-hand man who wore other realize a handsome Tooro became able to lend a helping duelelt Kudetsu. Despite Dr. When Arao, Masura has deceptive strength, In his two sister Mouse, pit himself his death but is a waitress, who stopped him to Mountemi, triumphant. He was revealed to burst the wise and strong from existence in Animalents.

Tyike is a Bilation century known for how tying him known for his friend due to his colleagues, Kagumei Magao was soc beyond due to his own years from him. He will survived informed him by Vacu's from his Duelmansh

KeyboardInterrupt: 