### 1. 导入相关包

In [33]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass

import wandb
import os

wandb.init(project="my_gpt2", name='add_kv_cache')

torch.manual_seed(1024)


<torch._C.Generator at 0x7f3a1022c270>

### 2. 定义 GPT 相关参数

In [34]:
@dataclass
class GPTConfig:
    block_size: int = 512
    batch_size: int = 12    
    n_layer: int = 6
    n_head: int = 12
    n_embed: int = 768
    head_size: int = n_embed // n_head
    dropout: float = 0.1 
    vocab_size: int = 50257
    use_cache: bool = False



### 3. 模型结构

In [35]:
class SingleHeadAttention(nn.Module):
    def __init__(self, config):
        super(SingleHeadAttention, self).__init__()
        self.key = nn.Linear(config.n_embed, config.head_size)
        self.query = nn.Linear(config.n_embed, config.head_size)
        self.value = nn.Linear(config.n_embed, config.head_size)

        self.head_size = config.head_size
        self.use_cache = config.use_cache

        # 不计算梯度的方式
        self.register_buffer(
            "attention_mask",
            torch.tril(
                torch.ones(
                    config.block_size, config.block_size
                )
            )
        )

        self.dropout = nn.Dropout(config.dropout)

        # 缓存k,v
        self.k_cache = None
        self.v_cache = None
        self.out_cache = None

    def forward(self, x):
        batch_size, seq_len, embed_size = x.size()

        # 使用缓存
        if self.use_cache:
            if self.k_cache is None:
                self.k_cache = self.key(x)
                self.v_cache = self.value(x)
            else:
                # 添加k,v新的行
                k = self.key(x[:, -1, :])
                v = self.value(x[:, -1, :])

                self.k_cache = torch.cat([self.k_cache, k], dim=1)
                self.v_cache = torch.cat([self.v_cache, v], dim=1)
            
            q = self.query(x[:, -1, :])

            weight = q @ self.k_cache.transpose(-2, -1) # (batch_size, 1, head_size) @ (batch_size, head_size, seq_len) = (batch_size, 1, seq_len)

            weight = F.softmax(weight / (self.head_size ** 0.5), dim=-1)
            weight = self.dropout(weight)
            new_token = weight @ self.v_cache # (batch_size, 1, seq_len) @ (batch_size, seq_len, head_size) = (batch_size, 1, head_size)

            if self.out_cache:
                self.out_cache = torch.cat([self.out_cache, new_token], dim=1)
            else:
                self.out_cache = new_token

            return self.out_cache   
            


        else:
            k, q, v = self.key(x), self.query(x), self.value(x)

            weight = q @ k.transpose(-2, -1)
            
            # mask
            weight = weight.masked_fill(
                self.attention_mask[:seq_len, :seq_len] == 0,
                float("-inf")
            )
            
            weight = F.softmax(weight / (self.head_size ** 0.5), dim=-1)
            weight = self.dropout(weight)
            out = weight @ v

            return out
    


In [36]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList(
            [
                SingleHeadAttention(config) for _ in range(config.n_head)
            ]
        )
        self.fc = nn.Linear(config.n_embed, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        output = torch.cat(
            [
                head(x) for head in self.heads
            ],
            dim=-1
        )
        output = self.fc(output)
        output = self.dropout(output)
        return output

In [37]:
# MLP

class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embed, 4 * config.n_embed),
            nn.GELU(),
            nn.Linear(4 * config.n_embed, config.n_embed),
            nn.Dropout(config.dropout)
        ) 

    def forward(self, x):
        return self.mlp(x)
        

In [38]:
class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        # 注意残差链接 GPT模型使用的是前归一化
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

### 4. 完整的GPT

In [39]:
class GPT(nn.Module):
    def __init__(self, config):
        super(GPT, self).__init__()
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)

        self.blocks = nn.Sequential(
            *[
                Block(config) for _ in range(config.n_layer)
            ]
        )

        self.ln_final = nn.LayerNorm(config.n_embed)
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, std=0.02, mean=0)


    def forward(self, ids, targets=None):
        batch_size, seq_len = ids.size()
        token_embedding = self.token_embedding_table(ids)
        # position_ids = torch.arange(seq_len, device=ids.device).unsqueeze(0).expand(batch_size, seq_len)
        position_ids = torch.arange(seq_len, device=ids.device)

        position_embedding = self.position_embedding_table(position_ids)

        x = token_embedding + position_embedding # 广播机制

        x = self.blocks(x)

        x = self.ln_final(x)

        logits = self.lm_head(x)

        if targets is None:
            loss = None
        
        else:
            _, seq_len, vocab_size = logits.size()
            logits = logits.view(-1, vocab_size)
            targets = targets.view(-1)

            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
        
    def generate(self, idx, max_new_tokens):
        pass


In [40]:
from datasets import load_from_disk

class MyDataset(Dataset):
    def __init__ (self, path, block_size=512):

        import tiktoken

        self.enc = tiktoken.get_encoding('gpt2')
        self.block_size = block_size

        self.eos_token = self.enc.encode(
            '<|endoftext|>',allowed_special={"<|endoftext|>"})[0]
        
        import json

        self.encoded_data = []

        raw_data = load_from_disk(path)

        eos_encoded = []
        for text in raw_data['train']['text'][:]:

            # print(example)
            encoded_text = self.enc.encode(text)
            eos_encoded.extend(encoded_text + [self.eos_token])
        

        self.exmaples = []

        # 分割出训练样本
        for i in range(0, len(eos_encoded), block_size):
            chunk = eos_encoded[i:i+block_size + 1] 
            if len(chunk) < block_size + 1:
                chunk += [self.eos_token] * (block_size + 1 - len(chunk))
            self.exmaples.append(chunk)
        
    
    def __len__(self):
        return len(self.exmaples)

    def __getitem__(self, idx):
        chunk = self.exmaples[idx]
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return {
            "input_ids": x,
            "labels": y
        }

    def encode(self, text):
        """编码器"""
        return self.enc.encode(text) 

    def decode(self, ids):
        """解码器"""
        ids = list(ids)
        return self.enc.decode(ids)




In [41]:
dataset = MyDataset('./dataset/anime_text/')

In [42]:
len(dataset)

355791

In [43]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])

In [44]:
len(train_dataset), len(test_dataset)

(320212, 35579)

In [45]:
train_loader, test_loader = DataLoader(train_dataset, batch_size=12, shuffle=True), DataLoader(test_dataset, batch_size=12, shuffle=False)

### 5. 训练

In [46]:
model = GPT(GPTConfig())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

In [47]:
total_params = sum(p.numel() for p in model.parameters())

print(f'Total parameters: {total_params/ 1e6}M')

Total parameters: 120.116736M


In [48]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

In [49]:
next(iter(train_loader))

{'input_ids': tensor([[  115,   163,   230,  ...,   225, 12859,   228],
         [  171,   120,   249,  ...,   290, 15347,  2095],
         [49035,   119, 27950,  ...,   240,   164,   231],
         ...,
         [  236, 26344,   246,  ...,   171,   120,   234],
         [21410, 31660, 37772,  ..., 13160,   286,   262],
         [  238,   171,   120,  ...,    34, 25634, 20025]]),
 'labels': tensor([[  163,   230,   115,  ..., 12859,   228, 31660],
         [  120,   249, 25001,  ..., 15347,  2095,   287],
         [  119, 27950,   249,  ...,   164,   231,   110],
         ...,
         [26344,   246, 13783,  ...,   120,   234, 26193],
         [31660, 37772,   246,  ...,   286,   262, 43998],
         [  171,   120,   231,  ..., 25634, 20025,   171]])}

In [50]:
for batch_idx, batch in enumerate(train_loader):
    ids = batch['input_ids'].to(device)
    targets = batch['labels'].to(device)
    logits, loss = model(ids, targets)
    print(loss)
    break

tensor(10.9717, device='cuda:0', grad_fn=<NllLossBackward0>)


In [51]:
len(train_loader)

26685

In [52]:
def train(model, optimizer, sheduler, train_loader, test_loader, device):
    model.train()
  
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        ids = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)

        # 计算loss
        logits, loss = model(ids, targets)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 调整学习率
        sheduler.step()

        total_loss += loss.item()

        # 打印信息
        if batch_idx % 100 == 0:
            print(f' Batch: {batch_idx}, Loss: {loss.item():.4f}')

        wandb.log({"loss": loss.item()})

    return total_loss / len(train_loader)

In [53]:
def eval(model, test_loader, device):

    model.eval()

    val_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            ids = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            logits, loss = model(ids, targets)
            val_loss += loss.item()
            
            print(f'Val Loss: {val_loss / len(test_loader):.4f}')

    return val_loss / len(test_loader)

In [54]:
def train_gpt(epoch_num = 3):

    for epoch in range(epoch_num):

        train_loss = train(model, optimizer, sheduler, train_loader, test_loader, device)
        val_loss = eval(model, test_loader, device)

        print(f'Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}')

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'sheduler_state_dict': sheduler.state_dict(),
            'val_loss': val_loss,
        }
        


        # 定义保存路径
        save_path = f'checkpoint/gpt_all/epoch_{epoch}.pth'

        # 检查文件夹是否存在，如果不存在则创建
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # 保存模型
        torch.save(checkpoint, save_path)   
        


        torch.save(checkpoint, save_path)
    
    print('finish training!!')


In [55]:
# train_gpt(3)

In [56]:
dataset = MyDataset('./dataset/anime_text/')
chpt_path = './checkpoint/gpt_all/epoch_0.pth'
model = GPT(GPTConfig())
checkpoint = torch.load(chpt_path)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load(chpt_path)


<All keys matched successfully>

In [57]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [74]:
# 进行推理


def inference(model, prompt, max_len):
    model.eval()
    model = model.to(device)
    ids = torch.tensor(dataset.encode(prompt)).to(device)
    ids = ids.unsqueeze(0)
    # print(ids)
    res = []
    for _ in range(max_len):
        b, sl = ids.size()
        if sl > GPTConfig.block_size:
            ids = ids[:, :-GPTConfig.block_size]
       
        logits, _ = model(ids, targets=None)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim = -1)
        idx_new = torch.multinomial(probs, num_samples=1)
        ids = torch.cat((ids, idx_new), dim=1)
    ids = ids[:, -max_len:].flatten()
    # print(ids)
    return dataset.decode(ids)


In [77]:

while 1:
    prompt = input("我:")
    res = inference(model, prompt, 300)
    print(f'我：{prompt} \n chat-bot: {res}')


chat-bot: 。你可以叫我沙耶娅谢你。

这是自己认识的，能看穿七人去的九条须太郎，她憎恨精灵城世界的事物，也非常自恋自己的荒野的头发。精灵深影将你知道哪些是九条须太郎的威入。须太郎杀害查，你们嚁犍到须太郎那里藐了。

经历
：
进入哪些的资料笼结束后，在一家人面前装刻身份膝靴。这个须太郎因为她
chat-bot: 之剑已被L'Antica打死，封印了“L'Antica”（詹姆斯），在BOLDBES训练时，她升格为G'15UB型电力机，服役于大卡多娘。

总的来说，缪尔薇是妖精英剑这种角色，在游戏中扮演着重要的角色。她的出现让过去具备了强大的战斗能力和情感。<|endoftext|>切激（Ringo）是由TRIGGER所制作的游戏《掠夺者》及其衍生作品的登场角色。

切激是婆罗尔的
chat-bot: 原酬的女岛孝略强大，被上井樱、鹰堂附身一同进行工作。她还会能在陽夜最萌大会的物品中获得超人称呼，保护一井虎发生突然见返的贵族物品帮他，塞尔薇遇到了她，黑衣索起了一位“电风虎风”称呼，即米斯艾莉正在维护正在暗杀邪恶和帝国不利生活而遇害的怪物。在黑衣索起的过程中，她坦诚地微泄无比
chat-bot: ？你请叫我要把我当作妹妹产生近感了很久。……，我母亲照顾很多玩，还会做鸭子，去拉面，并在咖啡店工作过程中打工。

除了家务人，我还是一个很好的朋友。我担任同好会的缘故，结束后我和妹妹一起相识，帮助我们继续过问题。这是因为我知道只是为了帮助我的朋友，我会积极地面对。我喜欢快乐和照顾人，�
chat-bot: ？”，我大哭得好柚子，经常会把其他孩子带回幌室带回家里。我非常讨厌打柚子，因为我背着习惯和他人相处的朋友很深，所以有时候还会折磨旁人的存在。我是一个野外，类型的百合宿舍狂风地。我和雪非常要好，常常用令地上学，无论即使是各种裁缝。

在宿舍狂们的帮助下，我觉得他们能成为第一个看见并接收集的�
chat-bot: 舰。她的潜艇是一把敦号舰，可以通过突击搜索敌人的委托角色。

赤麻心在游戏的剧情中经历了许多事件。作为第七届以及执行任务过错误的保镖，赤麻与他的琨巴赫一同脱逃。她一直哭泣，但在第五人格模式下一直陪伴着父母。她更喜欢欺负他，给人的印象是配合的，是一个睡梦类型的中二病，性格活泼的角色。
chat-bot: 队、现代战力驾驶等人共同负责，首都无穷、使用姓“玩笑”。凌百

KeyboardInterrupt: Interrupted by user