# 1. Packages

In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import tiktoken
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass

import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
import torch.multiprocessing as mp

torch.manual_seed(1024)

<torch._C.Generator at 0x7f0ba0e923d0>

# 2. Parameters

In [58]:
# @dataclass
# class GPTConfig:
#     block_size: int = 512 # 文本最大长度，max_seq
#     batch_size: int = 12
#     n_layer: int = 12
#     n_head: int = 12
#     n_embd: int = 768     # hidden_dim, hidden_size; 此处emb_size
#     hidden_dim: int = n_embd
#     # 为了 tie_embedding_weight
#     dropout: float = 0.1
#     head_size: int = n_embd // n_head
#     # vocab_size
#     # 和GPT2官方tokenizer
#     vocab_size: int = 50527

@dataclass
class GPTConfig:
    # 减小模型配置
    block_size: int = 256  # 从512减至256
    batch_size: int = 4    # 从12减至4
    n_layer: int = 6       # 从12减至6
    n_head: int = 8        # 从12减至8
    n_embd: int = 384      # 从768减至384
    hidden_dim: int = n_embd
    dropout: float = 0.5
    head_size: int = n_embd // n_head
    vocab_size: int = 50527

# 3. Structure of GPT

In [47]:
# 1. Single head attention
class SingleHeadAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)
        self.head_size = config.head_size

        # 用 register_buffer 注册 attention_mask
        # 不用计算梯度，节约内存和显存
        self.register_buffer(
            "attention_mask",
            # tril 是下三角
            # block_size 文本最大长度, 512
            torch.tril(
                torch.ones(config.block_size, config.block_size)
            )
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.size()
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        weight = q @ k.transpose(-2, -1)  # @ 是 torch.matmul
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float('-inf')
        )
        # 注意计算 weight 时要除以 sqrt(d_k)
        weight = F.softmax(weight, dim=-1) / math.sqrt(self.head_size)
        
        # dropout 要放在weight后
        weight = self.dropout(weight)
        output = weight @ v
        return output
    
# 2. Multi head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.heads =  nn.ModuleList(
            [
                SingleHeadAttention(config)
                for _ in range(config.n_head)
            ]
        )
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        output = torch.cat(
            [h(x) for h in self.heads],
            dim = -1
        )
        output = self.proj(output)
        output = self.dropout(output)
        return output
    
# 3. feed forward(MLP)
# class FeedForward(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(config.hidden_dim, 4 * config.hidden_dim),  # swiglu # up to 8/3
#             nn.GELU,
#             nn.Linear(4 * config.hidden_dim, config.hidden_dim),
#             nn.Dropout(config.dropout)
#         )
#     def forward(self, x):
#         return self.net(x)
    
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),  # <-- Fixed comma here
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout)
        )

    def forward(self, x):
        return self.net(x)
# 4. block
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)


    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

# 5. GPT
# # class GPT(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         # (embedding, position, norm, mlp, block)
#         # position embedding 从0，1, ...embedding -> rope
#         # norm   layer norm -> rms norm
#         # mlp -> swiglu
#         # mha -> gqa
#         self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
#         self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
#         self.blocks = nn.Sequential(
#             *[Block(config) for _ in range(config.n_layer)]
#         )
#         self.ln_final = nn.LayerNorm(config.n_embd)
#         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
#         # now SLM use tie weight to lessen parameters

#         # linear (4->8), weight实际上的shape是8*4
#         self.token_embedding_table.weight = self.lm_head.weight

#     def _init_weights(self, module):
#         if isinstance(module, nn.Linear):
#             # 初始化为高斯分布
#             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
#             if module.bias is not None:
#                 torch.nn.init.zeros_(module.bias)
#             elif isinstance(module, nn.Embedding):
#                 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

#     def forward(self, idx, targets=None):
#         # idx 输入是 token ids
#         # targets 是目标 token ids
#         # shape 一样
#         batch, seq_len = idx.size()  # (batch, seq_len)
#         token_emb = self.token_embedding_table(idx)  # (batch, seq_len, n_embd)
#         pos_emb = self.position_embedding_table(
#             torch.arange(seq_len, device=idx.device)
#         )
#         # token_embedding, position_embedding 可以相加
#         x = token_emb + pos_emb
#         x = self.blocks(x)
#         x = self.ln_final(x)
#         logits = self.lm_head(x)
#         if targets is None:
#             loss = None
#         else:
#             batch, seq_len, vocab_size = logits.size()
#             logits = logits.view(batch * seq_len, vocab_size)
#             targets = targets.view(batch * seq_len)
#             loss = F.cross_entropy(logits, targets)
#         return logits, loss
    
#     def generate(self, idx, max_new_tokens):
#         pass

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)
        # 修改输出层以匹配emoji分类任务
        self.classifier = nn.Linear(config.n_embd, len(dataset.label_encoder.classes_))
        
        # 初始化权重
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()  # batch_size, sequence_length
        
        # 1. 获取token和位置嵌入
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = token_emb + pos_emb
        
        # 2. 通过transformer blocks
        x = self.blocks(x)
        x = self.ln_final(x)
        
        # 3. 对序列进行池化，取平均值作为文本表示
        x = x.mean(dim=1)  # [B, n_embd]
        
        # 4. 通过分类器得到emoji预测
        logits = self.classifier(x)  # [B, num_classes]
        
        # 5. 如果提供了目标，计算损失
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # 生成功能可以根据需要实现
        pass

# 4. construct the dataset

## understanding the input of model

In [48]:
# class MyDataset(Dataset):
#     def __init__(self, path, block_size=512):

#         import tiktoken
#         self.enc = tiktoken.get_encoding("gpt2")
#         self.block_size = block_size  # pos最大长度

#         self.encoded_data = []
#         # <|endoftext|> 分割不同文本
#         self.eos_token = self.enc.encode(
#             "<|endoftext|>",
#             allowed_special={"<|endoftext|>"}
#         )[0]
        
#         self.max_lines = 1000
#         import json

#         raw_data = []
#         with open(path, 'r') as f:
#             for i, line in enumerate(f):
#                 if i >= self.max_lines:
#                     break
#                 try:
#                     text = json.loads(line.strip())["text"]
#                     raw_data.append(text)
#                 except Exception as e:
#                     continue

#         full_encoded =  []
#         for text in raw_data:
#             encoded_text = self.enc.encode(text)
#             full_encoded.extend(encoded_text + [self.eos_token])

#         # block_size is 512
#         # 长 -> 短(512)
#         for i in range(0, len(full_encoded), self.block_size):
#             chunk = full_encoded[i:i+self.block_size+1] # 512 实际上是513
#             if len(chunk) < self.block_size + 1:
#                 chunk = chunk + [self.eos_token] * (self.block_size + 1 - len(chunk))
#             self.encoded_data.append(chunk)

#     def __len__(self):
#         return len(self.encoded_data)
    
#     def __getitem__(self, idx):
#         chunk = self.encoded_data[idx]
#         x = torch.tensor(chunk[:-1], dtype=torch.long)
#         y = torch.tensor(chunk[1:], dtype=torch.long)
#         return x, y
    
#     def encode(self, text):
#         """将文本编码为token IDs"""
#         return self.enc.encode(text)
    
#     def decode(self, ids):
#         """将token IDs解码为文本"""
#         return self.enc.decode(ids)

class MyDataset(Dataset):
    def __init__(self, path, block_size=512):
        self.block_size = block_size
        self.texts = []
        self.emojis = []
        
        # 加载数据集
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
            for item in data:
                self.texts.append(item["text"])
                self.emojis.append(item["emoji"])

        # 将表情符号转换为整数标签
        self.label_encoder = LabelEncoder()
        self.emoji_labels = self.label_encoder.fit_transform(self.emojis)
        
        # 初始化tokenizer
        self.tokenizer = tiktoken.get_encoding("gpt2")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        emoji_label = self.emoji_labels[idx]

        # 将文本编码为token IDs
        token_ids = self.tokenizer.encode(text)

        # 处理序列长度
        if len(token_ids) > self.block_size:
            token_ids = token_ids[:self.block_size]
        else:
            # 填充到固定长度
            token_ids = token_ids + [self.tokenizer.eot_token] * (self.block_size - len(token_ids))

        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(emoji_label, dtype=torch.long)

    def get_emoji_mapping(self):
        return dict(zip(self.label_encoder.classes_, 
                       self.label_encoder.transform(self.label_encoder.classes_)))

# 5. run the functions

In [59]:
# model = GPT(GPTConfig())
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)

# # 打印模型参数量

# total_params = sum(p.numel() for p in model.parameters())
# print(f"Total parameters: {total_params / 1e6} M")

# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# #设置cosine学习率
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# 分布式训练设置函数
def setup_ddp(rank, world_size):
    """初始化DDP设置"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_ddp():
    """清理DDP进程组"""
    dist.destroy_process_group()

# 分布式训练函数
def train_ddp(rank, world_size, config):
    # 设置DDP
    setup_ddp(rank, world_size)
    
    # 初始化数据集
    dataset = MyDataset(path="dataset.json", block_size=config.block_size)
    
    # 划分训练集和验证集
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    # 创建分布式采样器
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank)
    
    # 创建DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=train_sampler,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        sampler=val_sampler,
        pin_memory=True
    )
    
    # 初始化模型
    model = GPT(config).to(rank)
    model = DDP(model, device_ids=[rank])
    
    # 初始化优化器和学习率调度器
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
    scaler = GradScaler()
    
    # 训练循环
    num_epochs = 3
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # 确保每个epoch的数据分布不同
        model.train()
        total_loss = 0
        
        for batch_idx, (token_ids, emoji_labels) in enumerate(train_loader):
            token_ids = token_ids.to(rank)
            emoji_labels = emoji_labels.to(rank)
            
            with autocast():
                logits, loss = model(token_ids, targets=emoji_labels)
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            
            total_loss += loss.item()
            
            if rank == 0 and batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        # 验证
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for token_ids, emoji_labels in val_loader:
                token_ids = token_ids.to(rank)
                emoji_labels = emoji_labels.to(rank)
                with autocast():
                    logits, loss = model(token_ids, targets=emoji_labels)
                val_loss += loss.item()
        
        if rank == 0:
            print(f'Epoch: {epoch + 1}, Train Loss: {total_loss/len(train_loader):.4f}, '
                  f'Val Loss: {val_loss/len(val_loader):.4f}')
            
            # 保存模型
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.module.state_dict(),  # 注意这里使用model.module
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss / len(val_loader)
            }
            torch.save(checkpoint, f'model_epoch_{epoch + 1}.pt')
    
    cleanup_ddp()

# 单GPU训练函数
def train_single_gpu(gpu_id=0):
    # 设置要使用的GPU
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')
    
    config = GPTConfig()
    
    # 初始化数据集和数据加载器
    dataset = MyDataset(path="dataset.json", block_size=config.block_size)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        pin_memory=True
    )
    
    # 初始化模型
    model = GPT(config).to(device)
    
    # 初始化优化器等
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
    scaler = GradScaler()
    
    # 训练循环
    num_epochs = 30
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (token_ids, emoji_labels) in enumerate(train_loader):
            token_ids = token_ids.to(device)
            emoji_labels = emoji_labels.to(device)
            
            with autocast():
                logits, loss = model(token_ids, targets=emoji_labels)
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        # 验证
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for token_ids, emoji_labels in val_loader:
                token_ids = token_ids.to(device)
                emoji_labels = emoji_labels.to(device)
                with autocast():
                    logits, loss = model(token_ids, targets=emoji_labels)
                val_loss += loss.item()
        
        print(f'Epoch: {epoch + 1}, Train Loss: {total_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}')
        
        # 保存模型
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss / len(val_loader)
        }
        torch.save(checkpoint, f'checkpoint/model_epoch_{epoch + 1}.pt')

In [60]:
device

'cuda'

In [61]:
# 选择训练模式：

# 1. 使用单个GPU（例如GPU 1）
train_single_gpu(gpu_id=1)

# 或者

# 2. 使用所有可用GPU进行分布式训练
# def main_ddp():
#     world_size = torch.cuda.device_count()  # 获取可用的GPU数量
#     config = GPTConfig()
#     mp.spawn(
#         train_ddp,
#         args=(world_size, config),
#         nprocs=world_size,
#         join=True
#     )

# main_ddp()  # 取消注释此行来使用多GPU训练

Epoch: 0, Batch: 0, Loss: 3.9482
Epoch: 0, Batch: 10, Loss: 3.6270
Epoch: 0, Batch: 20, Loss: 3.3979
Epoch: 1, Train Loss: 4.0807, Val Loss: 4.3141
Epoch: 1, Batch: 0, Loss: 3.6963
Epoch: 1, Batch: 10, Loss: 4.2041
Epoch: 1, Batch: 20, Loss: 4.4878
Epoch: 2, Train Loss: 3.7246, Val Loss: 4.6834
Epoch: 2, Batch: 0, Loss: 3.5615
Epoch: 2, Batch: 10, Loss: 2.9434
Epoch: 2, Batch: 20, Loss: 3.1699
Epoch: 3, Train Loss: 3.6081, Val Loss: 4.8612
Epoch: 3, Batch: 0, Loss: 2.4219
Epoch: 3, Batch: 10, Loss: 3.4126
Epoch: 3, Batch: 20, Loss: 4.1592
Epoch: 4, Train Loss: 3.4786, Val Loss: 4.5826


KeyboardInterrupt: 

# 6. explanation of the checkpoint.pt

In [54]:
def inspect_checkpoint(checkpoint_path):
    """
    检查保存的模型checkpoint文件
    参数:
        checkpoint_path: 模型文件路径，如 'model_epoch_1.pt'
    """
    # 加载checkpoint
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    print(f"Checkpoint信息:")
    print("-" * 50)
    
    # 打印基本信息
    print(f"Epoch: {checkpoint['epoch']}")
    print(f"Validation Loss: {checkpoint['val_loss']:.4f}")
    
    # 检查模型状态字典
    model_state = checkpoint['model_state_dict']
    print("\n模型参数信息:")
    print("-" * 50)
    total_params = 0
    for name, param in model_state.items():
        param_count = param.numel()
        total_params += param_count
        print(f"{name}: shape {list(param.shape)}, parameters: {param_count:,}")
    print(f"\n总参数量: {total_params:,}")
    
    # 检查优化器状态
    optimizer_state = checkpoint['optimizer_state_dict']
    print("\n优化器信息:")
    print("-" * 50)
    print(f"优化器类型: {optimizer_state['param_groups'][0]['lr']:.2e} (学习率)")
    
    # 检查学习率调度器状态
    scheduler_state = checkpoint['scheduler_state_dict']
    print("\n学习率调度器信息:")
    print("-" * 50)
    print(f"Last epoch: {scheduler_state['last_epoch']}")
    
    return checkpoint

# 使用示例:
# 查看某个特定epoch的模型
checkpoint_path = 'model_epoch_1.pt'  # 更改为你想查看的文件名
checkpoint = inspect_checkpoint(checkpoint_path)

# 如果想加载模型进行预测，可以这样做：
def load_model_for_inference(checkpoint_path, device='cuda'):
    """
    加载模型用于推理
    """
    # 创建模型实例
    config = GPTConfig()
    model = GPT(config)
    
    # 加载checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # 加载模型权重
    if 'module.state_dict' in checkpoint['model_state_dict']:
        # 如果是从DDP模型保存的
        model.load_state_dict({k[7:]: v for k, v in checkpoint['model_state_dict'].items()})
    else:
        model.load_state_dict(checkpoint['model_state_dict'])
    
    model = model.to(device)
    model.eval()  # 设置为评估模式
    return model

# 使用示例:
# model = load_model_for_inference('model_epoch_1.pt')

Checkpoint信息:
--------------------------------------------------
Epoch: 1
Validation Loss: 4.1938

模型参数信息:
--------------------------------------------------
token_embedding_table.weight: shape [50527, 384], parameters: 19,402,368
position_embedding_table.weight: shape [256, 384], parameters: 98,304
blocks.0.att.heads.0.attention_mask: shape [256, 256], parameters: 65,536
blocks.0.att.heads.0.key.weight: shape [48, 384], parameters: 18,432
blocks.0.att.heads.0.key.bias: shape [48], parameters: 48
blocks.0.att.heads.0.value.weight: shape [48, 384], parameters: 18,432
blocks.0.att.heads.0.value.bias: shape [48], parameters: 48
blocks.0.att.heads.0.query.weight: shape [48, 384], parameters: 18,432
blocks.0.att.heads.0.query.bias: shape [48], parameters: 48
blocks.0.att.heads.1.attention_mask: shape [256, 256], parameters: 65,536
blocks.0.att.heads.1.key.weight: shape [48, 384], parameters: 18,432
blocks.0.att.heads.1.key.bias: shape [48], parameters: 48
blocks.0.att.heads.1.value.weight: 

# 7. start our inference!

In [62]:
class EmojiPredictor:
    def __init__(self, model_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        # 加载模型和配置
        self.device = device
        self.config = GPTConfig()
        self.model = GPT(self.config)
        
        # 加载保存的模型权重
        checkpoint = torch.load(model_path, map_location=device)
        if 'module.state_dict' in checkpoint:
            # 如果是从DDP模型保存的
            self.model.load_state_dict({k[7:]: v for k, v in checkpoint['model_state_dict'].items()})
        else:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            
        self.model = self.model.to(device)
        self.model.eval()
        
        # 初始化tokenizer
        self.tokenizer = tiktoken.get_encoding("gpt2")
        
        # 从训练数据中获取emoji映射
        self.dataset = MyDataset("dataset.json")
        self.emoji_mapping = self.dataset.get_emoji_mapping()
        self.reverse_emoji_mapping = {v: k for k, v in self.emoji_mapping.items()}

    def predict(self, text):
        """
        预测输入文本对应的表情
        参数:
            text: 输入文本
        返回:
            predicted_emoji: 预测的表情
            probabilities: 每个表情的概率分布
        """
        # 对输入文本进行编码
        token_ids = self.tokenizer.encode(text)
        
        # 处理序列长度
        if len(token_ids) > self.config.block_size:
            token_ids = token_ids[:self.config.block_size]
        else:
            token_ids = token_ids + [self.tokenizer.eot_token] * (self.config.block_size - len(token_ids))
        
        # 转换为tensor并移到对应设备
        token_ids = torch.tensor(token_ids).unsqueeze(0).to(self.device)  # 添加batch维度
        
        # 进行预测
        with torch.no_grad():
            logits, _ = self.model(token_ids)
            probabilities = torch.softmax(logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            
        # 获取预测的表情
        predicted_emoji = self.reverse_emoji_mapping[predicted_class]
        
        # 获取所有表情的概率分布
        prob_dict = {}
        probs = probabilities[0].cpu().numpy()
        for emoji_class in range(len(self.emoji_mapping)):
            emoji = self.reverse_emoji_mapping[emoji_class]
            prob_dict[emoji] = float(probs[emoji_class])
        
        return predicted_emoji, prob_dict

    def predict_top_k(self, text, k=3):
        """
        预测输入文本对应的前k个最可能的表情
        """
        _, prob_dict = self.predict(text)
        # 按概率排序并返回前k个
        sorted_emojis = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)[:k]
        return sorted_emojis

# 使用示例
def main():
    # 初始化预测器
    predictor = EmojiPredictor('model_epoch_3.pt')  # 使用最后一个epoch的模型
    
    # 交互式预测
    while True:
        text = input("\n请输入文本 (输入'quit'退出): ")
        if text.lower() == 'quit':
            break
            
        # 获取预测结果
        emoji, probs = predictor.predict(text)
        print(f"\n预测的表情: {emoji}")
        
        # 显示前3个最可能的表情及其概率
        top_3 = predictor.predict_top_k(text, k=3)
        print("\n前3个最可能的表情:")
        for emoji, prob in top_3:
            print(f"{emoji}: {prob:.4f}")

if __name__ == "__main__":
    main()


预测的表情: 得意

前3个最可能的表情:
得意: 0.1309
🎇: 0.1007
😭: 0.0815

预测的表情: 得意

前3个最可能的表情:
得意: 0.1196
🎇: 0.0948
😭: 0.0872

预测的表情: 得意

前3个最可能的表情:
得意: 0.1206
🎇: 0.0889
😭: 0.0831

预测的表情: 得意

前3个最可能的表情:
得意: 0.1234
🎇: 0.0974
😭: 0.0859

预测的表情: 得意

前3个最可能的表情:
得意: 0.1209
🎇: 0.0924
😭: 0.0851

预测的表情: 得意

前3个最可能的表情:
得意: 0.1234
🎇: 0.0974
😭: 0.0859

预测的表情: 得意

前3个最可能的表情:
得意: 0.1215
🎇: 0.0916
😭: 0.0837
