# Loading Data & Preprocessing

In [1]:
import re
import torch
import time
import random
import numpy as np
import torch.nn as nn

In [2]:
poetry = []
tf_word = {}
with open('poetry.txt', encoding='utf-8') as f:
    for line in f:
        line = re.sub('（\S+）', '', line)
        for word in line:
            if word not in tf_word:
                tf_word[word] = 1
            else:
                tf_word[word] += 1
        if len(line) > 15:
            poetry += re.split('[，。！？；]', line.strip())[:-1]

In [3]:
five_words_poetry = list(filter(lambda x: len(x) == 5, poetry))
seven_words_poetry = list(filter(lambda x: len(x) == 7, poetry))

In [4]:
print(five_words_poetry[:5], seven_words_poetry[:5], sep='\n')

['寒随穷律变', '春逐鸟声开', '初风飘带柳', '晚雪间花梅', '碧林青旧竹']
['暧暧去尘昏灞岸', '飞飞轻盖指河梁', '云峰衣结千重叶', '雪岫花开几树妆', '深悲黄鹤孤舟远']


In [5]:
print('五言诗诗句总数:{}，七言诗诗句总数:{}'.format(len(five_words_poetry), len(seven_words_poetry)))

五言诗诗句总数:296255，七言诗诗句总数:141968


In [6]:
min_tf_word = 150
# 过滤出现次数小于该值的字
word_seq = []
word2idx = {}

In [7]:
for line in five_words_poetry:
    words = [word if tf_word[word] > min_tf_word else '<UNK>' for word in line]
    word_seq.append(words)

In [8]:
print(word_seq[:20])

[['寒', '随', '穷', '律', '变'], ['春', '逐', '鸟', '声', '开'], ['初', '风', '飘', '带', '柳'], ['晚', '雪', '间', '花', '梅'], ['碧', '林', '青', '旧', '竹'], ['绿', '沼', '翠', '新', '苔'], ['芝', '田', '初', '雁', '去'], ['绮', '树', '巧', '莺', '来'], ['晚', '霞', '聊', '自', '<UNK>'], ['初', '晴', '弥', '可', '喜'], ['日', '<UNK>', '百', '花', '色'], ['风', '动', '千', '林', '翠'], ['池', '鱼', '跃', '不', '同'], ['园', '鸟', '声', '还', '异'], ['寄', '言', '博', '通', '者'], ['知', '予', '物', '外', '志'], ['一', '朝', '春', '夏', '改'], ['隔', '夜', '鸟', '花', '迁'], ['阴', '阳', '深', '浅', '叶'], ['晓', '夕', '重', '轻', '烟']]


# Embedding

In [9]:
class Embedding:
    def __init__(self, data):
        self.data = data
        self.idx2word = []
        self.word2idx = {}
    
    def mk_embedding(self):
        for line in self.data:
            for word in line:
                if word not in self.word2idx:
                    self.word2idx[word] = len(self.idx2word)
                    self.idx2word.append(word)
                    
    def word_seq2idx_seq(self, word_seq):
        idx_seq = []
        for line in word_seq:
            idx_seq.append(list(map(lambda word: embedding.word2idx[word], line)))
        return idx_seq

In [10]:
embedding = Embedding(word_seq)
embedding.mk_embedding()
idx_seq = embedding.word_seq2idx_seq(word_seq)

In [11]:
print(idx_seq[:5])

[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]


# Dataset

In [12]:
from torch.utils.data import Dataset, DataLoader

class PoetryDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, i):
        return self.x[i,:], self.y[i,:]

def x2y(x):
    return x[1:]
    
x_train = np.array(idx_seq)
y_train = np.array(list(map(lambda x: (x2y(x)), idx_seq)))

In [13]:
print('x:\n{}\ny:\n{}'.format(x_train, y_train))

x:
[[   0    1    2    3    4]
 [   5    6    7    8    9]
 [  10   11   12   13   14]
 ...
 [ 170   27   42  149  264]
 [  42   42  734 1731   42]
 [ 172  994 1530  451 2055]]
y:
[[   1    2    3    4]
 [   6    7    8    9]
 [  11   12   13   14]
 ...
 [  27   42  149  264]
 [  42  734 1731   42]
 [ 994 1530  451 2055]]


In [14]:
train_set = DataLoader(PoetryDataset(x_train, y_train), batch_size=128)

# Network

In [15]:
class RNN(torch.nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers=2):
        super().__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.embedding = torch.nn.Embedding(*embedding_size)
        self.embedding.weight = torch.nn.Parameter(torch.randn(*embedding_size))
        self.embedding.weight.requires_grad = True
        
        self.lstm = torch.nn.LSTM(embedding_size[1], hidden_size, num_layers, batch_first=True)
        
        self.classifier = torch.nn.Sequential(torch.nn.Dropout(0.2),
                                              torch.nn.Linear(self.hidden_size, self.hidden_size//2),
                                              torch.nn.Sigmoid(),
                                              
                                              torch.nn.Dropout(0.2),
                                              torch.nn.Linear(self.hidden_size//2, self.hidden_size//4),
                                              torch.nn.Sigmoid(),
                                              
                                              torch.nn.Dropout(0.2),
                                              torch.nn.Linear(self.hidden_size//4, self.embedding_size[0]),
                                             )
    def forward(self, x):
        x = self.embedding(x)
        hidden, _ = self.lstm(x.to(dtype=torch.float32), None)
        y = self.classifier(hidden)
        return y

In [16]:
num_words = len(embedding.idx2word)
# 诗句中出现的词总数
num_word_size = 200
# embedding后每个单词的维度
num_hidden_size = 1024
# LSTM隐藏层输出维度
num_epoch = 5

In [17]:
model = RNN((num_words, num_word_size), num_word_size)
model.to(device='cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
calc_loss = nn.CrossEntropyLoss()

In [18]:
def calc_acc(pred, y):
    pred = np.argmax(pred.detach().cpu().numpy(), axis=1)
    return (y.detach().cpu().numpy() == pred).sum() / pred.shape[0]

In [19]:
try:
    model = torch.load('5_epoch_model.pkl')
except FileNotFoundError:
    model.train()
    for epoch in range(num_epoch):
        start_time = time.time()
        train_loss = 0.0
        train_acc = 0.0

        for data in train_set:
            x = data[0].to(device='cuda', dtype=torch.long)
            y = data[1].to(device='cuda', dtype=torch.long)
            pred = model(x)[:,:-1,:]
            optimizer.zero_grad()
            # 去掉预测序列的最后最后一个值
            pred = pred.transpose(1,2)
            batch_loss = calc_loss(pred, y)
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss / len(train_set)
            train_acc += calc_acc(pred, y) / len(train_set)

        print('[{:03d}/{:03d}] time:{:.2f}(sec) loss:{:.4f} acc:{:.4f}'.format(epoch+1, num_epoch, time.time()-start_time, train_loss, train_acc))

    torch.save(model, '{}_epoch_model.pkl'.format(num_epoch))

[001/005] time:23.91(sec) loss:7.2610 acc:0.0434
[002/005] time:23.97(sec) loss:7.2500 acc:0.0763
[003/005] time:23.23(sec) loss:7.2436 acc:0.0595
[004/005] time:22.14(sec) loss:7.2180 acc:0.0504
[005/005] time:23.92(sec) loss:7.1965 acc:0.0838


In [30]:
class Generator:
    def __init__(self, data, model, embedding):
        self.data = data
        self.model = model.eval()
        self.embedding = embedding
    
    def idx2word(self, idx_seq):
        return list(map(lambda x:self.embedding.idx2word[x], idx_seq[0]))
        
    def generate(self):
        i = random.randint(0, len(self.data))
        x = torch.tensor(self.data[i], device=torch.device('cuda'), dtype=torch.long).unsqueeze(0)
        pred = self.model(x)
        prob_pred = np.argsort(pred.detach().cpu().numpy(), axis=2)
        idx_pred = np.zeros((1,5))
        for i in range(5):
            rand_word_idx = np.random.randint(0,5,1)
            idx_pred[0,i] = prob_pred[:,i,rand_word_idx]
        idx_pred = idx_pred.astype(np.long)
        return [self.idx2word(x)[0]] + self.idx2word(idx_pred)[:-1]

In [31]:
generator = Generator(x_train, model, embedding)

In [36]:
for i in range(20):
    print(generator.generate())

['鞭', '笺', '娟', '寞', '倏']
['阴', '每', '寞', '樱', '但']
['残', '笺', '娟', '牡', '但']
['愁', '娟', '寞', '娟', '只']
['意', '蓉', '蓉', '辱', '偷']
['秋', '偷', '蓉', '樱', '偷']
['林', '霖', '翩', '扁', '笺']
['眼', '跎', '娟', '牡', '娟']
['余', '蓉', '寞', '樱', '但']
['长', '寞', '娟', '辱', '递']
['春', '偷', '寞', '或', '偷']
['孤', '聊', '寞', '特', '但']
['精', '寞', '娟', '特', '递']
['誓', '蓉', '翩', '樱', '聊']
['弟', '蓉', '跎', '翩', '寞']
['秋', '倏', '寞', '跎', '蓉']
['独', '寞', '翁', '笺', '倏']
['宿', '讵', '跎', '娟', '偷']
['又', '跎', '蓉', '笺', '递']
['愿', '跎', '寞', '跎', '跎']
