In [10]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import random

# 1. 数据准备
def load_poems(file_path):
    """Read a UTF-8 text file and return its non-blank lines, stripped (one poem per line)."""
    poems = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                poems.append(text)
    return poems

# 2. 数据预处理
def build_vocab(poems):
    """Build character <-> index lookup tables ordered by descending frequency.

    Characters with equal counts keep the order in which they first appear in
    ``poems``. Returns ``(char_to_idx, idx_to_char, vocab_size)``.
    """
    frequencies = Counter(ch for poem in poems for ch in poem)
    ranked = [ch for ch, _count in frequencies.most_common()]
    idx_to_char = dict(enumerate(ranked))
    char_to_idx = {ch: i for i, ch in idx_to_char.items()}
    return char_to_idx, idx_to_char, len(ranked)

def poem_to_indices(poem, char_to_idx):
    """Map each character of ``poem`` to its vocabulary index; unknown characters raise KeyError."""
    return list(map(char_to_idx.__getitem__, poem))

def create_dataset(poems, char_to_idx, seq_length=50):
    """Slice every poem into overlapping next-character training pairs.

    Each poem is encoded with ``poem_to_indices``. For every start offset ``i``
    with ``i + seq_length < len(poem)``, the input is ``indices[i:i+seq_length]``
    and the target is the same window shifted one character to the right.
    Poems with ``len(poem) <= seq_length`` contribute no samples.

    Args:
        poems: iterable of strings.
        char_to_idx: character -> index mapping (unknown characters raise KeyError).
        seq_length: window length; must be positive.

    Returns:
        ``(inputs, targets)``: two int64 arrays of shape ``(num_samples, seq_length)``.
        The result stays 2-D even when ``num_samples == 0``.

    Raises:
        ValueError: if ``seq_length <= 0``.
    """
    if seq_length <= 0:
        raise ValueError("seq_length must be positive")
    inputs = []
    targets = []
    for poem in poems:
        indices = poem_to_indices(poem, char_to_idx)
        for i in range(len(indices) - seq_length):
            inputs.append(indices[i:i + seq_length])
            targets.append(indices[i + 1:i + seq_length + 1])
    # Explicit dtype and shape. Without them, an empty sample list becomes a float64 array
    # of shape (0,), and the default integer dtype is platform dependent (int32 on
    # Windows with NumPy < 2).
    inputs = np.array(inputs, dtype=np.int64).reshape(-1, seq_length)
    targets = np.array(targets, dtype=np.int64).reshape(-1, seq_length)
    return inputs, targets

class PoemDataset(Dataset):
    """Map-style dataset that pairs the i-th input window with the i-th target window."""

    def __init__(self, X, y):
        # Only references are stored; samples are looked up lazily in __getitem__.
        self.X, self.y = X, y

    def __len__(self):
        # The dataset size is taken from the inputs.
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# 3. 模型构建
class PoetryModel(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> per-position vocab logits."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        # Kept so init_hidden can size the LSTM state.
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden):
        """Return logits for every position and the updated LSTM state.

        Args:
            x: integer token indices, batch dimension first (the LSTM uses ``batch_first=True``).
            hidden: ``(h0, c0)`` tuple, e.g. from ``init_hidden``; ``None`` makes the LSTM start from zeros.

        Returns:
            ``(logits, (h_n, c_n))``; logits have a trailing dimension of ``vocab_size``.
        """
        embedded = self.embedding(x)
        out, hidden = self.lstm(embedded, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h0, c0)`` pair, each ``(num_layers, batch_size, hidden_dim)``.

        The tensors share the dtype and device of the model's parameters and do not require grad.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        # new_zeros replaces the legacy `.data.new(...).zero_()` constructor.
        return weight.new_zeros(shape), weight.new_zeros(shape)

# 4. 训练函数
def train_model(model, dataloader, epochs, learning_rate,
                save_dir='data2/test4model3', save_every=10):
    """Train ``model`` on next-character prediction with cross-entropy and Adam.

    Args:
        model: module whose ``forward(x, hidden)`` returns ``(logits, hidden)`` and
            which provides ``init_hidden(batch_size)`` (e.g. ``PoetryModel``).
        dataloader: yields ``(inputs, targets)`` index tensors of identical shape.
        epochs: number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        save_dir: directory for checkpoints; created if it does not exist.
        save_every: save ``model.state_dict()`` whenever ``batch % save_every == 0``;
            0 or None disables checkpointing.
    """
    import os

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Move each batch to wherever the model lives instead of assuming CPU.
    device = next(model.parameters()).device
    if save_every:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        for batch, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Batches are independent (and usually shuffled) windows, so carrying the
            # LSTM state over from the previous batch is meaningless. Start from zeros
            # sized to *this* batch, because the last batch may be smaller than the others.
            hidden = model.init_hidden(inputs.size(0))
            outputs, _ = model(inputs, hidden)
            # The vocabulary size comes from the logits, not from a notebook global.
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch}/{len(dataloader)}], Loss: {loss.item():.4f}')
            if save_every and batch % save_every == 0:
                checkpoint = os.path.join(
                    save_dir, 'poetry_model_epoch_{}_batch_{}.pth'.format(epoch + 1, batch))
                torch.save(model.state_dict(), checkpoint)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# 5. 生成函数
def _split_poem_lines(text, separators=('，', '。')):
    """Split ``text`` right after every separator character; any trailing unterminated text is dropped."""
    poem_lines = []
    current_line = ""
    for char in text:
        current_line += char
        if char in separators:
            poem_lines.append(current_line)
            current_line = ""
    return poem_lines


def generate_poem(model, idx_to_char, char_to_idx, start_string,
                  poem_type='五言', temperature=0.8, max_lines=4):
    """
    Generate a regularly formatted classical poem by sampling from ``model``.

    Each line holds ``line_length`` characters (5 for ``'五言'``, otherwise 7),
    followed by a forced separator. The separator is '，' after odd-numbered lines
    and '。' after even-numbered ones. While a line is still being filled, both
    separators are masked out of the sampling distribution.

    Args:
        model: module with ``forward(x, hidden) -> (logits, hidden)`` and ``init_hidden(batch_size)``.
        idx_to_char: index -> character table.
        char_to_idx: character -> index table; must contain every prompt character and '，'/'。'.
        start_string: non-empty prompt whose characters count toward the first line. If it
            already has ``line_length`` or more characters, the first line is closed immediately.
        poem_type: '五言' for 5 characters per line; any other value gives 7.
        temperature: must be > 0; the logits are divided by it before the softmax.
        max_lines: number of lines to generate.

    Returns:
        The poem as text, one line (including its separator) per text line.

    Raises:
        ValueError: if ``start_string`` is empty or ``temperature <= 0``.
        KeyError: if a prompt character or a separator is missing from ``char_to_idx``.
    """
    if not start_string:
        raise ValueError("start_string must be non-empty")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()

    line_length = 5 if poem_type == '五言' else 7
    separators = ['，', '。']
    separator_idx = [char_to_idx[sep] for sep in separators if sep in char_to_idx]

    input_seq = torch.LongTensor([char_to_idx[char] for char in start_string]).unsqueeze(0)
    hidden = model.init_hidden(1)

    generated_poem = start_string
    current_line_len = len(start_string)
    lines = 1

    # Inference only: do not build an autograd graph across the sampling loop.
    with torch.no_grad():
        while lines <= max_lines:
            # Always run the model so that the previous character is folded into `hidden`.
            output, hidden = model(input_seq, hidden)
            if current_line_len >= line_length:
                # The line is full, so close it. `>=` rather than `==` also terminates when
                # the prompt alone is at least one line long (previously an infinite loop).
                predicted_char = '，' if lines % 2 == 1 else '。'
                predicted_idx = char_to_idx[predicted_char]
                lines += 1
                current_line_len = 0
            else:
                probabilities = torch.softmax(output[:, -1, :] / temperature, dim=-1)
                if separator_idx:
                    # Gives the same distribution as re-sampling until a non-separator appears.
                    probabilities[:, separator_idx] = 0.0
                predicted_idx = torch.multinomial(probabilities, 1).item()
                predicted_char = idx_to_char[predicted_idx]
                current_line_len += 1
            generated_poem += predicted_char
            input_seq = torch.LongTensor([[predicted_idx]])

    # One poem line (with its separator) per output line.
    return "\n".join(_split_poem_lines(generated_poem, separators))

In [11]:
# Hyper-parameters.
seq_length = 30      # characters per training window (input and its shifted target)
batch_size = 781     # 23430 windows / 781 = exactly 30 batches per epoch on the current corpus
embedding_dim = 64
hidden_dim = 64
num_layers = 2
# Larger configuration tried earlier: batch_size = 64, embedding_dim = 256, hidden_dim = 512.
# NOTE(review): 0.1 is far above Adam's default of 1e-3. The training log shows the loss
# jumping from ~4.99 to ~6.92 after the first step; consider a smaller value.
learning_rate = 0.1
epochs = 200

# Seed every RNG this notebook relies on (weight init, DataLoader shuffling,
# torch.multinomial sampling during generation) so that runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [12]:

# Load the corpus (append [:100] to use a small debug subset) and derive the character vocabulary.
poems = load_poems(file_path='data2/poems.txt')  # replace with your own file path
char_to_idx, idx_to_char, vocab_size = build_vocab(poems=poems)

In [13]:
# Build sliding (input, target) windows, convert both to int64 tensors and batch them with shuffling.
X, y = (torch.from_numpy(arr).long() for arr in create_dataset(poems, char_to_idx, seq_length))
dataset = PoemDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
X.shape, y.shape

(torch.Size([23430, 30]), torch.Size([23430, 30]))

In [14]:
model = PoetryModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)

In [None]:
import psutil


def show_memory():
    """Print the resident set size (RSS) of the current Python process in MB."""
    rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f"内存使用: {rss_mb:.2f} MB")


show_memory()

ERROR: Could not find file C:\Users\hypherd\AppData\Local\Temp\ipykernel_20404\2848607697.py


In [None]:
train_model(model=model, dataloader=dataloader, epochs=epochs, learning_rate=learning_rate)

Epoch [1/200], Batch [0/30], Loss: 4.9894
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [1/30], Loss: 6.9189
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [2/30], Loss: 5.5664
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [3/30], Loss: 5.8171
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [4/30], Loss: 5.7976
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [5/30], Loss: 5.6899
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [6/30], Loss: 5.6277
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [7/30], Loss: 5.5844
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [8/30], Loss: 5.5693
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [9/30], Loss: 5.4816
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [10/30], Loss: 5.4473
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [11/30], Loss: 5.3994
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [12/30], Loss: 5.3586
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [13/30], Loss: 5.3249
当前分配内存 0.0 MB
缓存内存 0.0 MB
Epoch [1/200], Batch [14/30], Loss: 5.2863
当

In [4]:
# Rebuild the architecture with the current hyper-parameters and restore saved weights.
# NOTE(review): this reads from data2/test4model2/, while train_model writes its checkpoints
# under data2/test4model3/. Confirm which run is intended. vocab_size, embedding_dim,
# hidden_dim and num_layers must match the values that run was trained with.
model = PoetryModel(vocab_size, embedding_dim, hidden_dim, num_layers)
# map_location='cpu': generate_poem builds CPU tensors, so the weights must be on the CPU
# even if the checkpoint was written from a GPU. torch.load unpickles the file; only load
# checkpoints you trust.
model.load_state_dict(torch.load('data2/test4model2/poetry_model_epoch_32_batch_0.pth', map_location='cpu'))

<All keys matched successfully>

In [5]:
start_chars = ["春"]
poem_types = ['五言', '七言']

# 20 rounds. Each round prints one section per poem type, headed by its own title, so
# every poem appears under the correct header (previously the 五言 header was printed only
# once, and later 五言 poems were listed under the 七言 header).
for _ in range(20):
    for poem_type in poem_types:
        print(f"\n生成{poem_type}诗:")
        for start in start_chars:
            generated_poem = generate_poem(model, idx_to_char, char_to_idx, start,
                                           poem_type=poem_type, temperature=1)
            print(f"\n以'{start}'开头的{poem_type}诗:")
            print(generated_poem)


生成五言诗:

以'春'开头的五言诗:
春虽故乡未，
云霞五说与。
云倚画树唱，
离处隔鸡琴。

生成七言诗:

以'春'开头的七言诗:
春分道只和惠志，
聪间王州泪奔高。
公朝手驿閒事已，
华想家尘凤声来。

以'春'开头的五言诗:
春卧稳自近，
踟蹰路荣是。
名成新词近，
东篱影悬仙。

生成七言诗:

以'春'开头的七言诗:
春过山卧兴衰旧，
更阑分君臣可愁。
离然散卧霜叶叶，
青为断处我平归。

以'春'开头的五言诗:
春驰栈知春，
却在紫是含。
祇遥雾势年，
缓飘无辞穗。

生成七言诗:

以'春'开头的七言诗:
春接修时遣扫望，
枝堂梦得神居处。
别后成旅路真踪，
不得分分白醉故。

以'春'开头的五言诗:
春鸿迎又棹，
拾翠谁复响。
远来浮世利，
唤静为始堪。

生成七言诗:

以'春'开头的七言诗:
春在食迳收前宝，
抄得新书已在学。
独听病依难照香，
天惊后清飞时淅。

以'春'开头的五言诗:
春赏古长终，
有令自荒干。
怡惊蝉展名，
江淹散谩恋。

生成七言诗:

以'春'开头的七言诗:
春分取欺炉香儒，
清星对善友相未。
便牵魂密犹吟影，
公退斯得陇残红。

以'春'开头的五言诗:
春虽自枯谙，
直艇高斋轻。
筹难名忽岭，
兴后逢时绝。

生成七言诗:

以'春'开头的七言诗:
春锁惊后岸垂钓，
风卷喷云梦憾江。
素道归路亦醉吟，
暖歌尽鹭敲翠枝。

以'春'开头的五言诗:
春讲似秋丝，
留畔野去故。
旌表宜颜古，
心同入当欲。

生成七言诗:

以'春'开头的七言诗:
春棹山圆非却屠，
接印不独能肃物。
竹逐春山尽石江，
三泉石影轻曲印。

以'春'开头的五言诗:
春虽了露中，
红山只庭迟。
幽窗前见问，
春风春平鼻。

生成七言诗:

以'春'开头的七言诗:
春棹笼霜起竹人，
山光生寿老时髦。
翅佩烟笼琴生冷，
欲色芳受百笑知。

以'春'开头的五言诗:
春卧九路赊，
今松院恨长。
警露吟出格，
求化营年知。

生成七言诗:

以'春'开头的七言诗:
春万事搜官事远，
两礼岸烟霞高然。
云霞曾时心惟来，
自由俱汉绿微难。

以'春'开头的五言诗:
春谩花光惆，
无凭费叟烟。
数家芦苇风，
茱萸饮数春。

生成七言诗:

以'春'开头的七言诗:
春莫在未为仙照，
身无别随鬼流轻。
意知更心常玉箸，
故帆春幽景未迟。

以'春'开头的五言诗: