In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import random

# 1. 数据准备
def load_poems(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        poems = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
    return poems

# 2. 数据预处理
def build_vocab(poems):
    all_chars = [char for poem in poems for char in poem]
    counter = Counter(all_chars)
    char_freq = sorted(counter.items(), key=lambda x: -x[1])
    chars = [char for char, freq in char_freq]
    char_to_idx = {char: idx for idx, char in enumerate(chars)}
    idx_to_char = {idx: char for idx, char in enumerate(chars)}
    return char_to_idx, idx_to_char, len(chars)

def poem_to_indices(poem, char_to_idx):
    return [char_to_idx[char] for char in poem]

def create_dataset(poems, char_to_idx, seq_length=50):
    inputs = []
    targets = []
    for poem in poems:
        indices = poem_to_indices(poem, char_to_idx)
        for i in range(0, len(indices) - seq_length):
            inputs.append(indices[i:i+seq_length])
            targets.append(indices[i+1:i+seq_length+1])
    return np.array(inputs), np.array(targets)

class PoemDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# 3. 模型构建
class PoetryModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out)
        return out, hidden
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.num_layers, batch_size, self.hidden_dim).zero_(),
                  weight.new(self.num_layers, batch_size, self.hidden_dim).zero_())
        return hidden

# 4. 训练函数
def train_model(model, dataloader, epochs, learning_rate):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(epochs):
        model.train()
        hidden = model.init_hidden(batch_size)
        
        for batch, (inputs, targets) in enumerate(dataloader):
            hidden = tuple([h.data for h in hidden])
            outputs, hidden = model(inputs, hidden)
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch}/{len(dataloader)}], Loss: {loss.item():.4f}')
            if batch % 10 == 0:
                # 保存模型
                torch.save(model.state_dict(), 'model/poetry_model_epoch_{}_batch_{}.pth'.format(epoch+1, batch))
        torch.cuda.empty_cache()

# 5. 生成函数
def generate_poem(model, idx_to_char, char_to_idx, start_string, 
                 poem_type='五言', temperature=0.8, max_lines=4):
    """
    生成格式整齐的古诗
    poem_type: 五言或七言
    """
    model.eval()
    
    # 设置诗句长度
    line_length = 5 if poem_type == '五言' else 7
    separators = ['，', '。']
    
    # 初始化输入
    input_seq = [char_to_idx[char] for char in start_string]
    input_seq = torch.LongTensor(input_seq).unsqueeze(0)
    hidden = model.init_hidden(1)
    
    generated_poem = start_string
    current_line_len = len(start_string)
    lines = 1
    expecting_separator = False
    
    while lines <= max_lines:
        # 生成下一个字符
        output, hidden = model(input_seq, hidden)
        last_char_logits = output[:, -1, :] / temperature
        probabilities = torch.softmax(last_char_logits, dim=-1)
        predicted_idx = torch.multinomial(probabilities, 1).item()
        predicted_char = idx_to_char[predicted_idx]
        while not expecting_separator and predicted_char in separators:
            probabilities[0][predicted_idx] = 0.0  # 避免重复
            predicted_idx = torch.multinomial(probabilities, 1).item()
            predicted_char = idx_to_char[predicted_idx]

        
        # 行长度控制
        if not expecting_separator:
            current_line_len += 1
            # 达到指定长度时应该生成分隔符
            if current_line_len == line_length:
                expecting_separator = True
        else:
            # 检查是否是有效的分隔符
            predicted_char='，' if lines%2==1 else '。'
            predicted_idx = char_to_idx[predicted_char]
            # if predicted_char not in separators:
            #     predicted_char = random.choice(separators)
            #     predicted_idx = char_to_idx[predicted_char]
            lines += 1
            current_line_len = 0
            expecting_separator = False
        generated_poem += predicted_char
        input_seq = torch.LongTensor([predicted_idx]).unsqueeze(0)
        if lines > max_lines:
            break
    
    # # 后处理，确保最后以句号结束
    # if not generated_poem.endswith('。'):
    #     generated_poem += '。'
    
    # 格式化为每行一句
    poem_lines = []
    current_line = ""
    for char in generated_poem:
        current_line += char
        if char in ['，', '。']:
            poem_lines.append(current_line)
            current_line = ""
    
    # # 确保偶数行（完整的对联）
    # if len(poem_lines) % 2 != 0 and len(poem_lines) > 1:
    #     poem_lines = poem_lines[:-1]
    
    return "\n".join(poem_lines)

In [8]:
# 参数设置
seq_length = 30
batch_size = 781*2
embedding_dim = 128
hidden_dim = 128
# batch_size = 64
# embedding_dim = 256
# hidden_dim = 512
num_layers = 2
learning_rate = 0.1
epochs = 200

In [9]:

# 加载数据
poems = load_poems('data/poems.txt')  # 替换为你的文件路径
char_to_idx, idx_to_char, vocab_size = build_vocab(poems)

In [10]:
# 创建数据集
X, y = create_dataset(poems, char_to_idx, seq_length)
X = torch.from_numpy(X).long()
y = torch.from_numpy(y).long()
dataset = PoemDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
X.shape,y.shape

(torch.Size([23430, 30]), torch.Size([23430, 30]))

In [11]:
model = PoetryModel(vocab_size, embedding_dim, hidden_dim, num_layers)

In [12]:
train_model(model, dataloader, epochs, learning_rate)

Epoch [1/200], Batch [0/15], Loss: 8.0750
Epoch [1/200], Batch [1/15], Loss: 7.8473
Epoch [1/200], Batch [2/15], Loss: 9.2196
Epoch [1/200], Batch [3/15], Loss: 9.6992
Epoch [1/200], Batch [4/15], Loss: 8.3168
Epoch [1/200], Batch [5/15], Loss: 7.8207
Epoch [1/200], Batch [6/15], Loss: 7.3639
Epoch [1/200], Batch [7/15], Loss: 7.1081
Epoch [1/200], Batch [8/15], Loss: 7.0009
Epoch [1/200], Batch [9/15], Loss: 6.8943
Epoch [1/200], Batch [10/15], Loss: 6.8211
Epoch [1/200], Batch [11/15], Loss: 6.9482
Epoch [1/200], Batch [12/15], Loss: 7.0042
Epoch [1/200], Batch [13/15], Loss: 6.9068
Epoch [1/200], Batch [14/15], Loss: 6.8562
Epoch [2/200], Batch [0/15], Loss: 8.3369
Epoch [2/200], Batch [1/15], Loss: 7.1474
Epoch [2/200], Batch [2/15], Loss: 6.9345
Epoch [2/200], Batch [3/15], Loss: 6.9039
Epoch [2/200], Batch [4/15], Loss: 6.8394
Epoch [2/200], Batch [5/15], Loss: 6.8148
Epoch [2/200], Batch [6/15], Loss: 6.7549
Epoch [2/200], Batch [7/15], Loss: 6.7131
Epoch [2/200], Batch [8/15], 