In [1]:
from __future__ import print_function
import argparse
import math
import sys
import time

import numpy as np
import six

import chainer
from chainer import cuda
import chainer.links as L
from chainer import optimizers
from chainer import serializers
import chainer.functions as F

In [2]:
class RNNLM(chainer.Chain):

    """Recurrent neural net language model for the Penn Treebank corpus.

    The network is embed -> LSTM -> LSTM -> linear, with dropout applied to
    the input of each LSTM layer and of the final linear layer.  The model
    is fed one time step per call; ``chainer.links.LSTM`` keeps its recurrent
    state between calls, so input of arbitrary length can be processed.
    Call ``reset_state`` before starting an unrelated sequence.

    Args:
        n_vocab: Number of word IDs; sizes the embedding table and the
            output layer.
        n_units: Width of the embedding and of both LSTM layers.
        train: Forwarded to every ``F.dropout`` call as its ``train``
            argument.

    """
    def __init__(self, n_vocab, n_units, train=True):
        super(RNNLM, self).__init__(
            embed=L.EmbedID(n_vocab, n_units),
            l1=L.LSTM(n_units, n_units),
            l2=L.LSTM(n_units, n_units),
            l3=L.Linear(n_units, n_vocab),
        )
        self.train = train

    def reset_state(self):
        """Reset the recurrent state of both LSTM layers."""
        for lstm in (self.l1, self.l2):
            lstm.reset_state()

    def __call__(self, x):
        """Run one time step and return the output of the final linear layer.

        Args:
            x: Input passed to the embedding link (presumably a minibatch of
                word IDs for the current time step -- confirm with callers).

        Returns:
            The output of ``l3``, which has ``n_vocab`` output units; no
            softmax is applied here.
        """
        hidden = self.embed(x)
        # Each LSTM layer sees a dropout-masked version of the layer below.
        for lstm in (self.l1, self.l2):
            hidden = lstm(F.dropout(hidden, train=self.train))
        return self.l3(F.dropout(hidden, train=self.train))

In [15]:
# Word -> integer ID table shared by every call to load_data; IDs are handed
# out in order of first appearance across all files loaded so far.
vocab = {}


def load_data(filename):
    """Read a text file and encode it as an array of word IDs.

    Every newline in the file is replaced by the token ``<eos>`` and the
    result is split on whitespace.  Words not yet present in the module-level
    ``vocab`` dict are added to it with the next unused ID, so IDs are
    consistent across successive calls.

    Args:
        filename: Path of the text file to read.

    Returns:
        A 1-D ``np.int32`` array with one word ID per token, in file order.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(filename) as f:
        text = f.read()
    words = text.replace('\n', '<eos>').strip().split()

    dataset = np.empty((len(words),), dtype=np.int32)
    for i, word in enumerate(words):
        # len(vocab) is evaluated before insertion, so a new word receives the
        # next free ID while a known word keeps the ID it already has.
        dataset[i] = vocab.setdefault(word, len(vocab))
    return dataset

In [16]:
# Load the splits in train -> valid -> test order: load_data hands out word
# IDs in order of first appearance, so this order fixes the ID assignment.
train_data, valid_data, test_data = map(
    load_data, ('ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt'))

In [17]:
# Number of distinct tokens that load_data has assigned an ID to so far.
print('#vocab =', len(vocab))

#vocab = 10000


In [39]:
# Training hyperparameters read by the model, optimizer and training cells.
# NOTE(review): the training loop only decays the learning rate once
# `epoch >= 6`, which is never reached with n_epoch = 5 -- confirm intended.
n_epoch = 5   # number of epochs
n_units = 650  # number of units per layer
batchsize = 5   # minibatch size
bprop_len = 20   # length of truncated BPTT
grad_clip = 5    # gradient norm threshold to clip

In [40]:
# Prepare the RNNLM model (the class defined earlier in this notebook) and
# wrap it in a Classifier; the wrapper is what the training and evaluation
# code calls as model(x, t) to obtain the loss.
lm = RNNLM(len(vocab), n_units)
model = L.Classifier(lm)
model.compute_accuracy = False  # we only want the perplexity

In [41]:
# Overwrite every model parameter in place with draws from U(-0.1, 0.1).
for param in model.params():
    data = param.data
    data[:] = np.random.uniform(low=-0.1, high=0.1, size=data.shape)

In [42]:
# Setup optimizer: plain SGD with an initial learning rate of 1.0, attached
# to the Classifier-wrapped model, plus a gradient-clipping hook that uses
# `grad_clip` as its threshold.
optimizer = optimizers.SGD(lr=1.)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.GradientClipping(grad_clip))

In [43]:
def evaluate(dataset):
    """Compute the perplexity of the current model over ``dataset``.

    The sequence is fed one token at a time (minibatch of size one) through a
    copy of the global ``model``; each token is used to predict the one that
    follows it.  Dropout is switched off on the copy so the result is
    deterministic and not penalised by dropout noise.

    Args:
        dataset: 1-D array of word IDs (as returned by ``load_data``).

    Returns:
        ``exp`` of the mean per-step loss, as a Python float.

    Raises:
        ValueError: If ``dataset`` holds fewer than two tokens, so there is
            no (input, target) pair to score.
    """
    n_steps = dataset.size - 1
    if n_steps < 1:
        raise ValueError('evaluate() needs a dataset of at least two tokens')

    # Evaluation routine
    evaluator = model.copy()  # to use different state
    evaluator.predictor.reset_state()  # initialize state
    # RNNLM passes `self.train` to every F.dropout call; without this line
    # the copy inherits train=True and evaluates with dropout still active.
    # Assumes model.copy() yields its own predictor object, which the state
    # reset above already relies on -- the training model is left untouched.
    evaluator.predictor.train = False  # dropout does nothing

    sum_log_perp = 0
    for i in six.moves.range(n_steps):
        # volatile='on' is the Chainer v1 inference flag (presumably so no
        # backprop graph is retained -- confirm against the Chainer version).
        x = chainer.Variable(np.asarray(dataset[i:i + 1]), volatile='on')
        t = chainer.Variable(np.asarray(dataset[i + 1:i + 2]), volatile='on')
        loss = evaluator(x, t)
        sum_log_perp += loss.data
    return math.exp(float(sum_log_perp) / n_steps)

In [45]:
# Learning loop: bookkeeping shared with the training cell below.
whole_len = len(train_data)
# Distance between the starting offsets of the parallel minibatch streams;
# also the number of iterations that make up one epoch.
jump = whole_len // batchsize
batch_idxs = list(range(batchsize))

epoch = 0
accum_loss = 0                  # loss summed over the current BPTT window
cur_log_perp = np.zeros(())     # loss summed since the last progress report
start_at = cur_at = time.time()
print('going to train {} iterations'.format(jump * n_epoch))

going to train 929585 iterations


In [46]:
# Main training loop: one iteration feeds a single time step of `batchsize`
# parallel streams through the model.  Relies on the bookkeeping variables
# (whole_len, jump, batch_idxs, accum_loss, cur_log_perp, epoch, cur_at)
# initialised in the previous cell; re-run that cell before re-running this
# one, otherwise stale values are carried over.
# NOTE(review): the saved output of this cell reports a training perplexity
# of nan at iteration 10000, i.e. the accumulated loss was already nan by
# then; the cause is not visible from this code -- investigate.
for i in six.moves.range(jump * n_epoch):
    # Stream j reads position jump * j + i (wrapping around the corpus);
    # the target is the token immediately after the input token.
    x = chainer.Variable(np.asarray(
        [train_data[(jump * j + i) % whole_len] for j in batch_idxs]))
    t = chainer.Variable(np.asarray(
        [train_data[(jump * j + i + 1) % whole_len] for j in batch_idxs]))
    loss_i = model(x, t)
    accum_loss += loss_i
    cur_log_perp += loss_i.data

    # Every bprop_len steps: backprop through the accumulated loss, cut the
    # graph so later windows do not backprop into earlier ones, then update.
    if (i + 1) % bprop_len == 0:  # Run truncated BPTT
        model.zerograds()
        accum_loss.backward()
        accum_loss.unchain_backward()  # truncate
        accum_loss = 0
        optimizer.update()

    # Progress report every 10000 iterations: perplexity is exp of the mean
    # loss over the interval, throughput is iterations per wall-clock second.
    if (i + 1) % 10000 == 0:
        now = time.time()
        throuput = 10000. / (now - cur_at)
        perp = math.exp(float(cur_log_perp) / 10000)
        print('iter {} training perplexity: {:.2f} ({:.2f} iters/sec)'.format(
            i + 1, perp, throuput))
        cur_at = now
        cur_log_perp.fill(0)

    # Every `jump` iterations counts as one epoch: validate and report.
    if (i + 1) % jump == 0:
        epoch += 1
        print('evaluate')
        now = time.time()
        perp = evaluate(valid_data)
        print('epoch {} validation perplexity: {:.2f}'.format(epoch, perp))
        cur_at += time.time() - now  # skip time of evaluation

        # NOTE(review): the loop runs jump * n_epoch iterations, so `epoch`
        # tops out at n_epoch; with n_epoch = 5 this decay branch never runs.
        if epoch >= 6:
            optimizer.lr /= 1.2
            print('learning rate =', optimizer.lr)

    sys.stdout.flush()

iter 10000 training perplexity: nan (14.40 iters/sec)


KeyboardInterrupt: 

In [None]:
# Evaluate on test dataset: run evaluate() over the test split and print the
# resulting perplexity.
print('test')
test_perp = evaluate(test_data)
print('test perplexity:', test_perp)

In [None]:
# Save the model and the optimizer as NPZ files; the paths are relative, so
# they are written to the current working directory.
print('save the model')
serializers.save_npz('rnnlm.model', model)
print('save the optimizer')
serializers.save_npz('rnnlm.state', optimizer)