In [14]:
import os
import numpy as np
import mxnet as mx

from bucket_io import BucketSentenceIter, default_build_vocab, SimpleBatch

data_dir = 'data'

def Perplexity(label, pred):
    """ Calculates prediction perplexity
    Args:
        label (mx.nd.array): labels array
        pred (mx.nd.array): prediction array
    Returns:
        float: calculated perplexity
    """
    
#     print(label.shape)
#     print(pred.shape)
    
    # collapse the time, batch dimension
    label = label.reshape((-1,))
    pred = pred.reshape((-1, pred.shape[-1]))
    
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
    return np.exp(loss / label.size)

In [4]:
batch_size = 32
buckets = [6, 11, 16, 21, 26, 31, 36, 41]
num_hidden = 200
num_embed = 200
num_lstm_layer = 2

num_epoch = 20
learning_rate = 0.01
momentum = 0.9

# Update count per available GPUs
gpu_count = 1
contexts = [mx.context.gpu(i) for i in range(gpu_count)]
# contexts = mx.cpu()
vocab = default_build_vocab(os.path.join(data_dir, 'path_train.txt'))

In [5]:
vocab.append(0)

len(vocab)

214

In [60]:
init_h = [mx.io.DataDesc('LSTM_state', (num_lstm_layer, batch_size, num_hidden), layout='TNC')]
init_c = [mx.io.DataDesc('LSTM_state_cell', (num_lstm_layer, batch_size, num_hidden), layout='TNC')]
init_states = init_c + init_h


data_train = BucketSentenceIter(os.path.join(data_dir, 'path_train.txt'),
                                vocab, buckets, batch_size, init_states,
                                time_major=True)

data_val = BucketSentenceIter(os.path.join(data_dir, 'path_val.txt'),
                              vocab, buckets, batch_size, init_states,
                              time_major=True)

def sym_gen(seq_len):
    """ Generates the MXNet symbol for the RNN
    Args:
        seq_len (int): input sequence length
    Returns:
        tuple: tuple containing symbol, data_names, label_names
    """
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                             output_dim=num_embed, name='embed')

    # TODO(tofix)
    # currently all the LSTM parameters are concatenated as
    # a huge vector, and named '<name>_parameters'. By default
    # mxnet initializer does not know how to initilize this
    # guy because its name does not ends with _weight or _bias
    # or anything familiar. Here we just use a temp workaround
    # to create a variable and name it as LSTM_bias to get
    # this demo running. Note by default bias is initialized
    # as zeros, so this is not a good scheme. But calling it
    # LSTM_weight is not good, as this is 1D vector, while
    # the initialization scheme of a weight parameter needs
    # at least two dimensions.
    rnn_params = mx.sym.Variable('LSTM_bias')

    # RNN cell takes input of shape (time, batch, feature)
    rnn = mx.sym.RNN(data=embed, state_size=num_hidden,
                     num_layers=num_lstm_layer, mode='lstm',
                     name='LSTM',
                     # The following params can be omitted
                     # provided we do not need to apply the
                     # workarounds mentioned above
                     parameters=rnn_params)

    # the RNN cell output is of shape (time, batch, dim)
    # if we need the states and cell states in the last time
    # step (e.g. when building encoder-decoder models), we
    # can set state_outputs=True, and the RNN cell will have
    # extra outputs: rnn['LSTM_output'], rnn['LSTM_state']
    # and for LSTM, also rnn['LSTM_state_cell']

    # now we collapse the time and batch dimension to do the
    # final linear logistic regression prediction
    hidden = mx.sym.Reshape(data=rnn, shape=(-1, num_hidden))

    pred = mx.sym.FullyConnected(data=hidden, num_hidden=len(vocab),
                                 name='pred')

    # reshape to be of compatible shape as labels
    pred_tm = mx.sym.Reshape(data=pred, shape=(seq_len, -1, len(vocab)))

    sm = mx.sym.SoftmaxOutput(data=pred_tm, label=label, preserve_shape=True,
                              name='softmax')

    data_names = ['data', 'LSTM_state', 'LSTM_state_cell']
    label_names = ['softmax_label']

    return sm, data_names, label_names

if len(buckets) == 1:
    mod = mx.mod.Module(*sym_gen(buckets[0]), context=contexts)
else:
    mod = mx.mod.BucketingModule(sym_gen,
                                 default_bucket_key=data_train.default_bucket_key,
                                 context=contexts)

bucket of len   6 : 2360 samples
bucket of len  11 : 5436 samples
bucket of len  16 : 6162 samples
bucket of len  21 : 4040 samples
bucket of len  26 : 2162 samples
bucket of len  31 : 965 samples
bucket of len  36 : 187 samples
bucket of len  41 : 16 samples
bucket of len   6 : 308 samples
bucket of len  11 : 731 samples
bucket of len  16 : 800 samples
bucket of len  21 : 561 samples
bucket of len  26 : 309 samples
bucket of len  31 : 108 samples
bucket of len  36 : 23 samples
bucket of len  41 : 4 samples


In [62]:
import logging

head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

# TODO: add epoch_end_callback
mod.fit(data_train, eval_data=data_val, num_epoch=20,
        eval_metric=mx.metric.np(Perplexity),
        batch_end_callback=mx.callback.Speedometer(batch_size, 50),
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        optimizer='sgd',
        optimizer_params={'learning_rate': learning_rate,
                          'momentum': momentum, 'wd': 0.00001})

2018-02-27 13:33:37,945 Already bound, ignoring bind()
2018-02-27 13:33:37,947 optimizer already initialized, ignoring.
2018-02-27 13:33:38,235 Epoch[0] Batch [50]	Speed: 5649.66 samples/sec	Perplexity=67.972870
2018-02-27 13:33:38,514 Epoch[0] Batch [100]	Speed: 5780.47 samples/sec	Perplexity=70.640827
2018-02-27 13:33:38,825 Epoch[0] Batch [150]	Speed: 5179.75 samples/sec	Perplexity=70.425884
2018-02-27 13:33:39,214 Epoch[0] Batch [200]	Speed: 4135.93 samples/sec	Perplexity=70.993498
2018-02-27 13:33:39,489 Epoch[0] Batch [250]	Speed: 5842.64 samples/sec	Perplexity=72.166677
2018-02-27 13:33:39,673 Epoch[0] Batch [300]	Speed: 8827.84 samples/sec	Perplexity=71.373799
2018-02-27 13:33:39,869 Epoch[0] Batch [350]	Speed: 8498.03 samples/sec	Perplexity=72.537611
2018-02-27 13:33:40,058 Epoch[0] Batch [400]	Speed: 8635.72 samples/sec	Perplexity=75.066196
2018-02-27 13:33:40,231 Epoch[0] Batch [450]	Speed: 9367.29 samples/sec	Perplexity=73.479239
2018-02-27 13:33:40,419 Epoch[0] Batch [500]

2018-02-27 13:33:56,037 Epoch[5] Train-Perplexity=91.274512
2018-02-27 13:33:56,039 Epoch[5] Time cost=2.627
2018-02-27 13:33:56,389 Epoch[5] Validation-Perplexity=73.947179
2018-02-27 13:33:56,566 Epoch[6] Batch [50]	Speed: 9244.38 samples/sec	Perplexity=67.529894
2018-02-27 13:33:56,739 Epoch[6] Batch [100]	Speed: 9413.58 samples/sec	Perplexity=68.645420
2018-02-27 13:33:56,949 Epoch[6] Batch [150]	Speed: 7669.15 samples/sec	Perplexity=68.382595
2018-02-27 13:33:57,199 Epoch[6] Batch [200]	Speed: 6513.65 samples/sec	Perplexity=69.015512
2018-02-27 13:33:57,395 Epoch[6] Batch [250]	Speed: 8292.91 samples/sec	Perplexity=70.410132
2018-02-27 13:33:57,574 Epoch[6] Batch [300]	Speed: 9007.44 samples/sec	Perplexity=68.635914
2018-02-27 13:33:57,763 Epoch[6] Batch [350]	Speed: 8533.15 samples/sec	Perplexity=70.627293
2018-02-27 13:33:57,948 Epoch[6] Batch [400]	Speed: 8756.86 samples/sec	Perplexity=72.991670
2018-02-27 13:33:58,126 Epoch[6] Batch [450]	Speed: 9082.92 samples/sec	Perplexity=

2018-02-27 13:34:14,135 Epoch[11] Batch [650]	Speed: 9044.20 samples/sec	Perplexity=70.749973
2018-02-27 13:34:14,172 Epoch[11] Train-Perplexity=86.996012
2018-02-27 13:34:14,174 Epoch[11] Time cost=2.826
2018-02-27 13:34:14,479 Epoch[11] Validation-Perplexity=74.063836
2018-02-27 13:34:14,656 Epoch[12] Batch [50]	Speed: 9330.04 samples/sec	Perplexity=66.527801
2018-02-27 13:34:14,886 Epoch[12] Batch [100]	Speed: 7028.06 samples/sec	Perplexity=67.234855
2018-02-27 13:34:15,105 Epoch[12] Batch [150]	Speed: 7391.84 samples/sec	Perplexity=67.410690
2018-02-27 13:34:15,301 Epoch[12] Batch [200]	Speed: 8298.07 samples/sec	Perplexity=67.709186
2018-02-27 13:34:15,481 Epoch[12] Batch [250]	Speed: 8963.48 samples/sec	Perplexity=69.160155
2018-02-27 13:34:15,662 Epoch[12] Batch [300]	Speed: 8935.99 samples/sec	Perplexity=67.530514
2018-02-27 13:34:15,851 Epoch[12] Batch [350]	Speed: 8594.67 samples/sec	Perplexity=69.341479
2018-02-27 13:34:16,039 Epoch[12] Batch [400]	Speed: 8667.53 samples/sec

2018-02-27 13:34:31,602 Epoch[17] Batch [550]	Speed: 7908.54 samples/sec	Perplexity=69.426898
2018-02-27 13:34:31,770 Epoch[17] Batch [600]	Speed: 9644.71 samples/sec	Perplexity=66.368258
2018-02-27 13:34:31,955 Epoch[17] Batch [650]	Speed: 8765.11 samples/sec	Perplexity=70.395892
2018-02-27 13:34:31,994 Epoch[17] Train-Perplexity=84.033655
2018-02-27 13:34:31,996 Epoch[17] Time cost=2.609
2018-02-27 13:34:32,315 Epoch[17] Validation-Perplexity=74.429727
2018-02-27 13:34:32,524 Epoch[18] Batch [50]	Speed: 7846.75 samples/sec	Perplexity=66.125691
2018-02-27 13:34:32,860 Epoch[18] Batch [100]	Speed: 4815.85 samples/sec	Perplexity=66.396661
2018-02-27 13:34:33,119 Epoch[18] Batch [150]	Speed: 6405.92 samples/sec	Perplexity=66.254964
2018-02-27 13:34:33,308 Epoch[18] Batch [200]	Speed: 8570.08 samples/sec	Perplexity=66.912909
2018-02-27 13:34:33,494 Epoch[18] Batch [250]	Speed: 8661.24 samples/sec	Perplexity=68.508803
2018-02-27 13:34:33,681 Epoch[18] Batch [300]	Speed: 8688.10 samples/sec

In [None]:
# TODO: add the inference model for generating the next nodes

In [28]:
data_val = BucketSentenceIter(os.path.join(data_dir, 'path_val.txt'),
                              vocab, buckets, batch_size, init_states,
                              time_major=True)

cnt = 0
for batch in data_val:
    print(type(batch))
    print((batch.data))
    print((batch.label))
    
print(cnt)
    

bucket of len   6 : 308 samples
bucket of len  11 : 731 samples
bucket of len  16 : 800 samples
bucket of len  21 : 561 samples
bucket of len  26 : 309 samples
bucket of len  31 : 108 samples
bucket of len  36 : 23 samples
bucket of len  41 : 4 samples
<class 'bucket_io.SimpleBatch'>
[
[[ 181.  114.   34.  126.  187.   61.   95.   87.   42.   45.   33.   31.
   200.  164.   22.  141.  147.  125.  109.   71.  155.  134.   96.  211.
    44.   82.   66.  125.   50.   46.   32.   51.]
 [  55.  181.   97.  189.   55.   81.  141.  183.  199.  198.   22.  195.
   127.   36.  122.  111.  188.  123.   70.  161.  204.  132.  167.  176.
   130.   47.   71.  123.   69.  141.  130.   76.]
 [ 135.  116.  168.   91.  135.  150.  111.  180.   73.  130.   77.    1.
   202.  126.   11.  209.   48.   53.  104.  133.   78.   22.  159.  133.
   197.  138.  161.   53.   85.  111.  197.   19.]
 [  96.   66.  171.  133.   96.  184.  142.  200.  186.  197.   84.  103.
   131.  189.  169.   70.  202.  178.  120

[
[[  37.  117.   54.   58.  199.   59.  108.   70.  190.  131.   95.  165.
   209.  127.  124.    8.  121.  141.  184.  171.   10.  184.   78.   38.
    12.   74.   13.  201.   42.   64.  105.  204.]
 [  24.  183.  140.  140.   73.  124.  143.  104.   59.   38.  141.   23.
    70.  202.  145.   14.   37.  111.   98.   37.  166.   98.  165.  108.
    50.  116.   96.   67.  199.    2.   60.   78.]
 [  73.  180.  144.  144.  186.  145.   44.  120.  124.  108.  111.   42.
   104.  131.  112.  126.   24.  142.   20.   24.   98.   20.   23.  143.
   118.   66.  135.   94.   73.   80.   93.  165.]
 [ 186.  200.  190.  190.  211.  112.  130.   16.  145.  143.  142.  199.
   120.   38.   35.   36.   73.   89.   80.   73.   20.   80.   42.   44.
    85.   71.   55.   91.  186.    9.  187.   23.]
 [ 211.  127.   59.   59.  210.   35.  197.  128.  112.   44.   89.   73.
    99.  108.   63.  164.  186.  186.    9.  186.   80.    9.  199.  130.
    61.  197.  181.  133.  211.  122.   82.   42.]
 [ 

[
[[ 149.  110.  108.  134.   32.   36.   89.  165.  201.  162.  193.  176.
   162.   91.  202.   75.  178.  164.  179.  110.   65.  202.  134.    4.
   211.    2.  184.    8.   51.  157.  104.  209.]
 [   7.  178.  143.  132.  130.  164.  186.   23.   51.   28.    5.  170.
    94.  189.  131.   19.   62.   36.  176.  178.   33.  131.  132.  201.
   210.   80.  150.    5.    7.   33.  120.  142.]
 [  19.   62.   44.   22.  197.  138.  211.   42.    7.   76.    8.  197.
    91.  126.   38.    7.    5.   99.  133.   53.   22.    3.    1.   51.
    95.    9.   81.   62.  149.   22.   16.   89.]
 [  75.    5.  130.   77.   71.   47.  210.  199.  149.   19.   14.  130.
   133.   36.  108.   51.    8.  120.   91.  159.   77.   21.  103.    7.
    46.   77.   12.  178.   82.   77.  128.  186.]
 [ 161.    8.  197.   84.  161.   82.   95.   73.   82.    7.  126.   44.
   176.  164.  143.  201.   14.   16.   94.  167.   84.  206.  185.  149.
   118.   84.   50.   53.  187.   84.   41.  211.]
 [ 

[
[[  99.   73.    9.  193.   16.   36.   67.  128.   36.  128.   53.   44.
   210.  124.    5.  104.   61.  159.   80.   66.  103.    5.    9.  143.
    60.  130.  149.    1.   84.  138.  178.    7.]
 [  36.  186.   77.   31.  104.  126.  201.   41.  126.   16.  159.  130.
    95.  145.    8.  120.   81.  167.    9.  116.    1.  193.   77.  108.
   105.  197.   82.  132.  110.   47.   62.   19.]
 [ 164.  211.   84.  195.  120.   14.   51.  154.   14.  104.  167.  197.
    46.  112.   14.   99.  150.   96.  122.  181.  195.   31.   84.   38.
   115.   71.  187.   22.  178.   82.    5.   75.]
 [ 138.  176.  110.    1.   99.    8.    7.   90.    8.   70.   96.   71.
   118.   35.   67.   36.  184.  135.   11.   55.   31.  195.  110.    3.
    29.  161.   55.   77.   53.  187.    8.  161.]
 [  47.  133.  178.  132.   36.    5.  149.  185.    5.  209.  135.  161.
    85.   63.  201.  126.   98.   55.  169.  135.  193.    1.  178.   21.
    63.  133.  135.   84.  159.   55.   14.  133.]
 [ 

[
[[ 176.  170.   53.  202.  141.  128.   94.  159.   94.  162.   36.    8.
    33.   28.    9.  181.  164.  178.  116.  178.  181.  142.  105.   55.
   166.   91.  123.  105.   22.  126.   74.  110.]
 [ 170.  197.  178.  131.  111.   16.   91.  167.   91.   94.  126.    5.
    22.   76.   77.  116.   36.   62.   66.   53.   55.   89.  115.  181.
    98.  133.   53.   60.   77.   14.  116.  178.]
 [ 161.   71.   62.    3.  142.  104.  133.   96.  133.   91.   14.  193.
    77.   19.   84.   66.  126.    5.   71.  159.  105.  186.   29.  116.
    20.  176.  178.   93.   84.    8.  181.   62.]
 [  75.   66.    5.   21.   89.  120.  176.  135.  176.  133.    8.   31.
    84.   75.  110.   71.   14.    8.  161.  167.  115.  211.   63.   66.
    80.  170.   62.  187.  110.    5.   55.    5.]
 [  19.  116.    8.  206.  186.   99.  170.   55.  170.  210.    5.  195.
   110.  161.  178.  197.    8.   14.  133.   96.   29.  176.   35.   71.
     9.  197.    5.   55.  178.  193.  105.    8.]
 [ 

[
[[ 190.   11.  169.   59.  144.  206.  166.   35.   59.  190.    9.  110.
   118.   87.  166.   58.  140.   63.  140.  183.  109.  103.  206.   41.
   184.  140.   21.  145.   77.   48.  204.  117.]
 [  59.  169.  103.  124.  190.   87.  184.   63.  124.   59.   77.  178.
    46.  183.  184.  140.  144.   29.  144.  180.   70.    1.   87.  128.
   150.  144.  206.  112.   84.  202.   78.  183.]
 [ 124.  103.    1.  145.   59.  183.  150.   29.  145.  124.   84.   62.
   141.  180.  150.  144.  190.  115.  190.  200.  104.  195.  183.   16.
    81.  190.   87.   35.  110.  131.  165.  180.]
 [ 145.    1.  195.  112.  124.  180.   81.  115.  112.  145.  110.    5.
   111.  200.   81.  190.   59.  105.   59.  127.  120.   31.  180.  104.
    12.   59.  183.   63.  178.   38.   23.  200.]
 [ 112.  195.   31.   35.  145.  200.   12.  105.   35.  112.  178.    8.
   142.  127.   12.   59.  124.   60.  124.  202.   99.  193.  200.   70.
    50.  124.  180.   29.   62.  108.   42.  127.]
 [ 

[
[[ 169.  197.  125.  120.  165.   62.   38.  112.   45.   63.  101.   97.
   134.   95.  166.  167.   10.  105.   62.   77.   46.   21.  186.  187.
   110.   73.  110.   45.  154.  209.   65.  111.]
 [ 103.   71.  123.   99.   23.    5.  108.   35.  198.   29.  198.  168.
   132.  141.   98.  159.  166.   60.    5.   65.  141.  206.  211.   55.
   178.  186.  178.  198.   41.   70.   33.  209.]
 [   1.  161.   53.   36.   42.    8.  143.   63.  130.  115.  130.  171.
    22.  111.   20.   53.   98.   93.    8.   33.  111.   87.  210.  135.
    53.  211.   53.  130.  128.  104.   22.   70.]
 [ 195.  133.  178.  126.  199.   14.   44.   29.  197.  105.  197.   37.
    77.  209.   11.  178.   20.  187.   14.   22.  142.  183.   95.   96.
   159.  176.  159.  197.   16.  120.   77.  104.]
 [  31.  210.   62.  189.   73.   67.  130.  115.   71.   60.   71.   24.
    84.   70.  169.   62.   11.   82.  126.  122.   89.  180.   46.  167.
   167.  133.  167.   71.  104.   16.   84.  120.]
 [ 

[
[[ 131.  204.  143.  103.  121.    1.   78.   95.    4.  181.   71.   10.
   120.  185.   71.    4.  176.    2.   11.   29.  133.   84.   41.  116.
    62.   65.   47.   62.   96.  185.  161.  189.]
 [  38.   78.   44.    1.   37.  103.  165.  141.   28.   55.  161.  166.
    16.   90.  161.  201.  133.   80.  169.  115.  176.  110.  128.   66.
   178.   33.  138.    5.  135.  103.  133.  126.]
 [ 108.  165.  130.  195.   24.  185.   23.  111.  162.  135.  133.   98.
   128.  154.  133.   67.  210.    9.   10.  105.  170.  178.   16.   71.
    53.   22.  164.    8.   55.    1.   91.   36.]
 [ 143.   23.  197.   31.   73.   90.   42.  209.   94.   96.   91.   20.
    41.   41.  210.   14.   95.  122.  166.   60.  197.   62.  104.  161.
   159.  122.   36.   14.  181.  132.   94.   99.]
 [  44.   42.   71.  193.  186.  154.  199.   70.   91.  167.   94.   11.
   154.  128.   95.    8.   46.   11.  184.   93.   71.    5.  120.  133.
   167.   11.   99.  126.  116.   22.   67.  120.]
 [ 

[
[[ 209.   33.   97.  124.  180.   87.  190.   26.  165.  204.  105.  109.
   115.   48.  140.  206.  124.   59.   61.  150.    5.  144.   73.  115.
    48.   33.   93.   59.  141.   82.  165.   33.]
 [  70.   22.  168.  145.  200.  183.   59.  180.   23.   78.   60.   70.
   105.  202.  144.   87.  145.  124.   12.   81.   62.  190.  186.  105.
   202.   22.  187.  124.  111.  187.   23.   22.]
 [ 104.  122.  171.  112.  127.  180.  124.  200.   42.  165.   93.  104.
    60.  131.  190.  183.  112.  145.   50.   12.  178.   59.  211.   60.
   131.  122.   55.  145.  142.   55.   42.  122.]
 [ 120.   11.   37.   35.  202.  200.  145.  127.  199.   23.  187.  120.
    93.   38.   59.  180.   35.  112.  118.   50.   53.  124.  176.   93.
    38.   11.  181.  112.   89.  181.  199.   11.]
 [  99.  169.   24.   63.  131.  127.  112.  202.   73.   42.   55.   99.
   187.  108.  124.  200.   63.   35.   46.  118.  159.  145.  170.  187.
   108.  169.  116.   35.  186.  116.   73.  169.]
 [ 

[
[[  13.  147.   35.   98.  206.   23.  164.  184.  104.   42.  143.  187.
   142.  112.   46.   97.  135.   38.  163.  172.  178.   98.   76.  123.
   194.  168.   91.  171.  206.  164.  199.  120.]
 [  96.  100.   63.   20.   87.   42.   36.  150.   70.  199.   44.   55.
    89.   35.  141.  168.   55.  108.   26.   66.   62.   20.   19.   53.
   117.  171.  189.   37.   87.  138.   73.   99.]
 [ 135.   78.   29.   80.  183.  199.  126.   81.  209.   73.  130.  181.
   186.   63.  111.  171.  181.  143.  180.   71.    5.   80.   75.  159.
   183.   37.  126.   24.  183.   47.  186.   36.]
 [  55.  165.  115.    9.  180.   73.  189.   12.  142.  186.  197.  116.
   211.   29.  142.   37.  116.   44.  200.  197.    8.    9.  161.  167.
   180.   24.   36.   73.  180.   82.  211.  164.]
 [ 181.   23.  105.   77.  200.  186.   91.   50.   89.  211.   71.   66.
   176.  115.   89.   24.   66.  130.  127.  130.   14.   77.  170.   96.
   200.   73.  164.  186.  200.  187.  176.  138.]
 [ 

[
[[  61.  127.   95.  187.   73.  131.  180.   96.   63.  176.   42.  147.
    55.  127.   42.   93.   21.  116.   33.  147.   80.   78.   24.   81.
   176.  164.   78.   61.  180.   82.   26.  140.]
 [  81.  202.   46.   55.  186.   38.  200.  135.   29.  170.  199.  100.
   181.  202.  199.  187.  206.   66.   22.  100.    9.  165.   73.   12.
   170.   36.  165.   12.  200.  149.  180.  144.]
 [ 150.  131.  118.  181.  211.  108.  127.   55.  115.  197.   73.   78.
   116.  131.   73.   82.   87.   71.  122.   78.   77.   23.  186.   50.
   197.   99.   23.   50.  127.    7.  200.  190.]
 [ 184.   38.   85.  116.  176.  143.  202.  181.  105.   71.  186.  165.
    66.   38.  186.  149.  183.  161.   11.  165.   84.   42.  211.  118.
    71.  120.   42.  118.  202.   51.  127.   59.]
 [  98.  108.   61.   66.  170.   44.  131.  116.   60.   66.  211.   23.
    71.  108.  211.    7.  180.  133.  169.   23.  110.  199.  176.   46.
    66.   16.  199.   46.  131.  201.  202.  124.]
 [ 

In [64]:
init_h = [mx.io.DataDesc('LSTM_state', (num_lstm_layer, 1, num_hidden), layout='TNC')]
init_c = [mx.io.DataDesc('LSTM_state_cell', (num_lstm_layer, 1, num_hidden), layout='TNC')]
init_states = init_c + init_h

In [65]:
['data'] + [x[0] for x in init_states]

['data', 'LSTM_state_cell', 'LSTM_state']

In [66]:
data = mx.nd.zeros((6, 1))
label = mx.nd.zeros((6, 1))
data_all = [mx.nd.array(data)] + [mx.nd.zeros(x[1]) for x in init_states]
label_all = [mx.nd.array(label)]

In [71]:
batch = SimpleBatch(data=data_all, label=label_all,
            data_names=['data', 'LSTM_state_cell', 'LSTM_state'], 
            data_layouts=['TN', 'TNC', 'TNC'],
            label_names=['softmax_label'],
            label_layouts=['TN'],
            bucket_key=6)

In [72]:
mod.forward(batch)

In [98]:
np.argmax(mod.get_outputs()[0].asnumpy()[0, 0, :], axis=0)

130

In [118]:
mod.

[
 [[[  3.63800535e-03   1.00255553e-02   5.02926996e-03 ...,   5.17999148e-03
      2.60499131e-04   4.59771603e-04]]
 
  [[  5.57415746e-03   1.30859939e-02   1.09013659e-03 ...,   1.17595214e-02
      5.51762678e-05   1.51559885e-04]]
 
  [[  1.31021151e-02   7.19729904e-03   5.83784713e-04 ...,   1.62128974e-02
      5.80270316e-05   6.49125213e-05]]
 
  [[  1.88673493e-02   3.40263848e-03   9.63349943e-04 ...,   1.49647761e-02
      1.21328609e-04   4.95084132e-05]]
 
  [[  2.95093488e-02   4.20263177e-03   1.51370210e-03 ...,   1.99012626e-02
      1.89775106e-04   3.59318547e-05]]
 
  [[  4.37201411e-02   4.07776190e-03   2.15858687e-03 ...,   1.81674566e-02
      3.21687519e-04   3.34183387e-05]]]
 <NDArray 6x1x214 @gpu(0)>]

In [125]:
mod.symbol

<Symbol softmax>