In [2]:
import torch
import random
import zipfile
import time
import math
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [1]:
def grad_clipping(params, theta, device):
    """Rescale gradients in place so their global L2 norm is at most ``theta``.

    Args:
        params: Iterable of tensors whose ``.grad`` should be clipped. Any
            iterable is accepted, including a one-shot generator such as
            ``module.parameters()``.
        theta: Maximum allowed L2 norm, measured over all gradients together.
        device: Device on which the squared-norm accumulator is allocated.

    Returns:
        None. Gradients are modified in place.
    """
    # Collect the gradients once: they are walked twice (measure, then scale),
    # and a generator would already be exhausted on the second pass.
    # Parameters that have not received a gradient yet are skipped.
    grads = [param.grad for param in params if param.grad is not None]

    norm = torch.tensor([0.0], device=device)
    for grad in grads:
        norm += (grad.data ** 2).sum()
    norm = norm.sqrt().item()

    # Only shrink; gradients already inside the ball are left untouched.
    if norm > theta:
        for grad in grads:
            grad.data *= (theta / norm)

In [4]:
def sgd(params, lr, batch_size):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each tensor is decreased by ``lr * grad / batch_size``. The update is
    written through ``.data`` so it is not recorded by autograd.

    Args:
        params: Iterable of tensors with a populated ``.grad``.
        lr: Learning rate.
        batch_size: Divisor applied to the gradient.
    """
    for tensor in params:
        step = lr * tensor.grad / batch_size
        tensor.data -= step

In [5]:
def one_hot(x,n_class):
    """Encode a vector of class indices as one-hot rows.

    Args:
        x: Tensor of class indices; it is cast to int64 before use.
        n_class: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A float tensor of shape ``(x.shape[0], n_class)`` on the same device
        as ``x``, with a 1 in column ``x[i]`` of row ``i`` and 0 elsewhere.
    """
    class_idx = x.long()
    blank = torch.zeros(class_idx.shape[0], n_class,
                        dtype=torch.float, device=class_idx.device)
    # scatter_ writes the scalar 1 at (row, class_idx[row]) and returns `blank`.
    return blank.scatter_(1, class_idx.view(-1, 1), 1)
def to_onehot(X, n_class):
    """Turn a 2-D batch of indices into a list of one-hot matrices.

    Args:
        X: 2-D tensor of class indices; each column is encoded separately.
        n_class: Number of classes.

    Returns:
        A list with ``X.shape[1]`` entries; entry ``i`` is the one-hot
        encoding of column ``i`` of ``X``.
    """
    encoded_columns = []
    for col in range(X.shape[1]):
        encoded_columns.append(one_hot(X[:, col], n_class))
    return encoded_columns

In [7]:
def load_data_jay_lyrics(zip_path='./Dataset/JayLyrics/jaychou_lyrics.txt.zip',
                         member='jaychou_lyrics.txt', max_chars=10000):
    """Load the lyrics corpus from a zip archive and build a character vocabulary.

    Args:
        zip_path: Path of the zip archive to read.
        member: Name of the UTF-8 text file inside the archive.
        max_chars: Keep only this many leading characters of the corpus;
            ``None`` keeps everything.

    Returns:
        A 5-tuple ``(corpus_chars, corpus_indice, idx_to_char_dict,
        char_to_idx_dict, vocab_size)``: the cleaned text, the text as a list
        of vocabulary indices, the two lookup dictionaries, and the number of
        distinct characters.
    """
    with zipfile.ZipFile(zip_path) as zin:
        with zin.open(member) as f:
            corpus_chars = f.read().decode('utf-8')

    # Line breaks become spaces so the corpus is one continuous sequence.
    corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ')
    if max_chars is not None:
        corpus_chars = corpus_chars[:max_chars]

    # Sort the vocabulary: iterating a bare set of strings yields an order
    # that can change between interpreter runs, which would silently change
    # the char <-> index mapping (and invalidate any saved parameters).
    idx_to_char = sorted(set(corpus_chars))
    char_to_idx_dict = {char: i for i, char in enumerate(idx_to_char)}
    idx_to_char_dict = dict(enumerate(idx_to_char))
    corpus_indice = [char_to_idx_dict[char] for char in corpus_chars]
    return corpus_chars, corpus_indice, idx_to_char_dict, char_to_idx_dict, len(idx_to_char_dict)

# Load the corpus once at notebook scope. Later cells read these globals;
# vocab_size in particular sets the model's input and output widths.
corpus_chars, corpus_indice, idx_to_char_dict, char_to_idx_dict, vocab_size = load_data_jay_lyrics()

In [13]:
def data_iter_consecutive(corpus_indice, batch_size, num_step, device):
    """Yield consecutive (input, label) mini-batches from an index sequence.

    The sequence is cut into ``batch_size`` equally long rows (any remainder
    is dropped). Successive batches take adjacent windows of ``num_step``
    columns, so row ``r`` of one batch continues where row ``r`` of the
    previous batch stopped.

    Args:
        corpus_indice: Sequence of integer vocabulary indices.
        batch_size: Number of rows (parallel sequences) per batch.
        num_step: Number of time steps (columns) per batch.
        device: Device on which the index tensor is created.

    Yields:
        ``(X, Y)`` float tensors, both of shape ``(batch_size, num_step)``,
        where ``Y`` is ``X`` shifted one position to the right.
    """
    # Kept as float for compatibility with existing callers, which cast to
    # long themselves before indexing or computing the loss.
    corpus_indice = torch.tensor(corpus_indice, dtype=torch.float, device=device)
    data_len = len(corpus_indice)
    batch_len = data_len // batch_size
    indice = corpus_indice[0:batch_len * batch_size].view(batch_size, batch_len)

    # Labels are the inputs shifted by one, so the final column of each row
    # can only ever be a label. Counting windows over batch_len - 1 columns
    # keeps Y as wide as X even when batch_len is a multiple of num_step.
    epoch_size = (batch_len - 1) // num_step

    for start in range(0, epoch_size * num_step, num_step):
        X = indice[:, start: start + num_step]
        Y = indice[:, start + 1: start + num_step + 1]
        yield X, Y

In [11]:
# Model dimensions: inputs are one-hot characters and outputs are per-character
# scores, so both widths equal vocab_size; the hidden state has 256 units.
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
def get_params():
    """Create the trainable parameters of the LSTM and its output layer.

    Sizes come from the notebook globals ``num_inputs``, ``num_hiddens`` and
    ``num_outputs``; tensors are placed on the global ``device``. Weight
    matrices are drawn from N(0, 0.01^2) and biases start at zero.

    Returns:
        An ``nn.ParameterList`` ordered as
        ``[W_xf, W_hf, b_f, W_xi, W_hi, b_i, W_xc, W_hc, b_c,
        W_xo, W_ho, b_o, W_hq, b_q]``.
    """
    def _weight(shape):
        values = np.random.normal(0, 0.01, size=shape)
        return nn.Parameter(torch.tensor(values, device=device, dtype=torch.float),
                            requires_grad=True)

    def _bias(size):
        return nn.Parameter(torch.zeros(size, device=device, dtype=torch.float),
                            requires_grad=True)

    collected = []
    # One (input weight, hidden weight, bias) triple per gate, in the order
    # forget, input, candidate, output.
    for _gate in ('forget', 'input', 'candidate', 'output'):
        collected.append(_weight((num_inputs, num_hiddens)))
        collected.append(_weight((num_hiddens, num_hiddens)))
        collected.append(_bias(num_hiddens))

    # Output layer parameters
    collected.append(_weight((num_hiddens, num_outputs)))
    collected.append(_bias(num_outputs))

    return nn.ParameterList(collected)

# Instantiate the parameters once; re-running this cell re-initialises the
# model and discards anything learned so far.
params = get_params()

In [14]:
def init_lstm_state(batch_size, num_hiddens):
    """Return a zero-filled initial LSTM state.

    Args:
        batch_size: Number of sequences processed in parallel.
        num_hiddens: Number of hidden units.

    Returns:
        A ``(C, H)`` tuple (cell state, hidden state) of float tensors, each
        of shape ``(batch_size, num_hiddens)``, on the global ``device``.
    """
    state_shape = (batch_size, num_hiddens)
    cell_state = torch.zeros(state_shape, dtype=torch.float, device=device)
    hidden_state = torch.zeros(state_shape, dtype=torch.float, device=device)
    return cell_state, hidden_state

def lstm(inputs, init_state, params):
    """Run an LSTM over a sequence and project each hidden state to outputs.

    Args:
        inputs: Iterable of 2-D input matrices, one per time step.
        init_state: ``(C, H)`` tuple holding the starting cell and hidden state.
        params: The 14 parameters in the order
            ``W_xf, W_hf, b_f, W_xi, W_hi, b_i, W_xc, W_hc, b_c,
            W_xo, W_ho, b_o, W_hq, b_q``.

    Returns:
        ``(outputs, (C, H))``: a list with one output matrix per time step,
        and the state after the last step.
    """
    (W_xf, W_hf, b_f,
     W_xi, W_hi, b_i,
     W_xc, W_hc, b_c,
     W_xo, W_ho, b_o,
     W_hq, b_q) = params
    cell, hidden = init_state

    def _affine(x, h, w_x, w_h, bias):
        # Shared pre-activation of every gate: x @ W_x + h @ W_h + b.
        return torch.mm(x, w_x) + torch.mm(h, w_h) + bias

    outputs = []
    for x_t in inputs:
        forget_gate = torch.sigmoid(_affine(x_t, hidden, W_xf, W_hf, b_f))
        input_gate = torch.sigmoid(_affine(x_t, hidden, W_xi, W_hi, b_i))
        output_gate = torch.sigmoid(_affine(x_t, hidden, W_xo, W_ho, b_o))
        candidate = torch.tanh(_affine(x_t, hidden, W_xc, W_hc, b_c))

        # Keep part of the old memory, add part of the candidate.
        cell = forget_gate * cell + input_gate * candidate
        hidden = output_gate * torch.tanh(cell)

        outputs.append(torch.mm(hidden, W_hq) + b_q)
    return outputs, (cell, hidden)

In [16]:
def predict_lstm(prefix, num_chars, init_state_fn, model, params, idx_to_char_dict, char_to_idx_dict):
    """Greedily generate text that continues ``prefix``.

    The prefix is fed through the model one character at a time to warm up
    the state; after that, the highest-scoring character of each step is
    appended and fed back in as the next input.

    Args:
        prefix: Non-empty seed string; every character must be a key of
            ``char_to_idx_dict``.
        num_chars: Number of characters to generate after the prefix.
        init_state_fn: Callable ``(batch_size, num_hiddens) -> state``.
        model: Callable ``(inputs, state, params) -> (outputs, state)``.
        params: Model parameters; ``params[1]`` must be the hidden-to-hidden
            weight ``W_hf``, whose second dimension gives the hidden size.
        idx_to_char_dict: Maps vocabulary index to character.
        char_to_idx_dict: Maps character to vocabulary index.

    Returns:
        The prefix followed by ``num_chars`` generated characters.

    Raises:
        IndexError: If ``prefix`` is empty.
        KeyError: If a prefix character is missing from ``char_to_idx_dict``.
    """
    vocab_size = len(idx_to_char_dict)
    W_hf = params[1]
    output = [prefix[0]]
    # Generation handles a single sequence, hence batch size 1.
    test_state = init_state_fn(len(output), W_hf.shape[1])

    # Inference only: without no_grad the carried state would keep extending
    # one autograd graph across every generated character.
    with torch.no_grad():
        for t in range(num_chars + len(prefix) - 1):
            # Build the input on the parameters' device instead of relying on
            # a notebook-global `device` that may have been reassigned.
            X_indice = torch.tensor([char_to_idx_dict[output[-1]]], device=W_hf.device).view(-1, 1)
            X_one_hot = to_onehot(X_indice, vocab_size)
            Y, test_state = model(X_one_hot, test_state, params)
            if t < len(prefix) - 1:
                # Still consuming the prefix: discard the prediction.
                output.append(prefix[t + 1])
            else:
                output.append(idx_to_char_dict[int(Y[0].argmax(dim=1).item())])
    return ''.join(output)

In [37]:
def train_and_predict(corpus_indice, data_iter, batch_size, num_steps, num_hiddens, num_epoches,
                      device, init_state_fn, model, lr, clipping_theta, params,
                      idx_to_char_dict, char_to_idx_dict,
                      prefixes, num_chars, pred_period=50):
    """Train a character-level model with clipped SGD and periodically sample text.

    Args:
        corpus_indice: Corpus as a sequence of vocabulary indices.
        data_iter: Callable ``(corpus_indice, batch_size, num_steps, device)``
            yielding ``(X, Y)`` index batches.
        batch_size: Number of sequences per batch.
        num_steps: Number of time steps per batch.
        num_hiddens: Hidden size handed to ``init_state_fn``.
        num_epoches: Number of passes over the corpus.
        device: Device passed to ``data_iter`` and ``grad_clipping``.
        init_state_fn: Callable ``(batch_size, num_hiddens) -> state tuple``.
        model: Callable ``(inputs, state, params) -> (outputs, state)``.
        lr: SGD learning rate.
        clipping_theta: Upper bound on the global gradient norm.
        params: Indexable collection of trainable tensors.
        idx_to_char_dict: Maps vocabulary index to character.
        char_to_idx_dict: Maps character to vocabulary index.
        prefixes: Seed strings used for the periodic text samples.
        num_chars: Number of characters generated per sample.
        pred_period: Report loss/perplexity and print samples every this many
            epochs.

    Returns:
        None. ``params`` are updated in place and progress is printed.
    """
    vocab_size = len(idx_to_char_dict)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(num_epoches):
        train_loss_sum, n = 0.0, 0
        train_state = init_state_fn(batch_size, num_hiddens)
        for X_indice_list, Y_indice_list in data_iter(corpus_indice, batch_size, num_steps, device):
            # The state is carried from batch to batch within an epoch, but
            # detached so back-propagation stops at the batch boundary.
            train_state = tuple(state.detach() for state in train_state)
            X_one_hot = to_onehot(X_indice_list, vocab_size)
            # Use the model that was passed in rather than a hard-coded one.
            Y_hat_list, train_state = model(X_one_hot, train_state, params)
            # Outputs are stacked time-major, so the labels are transposed to
            # the same order before flattening.
            outputs = torch.cat(Y_hat_list, dim=0)
            targets = torch.transpose(Y_indice_list, 0, 1).contiguous().view(-1).long()
            loss = loss_fn(outputs, targets)

            # Zero every gradient that exists; checking each parameter on its
            # own avoids depending on the state of params[0] alone.
            for param in params:
                if param.grad is not None:
                    param.grad.data.zero_()

            loss.backward()

            grad_clipping(params, clipping_theta, device)
            # CrossEntropyLoss already averages over the batch (default
            # reduction), so no extra division by the batch size is applied.
            sgd(params, lr, 1)
            train_loss_sum += loss.item()
            n += 1

        if (epoch + 1) % pred_period == 0:
            # Guard against an iterator that produced no batches.
            train_loss = train_loss_sum / n if n else float('nan')
            try:
                perplexity = math.exp(train_loss)
            except OverflowError:
                perplexity = float('inf')
            print('epoch: %d, train_loss: %.2f, perplexity: %f'%(epoch+1, train_loss, perplexity))

            for prefix in prefixes:
                print(predict_lstm(prefix, num_chars, init_state_fn, model, params, idx_to_char_dict, char_to_idx_dict))

In [40]:
# Training hyper-parameters: number of epochs, time steps per batch, sequences
# per batch, SGD learning rate, and the gradient-norm clipping threshold.
num_epoches, num_steps, batch_size, lr, clipping_theta = 300, 50, 32, 1e3, 1e-2
# Sampling settings: characters generated after each prefix, and the prefixes.
num_chars, prefixes = 50, ['分开','不分开']
# NOTE(review): repeats the `device` assignment made in the import cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

In [41]:
# Train the from-scratch LSTM on consecutive batches; progress and sample
# text are printed every 50 epochs.
train_and_predict(corpus_indice, data_iter_consecutive, batch_size, num_steps, num_hiddens, num_epoches,
                      device, init_lstm_state, lstm, lr, clipping_theta, params,
                      idx_to_char_dict, char_to_idx_dict, 
                      prefixes, num_chars)

epoch: 50, train_loss: 3.24, perplexity: 25.495932
分开  我  我   我 我      我 我   我 我 我 我 我 我 我 我 我 我 我 我 我 
不分开暴你说  我      我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 
epoch: 100, train_loss: 2.33, perplexity: 10.245249
分开始的手不  我的外  我想  我想你的我的可爱你的可不  爱情  我想 我想你的看远    我   
不分开  我不能  我想 我不能 我想要 我想你的黑白 看 一个 我爱女  我想想 我想   我想你 我也
epoch: 150, train_loss: 1.66, perplexity: 5.256422
分开始的让我 一个许在那已经过  是我后能不是一定实很久了汉堡 在抽离开的我的 的 一切又 一个满  我
不分开暴风 我不想想了很美走 就怎么我跟透不会金 一步三步望著我 一枝步两  相爱我可爱女人  你爱我不能
epoch: 200, train_loss: 1.43, perplexity: 4.169124
分开始共山  在那里 不达米的快就耳 哼哼哼哈兮 飞 飞 如果我 在最  你  你 太彻   没有你是让
不分开暴  不舍 爱你 你 太快就这样 难过去 不放开不 后出的 对不能知后 哼哼哈兮 它在抽离  爱  
epoch: 250, train_loss: 1.09, perplexity: 2.987868
分开始共多 难过了 一颗三颗   你 你的完 是  的回忆  对不 后  我胸  你 你   你 你 太
不分开     你 你    你 你 太快就能活  这样 印地安的在 爱我  爱你 你    你 你 太快
epoch: 300, train_loss: 1.04, perplexity: 2.824185
分开始 一只看远  我 我的手 我胸起  我 我 我的手 我的手 我胸定 我的手不能知不会痛  我 我 
不分开暴 我 我不想  我的手不觉  我不能了离  我 我 我 我的手 我的手 我胸定 我的手不能知不会痛
epoch: 350, train_loss: 0.93, perplexity: 2.530883


KeyboardInterrupt: 