In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# ====================== 1. Load data ======================
# The JSON file holds a list of poems; each poem has a 'paragraphs' entry that
# is a list of already-tokenized paragraphs (lists of tokens).
with open('tokenized_poems.json', 'r', encoding='utf-8') as f:
    poems = json.load(f)

# ====================== 2. Build vocabulary ======================
from collections import Counter

# Flatten every paragraph of every poem into one token stream and count it.
all_tokens_list = [token for poem in poems for para in poem['paragraphs'] for token in para]
cnt = Counter(all_tokens_list)
print(cnt.most_common(20))  # quick look at the most frequent tokens

min_freq = 5  # keep tokens seen at least `min_freq` times; rarer ones map to <UNK>
all_tokens = {token for token, freq in cnt.items() if freq >= min_freq}

# Special tokens come first so that <PAD> is always index 0. The remaining
# tokens are sorted: iterating a set of strings yields a different order on
# every kernel restart (hash randomisation), which would silently change the
# token <-> index mapping and make previously saved checkpoints unusable.
# A corpus token that happens to equal a special token is dropped here so it
# cannot shadow the reserved index.
special_tokens = ['<PAD>', '<START>', '<END>', '<UNK>']
tokens = special_tokens + sorted(all_tokens.difference(special_tokens))
token2idx = {token: idx for idx, token in enumerate(tokens)}
idx2token = {idx: token for token, idx in token2idx.items()}

# Persist both mappings so that inference code can reuse the exact vocabulary.
with open('token2idx.json', 'w', encoding='utf-8') as f:
    json.dump(token2idx, f, ensure_ascii=False, indent=4)

# JSON object keys are always strings, so the integer keys of idx2token are
# written as "0", "1", ...; convert them back with int() after loading.
with open('idx2token.json', 'w', encoding='utf-8') as f:
    json.dump(idx2token, f, ensure_ascii=False, indent=4)

vocab_size = len(tokens)
print(f"筛选后词汇表大小: {vocab_size}")



In [None]:

# ====================== 3. Preprocessing ======================
# Every paragraph is wrapped as <START> ... <END>, so the padded length is the
# longest paragraph plus those two extra positions.
max_len = 2 + max(map(len, (para for poem in poems for para in poem['paragraphs'])))

def poem_to_ids(poem):
    """Encode one token sequence as a list of vocabulary ids.

    The result is ``<START>``, the id of every token (``<UNK>`` for tokens
    missing from ``token2idx``), ``<END>``, then ``<PAD>`` ids until the list
    reaches the module-level ``max_len``. Sequences that are already at least
    ``max_len`` long are returned unpadded and untruncated.

    Args:
        poem: iterable of tokens (the notebook calls this with one paragraph).

    Returns:
        list[int]: the encoded, right-padded sequence.
    """
    ids = [
        token2idx['<START>'],
        *(token2idx.get(tok, token2idx['<UNK>']) for tok in poem),
        token2idx['<END>'],
    ]
    # Right-pad up to the fixed length shared by all sequences.
    missing = max_len - len(ids)
    if missing > 0:
        ids.extend([token2idx['<PAD>']] * missing)
    return ids

# Encode every paragraph of every poem; each paragraph is one training sample.
data = list(map(poem_to_ids, (para for poem in poems for para in poem['paragraphs'])))

# ====================== 4. 自定义数据集 ======================
class PoemDataset(Dataset):
    """Next-token-prediction dataset over pre-encoded id sequences.

    Each item is a pair ``(inputs, targets)`` of ``torch.long`` tensors where
    ``inputs`` is the stored sequence without its last id and ``targets`` is
    the same sequence without its first id (i.e. shifted left by one).
    """

    def __init__(self, data):
        # `data` is kept by reference: an indexable collection of id sequences.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        inputs = torch.tensor(sequence[:-1], dtype=torch.long)
        targets = torch.tensor(sequence[1:], dtype=torch.long)
        return inputs, targets

# Wrap the encoded paragraphs and serve them as shuffled mini-batches of 32.
dataset = PoemDataset(data)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
)


In [None]:
import json
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ====================== 5. 定义 Transformer 模型 ======================
class PoemTransformer(nn.Module):
    """Encoder-decoder Transformer over token ids with learned positions.

    Args:
        vocab_size: number of entries in the vocabulary (embedding rows and
            output logits).
        embed_size: model width (``d_model``).
        num_heads: attention heads per layer.
        num_layers: number of encoder layers and of decoder layers.
        ff_dim: hidden size of the position-wise feed-forward blocks.
        dropout: dropout probability used inside the Transformer.
        pad_idx: id of the padding token. Its embedding is pinned to zero and
            padded positions are masked out of every attention. Defaults to 0,
            the index the notebook's vocabulary assigns to ``<PAD>``.
        max_positions: length of the learned positional table, i.e. the
            longest sequence the model accepts.
    """

    def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, ff_dim=512, dropout=0.1,
                 pad_idx=0, max_positions=5000):
        super().__init__()
        self.pad_idx = pad_idx
        self.max_positions = max_positions

        # Token embedding; the <PAD> row stays zero and receives no gradient.
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)

        # Learned positional encoding (attention itself is order-agnostic).
        self.positional_encoding = nn.Parameter(torch.rand(1, max_positions, embed_size))

        # Transformer stack; with the default batch_first=False it works on
        # (sequence_length, batch_size, features) tensors.
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout
        )

        # Projection from model width to vocabulary logits.
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, src, tgt):
        """Compute next-token logits for ``tgt`` conditioned on ``src``.

        Args:
            src: LongTensor of token ids, shape (batch, src_len).
            tgt: LongTensor of token ids, shape (batch, tgt_len). The first
                position of every row must not be ``pad_idx``; otherwise that
                row has nothing to attend to and its output becomes NaN.

        Returns:
            FloatTensor of shape (tgt_len, batch, vocab_size). Note that it is
            time-major: transpose it before pairing it with batch-major
            targets.

        Raises:
            ValueError: if either sequence is longer than ``max_positions``.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        if max(src_len, tgt_len) > self.max_positions:
            raise ValueError(
                f"sequence length {max(src_len, tgt_len)} exceeds max_positions={self.max_positions}"
            )

        # Boolean masks, True = "do not attend to this key".
        src_pad_mask = src.eq(self.pad_idx)  # (B, S)
        tgt_pad_mask = tgt.eq(self.pad_idx)  # (B, T)
        # Causal mask: position i may only look at positions <= i. Without it
        # the decoder sees the very token it is asked to predict, and training
        # no longer matches step-by-step generation. Kept boolean so that it
        # has the same dtype as the key-padding masks.
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        # Embed the ids and add the positional encoding.
        src_emb = self.embed(src) + self.positional_encoding[:, :src_len, :]
        tgt_emb = self.embed(tgt) + self.positional_encoding[:, :tgt_len, :]

        # (B, L, E) -> (L, B, E), the layout nn.Transformer expects here.
        src_emb = src_emb.transpose(0, 1)
        tgt_emb = tgt_emb.transpose(0, 1)

        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=src_pad_mask,
        )

        # Map every decoder state to vocabulary logits: (T, B, V).
        return self.fc_out(output)

# ====================== 6. Training setup ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model, optimizer and loss. Targets equal to <PAD> are ignored by the loss.
model = PoemTransformer(vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # deliberately small learning rate
criterion = nn.CrossEntropyLoss(ignore_index=token2idx['<PAD>'])

# Number of passes over the training data.
epochs = 10

# Per-epoch history; the training loop appends to these and plots them.
epoch_losses = []
epoch_perplexities = []

# Live training curves: loss on the left, perplexity on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
loss_line, = ax1.plot([], [], label='Loss', color='blue')
perplexity_line, = ax2.plot([], [], label='Perplexity', color='red')

ax1.set_title('Training Loss over Epochs')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.set_xlim(0, epochs)

ax2.set_title('Training Perplexity over Epochs')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Perplexity')
ax2.set_xlim(0, epochs)

# The y limits are intentionally left unset. Calling set_ylim() switches y
# autoscaling off, which turns the relim()/autoscale_view() calls of the
# training loop into no-ops; with fixed limits of 5 and 50 the early epochs
# (loss around ln(vocab_size), perplexity far above 50) would be drawn outside
# the visible area.

ax1.legend()
ax2.legend()

plt.ion()  # interactive mode, so the figure can be redrawn while training

# ====================== Training loop ======================
# NOTE(review): the encoder input `x` is the full sequence minus its last
# token, so it already contains almost every token the decoder is asked to
# predict (targets are y[:, 1:]). The model can therefore read its labels
# through cross-attention; confirm that this source/target split is intended.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for x, y in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}', ncols=100):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()

        # Teacher forcing: the decoder is fed y without its last token and has
        # to predict y shifted left by one position.
        output = model(x, y[:, :-1])
        output_dim = output.shape[-1]

        # The model returns time-major logits (T, B, V) while the targets are
        # batch-major (B, T). Flattening both as they are would pair every
        # prediction with the label of a different sample/position, so bring
        # the logits to (B, T, V) first.
        logits = output.transpose(0, 1).reshape(-1, output_dim)
        targets = y[:, 1:].reshape(-1)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    # Perplexity is exp(mean cross-entropy). Averaging exp(loss) over batches
    # instead is biased upwards and dominated by the worst batches.
    avg_perplexity = math.exp(avg_loss)
    epoch_losses.append(avg_loss)
    epoch_perplexities.append(avg_perplexity)

    print(f'Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}, Avg Perplexity: {avg_perplexity:.4f}')

    # Refresh the live curves with the history collected so far.
    loss_line.set_data(range(1, epoch + 2), epoch_losses)
    perplexity_line.set_data(range(1, epoch + 2), epoch_perplexities)

    ax1.relim()
    ax1.autoscale_view()
    ax2.relim()
    ax2.autoscale_view()

    plt.draw()
    plt.pause(0.1)

    # Checkpoint the weights after every epoch.
    torch.save(model.state_dict(), f'poem_transformer_epoch_{epoch+1}.pth')

# Save the final weights once training has finished.
torch.save(model.state_dict(), 'poem_transformer_final.pth')

plt.ioff()  # leave interactive mode
plt.show()

