In [1]:
import warnings
# NOTE(review): with no category or module filter this suppresses every warning
# for the rest of the session, including deprecation notices raised by the
# libraries imported below. Narrow the filter once the noisy source is known.
warnings.filterwarnings('ignore')

import random
import time
import multiprocessing as mp
import numpy as np

import mxnet as mx
from mxnet import nd, gluon, autograd

import gluonnlp as nlp
import pickle

# Seed the Python, NumPy and MXNet random number generators with the same
# fixed value so shuffling and parameter initialisation are repeatable.
random.seed(123)
np.random.seed(123)
mx.random.seed(123)

In [2]:
# Load the two pickled splits. Note that the split bound to `test_dataset`
# is read from the *dev* file.
#
# SECURITY: unpickling can execute arbitrary code, so only point these paths
# at files produced by a trusted preprocessing run.
#
# The files are opened through context managers so the handles are closed as
# soon as each split has been read (the previous bare open() never closed them).
with open('../data/train_processed.p', 'rb') as train_file:
    train_dataset = pickle.load(train_file)
with open('../data/dev_processed.p', 'rb') as dev_file:
    test_dataset = pickle.load(dev_file)

In [3]:
# Token table mapping token -> [integer id, occurrence count].
# The four special tokens are pre-registered with ids 0-3 and a count of 1;
# every other token receives the next free id the first time it is seen.
# Tokens come from field 2 of each sample, split on single spaces, and both
# splits are scanned (training samples first, then the held-out ones).
vocabulary = {'<pad>': [0, 1], '<unk>': [1, 1], '<BOS>': [2, 1], '<EOS>': [3, 1]}
for split in (train_dataset, test_dataset):
    for item in split:
        for word in item[2].split(' '):
            entry = vocabulary.get(word)
            if entry is None:
                # First sighting: allocate the next id and start the count at 1.
                vocabulary[word] = [len(vocabulary), 1]
            else:
                entry[1] += 1

In [4]:
def preprocess(x):
    """Turn one raw sample into the fields consumed by the data loader.

    Parameters
    ----------
    x : tuple
        ``(name, audio, words)``; ``name`` is ignored, ``audio`` is any sized
        sequence and ``words`` is a string of tokens separated by single spaces.

    Returns
    -------
    tuple
        ``(audio, token_ids, audio_length, token_count)`` where ``token_ids``
        is a NumPy array of vocabulary ids for ``<BOS>`` + tokens + ``<EOS>``
        and both lengths are floats.

    Notes
    -----
    Tokens missing from the module-level ``vocabulary`` are mapped to the id
    of ``<unk>`` instead of raising ``KeyError``.
    """
    _, audio, words = x
    tokens = ['<BOS>'] + words.split(' ') + ['<EOS>']
    unk_id = vocabulary['<unk>'][0]
    token_ids = np.array(
        [vocabulary[token][0] if token in vocabulary else unk_id
         for token in tokens])
    return audio, token_ids, float(len(audio)), float(len(tokens))

def get_length(x):
    """Return, as a float, the number of entries in the second field of ``x``."""
    second_field = x[1]
    return float(len(second_field))

def preprocess_dataset(dataset):
    """Preprocess every sample in ``dataset`` and collect per-sample lengths.

    Parameters
    ----------
    dataset : sequence
        Raw samples accepted by ``preprocess``.

    Returns
    -------
    tuple
        ``(processed, lengths)``, both ``gluon.data.SimpleDataset`` objects;
        ``lengths[i]`` is ``get_length(processed[i])``.

    Side effects
    ------------
    Prints the elapsed wall-clock time and the number of samples.
    """
    start = time.time()
    with mp.Pool() as pool:
        dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))
    # get_length is a single len() call, so it is computed in this process.
    # Sending it through the pool would pickle every processed sample to a
    # worker a second time for no gain.
    lengths = gluon.data.SimpleDataset([get_length(sample) for sample in dataset])
    end = time.time()
    print('Done! Processing Time={:.2f}s, #Samples={}'.format(end - start, len(dataset)))
    return dataset, lengths

In [5]:
# Replace the raw splits with their preprocessed form and keep the
# per-sample lengths returned alongside them.
# NOTE: this rebinds train_dataset / test_dataset, so the cell cannot be run
# twice in a row: preprocess() unpacks three fields per sample, while the
# processed samples it returns carry four. Re-run the loading cell first.
train_dataset, train_data_lengths = preprocess_dataset(train_dataset)
test_dataset, test_data_lengths = preprocess_dataset(test_dataset)

Done! Processing Time=153.64s, #Samples=28539
Done! Processing Time=5.79s, #Samples=2703


In [6]:
# Step size handed to the trainer and number of samples per batch.
learning_rate = 0.005
batch_size = 32
# Bucketing settings passed to FixedBucketSampler as num_buckets / ratio.
bucket_num = 10
bucket_ratio = 0.2

def get_dataloader():
    """Create the training and test data loaders.

    Reads the notebook-level ``train_dataset``, ``train_data_lengths``,
    ``test_dataset``, ``batch_size``, ``bucket_num`` and ``bucket_ratio``.
    Training batches are drawn by a shuffled ``FixedBucketSampler`` built from
    ``train_data_lengths``; test batches are taken in order with a fixed
    ``batch_size``. The sampler statistics are printed as a side effect.

    Returns
    -------
    tuple
        ``(train_dataloader, test_dataloader)``.
    """
    batchify = nlp.data.batchify
    # One batchify op per sample field, in field order: Pad for the first two
    # fields, float32 Stack for the last two.
    collate = batchify.Tuple(
        batchify.Pad(),
        batchify.Pad(),
        batchify.Stack(dtype='float32'),
        batchify.Stack(dtype='float32'))

    bucket_sampler = nlp.data.sampler.FixedBucketSampler(
        train_data_lengths,
        batch_size=batch_size,
        num_buckets=bucket_num,
        ratio=bucket_ratio,
        shuffle=True)
    print(bucket_sampler.stats())

    train_loader = gluon.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=bucket_sampler,
        batchify_fn=collate)
    test_loader = gluon.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        batchify_fn=collate)
    return train_loader, test_loader

# Build both loaders once; the call also prints the bucket sampler statistics.
train_dataloader, test_dataloader = get_dataloader()

FixedBucketSampler:
  sample_num=28539, batch_num=894
  key=[14, 21, 28, 35, 42, 49, 56, 63, 70, 77]
  cnt=[1723, 1713, 2319, 5188, 8205, 6607, 2317, 420, 44, 3]
  batch_size=[35, 32, 32, 32, 32, 32, 32, 32, 32, 32]


In [7]:
# Device that the training code moves each batch to; CPU here.
context = mx.cpu()

In [8]:
# Print the first training batch as a sanity check of the batchify output.
# Breaking straight after the print stops the loader from assembling a second
# batch that would only be thrown away. An empty loader prints nothing.
for example in train_dataloader:
    print(example)
    break

(
[[[ -8.37042654  -6.16026004   6.29350528 ...   0.59649253   9.17632351
     8.05719958]
  [ -8.51415536  -6.47229985   5.25051731 ...   3.97691196  19.89046768
    12.82517394]
  [ -8.86324903  -0.31330465  -1.34714977 ...  -4.47258676  19.94017961
    11.05409373]
  ...
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]]

 [[ -7.23614367 -23.17302696 -28.19189459 ...   6.40932545 -10.95141835
     5.9777419 ]
  [ -4.05327935 -11.33491852 -29.78987646 ...  -1.95894314  11.89255368
    -9.4640987 ]
  [ -3.8196864  -19.94633816 -43.70269579 ... -14.56101723   5.60094435
    -2.42812674]
  ...
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...   0.           0.
     0.        ]
  [  0.           0.           0.         ...

In [None]:
class EncoderRNN(gluon.Block):
    """Single-layer GRU encoder.

    Parameters
    ----------
    hidden_size : int
        Number of GRU hidden units.
    input_size : int, optional
        Size of each input feature vector. Defaults to ``hidden_size``, which
        is what this class always used before the argument existed.
        NOTE(review): the audio feature width presumably differs from
        ``hidden_size`` - confirm and pass it explicitly.
    """

    def __init__(self, hidden_size, input_size=None):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        with self.name_scope():
            # Referenced through the imported `gluon` package: bare `Block`
            # and `rnn` are never imported in this notebook.
            self.rnn = gluon.rnn.GRU(
                hidden_size,
                input_size=hidden_size if input_size is None else input_size)

    def forward(self, input, hidden):
        """Run the GRU over ``input`` starting from ``hidden``.

        ``input`` has its first two axes swapped before it reaches the GRU
        (assumes a batch-major input being turned into the time-major layout
        that the GRU uses by default - TODO confirm against the loader).

        Returns ``(output, hidden)`` exactly as produced by the GRU layer.
        """
        # NDArray exposes `swapaxes` (also used by DecoderRNN), not `swap_axes`.
        input = input.swapaxes(0, 1)
        output, hidden = self.rnn(input, hidden)
        return output, hidden

    def initHidden(self, ctx, batch_size=1):
        """Return a zero initial state list on ``ctx``.

        The single array has shape ``(1, batch_size, hidden_size)``;
        ``batch_size`` defaults to 1, the value that used to be hard-coded.
        """
        return [nd.zeros((1, batch_size, self.hidden_size), ctx=ctx)]

In [None]:
class DecoderRNN(gluon.Block):
    """Embedding -> single-layer GRU -> dense projection decoder.

    Parameters
    ----------
    output_size : int
        Number of embedding rows and number of units in the output projection.
    hidden_size : int
        Embedding width and number of GRU hidden units.
    """

    def __init__(self, output_size, hidden_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        with self.name_scope():
            # Layers are referenced through the imported `gluon` package:
            # bare `Block`, `nn` and `rnn` are never imported in this notebook.
            self.embedding = gluon.nn.Embedding(output_size, hidden_size)
            self.rnn = gluon.rnn.GRU(hidden_size, input_size=self.hidden_size)
            # NOTE(review): Dense flattens its input by default, so with
            # in_units=hidden_size this projection presumably only accepts GRU
            # output whose batch dimension is 1. Left unchanged to keep the
            # output shape; revisit (flatten=False) before batching.
            self.out = gluon.nn.Dense(output_size, in_units=self.hidden_size)

    def forward(self, input, hidden):
        """Embed ``input``, run the GRU from ``hidden`` and project the result.

        The embedded tensor has its first two axes swapped before the GRU
        (assumes batch-major token ids - TODO confirm).

        Returns ``(output, hidden)``: the projected GRU output and the new state.
        """
        output = self.embedding(input).swapaxes(0, 1)
        output, hidden = self.rnn(output, hidden)
        output = self.out(output)
        return output, hidden

    def initHidden(self, ctx, batch_size=1):
        """Return a zero initial state list on ``ctx``.

        The single array has shape ``(1, batch_size, hidden_size)``;
        ``batch_size`` defaults to 1, the value that used to be hard-coded.
        """
        return [nd.zeros((1, batch_size, self.hidden_size), ctx=ctx)]

In [None]:
class Audio2TranscriptionNet(gluon.HybridBlock):
    """Wrapper that owns one EncoderRNN and one DecoderRNN (work in progress).

    FIXME: the forward pass below cannot run as written; see the notes in
    ``hybrid_forward``.
    NOTE(review): EncoderRNN and DecoderRNN are declared as plain ``Block``
    subclasses; Gluon's HybridBlock presumably rejects non-hybrid children
    when they are assigned - confirm, and either make the children hybrid or
    derive this class from ``gluon.Block``.
    """

    def __init__(self, dropout, prefix=None, params=None):
        """Create the encoder and decoder sub-blocks.

        ``dropout`` is accepted but not used anywhere in this class.
        ``prefix`` and ``params`` are forwarded to the parent constructor.
        """
        super(Audio2TranscriptionNet, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            # NOTE(review): 256 and 34798 are hard-coded; 34798 is presumably
            # the vocabulary size - confirm.
            self.encoder = EncoderRNN(hidden_size=256)
            self.decoder = DecoderRNN(hidden_size=256, output_size=34798)

    def hybrid_forward(self, F, audio, words, audio_length, words_length): # pylint: disable=arguments-differ
        """Intended forward pass; currently references names that do not exist.

        ``words``, ``audio_length`` and ``words_length`` are accepted but never
        read, and ``self.decoder`` is never called.
        """
        # FIXME: EncoderRNN.forward takes (input, hidden); no hidden state is
        # passed here.
        encoded = self.encoder(audio)  
        # FIXME: `self.agg_layer` is never assigned in __init__ and
        # `valid_length` is neither a parameter nor a local of this method.
        agg_state = self.agg_layer(encoded, valid_length)
        # FIXME: `self.output` is never assigned in __init__.
        out = self.output(agg_state)
        return out

In [None]:
# Stand-alone encoder / decoder instances; train() below reads `encoder`
# from module scope rather than from the `net` it is given.
# NOTE(review): 34798 is presumably the vocabulary size - confirm that it
# matches len(vocabulary).
# NOTE(review): no initialize() call for these blocks is visible in this
# section; parameters must be initialised before the first forward pass.
encoder = EncoderRNN(hidden_size=256)
decoder = DecoderRNN(hidden_size=256, output_size=34798)

def train(net, context, epochs):
    """Iterate over ``train_dataloader`` for ``epochs`` passes (work in progress).

    Only the encoder forward pass is active. The loss, backward pass,
    optimiser step, logging and evaluation are still commented out at the
    bottom of this function.

    Parameters
    ----------
    net
        Model whose ``collect_params()`` are handed to the trainer.
        NOTE(review): the forward pass uses the module-level ``encoder``, not
        ``net``, so the trainer and the forward pass may refer to different
        parameters - confirm the intended wiring.
    context
        Device each audio batch is moved to.
    epochs : int
        Number of passes over ``train_dataloader``.
    """
    trainer = gluon.Trainer(net.collect_params(), 'ftml',
                            {'learning_rate': learning_rate})
    # NOTE(review): a sigmoid binary cross-entropy loss is unusual for
    # predicting tokens from a vocabulary; it is unused until the commented
    # block below is enabled, so revisit it then.
    loss = gluon.loss.SigmoidBCELoss()

    parameters = net.collect_params().values()

    for epoch in range(epochs):
        start_epoch_time = time.time()
        epoch_L = 0.0
        epoch_sent_num = 0
        epoch_wc = 0

        start_log_interval_time = time.time()
        log_interval_wc = 0
        log_interval_sent_num = 0
        log_interval_L = 0.0

        # Each batch is the flat 4-field tuple built by the loader's batchify
        # function. There is no separate label element, so unpacking it as
        # ((...), label) raised ValueError.
        for i, (audio, words, alength, wlength) in enumerate(train_dataloader):
            L = 0
            # Throughput counter: sum of the audio lengths in this batch.
            wc = alength.sum().asscalar()
            log_interval_wc += wc
            epoch_wc += wc
            # Padded batches are batch-major, so the number of samples is
            # axis 0; axis 1 is the padded sequence length.
            batch_samples = audio.shape[0]
            log_interval_sent_num += batch_samples
            epoch_sent_num += batch_samples
            with autograd.record():
                # The initial state needs one row per sample in the batch;
                # encoder.initHidden(context) only covers a single sequence.
                hidden = [nd.zeros((1, batch_samples, encoder.hidden_size),
                                   ctx=context)]
                # The encoder returns (output, hidden).
                enc_output, hidden = encoder(audio.as_in_context(context), hidden)
# TODO: the remainder of the loop is still disabled. Before enabling it,
# define `output`, `label`, `grad_clip`, `log_interval` and `evaluate`, none
# of which exist in the visible part of the notebook.
#                 L = L + loss(output, label.as_in_context(context)).mean()
#             L.backward()
#             # Clip gradient
#             if grad_clip:
#                 gluon.utils.clip_global_norm(
#                     [p.grad(context) for p in parameters],
#                     grad_clip)
#             # Update parameter
#             trainer.step(1)
#             log_interval_L += L.asscalar()
#             epoch_L += L.asscalar()
#             if (i + 1) % log_interval == 0:
#                 print(
#                     '[Epoch {} Batch {}/{}] elapsed {:.2f} s, '
#                     'avg loss {:.6f}, throughput {:.2f}K wps'.format(
#                         epoch, i + 1, len(train_dataloader),
#                         time.time() - start_log_interval_time,
#                         log_interval_L / log_interval_sent_num, log_interval_wc
#                         / 1000 / (time.time() - start_log_interval_time)))
#                 # Clear log interval training stats
#                 start_log_interval_time = time.time()
#                 log_interval_wc = 0
#                 log_interval_sent_num = 0
#                 log_interval_L = 0
#         end_epoch_time = time.time()
#         test_avg_L, test_acc = evaluate(net, test_dataloader, context)
#         print('[Epoch {}] train avg loss {:.6f}, test acc {:.2f}, '
#               'test avg loss {:.6f}, throughput {:.2f}K wps'.format(
#                   epoch, epoch_L / epoch_sent_num, test_acc, test_avg_L,
#                   epoch_wc / 1000 / (end_epoch_time - start_epoch_time)))