<a href="https://colab.research.google.com/github/skywalker0803r/mxnet_course/blob/master/mxnet_lyrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install mxnet
#!pip install d2lzh
from mxnet import autograd,gluon,init,nd,cpu
import d2lzh as d2l
import math
from mxnet.gluon import loss as gloss,nn,rnn
import time
from google.colab import drive
import random
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
def load_lyrics(path):
  with open(path,encoding='utf8') as f:
    corpus_chars = f.read().replace('\n',' ').replace('\r',' ').replace('\u3000','')
  idx_to_char = list(set(corpus_chars))
  char_to_idx = dict([(char,i) for i,char in enumerate(idx_to_char)])
  vocab_size = len(char_to_idx)
  corpus_indices = [char_to_idx[char] for char in corpus_chars]
  return corpus_indices,char_to_idx,idx_to_char,vocab_size

# load data

In [3]:
(corpus_indices,char_to_idx,idx_to_char,vocab_size) = load_lyrics('/content/drive/My Drive/mxnet_dataset/歌詞.txt')
print('chars:',[idx_to_char[i] for i in corpus_indices[:20]])
print('indices:',corpus_indices[:20])

chars: ['窗', '外', '的', '麻', '雀', '在', '電', '線', '桿', '上', '多', '嘴', ' ', '妳', '說', '這', '一', '句', '很', '有']
indices: [77, 91, 52, 17, 48, 34, 2, 30, 76, 137, 101, 14, 39, 47, 136, 87, 79, 82, 21, 122]


# model

In [0]:
num_hiddens = 256
rnn_layer = rnn.RNN(num_hiddens)
rnn_layer.initialize()

In [5]:
state = rnn_layer.begin_state(batch_size=2)
state[0].shape

(1, 2, 256)

In [0]:
class RNNModel(nn.Block):
  def __init__(self,rnn_layer,vocab_size,**kwargs):
    super().__init__(**kwargs)
    self.rnn = rnn_layer
    self.vocab_size = vocab_size
    self.dense = nn.Dense(vocab_size)
  
  def forward(self,inputs,state):
    X = nd.one_hot(inputs.T,self.vocab_size)
    Y,state = self.rnn(X,state)
    output = self.dense(Y.reshape((-1,Y.shape[-1])))
    return output ,state
  
  def begin_state(self,*args,**kwargs):
    return self.rnn.begin_state(*args,**kwargs) 

# 訓練前

In [0]:
model = RNNModel(rnn_layer,vocab_size)
model.initialize(force_reinit=True)

In [8]:
d2l.predict_rnn_gluon('窗外',10,model,vocab_size,cpu(),idx_to_char,char_to_idx)

'窗外 ~節上結結結香上永'

# 開始訓練

In [9]:
d2l.train_and_predict_rnn_gluon(model = model,
                                num_hiddens = 256,
                                vocab_size = vocab_size,
                                ctx = cpu(),
                                corpus_indices = corpus_indices,
                                idx_to_char = idx_to_char,
                                char_to_idx = char_to_idx,
                                num_epochs = 500,
                                num_steps = 11,
                                lr = 1e2,
                                clipping_theta = 1e-2,
                                batch_size = 32,
                                pred_period = 10,
                                pred_len = 10,
                                prefixes=['窗外'])

epoch 10, perplexity 93.312315, time 0.02 sec
 - 窗外的的的的的的的的的的
epoch 20, perplexity 81.229718, time 0.02 sec
 - 窗外的的的的的的的的的的
epoch 30, perplexity 51.952918, time 0.02 sec
 - 窗外的愛溢出就像雨 出就
epoch 40, perplexity 35.027126, time 0.02 sec
 - 窗外的愛溢出就像雨念出就
epoch 50, perplexity 24.136526, time 0.03 sec
 - 窗外的愛溢出的一雨我的愛
epoch 60, perplexity 18.228982, time 0.02 sec
 - 窗外是我的愛溢出就像雨水
epoch 70, perplexity 14.291189, time 0.02 sec
 - 窗外是我詩愛溢出就像雨水
epoch 80, perplexity 14.617773, time 0.02 sec
 - 窗外是我也無法我的愛溢出
epoch 90, perplexity 9.723547, time 0.02 sec
 - 窗外夜我詩愛溢出就像雨 
epoch 100, perplexity 12.044058, time 0.02 sec
 - 窗外的著溢出就像雨水 窗
epoch 110, perplexity 7.347678, time 0.02 sec
 - 窗外的愛溢出就像雨水 窗
epoch 120, perplexity 5.172585, time 0.02 sec
 - 窗外的著寫在就像雨水 嘴
epoch 130, perplexity 4.973076, time 0.02 sec
 - 窗外夜我的愛溢出就像雨水
epoch 140, perplexity 5.311710, time 0.02 sec
 - 窗外的麻雀在就像雨水 窗
epoch 150, perplexity 3.582867, time 0.02 sec
 - 窗外的麻雀在電像雨水多窗
epoch 160, perplexity 3.184749, time 0.02 sec
 - 窗外的麻溢在電像的上念厚
epoch 17

# 測試

In [10]:
d2l.predict_rnn_gluon('窗外',10,model,vocab_size,cpu(),idx_to_char,char_to_idx)

'窗外的麻雀在電線桿上多嘴'