<center>
<h1>Classical Chinese Poetry Generation with PyTorch + LSTM</h1>
</center>

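### Course Overview:
This course uses the PyTorch framework to complete an NLP task: generating classical Chinese poetry. The model is an LSTM, and character-level word vectors are trained with Word2Vec. Two modes are supported: random poems, which start from a randomly chosen character so that each run produces a different poem, and acrostic poems (藏头诗), in which each line begins with one of four user-supplied characters. In class, the code is written from scratch and split across several files so the versions can be compared.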

<br>

### Imports:

In [3]:
import os
import numpy as np
import pickle
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset, DataLoader

<br>

### Generating the character-split file:

In [4]:
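def split_text(file="poetry_7.txt", train_num=6000):
    # Put a space between every character so that each character becomes one token
    with open(file, "r", encoding="utf-8") as f:
        all_data = f.read()
    split_data = " ".join(all_data)
    with open("split_7.txt", "w", encoding="utf-8") as f:
        f.write(split_data)
    # Return the first train_num lines (one poem per line), counted by lines so no poem is cut off
    return "\n".join(split_data.split("\n")[:train_num])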

<br>
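
A small sketch of what this step produces, using a toy couplet (the text is only an illustration, not taken from `poetry_7.txt`). `" ".join` puts a space between every character, including around the newline, so splitting each line on whitespace yields exactly one token per character. Iterating over a raw line instead would also return the spaces as tokens, which is why Word2Vec should receive lists of tokens rather than plain strings.

```python
# Toy input: one 7-character verse plus punctuation per line (illustrative text only)
raw = "两个黄鹂鸣翠柳，\n一行白鹭上青天。"

# Same operation as split_text: a space between every character (the newline too)
split_data = " ".join(raw)

# One token list per line; str.split() drops the extra spaces around "\n"
sentences = [line.split() for line in split_data.split("\n")]

# Iterating over a raw line instead yields the spaces as tokens as well
chars_of_raw_line = list(split_data.split("\n")[0])

print(sentences[0])
print(" " in chars_of_raw_line)
```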

### Training word vectors:

In [5]:
def train_vec(split_file="split_7.txt", org_file="poetry_7.txt", train_num=6000):
    param_file = "word_vec.pkl"
    with open(org_file, "r", encoding="utf-8") as f:
        org_data = f.read().split("\n")[:train_num]

    # Reuse cached vectors; delete word_vec.pkl after changing train_num or the corpus
    if os.path.exists(param_file):
        with open(param_file, "rb") as f:
            return org_data, pickle.load(f)

    if os.path.exists(split_file):
        with open(split_file, "r", encoding="utf-8") as f:
            all_data_split = f.read().split("\n")[:train_num]
    else:
        all_data_split = split_text().split("\n")[:train_num]

    # Word2Vec expects a list of token lists: split each line so every character is one token
    sentences = [line.split() for line in all_data_split]
    models = Word2Vec(sentences, vector_size=128, workers=7, min_count=1)
    # syn1neg: Word2Vec's output (negative-sampling) weights, one row per word in index_to_key order
    params = (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
    with open(param_file, "wb") as f:
        pickle.dump(params, f)
    return org_data, params

<br>
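
`train_vec` caches three objects: an embedding matrix and two vocabulary mappings. The sketch below builds comparable mappings by hand from token lists (counting with `collections.Counter`, most frequent first, which only approximates gensim's own ordering), uses a random matrix as a stand-in for the trained vectors, and round-trips everything through `pickle`. It checks the invariant the rest of the code relies on: `index_to_key` inverts `key_to_index`, and row `key_to_index[ch]` of the matrix belongs to character `ch`. `build_vocab` is a hypothetical helper, not part of gensim.

```python
import io
import pickle
from collections import Counter

import numpy as np

def build_vocab(sentences):
    """Hypothetical helper: character -> index and index -> character, most frequent first."""
    counts = Counter(tok for sent in sentences for tok in sent)
    index_to_key = [tok for tok, _ in counts.most_common()]
    key_to_index = {tok: i for i, tok in enumerate(index_to_key)}
    return key_to_index, index_to_key

sentences = [list("白日依山尽，"), list("黄河入海流。")]  # toy token lists
key_to_index, index_to_key = build_vocab(sentences)

# Random stand-in for the trained (vocab_size, 128) embedding matrix
w1 = np.random.default_rng(0).standard_normal((len(index_to_key), 128)).astype(np.float32)

# Round-trip through pickle, as train_vec does with word_vec.pkl (in memory here)
buffer = io.BytesIO()
pickle.dump((w1, key_to_index, index_to_key), buffer)
buffer.seek(0)
w1_loaded, k2i, i2k = pickle.load(buffer)

print(len(i2k), w1_loaded.shape)
```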

### Building the dataset:

In [6]:
class Poetry_Dataset(Dataset):
    def __init__(self, w1, word_2_index, all_data):
        self.w1 = w1
        self.word_2_index = word_2_index
        self.all_data = all_data

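    def __getitem__(self, index):
        a_poetry = self.all_data[index]

        # Characters -> vocabulary indices
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        xs = a_poetry_index[:-1]  # input: every character except the last
        ys = a_poetry_index[1:]   # target: the next character at each step
        xs_embedding = self.w1[xs]  # (seq_len, embedding_num), looked up from the Word2Vec matrix

        return xs_embedding, np.array(ys).astype(np.int64)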

    def __len__(self):
        return len(self.all_data)

<br>
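
The dataset turns each poem into a next-character prediction problem: the input at step t is character t and the target is character t + 1, so a poem of n characters yields n - 1 input/target pairs. The toy sketch below uses a hypothetical 4-character vocabulary and a tiny matrix in place of `w1` to show the shift and the embedding lookup.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding matrix (4 characters, 3-dim vectors)
word_2_index = {"春": 0, "眠": 1, "不": 2, "觉": 3}
w1 = np.arange(12, dtype=np.float32).reshape(4, 3)

a_poetry = "春眠不觉"
a_poetry_index = [word_2_index[ch] for ch in a_poetry]  # [0, 1, 2, 3]
xs = a_poetry_index[:-1]  # inputs:  春 眠 不
ys = a_poetry_index[1:]   # targets: 眠 不 觉 (the next character at each step)

xs_embedding = w1[xs]  # fancy indexing: one embedding row per input character
ys_array = np.array(ys, dtype=np.int64)
print(xs_embedding.shape, ys_array)
```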

### Building the model:

In [7]:
class Poetry_Model_lstm(nn.Module):
    def __init__(self, hidden_num, word_size, embedding_num):
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.hidden_num = hidden_num

        self.lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num, batch_first=True, num_layers=2,
                            bidirectional=False)
        self.dropout = nn.Dropout(0.3)
        self.flatten = nn.Flatten(0, 1)
        self.linear = nn.Linear(hidden_num, word_size)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, xs_embedding, h_0=None, c_0=None):
        # xs_embedding: (batch, seq_len, embedding_num)
        xs_embedding = xs_embedding.to(self.device)
        if h_0 is None or c_0 is None:
            # Zero initial states, shape (num_layers=2, batch, hidden_num)
            h_0 = torch.zeros(2, xs_embedding.shape[0], self.hidden_num)
            c_0 = torch.zeros(2, xs_embedding.shape[0], self.hidden_num)
        h_0 = h_0.to(self.device)
        c_0 = c_0.to(self.device)
        hidden, (h_0, c_0) = self.lstm(xs_embedding, (h_0, c_0))
        hidden_drop = self.dropout(hidden)
        # (batch, seq_len, hidden_num) -> (batch * seq_len, hidden_num)
        hidden_flatten = self.flatten(hidden_drop)
        pre = self.linear(hidden_flatten)  # logits over the vocabulary

        return pre, (h_0, c_0)

<br>
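
As a sanity check on the shapes in `forward` and in the loss call of the training loop, the sketch below runs the same layer stack (a 2-layer `nn.LSTM` with `batch_first=True`, `nn.Flatten(0, 1)`, `nn.Linear`) on random input, then applies `nn.CrossEntropyLoss` to targets flattened with `reshape(-1)`. The sizes are arbitrary assumptions. Two points it illustrates: the initial states keep the shape `(num_layers, batch, hidden)` even with `batch_first=True`, and `Flatten(0, 1)` and `reshape(-1)` both flatten in row-major order, so row k of the logits and element k of the targets refer to the same (poem, step).

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions): 4 poems, 31 steps, 128-dim embeddings, 500-character vocabulary
batch, seq_len, embedding_num, hidden_num, word_size = 4, 31, 128, 128, 500

lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num, batch_first=True, num_layers=2)
flatten = nn.Flatten(0, 1)  # merge the batch and time dimensions
linear = nn.Linear(hidden_num, word_size)

x = torch.randn(batch, seq_len, embedding_num)
h_0 = torch.zeros(2, batch, hidden_num)  # (num_layers, batch, hidden) even with batch_first=True
c_0 = torch.zeros(2, batch, hidden_num)

hidden, (h_n, c_n) = lstm(x, (h_0, c_0))  # top-layer output at every step
logits = linear(flatten(hidden))           # one row of vocabulary scores per (poem, step)

# Targets flattened in the same row-major order, as in the training loop
targets = torch.randint(0, word_size, (batch, seq_len))
loss = nn.CrossEntropyLoss()(logits, targets.reshape(-1))

print(hidden.shape)  # torch.Size([4, 31, 128])
print(h_n.shape)     # torch.Size([2, 4, 128])
print(logits.shape)  # torch.Size([124, 500])
print(loss.item())   # a single scalar
```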

### Generating a random poem:


In [8]:
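@torch.no_grad()  # inference only: do not build an autograd graph
def generate_poetry_auto():
    # Relies on globals from the main cell: model, w1, index_2_word, word_size, hidden_num
    model.eval()  # turn off dropout while generating
    result = ""
    word_index = np.random.randint(0, word_size)  # random first character, so each poem differs

    result += index_2_word[word_index]
    h_0 = torch.zeros(2, 1, hidden_num)
    c_0 = torch.zeros(2, 1, hidden_num)

    for i in range(31):  # first character + 31 generated = 32 characters in total
        word_embedding = torch.tensor(w1[word_index][None][None])  # shape (1, 1, embedding_num)
        pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
        word_index = int(torch.argmax(pre))  # greedy: take the most likely next character
        result += index_2_word[word_index]

    return result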


<br>
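
`generate_poetry_auto` decodes greedily with `torch.argmax`, so once the random first character is chosen, the rest of the poem is fixed by the model. A common alternative, not used in the code above, is to sample the next character from the softmax distribution, sharpened or flattened by a temperature. A minimal NumPy sketch; `sample_next_index` and `temperature` are hypothetical names. To plug it in, the logits would first be converted with `pre[0].detach().cpu().numpy()`.

```python
import numpy as np

def sample_next_index(logits, temperature=1.0, rng=None):
    """Sample a vocabulary index from logits; lower temperature -> closer to argmax."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()    # softmax
    return int(rng.choice(len(probs), p=probs))

logits = np.array([1.0, 5.0, 2.0, 0.5])
rng = np.random.default_rng(42)
print(sample_next_index(logits, temperature=0.01, rng=rng))  # effectively always 1 (the argmax)
print(sample_next_index(logits, temperature=1.0, rng=rng))   # usually 1, sometimes another index
```

A temperature near 0 reproduces greedy decoding; values above 1 flatten the distribution and make rarer characters more likely.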

### Generating an acrostic poem:

In [9]:
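@torch.no_grad()  # inference only: do not build an autograd graph
def generate_poetry_acrostic():
    # Relies on globals from the main cell: model, w1, word_2_index, index_2_word, hidden_num
    model.eval()  # turn off dropout while generating
    input_text = input("Enter four Chinese characters: ")[:4]
    if len(input_text) < 4 or any(ch not in word_2_index for ch in input_text):
        raise ValueError("Need four characters that all appear in the training vocabulary")
    result = ""
    punctuation_list = ["，", "。", "，", "。"]
    for i in range(4):
        result += input_text[i]  # each line starts with the given character
        # Reset the hidden state at the start of every line
        h_0 = torch.zeros(2, 1, hidden_num)
        c_0 = torch.zeros(2, 1, hidden_num)
        word = input_text[i]
        for j in range(6):  # 1 given + 6 generated = 7 characters per line
            word_index = word_2_index[word]
            word_embedding = torch.tensor(w1[word_index][None][None])
            pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
            word = index_2_word[int(torch.argmax(pre))]  # index -> character
            result += word
        result += punctuation_list[i]  # end the line with "，" or "。"

    return result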

<br>
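
Without the model, the acrostic loop is plain string assembly: each of the four lines starts with one of the given characters, gets six more characters from the model, and ends with the matching entry of `punctuation_list`. The sketch below isolates that layout behind a hypothetical `next_char` callable (standing in for one LSTM step plus `argmax`), so the structure can be tested without a trained model. `build_acrostic` is likewise a hypothetical name.

```python
from typing import Callable

def build_acrostic(heads: str, next_char: Callable[[str], str], chars_per_line: int = 7) -> str:
    """Hypothetical helper: assemble a 4-line acrostic from head characters and a next-character function."""
    punctuation_list = ["，", "。", "，", "。"]
    result = ""
    for i, head in enumerate(heads[:4]):
        line = head
        word = head
        for _ in range(chars_per_line - 1):  # 1 head + 6 generated = 7 characters
            word = next_char(word)
            line += word
        result += line + punctuation_list[i]
    return result

# Stub "model": always predicts the same character, just to show the layout
poem = build_acrostic("春夏秋冬", lambda prev: "一")
print(poem)  # 春一一一一一一，夏一一一一一一。秋一一一一一一，冬一一一一一一。
```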

### Main: define hyperparameters, the model, and the optimizer, then train

In [10]:
if __name__ == "__main__":

    all_data, (w1, word_2_index, index_2_word) = train_vec(train_num=300)

    batch_size = 32
    epochs = 1000
    lr = 0.01
    hidden_num = 128
    word_size, embedding_num = w1.shape

    dataset = Poetry_Dataset(w1, word_2_index, all_data)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)  # reshuffle the poems every epoch

    model = Poetry_Model_lstm(hidden_num, word_size, embedding_num)
    model = model.to(model.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for e in range(epochs):
        for batch_index, (batch_x_embedding, batch_y_index) in enumerate(dataloader):
            model.train()
            batch_x_embedding = batch_x_embedding.to(model.device)
            batch_y_index = batch_y_index.to(model.device)

            pre, _ = model(batch_x_embedding)
            loss = model.cross_entropy(pre, batch_y_index.reshape(-1))

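            loss.backward()  # backpropagate: gradients accumulate in .grad, parameters are not changed yet
            optimizer.step()  # the optimizer updates the parameters using those gradients
            optimizer.zero_grad()  # clear the accumulated gradients before the next batch

            # With 300 poems and batch_size = 32 there are 10 batches per epoch,
            # so this condition holds once per epoch (batch 0)
            if batch_index % 100 == 0:
                print(f"loss:{loss.item():.3f}")
                print(generate_poetry_auto())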

loss:8.173
渺声。。。成。。，。堪。，。。，。成。。。。。。。。。。。。定。
loss:7.006
寒。，，，，，，。，。，，，，，。。。。。。。，，，，。，。。。
loss:6.838
急。，，。，。，。。。，，。。。。。。。，，，，，。，，。。，。
loss:6.742
已。，，，。。。。。。，。，。。，。，。。。。。。。。。，，。。
loss:6.703
跃山山。。，。，，，。。，。，。。，，。，。。。，。。。。，。。
loss:6.656
翻有三，，，。。。。。。。。。。。，，。，。，。，，。。。。。。
loss:6.642
魂山三路，，，，，，，。，，，。。。，。。。。。。。，。。。，，
loss:6.634
嬴门风，三，，，，，，，。。。。，。。。。。。。。。。。，。。。
loss:6.563
评有风海，，，，，，，，，，，，，，，，，，。。。。。。，。，。
loss:6.460
鲛门风传斗，，，，，，。。。。。。。。。。。。。。。。。。。。。
loss:6.406
现门三树斗，，，，，，，。。。，。。。。。。。。。。。。。。。。
loss:6.242
傅门三海生天，，，，，。。。。。。一。。。。。。，。。。。。。。
loss:6.107
读山下车林天，，，山来。一一一。一来来。。来一。一山。。。山。。
loss:5.943
履山风车性三宣，一无一。一风花。一山风。一山天，一风一。一不风。
loss:5.827
珊色风光三林，，一来烟。一山青。一山一。一来风。一山如。一山时。
loss:5.760
眷台高路香有，，一山一山一来天，一山如声不来新。一来风人一来天。
loss:5.687
芃八若光海峰，，一须一。，一青。一来一。一无花。一来一。一天微。
loss:5.612
泚人梅光姓三宽，一山一。一天心。一来一山一不天。一来一山。不来。
loss:5.580
叉得梅路三天空，一步一。一天仙。一山一声一无花。未山一年。不来。
loss:5.539
岳阙别光不千过，砌山一觉一水仙。一山一山一云花。一来一山一不花，
loss:5.538
栈台若光海天天，一山一阁一水新。一来一风一无。，一教烟花一不花，
loss:5.451
雁有风光斗无天，一山如花旧天场。一来一年一无。，一山一年一人香，
loss:5.444
身门风光斗氏天，筠载编贵。天深。一是遥时不

KeyboardInterrupt: 
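
In this log the loss falls from 8.173 to 5.444 before training was interrupted. For scale: a model that gives every one of the V characters the same probability has a cross-entropy of exactly ln V, so `math.log(word_size)` is the loss of a uniform guess for a given vocabulary. A quick numerical check of that identity, with `V = 3000` as an arbitrary example rather than this run's vocabulary size:

```python
import math

import numpy as np

V = 3000  # example vocabulary size (assumption)
uniform_probs = np.full(V, 1.0 / V)
target = 123  # any target index gives the same value under a uniform prediction

cross_entropy = -math.log(uniform_probs[target])
print(round(cross_entropy, 3), round(math.log(V), 3))  # 8.006 8.006
```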