## Task 5
- 使用 GRU、LSTM 计算字符级别语言困惑度  


其实语言困惑度就是对RNN模型在训练过程中求出的损失值的指数形式（也即交叉熵的指数值），由于不要求输出对应的样例，所以就只在训练过程中输出对应的困惑度的值；

In [1]:
# 根据给定 txt 建立语料库
path = './poetryFromTang.txt'
text = open(path,'rb').read().lower().decode('utf-8')
print ('corpus length:', len(text))
#print(type(text))
chars = set(text)
#print(chars[0:10])
print ('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

corpus length: 16647
total chars: 2514


In [2]:
# 对训练集进行处理
maxlen = 40
step = 3
sentences = []
# 判断下一个字符的信息
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i : i + maxlen])
    next_chars.append(text[i + maxlen])

In [14]:
print(len(sentences[0]))

40


In [106]:
import numpy as np
X = []
y = []
for i ,sentence in enumerate(sentences):
    x = []
    for _ , char in enumerate(sentence):
        x.append(char_indices[char])
    X.append(x)
    y.append(char_indices[next_chars[i]])

In [108]:
print((X[2]))
print(y[1])

[644, 1980, 1544, 2073, 1373, 154, 2254, 406, 762, 2401, 1048, 1549, 435, 513, 1730, 1552, 126, 622, 1544, 619, 2053, 2347, 105, 2089, 834, 1763, 1674, 1048, 2228, 958, 2197, 1259, 1875, 2033, 1905, 1544, 2406, 311, 1667, 2504]
311


In [78]:
# 创建模型
# 创建 RNN + GRU 类
import torch.nn as nn

class BaseGRU(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0, bidirectional=False):
        super(BaseGRU, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(2514, 300)
        self.gru = nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional, dropout=dropout)
        self.h2o = nn.Linear(self.num_directions * hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)  
    def forward(self, inputs):
        hidden = self.initHidden(False)
        line = self.embedding(inputs)
        line = torch.transpose(line, 0, 1)
        output, hidden = self.gru(line, hidden)
        output = self.h2o(output[line.size(0)-1])
        output = self.softmax(output)
        return output    
    def initHidden(self, is_cuda=True):
        if is_cuda:
            hidden = torch.zeros(self.num_layers*self.num_directions, 1, self.hidden_size).cuda()
        else:
            hidden = torch.zeros(self.num_layers*self.num_directions, 1, self.hidden_size)
        return hidden

In [112]:
# 创建模型并进行训练
import torch.optim as optim
import math
import torch
criterion = nn.CrossEntropyLoss()
learning_rate = 0.05
n_hidden = 128
model = BaseGRU(300, n_hidden,len(chars), 2, 0.5, bidirectional=True)
def train(model, category_tensor, line_tensor, weight_clip=0.1):
    output = model(line_tensor)
    model.zero_grad()
    loss = criterion(output, category_tensor)
    loss.backward()
    for p in model.parameters():
        if hasattr(p.grad, "data"):
            p.data.add_(-learning_rate, p.grad.data)
    return output, loss.item()
for iter in range(0,1):
    for i in range(0,100):
        Y = torch.tensor([y[i]], dtype=torch.long)
        x = torch.tensor([X[i]], dtype=torch.long)
        _, loss = train(model,Y,x)
        print('perplexity:{}'.format(math.exp(loss)))

perplexity:2645.736828925497
perplexity:2951.3793151299483
perplexity:2509.8388895133567
perplexity:3018.1107104001226
perplexity:2477.552070313373
perplexity:1952.6048012003296
perplexity:2628.1760619370925
perplexity:3072.370747280244
perplexity:2725.3983313229996
perplexity:2806.180224461259
perplexity:2420.739026827909
perplexity:2888.0025117170517
perplexity:2344.9632858738937
perplexity:2738.640539417584
perplexity:2012.0085305711323
perplexity:2679.4806241140277
perplexity:2771.5134421120947
perplexity:2085.6913129180034
perplexity:2744.46967799424
perplexity:2220.997818287951
perplexity:2363.733741537051
perplexity:3000.566883377853
perplexity:2295.051606359756
perplexity:3024.983256069884
perplexity:2453.5463465641287
perplexity:2263.5873865196995
perplexity:1536.1524081058724
perplexity:2774.6049737495887
perplexity:2577.306895368079
perplexity:2052.5402243587655
perplexity:1936.9772064227448
perplexity:3067.8881276763705
perplexity:2198.387449320412
perplexity:2352.447906070