In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# ====================== 1. Load data ======================
# The JSON file holds the pre-tokenized corpus: a list of poems, each with a
# 'paragraphs' entry that is a list of token lists.
with open('D:/Desktop/tokenized_poems.json', 'r', encoding='utf-8') as f:
    poems = json.load(f)

# ====================== 2. Build the vocabulary ======================
from collections import Counter

# Flatten the corpus into a single token stream and count occurrences.
all_tokens_list = [token for poem in poems for para in poem['paragraphs'] for token in para]
cnt = Counter(all_tokens_list)
print(cnt.most_common(20))  # inspect the most frequent tokens

min_freq = 10  # tokens seen fewer than min_freq times are mapped to <UNK>
special_tokens = ['<PAD>', '<START>', '<END>', '<UNK>']

# Walk the counter (one entry per distinct token) instead of re-scanning the
# whole corpus. Special tokens are excluded so a corpus token that happens to
# equal one of them cannot create a duplicate vocabulary entry.
all_tokens = {
    token for token, freq in cnt.items()
    if freq >= min_freq and token not in special_tokens
}

# Sort before indexing: set iteration order for strings depends on the
# per-process hash seed, so an unsorted vocabulary would assign different ids
# on every kernel restart and silently mismatch previously saved weights.
tokens = special_tokens + sorted(all_tokens)
token2idx = {token: idx for idx, token in enumerate(tokens)}
idx2token = {idx: token for token, idx in token2idx.items()}
vocab_size = len(tokens)
print(f"筛选后词汇表大小: {vocab_size}")



In [None]:
# ====================== 3. Data preprocessing ======================

# Token count of every sentence in the corpus.
lengths = [len(para) for poem in poems for para in poem['paragraphs']]
if not lengths:
    # np.mean([]) is NaN and int(NaN) fails with an unhelpful message;
    # report the real problem instead.
    raise ValueError("no paragraphs found in poems; cannot derive sequence length")

max_content_len = int(np.mean(lengths) + 3)   # maximum number of content tokens kept
max_len = max_content_len + 2                 # room for <START> and <END>

print(f"每条数据最终长度: {max_len}")

def poem_to_ids(para):
    """Encode one tokenized sentence as a fixed-length list of vocabulary ids.

    The sentence is cut to ``max_content_len`` tokens, wrapped in ``<START>``
    and ``<END>``, then right-padded with ``<PAD>`` up to ``max_len``. Tokens
    that are not in ``token2idx`` are replaced by ``<UNK>``.

    Args:
        para: sequence of tokens making up one sentence.

    Returns:
        A list of ids of length ``max_len``.
    """
    seq = [token2idx['<START>']]
    seq.extend(
        token2idx[tok] if tok in token2idx else token2idx['<UNK>']
        for tok in para[:max_content_len]
    )
    seq.append(token2idx['<END>'])

    missing = max_len - len(seq)
    if missing > 0:
        seq += [token2idx['<PAD>']] * missing

    # Safety net: clip in case the sequence is still longer than max_len.
    return seq[:max_len]


# Convert every sentence to ids in a single pass. poem_to_ids already
# truncates long sentences and pads short ones, so no length filter is needed;
# the previous filtered pass was overwritten immediately and only doubled the
# preprocessing work.
data = [poem_to_ids(para) for poem in poems for para in poem['paragraphs']]

# Check the length of every sample: batching requires them to be identical.
print("所有样本长度分布：", set(len(d) for d in data))
assert all(len(d) == max_len for d in data), "every sample must have length max_len"

print(f"训练样本数: {len(data)}")

# Preview up to three samples, both as ids and as tokens. Slicing avoids an
# IndexError when the corpus yields fewer than three samples.
for sample in data[:3]:
    print(len(sample), sample)
    print([idx2token[idx] for idx in sample])

# ====================== 4. Custom dataset ======================
class PoemDataset(Dataset):
    """Next-token prediction dataset over pre-encoded id sequences.

    Each item is an ``(input, target)`` pair of ``torch.long`` tensors where
    the target is the input shifted left by one position.
    """

    def __init__(self, data):
        # data: indexable collection of id sequences
        # (assumed to be lists of ints of equal length; confirm against caller)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        inputs = torch.tensor(seq[:-1], dtype=torch.long)   # everything but the last id
        targets = torch.tensor(seq[1:], dtype=torch.long)   # everything but the first id
        return inputs, targets

# Wrap the encoded sentences and serve them as shuffled mini-batches of 32.
dataset = PoemDataset(data)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

In [None]:
# ====================== 5. RNN model ======================
class PoemRNN(nn.Module):
    """Embedding -> multi-layer LSTM -> linear projection onto the vocabulary.

    Args:
        vocab_size: number of entries in the vocabulary (also the output size).
        embed_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden state size.
        num_layers: number of stacked LSTM layers.
        pad_idx: id of the padding token. When ``None`` (the default) it is
            looked up as ``token2idx['<PAD>']`` from the notebook globals,
            which is the original behaviour. Passing it explicitly lets the
            model be built without the notebook's vocabulary in scope
            (e.g. when reloading saved weights elsewhere).
    """

    def __init__(self, vocab_size, embed_size=128, hidden_size=256, num_layers=2, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            pad_idx = token2idx['<PAD>']
        # Embedding layer; the padding row receives no gradient updates.
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        # LSTM layer; batch_first=True means inputs are laid out [batch, time, feature].
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        # Output layer mapping hidden states to vocabulary logits.
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Run the model over a batch of id sequences.

        Args:
            x: long tensor of token ids (assumed shape [B, T]).
            hidden: optional LSTM state to continue from; ``None`` starts
                from the LSTM's default initial state.

        Returns:
            ``(logits, hidden)`` where logits has shape [B, T, V] and hidden
            is the LSTM state after the last time step.
        """
        x = self.embed(x)  # [B, T, E]
        output, hidden = self.lstm(x, hidden)  # output: [B, T, H]
        output = self.fc(output)  # [B, T, V]
        return output, hidden

In [None]:
# ====================== 6. Train the model ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PoemRNN(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Target positions holding <PAD> are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=token2idx['<PAD>'])

# Number of passes over the training data.
epochs = 15
print(f"词汇表大小: {vocab_size}")  # report the vocabulary size

for epoch in range(epochs):
    model.train()
    total_loss = 0
    # One progress bar per epoch; the bar iterates over the batches.
    with tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}', ncols=100) as pbar:
        for x, y in pbar:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits, _ = model(x)
            # Flatten batch and time so each position is one classification example.
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            # Show the current batch loss on the progress bar.
            pbar.set_postfix(loss=batch_loss)
    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}')


# ====================== 7. Poem generation ======================
def generate_poem(model, start_words=None, max_gen_len=20):
    """Greedily decode a continuation from the trained model.

    The model is primed with ``<START>`` followed by ``start_words`` and then
    repeatedly fed its own most likely next token until ``<END>`` is predicted
    or ``max_gen_len`` tokens have been produced.

    Args:
        model: trained model returning ``(logits, hidden)`` when called with
            ``(input_ids, hidden)``.
        start_words: optional iterable of prompt tokens; tokens missing from
            ``token2idx`` are fed as ``<UNK>``.
        max_gen_len: maximum number of tokens to generate.

    Returns:
        The generated tokens joined into one string. The prompt tokens are
        not part of the returned text.
    """
    model.eval()
    # Follow the model's own device rather than a module-level global so the
    # inputs always live where the weights are.
    model_device = next(model.parameters()).device

    # Prime with <START> plus the optional prompt.
    input_idx = [token2idx['<START>']]
    if start_words:
        unk_id = token2idx['<UNK>']
        input_idx.extend(token2idx.get(word, unk_id) for word in start_words)
    input_tensor = torch.tensor([input_idx], dtype=torch.long, device=model_device)

    # Special tokens that must never appear in the output text; without this
    # mask greedy decoding can emit a literal "<UNK>" / "<PAD>" / "<START>".
    banned_ids = torch.tensor(
        [token2idx['<PAD>'], token2idx['<START>'], token2idx['<UNK>']],
        dtype=torch.long,
        device=model_device,
    )
    end_id = token2idx['<END>']

    hidden = None
    result = []
    with torch.no_grad():
        for _ in range(max_gen_len):
            output, hidden = model(input_tensor, hidden)
            last_word_logits = output[0, -1]  # logits at the last time step
            last_word_logits[banned_ids] = float('-inf')
            next_word_id = torch.argmax(last_word_logits).item()
            if next_word_id == end_id:
                break
            result.append(idx2token[next_word_id])
            # Feed only the new token; the LSTM state carries the history.
            input_tensor = torch.tensor([[next_word_id]], dtype=torch.long, device=model_device)
    return ''.join(result)

In [None]:
# ====================== 8. Example: generate a poem ======================
sample_poem = generate_poem(model, start_words=['湖边', '西', '飞雁'])
print("生成古诗示例：", sample_poem)

# ====================== 9. Save the model ======================
trained_weights = model.state_dict()
torch.save(trained_weights, '../../../张范渝超/qq下载/poem_rnn.pth')



In [None]:
import json
import torch

# ====================== Save model and vocabulary after training ======================
# Save the model weights
torch.save(model.state_dict(), '../../../张范渝超/qq下载/poem_rnn.pth')

# Save the vocabulary.
# JSON object keys are always strings, so the int keys of idx2token come back
# as "0", "1", ... after json.load and can no longer be indexed with int ids.
# The ordered 'tokens' list is stored alongside the two mappings so a loader
# can rebuild an int-keyed mapping with dict(enumerate(tokens)).
token_dict = {
    'token2idx': token2idx,
    'idx2token': idx2token,
    'tokens': list(tokens),
}

with open('token_dict.json', 'w', encoding='utf-8') as f:
    json.dump(token_dict, f, ensure_ascii=False)
