In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import re



# ---  参数设置 (参照 RNN-LM-song.html) ---
# Hyper-parameters and runtime device shared by the rest of this script.
SEQ_LEN = 16        # number of characters in each training window
BATCH_SIZE = 32     # samples per DataLoader batch
EMBEDDING_DIM = 128 # size of each character embedding vector
HIDDEN_DIM = 256    # size of the LSTM hidden state
NUM_EPOCHS = 30     # number of full passes over the dataset
LR = 0.001          # learning rate handed to the Adam optimizer
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"使用设备: {DEVICE}")

# ---  数据处理 (参照 RNN-LM-song.html) ---
def load_data(file_path):
    """Read a text corpus and build a character-level vocabulary.

    All whitespace (spaces, tabs, line breaks, ...) is removed from the text
    before the vocabulary is built. The corpus size and vocabulary size are
    printed as a side effect.

    Args:
        file_path: Path of a UTF-8 encoded text file.

    Returns:
        A 4-tuple ``(data, word_to_ix, ix_to_word, vocab_size)`` where
        ``data`` is the whitespace-free corpus string, ``word_to_ix`` maps
        each distinct character to an integer index, ``ix_to_word`` is the
        inverse mapping and ``vocab_size`` is the number of distinct
        characters.
    """
    with open(file_path, 'r', encoding='utf-8') as corpus_file:
        raw_text = corpus_file.read()

    # Collapse the corpus into one continuous string of characters.
    whitespace = re.compile(r'[\n\r\s]')
    data = whitespace.sub('', raw_text)
    print(f"语料库总字数: {len(data)}")

    # Indices follow the sorted order of the distinct characters, so the
    # mapping is deterministic for a given corpus.
    word_to_ix = {}
    ix_to_word = {}
    for ix, word in enumerate(sorted(set(data))):
        word_to_ix[word] = ix
        ix_to_word[ix] = word
    vocab_size = len(word_to_ix)

    print(f"词汇表大小: {vocab_size}")
    return data, word_to_ix, ix_to_word, vocab_size

# ---  Dataset 和 DataLoader (参照 RNN-LM-song.html) ---
class PoetryDataset(Dataset):
    """Sliding-window character dataset for next-character prediction.

    Sample ``i`` is the pair ``(x, y)`` where ``x`` holds the vocabulary
    indices of ``data[i : i + seq_len]`` and ``y`` holds the indices of the
    same window shifted one character to the right, i.e.
    ``data[i + 1 : i + 1 + seq_len]``.

    Args:
        data: The corpus as a string of characters.
        word_to_ix: Mapping from character to integer vocabulary index.
        seq_len: Number of characters in each input/target window.
    """

    def __init__(self, data, word_to_ix, seq_len):
        self.data = data
        self.word_to_ix = word_to_ix
        self.seq_len = seq_len
        self.vocab_size = len(word_to_ix)

    def __getitem__(self, i):
        """Return the ``(input, target)`` index tensors of window ``i``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
            KeyError: If the window contains a character missing from
                ``word_to_ix``.
        """
        n_samples = len(self)
        if i < 0:
            i += n_samples
        # Without this check an out-of-range index would silently yield
        # truncated or empty windows, and plain iteration over the dataset
        # (which stops on IndexError) would never terminate.
        if not 0 <= i < n_samples:
            raise IndexError("PoetryDataset index out of range")

        data_seq = self.data[i: i + self.seq_len]
        label_seq = self.data[i + 1: i + 1 + self.seq_len]

        # Map characters to their vocabulary indices.
        data_ix = [self.word_to_ix[char] for char in data_seq]
        label_ix = [self.word_to_ix[char] for char in label_seq]

        return (torch.tensor(data_ix, dtype=torch.long),
                torch.tensor(label_ix, dtype=torch.long))

    def __len__(self):
        # One window per start position that still leaves room for the
        # shifted target; clamp at zero because len() rejects negative
        # values when the corpus is shorter than seq_len.
        return max(0, len(self.data) - self.seq_len)

# ---  RNN (LSTM) 模型 (参照 RNN-LM-song.html) ---
class PoetryModel(nn.Module):
    """Character-level language model: embedding -> one-layer LSTM -> linear.

    Args:
        vocab_size: Number of distinct characters (embedding rows and
            output logits).
        embedding_dim: Size of each character embedding.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: tensors are laid out [batch, seq, feature].
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Compute next-character logits for every time step.

        Args:
            x: Index tensor of shape [batch_size, seq_len].
            hidden: Optional ``(h, c)`` LSTM state; when ``None`` the LSTM
                starts from its default (zero) state.

        Returns:
            ``(output, hidden)`` where ``output`` has shape
            [batch_size, seq_len, vocab_size] and ``hidden`` is the LSTM
            state after the last time step.
        """
        # nn.LSTM itself treats a state of None as "no state supplied", so
        # a single call covers both cases.
        lstm_out, hidden = self.lstm(self.embedding(x), hidden)
        # The linear layer is applied to the output of every time step.
        return self.fc(lstm_out), hidden

    def init_hidden(self, batch_size, device):
        """Return an all-zero ``(h, c)`` LSTM state on ``device``."""
        state_shape = (1, batch_size, self.hidden_dim)
        h0 = torch.zeros(state_shape).to(device)
        c0 = torch.zeros(state_shape).to(device)
        return (h0, c0)

# ---  训练函数 (参照 RNN-LM-song.html) ---
def train(model, data_loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``data_loader``.

    Every batch starts from a fresh zero LSTM state. Progress is printed
    every 100 batches (except for batch 0).

    Args:
        model: Module exposing ``init_hidden(batch_size, device)``,
            ``embedding.num_embeddings`` and a forward returning
            ``(logits, hidden)``.
        data_loader: Sized iterable of ``(data, label)`` index tensors.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss taking ``[N, vocab]`` logits and ``[N]`` targets.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses over the whole pass.
    """
    model.train()
    n_classes = model.embedding.num_embeddings
    batch_losses = []

    for i, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Fresh zero state sized to the actual batch (the last one may be
        # smaller than the configured batch size).
        state = model.init_hidden(inputs.size(0), device)
        logits, _ = model(inputs, state)

        # Flatten [batch, seq, vocab] -> [batch * seq, vocab] and
        # [batch, seq] -> [batch * seq] as expected by the criterion.
        flat_logits = logits.view(-1, n_classes)
        flat_targets = targets.view(-1)
        loss = criterion(flat_logits, flat_targets)

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        if i > 0 and i % 100 == 0:
            print(f"  批次 {i}/{len(data_loader)}, Loss: {loss.item():.4f}")

    return sum(batch_losses) / len(data_loader)

# ---  生成函数 (参照 RNN-LM-song.html) ---
def generate(model, start_word, length, ix_to_word, word_to_ix, device):
    """Greedily extend ``start_word`` by ``length`` characters.

    The whole prompt is fed through the model once to build up the LSTM
    state; the prediction made after its last character becomes the first
    generated character. Each generated character is then fed back exactly
    once to predict the next one. The most likely character (top-1) is
    always chosen, so the result is deterministic for a given model.

    The model is switched to eval mode and left in it.

    Args:
        model: Module exposing ``init_hidden(batch_size, device)`` and a
            forward taking ``([1, T] indices, hidden)`` and returning
            ``([1, T, vocab] logits, hidden)``.
        start_word: Non-empty prompt string.
        length: Number of characters to generate after the prompt.
        ix_to_word: Mapping from vocabulary index to character.
        word_to_ix: Mapping from character to vocabulary index.
        device: Device on which input tensors are created.

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: If ``start_word`` is empty.
        KeyError: If ``start_word`` contains a character that is not in
            ``word_to_ix``.
    """
    if not start_word:
        raise ValueError("start_word must contain at least one character")

    model.eval()
    results = list(start_word)
    prompt_ix = [word_to_ix[char] for char in start_word]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        hidden = model.init_hidden(1, device)

        # Warm up the state with the full prompt. Every prompt character is
        # fed exactly once; re-feeding the last one would make the state
        # see it twice and skew the first prediction.
        prompt = torch.tensor([prompt_ix], dtype=torch.long, device=device)
        output, hidden = model(prompt, hidden)

        for step in range(length):
            # Logits of the most recent time step -> most likely next index.
            word_ix = output[0, -1].topk(1)[1].item()
            results.append(ix_to_word[word_ix])

            # Feed the new character back, unless it is the final one.
            if step + 1 < length:
                next_input = torch.tensor(
                    [[word_ix]], dtype=torch.long, device=device
                )
                output, hidden = model(next_input, hidden)

    return "".join(results)

# ---  主执行逻辑 ---
# Script entry point: load the corpus, train the model for NUM_EPOCHS epochs
# and print sample generations along the way.
if __name__ == "__main__":

    # Load the corpus text and the character vocabulary.
    data, word_to_ix, ix_to_word, vocab_size = load_data('poetry_corpus.txt')

    # Build the sliding-window dataset and a shuffling DataLoader over it.
    dataset = PoetryDataset(data, word_to_ix, SEQ_LEN)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Create the model on the selected device, plus optimizer and loss.
    model = PoetryModel(vocab_size, EMBEDDING_DIM, HIDDEN_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    print("\n开始训练...")

    for epoch in range(1, NUM_EPOCHS + 1):
        avg_loss = train(model, data_loader, optimizer, criterion, DEVICE)
        print(f'--- Epoch {epoch}/{NUM_EPOCHS} 完成, 平均 Loss: {avg_loss:.4f} ---')

        # Every 5 epochs, print two sample generations to monitor progress.
        # The prompt characters must exist in the corpus vocabulary.
        if epoch % 5 == 0:
            print(f"\n--- Epoch {epoch} 测试生成 ---")
            generated_text = generate(model, '春眠', 40, ix_to_word, word_to_ix, DEVICE)
            print(f"  以 '春眠' 开头: {generated_text}")
            generated_text_2 = generate(model, '明月', 40, ix_to_word, word_to_ix, DEVICE)
            print(f"  以 '明月' 开头: {generated_text_2}")
            print("--------------------------\n")

    print("训练完成。")

    # --- Final generation test with the fully trained model ---
    print("\n--- 最终生成测试 ---")
    print(generate(model, '白日', 60, ix_to_word, word_to_ix, DEVICE))
    print(generate(model, '红豆', 60, ix_to_word, word_to_ix, DEVICE))

使用设备: cuda
语料库总字数: 25869
词汇表大小: 2491

开始训练...
  批次 100/808, Loss: 6.3675
  批次 200/808, Loss: 6.1105
  批次 300/808, Loss: 5.6957
  批次 400/808, Loss: 5.2936
  批次 500/808, Loss: 4.8081
  批次 600/808, Loss: 4.7293
  批次 700/808, Loss: 4.3257
  批次 800/808, Loss: 3.9525
--- Epoch 1/30 完成, 平均 Loss: 5.3732 ---
  批次 100/808, Loss: 3.5367
  批次 200/808, Loss: 3.4780
  批次 300/808, Loss: 3.1143
  批次 400/808, Loss: 2.8966
  批次 500/808, Loss: 2.5326
  批次 600/808, Loss: 2.4987
  批次 700/808, Loss: 2.5066
  批次 800/808, Loss: 2.0920
--- Epoch 2/30 完成, 平均 Loss: 2.9456 ---
  批次 100/808, Loss: 1.9191
  批次 200/808, Loss: 1.8137
  批次 300/808, Loss: 1.6639
  批次 400/808, Loss: 1.5365
  批次 500/808, Loss: 1.4701
  批次 600/808, Loss: 1.3435
  批次 700/808, Loss: 1.3320
  批次 800/808, Loss: 1.2313
--- Epoch 3/30 完成, 平均 Loss: 1.5895 ---
  批次 100/808, Loss: 1.1239
  批次 200/808, Loss: 0.9966
  批次 300/808, Loss: 0.9619
  批次 400/808, Loss: 0.9052
  批次 500/808, Loss: 0.8093
  批次 600/808, Loss: 0.9511
  批次 700/808, Loss: 0.8382
