In [7]:
import torch
import random
import zipfile
import time
import math
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [122]:
#梯度裁剪虽然能解决梯度爆炸的问题，但无法解决梯度衰减
def grad_clipping(params, theta, device):
    norm = torch.tensor([0.0], device = device)
    for param in params:
        norm += (param.grad.data**2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta/norm) 

In [59]:
def sgd(params, lr, batch_size):
    for param in params:
        param.data -= lr*param.grad/batch_size

In [38]:
def one_hot(x,n_class):
    x = x.long()
    res=torch.zeros(x.shape[0],n_class, dtype = torch.float, device = x.device)
    res.scatter_(1,x.view(-1,1),1)
    return res
def to_onehot(X, n_class):
    return [one_hot(X[:,i],n_class) for i in range(X.shape[1])]

In [1]:
#（偏）导数反映的就是两个变量之间的依赖关系，因为两个变量的微分是通过（偏）导数来连接，反映线性依赖关系。
#循环神经网络的设计其实就是为了捕捉损失和各个网络变量之间的线性关系，梯度消失导致线性依赖关系消失
#向前传播建立依赖关系，向后传播捕捉依赖关系。但传统的rnn往往由于固定的依赖关系的建立而失去对依赖关系的更好的捕捉，而学习门来控制依赖关系的建立，是
#gated recurrent unit 门控循环单元， 通过学习门来控制信息流动

In [5]:
def load_data_jay_lyrics():
    with zipfile.ZipFile('./Dataset/JayLyrics/jaychou_lyrics.txt.zip') as zin:
        with zin.open('jaychou_lyrics.txt') as f:
            corpus_chars = f.read().decode('utf-8')
    corpus_chars = corpus_chars.replace('\n',' ').replace('\r', ' ')
    corpus_chars = corpus_chars[:10000]
    idx_to_char = list(set(corpus_chars))
    char_to_idx_dict = dict([(char, i) for i, char in enumerate(idx_to_char)])
    idx_to_char_dict = dict([(i, char) for i, char in enumerate(idx_to_char)])
    corpus_indice = [char_to_idx_dict[char] for char in corpus_chars]
    return corpus_chars, corpus_indice, idx_to_char_dict, char_to_idx_dict, len(idx_to_char_dict)

In [131]:
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size = shape), device = device, dtype = torch.float)
        return  nn.Parameter(ts, requires_grad = True)
    
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                nn.Parameter(torch.zeros(num_hiddens, device = device, dtype = torch.float, requires_grad = True)))
    
    W_xr, W_hr, b_r = _three()
    W_xz, W_hz, b_z = _three()
    W_xh, W_hh, b_h = _three()
    
    #输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = nn.Parameter(torch.zeros(num_outputs, device = device, dtype = torch.float), requires_grad=True)
    
    return nn.ParameterList([W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h, W_hq, b_q])

In [150]:
def init_gru_state(batch_size, num_hiddens):
    return (torch.zeros((batch_size, num_hiddens), dtype = torch.float, device = device), )
def gru(inputs, state, params):
    W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        R = torch.sigmoid(torch.mm(X, W_xr)+torch.mm(H, W_hr)+b_r)
        Z = torch.sigmoid(torch.mm(X, W_xz)+torch.mm(H, W_hz)+b_z)
        H_tilda = torch.tanh(torch.mm(X, W_xh)+torch.mm(R*H,W_hh)+b_h)
        H = Z*H +(1-Z)*H_tilda 
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)

In [151]:
def data_iter_consecutive(corpus_indice, batch_size, num_step, device):
    corpus_indice = torch.tensor(corpus_indice, dtype = torch.float, device = device)
    data_len = len(corpus_indice)
    batch_len = data_len//batch_size
    indice = corpus_indice[0:batch_len*batch_size].view(batch_size, batch_len)
    epoch_size = batch_len//num_step
    
    for i in range(epoch_size):
        i = i*num_step
        X = indice[:, i: i+num_step]
        Y = indice[:, i+1:i+num_step+1]
        yield X, Y

In [152]:
def predict_rnn(prefix, num_chars, init_state, model, params, idx_to_char_dict, char_to_ida_dict):
    W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h, W_hq, b_q = params
    vocab_size = len(idx_to_char_dict)
    output = [prefix[0]]
    test_state = init_state(len(output), W_hr.shape[1]) 
    
    for t in range(num_chars+len(prefix)-1):
        X_indice = torch.tensor([char_to_idx_dict[output[-1]]], device = device).view(-1,1)
        X_one_hot = to_onehot(X_indice, vocab_size)
        Y, test_state = model(X_one_hot, test_state, params)
        if t<len(prefix)-1: 
            output.append(prefix[t+1])
        else:    
            output.append(idx_to_char_dict[int(Y[0].argmax(dim = 1).item())])
    return ''.join(output)

In [174]:
def train_and_predict(corpus_indice, data_iter, batch_size, num_steps, num_hiddens, device,
                      num_epoches, init_state, model, lr, clipping_theta, params,
                      idx_to_char_dict, char_to_idx_dict, 
                      prefixes, num_chars):
    
    vocab_size = len(idx_to_char_dict)
    loss_fn = nn.CrossEntropyLoss()
    params = params.to(device)
    
    for epoch in range(num_epoches): 
        train_loss_sum, n = 0.0, 0
        train_state = init_state(batch_size, num_hiddens)
        for X_indice_list, Y_indice_list in data_iter(corpus_indice, batch_size, num_steps, device):
            train_state = (train_state[0].detach(),)
            X_one_hot = to_onehot(X_indice_list, vocab_size)
            Y_hat_one_hot, state = gru(X_one_hot, train_state, params)
            outputs = torch.cat(Y_hat_one_hot, dim=0)
            Y_indice_list = torch.transpose(Y_indice_list, 0, 1).contiguous().view(-1).long()
            loss = loss_fn(outputs, Y_indice_list)
            
            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            
            loss.backward()
            
            grad_clipping(params, clipping_theta, device)
            sgd(params, lr, 1)
            train_loss_sum += loss.item()
            n += 1
            
        if (epoch+1)%50 == 0:
            train_loss = train_loss_sum/n
            try:
                perplexity = math.exp(train_loss)
            except OverflowError:
                perplexity = float('inf')
            print('epoch: %d, train_loss: %.2f, perplexity: %f'%(epoch+1, train_loss, perplexity))
            
            for prefix in prefixes:
                print(predict_rnn(prefix, num_chars, init_state, model, params, idx_to_char_dict, char_to_idx_dict))

In [132]:
corpus_chars, corpus_indice, idx_to_char_dict, char_to_idx_dict, vocab_size = load_data_jay_lyrics()

In [162]:
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
num_chars, prefixes = 40, ['分开','不分开']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
params = get_params()

In [175]:
train_and_predict(corpus_indice, data_iter_consecutive, batch_size, num_steps, num_hiddens, device,
                      num_epochs, init_gru_state, gru, lr, clipping_theta, params,
                      idx_to_char_dict, char_to_idx_dict, 
                      prefixes, num_chars)

epoch: 50, train_loss: 0.04, perplexity: 1.036668
分开 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不要再
不分开 她已经 娘子默子在 这样沙什么 不手不多 你给空 难墟烟有 在在转中 全隐村日 恨伤己苦 全面了纵
epoch: 100, train_loss: 0.04, perplexity: 1.035857
分开 我够能再生活 静静悄悄默默离开 陷入了危险边缘Baby  我的世界已狂风暴雨 Wu  爱情来的太快
不分开 她已经 后再简一切 后后风也目到 我都想的有样有样 我好声都满堡对唱都也张意义 或许在每后能开到 
epoch: 150, train_loss: 0.04, perplexity: 1.042424
分开 我爱能的风写在 别力的最像 我的念有已静欣赏赏那张张蝠 或来心人威过一只止一远 决轻的叹息 随制狠
不分开 她已经 娘子这这在江袋 我 就带你骑棒对 想这样没担忧 唱着歌 一直走 我想就这样牵着你的手不放开
epoch: 200, train_loss: 0.04, perplexity: 1.045460
分开 我不能再想 我不要再想 我不 我不 我不要再想你 不知不觉 你打我离 这样对吗干嘛这样 何必让酒牵
不分开几远单纯没有悲哀 我 想带你骑单车 我 想和你看棒球 想这样没担忧 唱着歌 一直走 我想就这样牵着你
epoch: 250, train_loss: 0.03, perplexity: 1.034784
分开 我爱能再风 我不好好生 我不能 让不走 这样球 气墟病夫 招家的梦 恨自放动 恨没有红的没用 情绪
不分开几远她纯 默是在一个坦日日 一记令我 在色村外废弃 它乡情起的我有样  后道道其其写回雨 脸今汹初失
epoch: 300, train_loss: 0.04, perplexity: 1.036067
分开 风色海烛 温暖了空屋 蔓色蜡烛 温暖了空屋 双截棍 岩烧店的烟味弥漫 古底沙沙a盯 带领下我满得事
不分开几远单纯 迎著一断 再来了渡每一天 手牵手 一步两步三步四步望著天 看星星 一颗两颗三颗四颗 连成线
