In [41]:
import os
import numpy as np
import mxnet as mx

from bucket_io import BucketSentenceIter, default_build_vocab

# Directory that the corpus files used in this notebook (path_train.txt,
# path_val.txt, path_test.txt) are resolved against.
data_dir = 'data'

def Perplexity(label, pred):
    """ Calculates prediction perplexity

    Perplexity is ``exp`` of the mean negative log-probability that ``pred``
    assigns to the target class of each position. Probabilities are floored
    at 1e-10 before the logarithm so a zero probability cannot produce
    ``inf``; NaN probabilities are floored the same way.

    Args:
        label (numpy.ndarray): target class indices, of any shape. It is
            flattened, and values are truncated to integers before being
            used as indices. (When this function is wrapped with
            ``mx.metric.np`` the inputs arrive as numpy arrays.)
        pred (numpy.ndarray): class probabilities with the class axis last.
            All leading axes are flattened and must yield one row per
            element of the flattened ``label``.
    Returns:
        float: calculated perplexity
    Raises:
        ZeroDivisionError: if ``label`` is empty.
    """
    # collapse the time, batch dimension
    label = np.asarray(label).reshape((-1,))
    pred = np.asarray(pred)
    pred = pred.reshape((-1, pred.shape[-1]))

    # Pick pred[i, label[i]] for every row in one vectorised lookup instead
    # of a Python-level loop; promote to float64 so the sum below does not
    # accumulate in reduced precision.
    rows = np.arange(pred.shape[0])
    target_prob = pred[rows, label.astype(np.int64)].astype(np.float64)

    # np.fmax (not np.maximum) so that a NaN probability is replaced by the
    # floor as well, which is what the builtin max(1e-10, p) does.
    total_nll = -np.log(np.fmax(target_prob, 1e-10)).sum()

    # float() keeps the division a plain Python one, so an empty batch
    # raises ZeroDivisionError rather than silently yielding NaN.
    return np.exp(float(total_nll) / label.size)

In [42]:
# -- Data pipeline ---------------------------------------------------------
batch_size = 128
buckets = [11, 21, 31, 41]   # bucket lengths handed to BucketSentenceIter

# -- Network size ----------------------------------------------------------
num_hidden = 200             # LSTM state size
num_embed = 200              # embedding output dimension
num_lstm_layer = 2

# -- Optimisation ----------------------------------------------------------
num_epoch = 2
learning_rate = 0.01
momentum = 0.0

# -- Devices ---------------------------------------------------------------
# Set gpu_count to the number of GPUs that should be used.
gpu_count = 1
contexts = [mx.gpu(i) for i in range(gpu_count)]
# CPU alternative: contexts = mx.cpu()

# -- Vocabulary (built from the training corpus) -----------------------------
vocab = default_build_vocab(os.path.join(data_dir, 'path_train.txt'))

In [43]:
# Add the extra entry 0 to the vocabulary (presumably the padding id used by
# BucketSentenceIter -- TODO confirm against bucket_io).
# NOTE(review): assuming vocab is list-like, this cell is not idempotent:
# every re-run appends another entry, which changes len(vocab) and with it
# the embedding / output layer sizes that sym_gen derives from len(vocab).
vocab.append(0)

# Display the resulting vocabulary size.
len(vocab)

214

In [44]:
# Descriptors for the two extra inputs fed alongside 'data': the LSTM state
# ('LSTM_state') and cell state ('LSTM_state_cell'), each shaped
# (num_lstm_layer, batch_size, num_hidden).
init_c = [mx.io.DataDesc('LSTM_state_cell',
                         (num_lstm_layer, batch_size, num_hidden),
                         layout='TNC')]
init_h = [mx.io.DataDesc('LSTM_state',
                         (num_lstm_layer, batch_size, num_hidden),
                         layout='TNC')]
init_states = init_c + init_h

# Bucketed, time-major iterators over the training and validation corpora.
data_train = BucketSentenceIter(
    os.path.join(data_dir, 'path_train.txt'),
    vocab, buckets, batch_size, init_states,
    time_major=True)
data_val = BucketSentenceIter(
    os.path.join(data_dir, 'path_val.txt'),
    vocab, buckets, batch_size, init_states,
    time_major=True)

def sym_gen(seq_len):
    """ Builds the MXNet symbol of the LSTM language model for one bucket
    Args:
        seq_len (int): input sequence length
    Returns:
        tuple: (symbol, data_names, label_names)
    """
    vocab_size = len(vocab)

    tokens = mx.sym.Variable('data')
    targets = mx.sym.Variable('softmax_label')
    embedded = mx.sym.Embedding(data=tokens, input_dim=vocab_size,
                                output_dim=num_embed, name='embed')

    # TODO(tofix)
    # All LSTM parameters are currently concatenated into one huge
    # vector named '<name>_parameters'. The default mxnet initializer
    # does not know how to initialize it, because the name does not
    # end with _weight, _bias or any other familiar suffix. As a
    # temporary workaround, to get this demo running, the variable is
    # created under the name LSTM_bias. Bias is initialized to zeros
    # by default, so this is not a good scheme. Naming it LSTM_weight
    # is no better: this is a 1D vector, while the initialization
    # scheme of a weight parameter needs at least two dimensions.
    lstm_params = mx.sym.Variable('LSTM_bias')

    # The RNN operator takes input of shape (time, batch, feature).
    # `parameters` could be omitted if the workaround described above
    # were not needed.
    lstm_out = mx.sym.RNN(data=embedded, state_size=num_hidden,
                          num_layers=num_lstm_layer, mode='lstm',
                          name='LSTM',
                          parameters=lstm_params)

    # The RNN output has shape (time, batch, dim). When the states and
    # cell states of the last time step are needed (e.g. for
    # encoder-decoder models), state_outputs=True gives the RNN extra
    # outputs: rnn['LSTM_output'], rnn['LSTM_state'] and, for LSTM,
    # rnn['LSTM_state_cell'].

    # Merge the time and batch axes so that a single fully connected
    # layer produces the per-position class scores.
    flat_hidden = mx.sym.Reshape(data=lstm_out, shape=(-1, num_hidden))
    scores = mx.sym.FullyConnected(data=flat_hidden, num_hidden=vocab_size,
                                   name='pred')

    # Restore the (time, batch, vocab) layout so it lines up with the labels.
    scores_tm = mx.sym.Reshape(data=scores,
                               shape=(seq_len, -1, vocab_size))

    softmax_out = mx.sym.SoftmaxOutput(data=scores_tm, label=targets,
                                       preserve_shape=True,
                                       name='softmax')

    return (softmax_out,
            ['data', 'LSTM_state', 'LSTM_state_cell'],
            ['softmax_label'])

# With exactly one bucket a plain Module built from that bucket's symbol is
# enough; in every other case sym_gen is handed to a BucketingModule.
if len(buckets) != 1:
    mod = mx.mod.BucketingModule(
        sym_gen,
        default_bucket_key=data_train.default_bucket_key,
        context=contexts)
else:
    mod = mx.mod.Module(*sym_gen(buckets[0]), context=contexts)

bucket of len  11 : 7796 samples
bucket of len  21 : 10202 samples
bucket of len  31 : 3127 samples
bucket of len  41 : 203 samples
bucket of len  11 : 1039 samples
bucket of len  21 : 1361 samples
bucket of len  31 : 417 samples
bucket of len  41 : 27 samples


In [45]:
import logging

# Prefix every log record with a timestamp; DEBUG level so the per-batch
# Speedometer lines emitted during fit() are shown.
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

# Train with plain SGD, reporting perplexity on the training batches and on
# the validation set after each epoch.
mod.fit(
    data_train,
    eval_data=data_val,
    num_epoch=num_epoch,
    eval_metric=mx.metric.np(Perplexity),
    batch_end_callback=mx.callback.Speedometer(batch_size, 50),
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    optimizer='sgd',
    optimizer_params={
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': 1e-5,
    },
)

2018-02-23 14:13:05,055 Epoch[0] Batch [50]	Speed: 7655.35 samples/sec	Perplexity=123.461754
2018-02-23 14:13:05,380 Epoch[0] Batch [100]	Speed: 9901.76 samples/sec	Perplexity=62.459987
2018-02-23 14:13:05,758 Epoch[0] Batch [150]	Speed: 8509.26 samples/sec	Perplexity=60.688413
2018-02-23 14:13:06,135 Epoch[0] Batch [200]	Speed: 8511.43 samples/sec	Perplexity=54.228900
2018-02-23 14:13:06,474 Epoch[0] Batch [250]	Speed: 9492.28 samples/sec	Perplexity=50.319478
2018-02-23 14:13:06,845 Epoch[0] Batch [300]	Speed: 8684.47 samples/sec	Perplexity=56.297255
2018-02-23 14:13:07,039 Epoch[0] Train-Perplexity=51.318953
2018-02-23 14:13:07,040 Epoch[0] Time cost=2.434
2018-02-23 14:13:07,313 Epoch[0] Validation-Perplexity=55.509679
2018-02-23 14:13:07,684 Epoch[1] Batch [50]	Speed: 8804.51 samples/sec	Perplexity=57.895443
2018-02-23 14:13:08,045 Epoch[1] Batch [100]	Speed: 8918.52 samples/sec	Perplexity=51.374932
2018-02-23 14:13:08,430 Epoch[1] Batch [150]	Speed: 8349.27 samples/sec	Perplexity=

In [46]:
# Iterator over the test corpus, configured like the train/val iterators.
data_test = BucketSentenceIter(
    os.path.join(data_dir, 'path_test.txt'),
    vocab, buckets, batch_size, init_states,
    time_major=True)

bucket of len  11 : 1559 samples
bucket of len  21 : 2040 samples
bucket of len  31 : 625 samples
bucket of len  41 : 41 samples


In [47]:
# Score the trained module on the test iterator with a fresh Perplexity
# metric, then print the accumulated value.
perplexity = mx.metric.np(Perplexity)
mod.score(eval_data=data_test, eval_metric=perplexity)
print(perplexity)

EvalMetric: {'Perplexity': 55.544038522751499}
