In [None]:
# Mount Google Drive (Colab only) so the dataset and checkpoints stored under
# MyDrive are reachable; re-running prints "Drive already mounted" instead.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Use an absolute path so re-running this cell is idempotent: the relative
# path 'drive/MyDrive/...' only resolves while the CWD is /content (the parent
# of the Drive mount point) and fails once the CWD has already been changed.
PROJECT_DIR = '/content/drive/MyDrive/Colab/Task5'
os.chdir(PROJECT_DIR)

## 1. 数据预处理

In [None]:
def readFile(path, encoding='utf-8'):
    """Read a text file and return its non-empty lines.

    Args:
        path: Path of the text file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so the Chinese
            corpus decodes the same way regardless of the platform's locale
            default encoding.

    Returns:
        A list with one string per line, trailing newline removed. Lines that
        are empty after removing the newline are skipped; other whitespace is
        kept as-is.
    """
    datas = []
    with open(path, encoding=encoding) as fp:
        # Iterate lazily instead of readlines() to avoid an extra full copy.
        for line in fp:
            line = line.rstrip('\n')
            if line:
                datas.append(line)
    return datas

In [None]:
# Load the raw corpus: one list entry per non-empty line of the file.
train_datas = readFile(path='dataset/poetryFromTang.txt')

In [None]:
import torch.utils.data
import jieba

class TangPoetryDataset(torch.utils.data.Dataset):
    """Word-level next-token-prediction dataset over lines of Tang poetry.

    Each line is segmented with ``jieba.lcut`` and mapped to integer ids.
    Sample ``i`` is ``(input, target)`` with
    ``input = [<sos>] + ids`` and ``target = ids + [<eos>]``.

    Attributes:
        vocab: dict token -> id; ids 0-3 are '<pad>', '<unk>', '<sos>', '<eos>'.
        idx2word: list id -> token (inverse of ``vocab``).
        datas: list of token-id lists, one per input line.
        skip_tokens: tokens dropped from the encoded sequences.
    """

    def __init__(self, datas, skip_tokens=("，", "。")):
        """Build the vocabulary and encode ``datas``.

        Args:
            datas: Iterable of raw text lines.
            skip_tokens: Tokens that are added to the vocabulary but dropped
                from the encoded sequences. Defaults to the Chinese comma and
                full stop, because keeping them made the model learn to emit
                only punctuation.
        """
        # Bound super call: the previous super(TangPoetryDataset) form created an
        # unbound super object, so the parent initializer was never invoked.
        super(TangPoetryDataset, self).__init__()
        self.skip_tokens = frozenset(skip_tokens)
        self.vocab, self.idx2word, self.datas = self._tokenize(datas)

    def _tokenize(self, datas):
        """Segment each line with jieba and encode it, growing the vocabulary.

        Returns:
            ``(word2idx, idx2word, encoded)`` where ``encoded`` holds one list of
            token ids per input line with ``self.skip_tokens`` removed.
        """
        idx2word = ['<pad>', '<unk>', '<sos>', '<eos>']
        word2idx = {word: idx for idx, word in enumerate(idx2word)}
        result = []

        for line in datas:
            cur_doc = []
            for word in jieba.lcut(line):
                # Unseen words get the next free id. Skipped tokens are still
                # registered (only dropped from the sequence), so the ids match
                # those of an unfiltered vocabulary.
                if word not in word2idx:
                    word2idx[word] = len(idx2word)
                    idx2word.append(word)
                if word not in self.skip_tokens:
                    cur_doc.append(word2idx[word])
            result.append(cur_doc)
        return word2idx, idx2word, result

    def __getitem__(self, index):
        """Return ``([<sos>] + ids, ids + [<eos>])`` for line ``index`` as lists."""
        doc = [self.vocab['<sos>']] + self.datas[index]
        label = self.datas[index] + [self.vocab['<eos>']]
        return doc, label

    def __len__(self):
        """Number of encoded lines."""
        return len(self.datas)

In [None]:
# Segment the corpus with jieba and build the vocabulary
# (jieba logs its dictionary loading the first time it is used).
train_dataset = TangPoetryDataset(datas=train_datas)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.932 seconds.
Prefix dict has been built successfully.


## 2. LSTM模型

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    """Single-layer word-level LSTM language model.

    Architecture: Embedding -> Dropout -> LSTM -> Dropout -> Linear(vocab).

    Args:
        args: Config object providing ``embed_size``, ``hidden_size``,
            ``dropout_rate`` and ``max_len`` (used by :meth:`generate`).
        vocab: dict token -> id; must contain '<eos>'. Its size sets the
            embedding and output dimensions.
    """

    def __init__(self, args, vocab):
        super(LSTM, self).__init__()
        self.args = args
        self.vocab = vocab

        # padding_idx=0 assumes '<pad>' has id 0, as in the dataset vocabulary.
        self.embedding = nn.Embedding(num_embeddings=len(vocab),
                                      embedding_dim=args.embed_size,
                                      padding_idx=0)

        self.lstm = nn.LSTM(input_size=args.embed_size,
                            hidden_size=args.hidden_size,
                            batch_first=True)

        self.output = nn.Linear(args.hidden_size, len(vocab))

        self.dropout = nn.Dropout(args.dropout_rate)

    def forward(self, doc, mask=None):
        """Compute next-token logits for every position.

        Args:
            doc: Token ids, shape (batch_size, length).
            mask: Unused; accepted for interface compatibility. Padded
                positions must be masked out by the caller (the LSTM is
                unidirectional, so trailing padding cannot influence the
                earlier, real positions).

        Returns:
            Logits of shape (batch_size, length, vocab_size).
        """
        # (batch_size, length) -> (batch_size, length, embed_dim)
        embed_doc = self.dropout(self.embedding(doc))

        # (batch_size, length, embed_dim) -> (batch_size, length, hidden_dim)
        encoded_doc, _ = self.lstm(embed_doc)
        encoded_doc = self.dropout(encoded_doc)

        # (batch_size, length, hidden_dim) -> (batch_size, length, vocab_size)
        return self.output(encoded_doc)

    def generate(self, seq_in):
        """Greedily extend a prompt until '<eos>' or ``args.max_len`` tokens.

        Args:
            seq_in: Tensor of prompt token ids; flattened to one sequence.

        Returns:
            1-D tensor on the model's device: the prompt followed by the
            generated ids. At least one token is always generated; generation
            stops after '<eos>' is emitted or once the total length (prompt
            included) reaches ``args.max_len``.
        """
        # The prompt may be built on the CPU while the model lives on the GPU.
        device = self.embedding.weight.device
        seq_in = seq_in.to(device).view(1, -1)

        # Encode the whole prompt; its final hidden state predicts the next token.
        _, (h_n, c_n) = self.lstm(self.embedding(seq_in))
        next_token = torch.argmax(self.output(h_n.view(1, -1)).view(-1))
        res = torch.cat([seq_in.view(-1), next_token.view(-1)], dim=0)

        # max_len bounds the loop in case '<eos>' is never predicted.
        while next_token != self.vocab['<eos>'] and res.shape[0] < self.args.max_len:
            embed_token = self.embedding(next_token.view(1, -1))
            _, (h_n, c_n) = self.lstm(embed_token, (h_n, c_n))
            next_token = torch.argmax(self.output(h_n.view(1, -1)).view(-1))
            res = torch.cat([res, next_token.view(-1)], dim=0)

        return res

## 3. 训练过程

In [None]:
def collate_fn(batch_data, pad_idx=0):
    """Pad a batch of ``(input_ids, label_ids)`` pairs to a common length.

    Args:
        batch_data: List of ``(doc, label)`` pairs of id lists, as produced by
            ``TangPoetryDataset.__getitem__``; the batch is padded to its
            longest ``doc``.
        pad_idx: Id used for padding. Defaults to 0, the '<pad>' id that the
            embedding layer uses as ``padding_idx``.

    Returns:
        ``(vec, mask, padded_label)``: int64 input ids and int64 labels of
        shape (batch_size, max_len), and a bool mask that is True on the real
        (non-padding) input positions.
    """
    batch_size = len(batch_data)
    max_len = max(len(example[0]) for example in batch_data)

    # Pad inputs with <pad> rather than 1 (<unk>), matching padding_idx.
    vec = torch.full((batch_size, max_len), pad_idx, dtype=torch.int64)
    mask = torch.zeros((batch_size, max_len), dtype=torch.bool)
    padded_label = torch.full((batch_size, max_len), pad_idx, dtype=torch.int64)

    for i, example in enumerate(batch_data):
        doc, label = example[0], example[1]
        vec[i, :len(doc)] = torch.tensor(doc, dtype=torch.int64)
        mask[i, :len(doc)] = True
        padded_label[i, :len(label)] = torch.tensor(label, dtype=torch.int64)

    return (vec, mask, padded_label)

困惑度计算方法如下：
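$$\mathbf{PPL}(\theta)=\Big(\prod_n\prod_t p_{\theta}(x_t^{(n)}\mid x_{1:(t-1)}^{(n)})\Big)^{-1/T}=\exp\Big(-\frac{1}{T}\sum_n\sum_t\log p_{\theta}(x_t^{(n)}\mid x_{1:(t-1)}^{(n)})\Big)=\exp\Big(\frac{1}{T}\mathrm{CE}(\theta)\Big)$$
其中 $T$ 为所有序列中（不含填充位置的）词元总数，$\mathrm{CE}(\theta)=-\sum_n\sum_t\log p_{\theta}(x_t^{(n)}\mid x_{1:(t-1)}^{(n)})$ 为累加的交叉熵。因此，只需在训练中累加模型的交叉熵（`reduction='sum'`）并统计词元数，即可算出困惑度。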

In [None]:
import jieba
import numpy as np
from tqdm import tqdm

# Encode an input prompt into token ids.
def preprocess(seq, vocab, tokenizer=None):
    """Turn a raw prompt string into a 1-D int64 tensor of token ids.

    Args:
        seq: Raw prompt text.
        vocab: dict token -> id; must contain '<sos>', and '<unk>' whenever
            the prompt contains a word missing from the vocabulary.
        tokenizer: Callable splitting ``seq`` into tokens. Defaults to
            ``jieba.lcut``, the segmentation used to build the vocabulary.

    Returns:
        LongTensor ``[<sos>, tok_1, ..., tok_n]``. Words missing from the
        vocabulary map to '<unk>' instead of raising KeyError.
    """
    if tokenizer is None:
        tokenizer = jieba.lcut
    tokens = ['<sos>'] + list(tokenizer(seq))
    ids = [vocab[tok] if tok in vocab else vocab['<unk>'] for tok in tokens]
    return torch.tensor(ids, dtype=torch.int64)

# Decode generated token ids back into text.
def postprocess(seq, idx2word):
    """Concatenate the vocabulary strings of every id in ``seq``.

    Args:
        seq: Iterable of token ids (ints or 0-d integer tensors, e.g. the
            elements of a 1-D tensor returned by generation).
        idx2word: Sequence mapping id -> token string.

    Returns:
        The decoded string; special tokens such as '<sos>' are kept as-is.
    """
    return "".join(idx2word[token_id] for token_id in seq)

def train_progress(args, model, vocab, idx2word, optimizer, criterion, train_dataloader, device,
                   prompt="烟笼寒水月笼沙", save_path="models/best_model.pth"):
    """Train ``model`` for ``args.epochs`` epochs, tracking per-token perplexity.

    After every epoch the training perplexity is printed, the weights are
    saved to ``save_path`` whenever the perplexity improves, and a greedy
    continuation of ``prompt`` is printed.

    Args:
        args: Config object providing ``epochs``.
        model: Module called as ``model(text, mask)`` returning logits of shape
            (batch, length, vocab_size), and exposing ``generate``.
        vocab: dict token -> id, used to encode ``prompt``.
        idx2word: list id -> token, used to decode the generated sample.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss expected to use ``reduction='sum'``; the summed loss
            is divided by the number of real tokens to get the mean NLL.
        train_dataloader: Yields ``(text, mask, labels)`` batches where
            ``mask`` is True on real (non-padding) positions.
        device: Device the batches are moved to (where ``model`` lives).
        prompt: Text whose greedy continuation is printed each epoch.
        save_path: Checkpoint path; its parent directory is created if needed.

    Returns:
        List of per-epoch training perplexities.
    """
    import os

    train_prep_arr = []
    best_train_prep = float('inf')
    best_epoch = -1

    # torch.save does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        print(f"Epoch {epoch}:")

        # Training
        train_len = 0
        train_loss = 0.

        model.train()
        for datas in tqdm(train_dataloader, total=len(train_dataloader)):
            optimizer.zero_grad()

            text, mask, labels = [data.to(device) for data in datas]

            output = model(text, mask)
            # Score only real tokens. Padded positions used to be scored with
            # zeroed (uniform) logits, adding log(vocab_size) per pad and
            # inflating perplexity. Logits stay on `device` to match `labels`.
            loss = criterion(output[mask], labels[mask])

            loss.backward()
            optimizer.step()

            # Accumulate summed NLL and the number of real tokens it covers.
            train_len += int(mask.sum().item())
            train_loss += loss.item()

        train_prep = np.exp(train_loss / train_len)
        train_prep_arr.append(train_prep)
        print(f'Train: | perplexity: {train_prep}')

        if train_prep < best_train_prep:
            best_train_prep = train_prep
            best_epoch = epoch
            torch.save(model.state_dict(), save_path)

        # Qualitative check: greedy continuation of the prompt.
        model.eval()
        with torch.no_grad():
            text = preprocess(prompt, vocab).to(device)
            res = model.generate(text)
            print(postprocess(res, idx2word))

    if best_epoch >= 0:
        print(f'Best epoch: {best_epoch} | perplexity: {best_train_prep}')
    return train_prep_arr

In [None]:
class Arguments:
    """Hyper-parameters for training and generation, read as attributes."""

    # Optimisation
    # NOTE(review): 5e-5 is small for Adam on a model trained from scratch;
    # the recorded perplexity barely moves over 30 epochs. Consider ~1e-3.
    epochs, batch_size, lr = 30, 16, 5e-5

    # Model size / regularisation
    embed_size = hidden_size = 50
    dropout_rate = 0.2

    # Upper bound on generated sequence length (prompt included)
    max_len = 50

In [None]:
import torch.optim

# Seed the global torch RNG so weight init, dropout and shuffling are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

args = Arguments()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab = train_dataset.vocab
idx2word = train_dataset.idx2word

# Move the model to `device` before building the optimizer, as the
# torch.optim docs recommend, instead of calling .cuda() afterwards.
model = LSTM(args, vocab).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Summed (not averaged) loss: the training loop divides by the token count
# to obtain per-token perplexity.
criterion = nn.CrossEntropyLoss(reduction='sum')

In [None]:
import torch.utils.data

# Reshuffle every epoch; collate_fn pads each batch to its longest sequence.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Train; the returned list holds one training perplexity per epoch.
train_prep_arr = train_progress(
    args, model, vocab, idx2word,
    optimizer, criterion, train_dataloader, device,
)

Epoch 0:


100%|██████████| 45/45 [00:01<00:00, 27.60it/s]


Train: | perplexity: 5475.2560992542185
<sos>烟笼寒水月笼沙客中遇天地关西洋洋新地清栖钓川涨金谷游人霜严玄都齐鲁高标齐鲁高标齐鲁高标暮雨复道彤庭登楼筝长流天生疏驿父执驿父执偃溟驿雷霆朝来盈尺能支毕屠何由竦乱落缘底先贤泪如珠复道之灾之灾
Epoch 1:


100%|██████████| 45/45 [00:01<00:00, 28.75it/s]


Train: | perplexity: 5461.002788381642
<sos>烟笼寒水月笼沙客中遇天地关西有名关西昔时之庆连晓战锒铛三百杯游魂征狄前轩地清栖钓川涨金谷天子再述画图韬当今要津困梁玉绳过射陇亩偃溟驿潇湘潇湘恐泥高标潇湘恐泥再拜字猛将画图清影机巧独送松柏平津
Epoch 2:


100%|██████████| 45/45 [00:01<00:00, 28.07it/s]


Train: | perplexity: 5449.004625458652
<sos>烟笼寒水月笼沙客中遇天地关西有名关西昔时之庆连晓战锒铛三百杯游魂征狄前轩地清栖钓川涨金谷天子再述画图韬当今要津困梁玉绳过射陇亩偃溟驿潇湘潇湘恐泥枭骜过无时竦连晓战梁未试二十四斥余朱炎赫风寒风寒
Epoch 3:


100%|██████████| 45/45 [00:01<00:00, 28.65it/s]


Train: | perplexity: 5438.0292835147975
<sos>烟笼寒水月笼沙客中遇天地关西有名关西昔时之庆连晓战童子晚前轩地清栖钓川涨金谷三十梁玉觞朱炎赫云间蒲为疏主人碧云天碧云天葵藿倾荷俱物役难甘原韬跣辕门缠黄云驱复道之灾之灾皇谟载衰病泾鼓间紫焰浮深异堆臂
Epoch 4:


100%|██████████| 45/45 [00:01<00:00, 28.80it/s]


Train: | perplexity: 5428.341789447291
<sos>烟笼寒水月笼沙客中遇天地关西有名关西昔时之庆连晓战童子晚前轩地清栖钓川涨金谷三十梁玉觞朱炎赫云间蒲为疏主人碧云天碧云天葵藿倾荷肉自僻近世人暮春知君命蒲为愚漂沙献延秋门龙堆连晓战锒铛三百杯钓广文到荷二十
Epoch 5:


100%|██████████| 45/45 [00:01<00:00, 29.28it/s]


Train: | perplexity: 5413.186902052568
<sos>烟笼寒水月笼沙客中遇天地关西有名关西有名涡者夜发画图风寒风寒勤王谢麻姑众山平津臂复道自适臂复道葱葱行自迟萦盈萦盈登楼玉绳缧震之盘右军不宁地清栖钓川涨消息每岁谢麻姑众山平津臂复道自适臂复道葱葱
Epoch 6:


100%|██████████| 45/45 [00:01<00:00, 29.17it/s]


Train: | perplexity: 5404.079690056019
<sos>烟笼寒水月笼沙客中遇天地关西有名关西有名涡者夜发画图风寒风寒勤王谢麻姑众山平津臂复道自适臂复道葱葱行自迟萦盈萦盈登楼玉绳玉泉睢请公问游魂超古今复道赐名炎月北潇湘复道晚玉绳玉泉睢请公问游魂超
Epoch 7:


100%|██████████| 45/45 [00:01<00:00, 29.22it/s]


Train: | perplexity: 5389.257768885635
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道陇亩偃溟驿潇湘潇湘星复道赐名炎月北潇湘复道晚玉绳玉泉睢请公问游魂超古今复道生鹊广文到登楼走穷谷断绝执秣马地清栖钓川涨金谷游人霜严请公问陌马空天生况资菱下有
Epoch 8:


100%|██████████| 45/45 [00:01<00:00, 29.27it/s]


Train: | perplexity: 5376.846753648327
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道陇亩偃溟驿潇湘潇湘星鼓间鸿毛以兹悟以兹悟下来自适臂复道自适臂复道葱葱行自迟萦盈萦盈登楼玉绳玉泉睢请公问游魂超古今复道生鹊广文到僻近齐鲁齐鲁高标齐鲁高标
Epoch 9:


100%|██████████| 45/45 [00:01<00:00, 29.14it/s]


Train: | perplexity: 5363.415258075109
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道行自迟萦盈萦盈天地间竦断折飧炎月北潇湘星复道彤庭登楼筝<eos>
Epoch 10:


100%|██████████| 45/45 [00:01<00:00, 28.90it/s]


Train: | perplexity: 5358.993387078175
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道湿声云间云间蒲为疏主人芳华西江雷霆竦乱落穷竟慷慨东郊执号谢麻姑众山平津臂复道<eos>
Epoch 11:


100%|██████████| 45/45 [00:01<00:00, 28.89it/s]


Train: | perplexity: 5341.419271957817
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道<eos>
Epoch 12:


100%|██████████| 45/45 [00:01<00:00, 29.18it/s]


Train: | perplexity: 5331.2723738296645
<sos>烟笼寒水月笼沙客中遇天地既东郊花门复道<eos>
Epoch 13:


100%|██████████| 45/45 [00:01<00:00, 29.06it/s]


Train: | perplexity: 5316.852334536886
<sos>烟笼寒水月笼沙鸿毛鸿毛以兹悟以兹悟下来十年自适亿广文到号自僻近齐鲁名军玉泉<eos>
Epoch 14:


100%|██████████| 45/45 [00:01<00:00, 29.54it/s]


Train: | perplexity: 5297.928679040219
<sos>烟笼寒水月笼沙鸿毛鸿毛以兹悟以兹悟下来十年自适亿<eos>
Epoch 15:


100%|██████████| 45/45 [00:01<00:00, 29.01it/s]


Train: | perplexity: 5282.109440343376
<sos>烟笼寒水月笼沙鸿毛鸿毛以兹悟以兹悟且且<eos>
Epoch 16:


100%|██████████| 45/45 [00:01<00:00, 28.62it/s]


Train: | perplexity: 5271.193668368601
<sos>烟笼寒水月笼沙鸿毛鸿毛以兹悟<eos>
Epoch 17:


100%|██████████| 45/45 [00:01<00:00, 29.02it/s]


Train: | perplexity: 5243.691091761153
<sos>烟笼寒水月笼沙鸿毛鸿毛以兹悟<eos>
Epoch 18:


100%|██████████| 45/45 [00:01<00:00, 28.99it/s]


Train: | perplexity: 5222.894398829187
<sos>烟笼寒水月笼沙鸿毛鸿毛<eos>
Epoch 19:


100%|██████████| 45/45 [00:01<00:00, 28.52it/s]


Train: | perplexity: 5200.122809316974
<sos>烟笼寒水月笼沙鸿毛<eos>
Epoch 20:


100%|██████████| 45/45 [00:01<00:00, 29.18it/s]


Train: | perplexity: 5178.585325949592
<sos>烟笼寒水月笼沙鸿毛<eos>
Epoch 21:


100%|██████████| 45/45 [00:01<00:00, 29.31it/s]


Train: | perplexity: 5142.70086431673
<sos>烟笼寒水月笼沙<eos>
Epoch 22:


100%|██████████| 45/45 [00:01<00:00, 29.03it/s]


Train: | perplexity: 5108.956158902017
<sos>烟笼寒水月笼沙<eos>
Epoch 23:


100%|██████████| 45/45 [00:01<00:00, 28.43it/s]


Train: | perplexity: 5050.12586627624
<sos>烟笼寒水月笼沙<eos>
Epoch 24:


100%|██████████| 45/45 [00:01<00:00, 29.16it/s]


Train: | perplexity: 4969.272317558007
<sos>烟笼寒水月笼沙<eos>
Epoch 25:


100%|██████████| 45/45 [00:01<00:00, 29.23it/s]


Train: | perplexity: 4841.885960578866
<sos>烟笼寒水月笼沙<eos>
Epoch 26:


100%|██████████| 45/45 [00:01<00:00, 28.77it/s]


Train: | perplexity: 4612.426412349055
<sos>烟笼寒水月笼沙<eos>
Epoch 27:


100%|██████████| 45/45 [00:01<00:00, 29.83it/s]


Train: | perplexity: 4321.191964825249
<sos>烟笼寒水月笼沙<eos>
Epoch 28:


100%|██████████| 45/45 [00:01<00:00, 29.52it/s]


Train: | perplexity: 4091.2394917278593
<sos>烟笼寒水月笼沙<eos>
Epoch 29:


100%|██████████| 45/45 [00:01<00:00, 28.67it/s]

Train: | perplexity: 3933.4033738074513
<sos>烟笼寒水月笼沙<eos>





## 4. 结果分析

可能的原因有：数据量太少；使用`jieba`按词分词后，词表中大部分词只出现寥寥几次（数据稀疏）；词向量没有经过预训练，质量较差。这些因素共同导致生成模型的训练效果很差。\
虽然模型最后学会了以`<eos>`结尾，但由于只用贪心搜索进行生成，模型到后期总是在提示词之后直接输出`<eos>`，可能需要改用 beam search 或随机采样等解码方式。\
训练时也尝试过保留“，”和“。”这两个标点，结果模型最后只会生成`，。<eos>`：虽然看上去确实学到了标点之间的关联，但仅此而已，生成的句子完全不可用。