In [1]:
import warnings
warnings.filterwarnings('ignore')

import random
import time
import multiprocessing as mp
import numpy as np

import mxnet as mx
from mxnet import nd, gluon, autograd

import gluonnlp as nlp
import pickle

random.seed(123)
np.random.seed(123)
mx.random.seed(123)

In [2]:
train_dataset = pickle.load(open('../data/dev_processed.p', 'rb'))
test_dataset = pickle.load(open('../data/dev_processed.p', 'rb'))

In [3]:
vocabulary = {'<pad>': [0, 1], '<unk>': [1, 1], '<BOS>': [2, 1], '<EOS>': [3, 1]}
for item in train_dataset + test_dataset:
    words = item[2].split(' ')
    for word in words:
        if word in vocabulary:
            vocabulary[word][1] += 1
        else:
            vocabulary[word] = [len(vocabulary), 1]

In [4]:
def preprocess(x):
    name, audio, words = x
    split_words = ['<BOS>'] + words.split(' ') + ['<EOS>']
    return audio, np.array([vocabulary[word][0] for word in split_words]), float(len(audio)), float(len(split_words))

def get_length(x):
    return float(len(x[1]))

def preprocess_dataset(dataset):
    start = time.time()
    with mp.Pool() as pool:
        dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))
        lengths = gluon.data.SimpleDataset(pool.map(get_length, dataset))
    end = time.time()
    print('Done! Processing Time={:.2f}s, #Samples={}'.format(end - start, len(dataset)))
    return dataset, lengths

In [5]:
train_dataset, train_data_lengths = preprocess_dataset(train_dataset)
test_dataset, test_data_lengths = preprocess_dataset(test_dataset)

Done! Processing Time=3.28s, #Samples=2703
Done! Processing Time=3.72s, #Samples=2703


In [39]:
learning_rate, batch_size = 0.005, 32
bucket_num, bucket_ratio = 10, 0.2
grad_clip = None
log_interval = 5

def get_dataloader():
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(dtype='float32'),
        nlp.data.batchify.Pad(dtype='float32'),
        nlp.data.batchify.Stack(dtype='float32'),
        nlp.data.batchify.Stack(dtype='float32'))
    batch_sampler = nlp.data.sampler.FixedBucketSampler(
        train_data_lengths,
        batch_size=batch_size,
        num_buckets=bucket_num,
        ratio=bucket_ratio,
        shuffle=True)
    print(batch_sampler.stats())

    train_dataloader = gluon.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=batch_sampler,
        batchify_fn=batchify_fn)
    test_dataloader = gluon.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        batchify_fn=batchify_fn)
    return train_dataloader, test_dataloader

train_dataloader, test_dataloader = get_dataloader()

FixedBucketSampler:
  sample_num=2703, batch_num=82
  key=[13, 22, 31, 40, 49, 58, 67, 76, 85, 94]
  cnt=[837, 805, 531, 282, 120, 70, 30, 15, 8, 5]
  batch_size=[46, 32, 32, 32, 32, 32, 32, 32, 32, 32]


In [40]:
context = mx.cpu()

In [41]:
for i, example in enumerate(train_dataloader):
    if i >= 1:
        break
    print(example)

(
[[[ -5.684265    -5.481247   -31.677723   ... -21.189108     1.9651077
    -1.8781087 ]
  [ -5.8898716   -3.5039613  -30.943327   ... -28.498064    -4.1945057
   -10.528649  ]
  [ -6.200805    -1.0122288  -26.037184   ... -21.092169    -0.6559814
   -12.771611  ]
  ...
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]]

 [[ -9.333745   -15.022931   -23.424726   ... -15.981082     1.5638912
    -2.4987006 ]
  [ -9.199183   -15.421494   -26.592268   ...  -7.841849     7.154292
     1.0559975 ]
  [ -9.145427   -15.609611   -27.623438   ...  -9.7564335   -0.18035188
     0.16993964]
  ...
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0. 

In [123]:
import numpy as np
import mxnet as mx
from io import open
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn, Block
from mxnet import ndarray as F

class SubSampler(gluon.HybridBlock):
    def __init__(self, size=3, prefix=None, params=None):
        super(SubSampler, self).__init__(prefix=prefix, params=params)
        self.size = size

    def forward(self, data, valid_length):
        masked_encoded = F.SequenceMask(data,
                                        sequence_length=valid_length,
                                        use_sequence_length=True)
        subsampled = F.Pooling(masked_encoded.swapaxes(0,2), kernel=(self.size), pool_type='max', stride=self.size).swapaxes(0,2)
        return subsampled

In [124]:
class EncoderRNN(Block):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        with self.name_scope():
            self.rnn1 = rnn.GRU(hidden_size, input_size=self.input_size)
            self.subsampler = SubSampler(size=3)
            self.rnn2 = rnn.GRU(hidden_size, input_size=self.hidden_size)

    def forward(self, input, hidden, lengths):
        input = input.swapaxes(0, 1)
        output, hidden = self.rnn1(input, hidden)
        subsampled = self.subsampler(output, lengths)
        output, hidden = self.rnn2(subsampled, hidden)
        return output, hidden

    def initHidden(self, batchsize, ctx):
        return [F.zeros((1, batchsize, self.hidden_size), ctx=ctx)]

In [125]:
class DecoderRNN(Block):
    def __init__(self, output_size, hidden_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        with self.name_scope():
            self.embedding = nn.Embedding(output_size, hidden_size)
            self.rnn = rnn.GRU(hidden_size, input_size=self.hidden_size)
            self.out = nn.Dense(output_size, in_units=self.hidden_size, flatten=False)

    def forward(self, input, hidden):
        output = self.embedding(input).swapaxes(0, 1)
        output, hidden = self.rnn(output, hidden)
        output = self.out(output)
        return output, hidden

    def initHidden(self, batchsize, ctx):
        return [F.zeros((1, batchsize, self.hidden_size), ctx=ctx)]

In [126]:
encoder = EncoderRNN(input_size=13, hidden_size=256)
decoder = DecoderRNN(hidden_size=256, output_size=34798)

encoder.initialize(mx.init.Xavier(), ctx=context)
decoder.initialize(mx.init.Xavier(), ctx=context)

def train(net, context, epochs):
    trainer = gluon.Trainer(net.collect_params(), 'ftml',
                            {'learning_rate': learning_rate})
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    parameters = net.collect_params().values()

    for epoch in range(epochs):
        start_epoch_time = time.time()
        epoch_L = 0.0
        epoch_sent_num = 0
        epoch_wc = 0

        start_log_interval_time = time.time()
        log_interval_wc = 0
        log_interval_sent_num = 0
        log_interval_L = 0.0

        for i, (audio, words, alength, wlength) in enumerate(train_dataloader):
            wc = alength.sum().asscalar()
            log_interval_wc += wc
            epoch_wc += wc
            log_interval_sent_num += audio.shape[1]
            epoch_sent_num += audio.shape[1]
            with autograd.record():
                encoder_hidden = encoder.initHidden(len(audio), context)
                encoder_outputs, encoder_hidden = encoder(audio.as_in_context(context), encoder_hidden, alength)
                # TODO: Get the last hidden state/attended state according to the lengths
                decoder_hidden = encoder_hidden
                decoder_outputs, decoder_hidden = decoder(words.as_in_context(context), decoder_hidden)
                decoder_outputs = decoder_outputs.swapaxes(0,1)
                L = loss(decoder_outputs, words.as_in_context(context)).sum()
            L.backward()
            
            if grad_clip:
                gluon.utils.clip_global_norm(
                    [p.grad(context) for p in parameters],
                    grad_clip)
            
            trainer.step(1)
            log_interval_L += L.asscalar()
            epoch_L += L.asscalar()
            if (i + 1) % log_interval == 0:
                print(
                    '[Epoch {} Batch {}/{}] elapsed {:.2f} s, '
                    'avg loss {:.6f}, throughput {:.2f}K fps'.format(
                        epoch, i + 1, len(train_dataloader),
                        time.time() - start_log_interval_time,
                        log_interval_L / log_interval_sent_num, log_interval_wc
                        / 1000 / (time.time() - start_log_interval_time)))
                start_log_interval_time = time.time()
                log_interval_wc = 0
                log_interval_sent_num = 0
                log_interval_L = 0
        end_epoch_time = time.time()
        test_avg_L, test_acc = evaluate(net, test_dataloader, context)
        print('[Epoch {}] train avg loss {:.6f}, test acc {:.2f}, '
              'test avg loss {:.6f}, throughput {:.2f}K fps'.format(
                  epoch, epoch_L / epoch_sent_num, test_acc, test_avg_L,
                  epoch_wc / 1000 / (end_epoch_time - start_epoch_time)))

In [127]:
train(encoder, context, 1)

[Epoch 0 Batch 5/82] elapsed 135.79 s, avg loss 0.311807, throughput 0.79K fps


KeyboardInterrupt: 