# Bayes by Backprop with ``gluon`` (RNNs for sequence prediction)

In this chapter, we apply [Bayes by Backprop](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb) ``(BBB)`` to learning [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb) for sequence prediction.

As we've seen, Bayes-by-backprop lets us fit expressive models efficiently and represent uncertainty about model parameters. Representing uncertainty helps avoid overfitting and is an important part of sound decision making.

Thankfully, ``BBB`` for RNNs is not much more difficult than in the feed-forward case. It mainly requires swapping a recurrent neural network for a feed-forward one and changing the log-likelihood to something appropriate for sequence modeling.
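
For a language model, the appropriate log-likelihood follows from the chain rule: $\log p(w_1, \dots, w_T) = \sum_t \log p(w_t \mid w_{<t})$. The sequence log-likelihood is therefore a sum of per-token terms, and each term is the familiar softmax cross-entropy. The minimal sketch below uses hypothetical next-token distributions written out by hand. In the notebook, these distributions come from the RNN's softmax outputs.

```python
import math

# Hypothetical next-token distributions over a 3-word vocabulary,
# one per time step (in practice these are the RNN's softmax outputs).
predicted = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
]
targets = [0, 1, 2]  # the token that actually came next at each step

# Per-token cross-entropy: -log p(target_t | history)
token_nll = [-math.log(p[t]) for p, t in zip(predicted, targets)]

# Chain rule: the sequence negative log-likelihood is the sum of per-token terms
sequence_nll = sum(token_nll)
print(round(sequence_nll, 4))
```

Summing rather than averaging these terms is what the Bayes-by-backprop loss later in this chapter does (via ``mx.nd.sum``), so the likelihood is weighed against the KL penalty on the right scale.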

In what follows, we reimplement the sequence model from [*Bayesian Recurrent Neural Networks* by Fortunato et al.](https://arxiv.org/pdf/1704.02798.pdf) and recreate the authors' experiments on the Penn Tree-Bank dataset, which we also used in the Straight-Dope chapter on [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb).

If you have not looked at the chapters [Bayes by Backprop](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb) or [Recurrent Neural Networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb), it is worth doing so, since we reuse a lot of that code.


## Import Packages

To begin, we import the packages we need.

```python
import math
import os
import time
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn
```

## Initialize Configuration Variables and Hyperparameters

Next we perform basic configuration and initialize model hyperparameters.

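```python
# To train on a GPU, use the commented-out settings instead
# (adjusting the paths to your own checkout of the repository).
#context = mx.gpu(0)
#args_home = '/home/ubuntu'
#args_data_root = args_home + '/mxnet-the-straight-dope/data'
#args_data = args_data_root + '/nlp/ptb.'
context = mx.cpu(0)
args_data = '../data/nlp/ptb.'  # prefix of ptb.train.txt, ptb.valid.txt, ptb.test.txt
args_model = 'lstm'         # one of 'rnn_relu', 'rnn_tanh', 'lstm', 'gru'
args_emsize = 100           # word-embedding size
args_nhid = 100             # hidden units per recurrent layer
args_nlayers = 2            # number of recurrent layers
args_lr = 1.0               # initial SGD learning rate
args_clip = 0.2             # gradient-clipping factor (scaled inside each training loop)
args_epochs = 1
args_batch_size = 32        # number of parallel text streams
args_bptt = 5               # truncated-BPTT segment length
args_tied = True            # share encoder/decoder weights (needs args_emsize == args_nhid)
args_cuda = 'store_true'    # not used in this notebook
args_log_interval = 500     # print the training loss every this many minibatches
args_save = 'model.param'   # where the best parameters are saved
```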

## Define Classes for Loading the Language Data

To load and access the Penn Tree-Bank data, we use the classes we defined in the Straight-Dope chapter on [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb). As in that notebook, the ``Dictionary`` class maps words (i.e. strings) to numeric token IDs. And the ``Corpus`` class stores the sequence of words in the training, validation and testing sets. The functions ``batchify`` and ``get_batch`` allow us to iterate over the corpus data.

```python
class Dictionary(object):
    """Maps words to integer token IDs and back."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        # A new word gets the next free ID; a known word keeps its existing ID
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    """Token-ID sequences for the training, validation and test splits."""
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(path + 'train.txt')
        self.valid = self.tokenize(path + 'valid.txt')
        self.test = self.tokenize(path + 'test.txt')

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = np.zeros((tokens,), dtype='int32')
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return mx.nd.array(ids, dtype='int32')


def batchify(data, batch_size):
    """Reshape data into (num_example, batch_size); each column is a contiguous text stream."""
    nbatch = data.shape[0] // batch_size
    data = data[:nbatch * batch_size]  # drop tokens that do not fill a whole column
    data = data.reshape((batch_size, nbatch)).T
    return data

def get_batch(source, i):
    """Return up to args_bptt rows starting at row i, plus the targets
    (the same rows shifted one step ahead), flattened to 1-D."""
    seq_len = min(args_bptt, source.shape[0] - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len]
    return data, target.reshape((-1,))
```


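To make the data layout concrete, here is a minimal NumPy sketch on a hypothetical two-line corpus. It mirrors ``Dictionary.add_word``, ``batchify`` and ``get_batch``. The names ``batchify_np`` and ``get_batch_np`` are our own, and ``get_batch_np`` takes ``bptt`` as an argument instead of reading the global ``args_bptt``.

```python
import numpy as np

# Hypothetical two-line corpus; each line gets an '<eos>' token, as in Corpus.tokenize.
lines = ["the cat sat", "the dog sat"]

# Mirror of Dictionary.add_word: a new word gets the next free integer ID.
word2idx = {}
ids = []
for line in lines:
    for word in line.split() + ["<eos>"]:
        if word not in word2idx:
            word2idx[word] = len(word2idx)
        ids.append(word2idx[word])

def batchify_np(data, batch_size):
    """NumPy mirror of batchify: drop the remainder, then lay the stream out as columns."""
    nbatch = data.shape[0] // batch_size
    data = data[:nbatch * batch_size]
    return data.reshape((batch_size, nbatch)).T

def get_batch_np(source, i, bptt):
    """NumPy mirror of get_batch; targets are the inputs shifted one step ahead."""
    seq_len = min(bptt, source.shape[0] - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len]
    return data, target.reshape((-1,))

batched = batchify_np(np.array(ids), batch_size=2)  # one column per text stream
x, y = get_batch_np(batched, 0, bptt=2)             # first segment
x_last, y_last = get_batch_np(batched, 2, bptt=2)   # final, shorter segment
```

Reading down a column of ``batched`` recovers one contiguous stretch of text. Reading ``x`` row by row, each entry of ``y`` is the token that follows the corresponding entry of ``x``. For example, "the" (ID 0) in the second column is followed by "dog" (ID 4).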
## Define our Recurrent Neural Network Class

Next we define the recurrent neural network class which we saw in the chapter [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb).

We've added a helper method, ``set_params_to``, which is used to set the parameters of the RNN to ones sampled from a variational posterior.

We also define an auxiliary function, ``detach``, which detaches the hidden state from the computation graph after every minibatch. The hidden state keeps its value, so each minibatch continues where the previous one left off. However, MXNet no longer tries to propagate the gradient indefinitely far back in time, which is known as truncated backpropagation through time.

```python
class RNNModel(gluon.Block):
    """A model with an encoder, recurrent layer, and a decoder."""

    def __init__(self, mode, vocab_size, num_embed, num_hidden,
                 num_layers, dropout=0.5, tie_weights=False, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        with self.name_scope():
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(vocab_size, num_embed,
                                        weight_initializer=mx.init.Uniform(0.1))
            if mode == 'rnn_relu':
                self.rnn = rnn.RNN(num_hidden, num_layers, activation='relu',
                                   dropout=dropout, input_size=num_embed)
            elif mode == 'rnn_tanh':
                self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            elif mode == 'lstm':
                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
                                    input_size=num_embed)
            elif mode == 'gru':
                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            else:
                raise ValueError("Invalid mode %s. Options are rnn_relu, "
                                 "rnn_tanh, lstm, and gru" % mode)
            if tie_weights:
                # Decoder reuses the embedding matrix (requires num_embed == num_hidden)
                self.decoder = nn.Dense(vocab_size, in_units=num_hidden,
                                        params=self.encoder.params)
            else:
                self.decoder = nn.Dense(vocab_size, in_units=num_hidden)
            self.num_hidden = num_hidden

    def forward(self, inputs, hidden):
        emb = self.drop(self.encoder(inputs))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.reshape((-1, self.num_hidden)))
        return decoded, hidden

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)

    def set_params_to(self, new_values):
        """Point every model parameter at the given NDArrays (e.g. a posterior sample).

        new_values must follow the order of self.collect_params(). The arrays are
        assigned, not copied, so they (and their dependence on the variational means)
        stay in the autograd graph. This relies on Gluon internals (Parameter._data
        indexed by context) that may differ across MXNet versions.
        """
        for model_param, new_value in zip(self.collect_params().values(), new_values):
            model_param_ctx = model_param.list_ctx()[0]
            model_param._data[model_param_ctx] = new_value
        return



def detach(hidden):
    """Detach the hidden state (one NDArray, or a list such as the LSTM's h and c) from the graph."""
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
    else:
        hidden = hidden.detach()
    return hidden
```

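To see what ``detach`` does to gradients, consider a toy scalar recurrence $h_t = a h_{t-1} + x_t$. This is our own illustration, not part of the model. The derivative $\partial h_t / \partial a$ is carried forward alongside $h_t$ (forward-mode differentiation). Detaching at a segment boundary keeps the value of $h$ but resets its gradient history to zero, which is the truncation ``detach`` introduces between minibatches.

```python
def run_recurrence(a, xs, bptt=None):
    """Toy recurrence h_t = a * h_{t-1} + x_t, tracking dh/da in forward mode.

    If bptt is given, the hidden state is "detached" every bptt steps: its value
    is kept, but its derivative with respect to a is reset to zero.
    """
    h, dh_da = 0.0, 0.0
    for t, x in enumerate(xs):
        if bptt is not None and t > 0 and t % bptt == 0:
            dh_da = 0.0  # detach: same value, no gradient path into earlier segments
        # product rule: d(a * h_prev)/da = h_prev + a * dh_prev/da
        dh_da = h + a * dh_da
        h = a * h + x
    return h, dh_da

full_h, full_grad = run_recurrence(0.5, [1.0, 1.0, 1.0, 1.0])
trunc_h, trunc_grad = run_recurrence(0.5, [1.0, 1.0, 1.0, 1.0], bptt=2)
```

Both runs end in the same state, $h_4 = 1.875$. The truncated gradient (2.5) drops the contribution that flows through the first segment, so it is smaller than the full gradient (2.75). That is the price of bounding memory and compute per minibatch.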
## Load the Penn Tree-Bank Corpus

By now we should be familiar with the Penn Tree-Bank corpus. It's basically the MNIST of sequence modeling.

```python
corpus = Corpus(args_data)
ntokens = len(corpus.dictionary)
train_data = batchify(corpus.train, args_batch_size).as_in_context(context)
val_data = batchify(corpus.valid, args_batch_size).as_in_context(context)
test_data = batchify(corpus.test, args_batch_size).as_in_context(context)
# Number of truncated-BPTT minibatches per epoch (used later to scale the KL term)
num_batches = int(np.ceil((train_data.shape[0] - 1) / float(args_bptt)))
```

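``num_batches`` counts the truncated-BPTT segments in one pass over the training data. The training loops step through rows ``0, bptt, 2*bptt, ...``, and the last segment may be shorter, because every input row needs a following row as its target. The following plain-Python check uses hypothetical sizes, and ``segment_bounds`` is a helper of our own.

```python
import math

def segment_bounds(num_rows, bptt):
    """(start, length) of each truncated-BPTT segment, following the training loop."""
    return [(i, min(bptt, num_rows - 1 - i)) for i in range(0, num_rows - 1, bptt)]

num_rows, bptt = 12, 5  # hypothetical: 12 rows of batchified data
segments = segment_bounds(num_rows, bptt)
num_batches = int(math.ceil((num_rows - 1) / float(bptt)))
```

Here the segments are ``[(0, 5), (5, 5), (10, 1)]``. Their count equals ``num_batches``, and together they cover every row except the last, which only ever serves as a target.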
## Define and Initialize Our Favorite Recurrent Neural Network

```python
args_dropout = 0.5
model = RNNModel(args_model, ntokens, args_emsize, args_nhid, args_nlayers, args_dropout, args_tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)

trainer = gluon.Trainer(
    model.collect_params(), 'sgd',
    {'learning_rate': args_lr, 'momentum': 0, 'wd': 0}
)

smce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
```

## Define Standard RNN Train and Evaluate Procedures

```python
def train_standard():
    global args_lr  # the learning rate is decayed below, so rebind the module-level value
    best_val = float("Inf")
    for epoch in range(args_epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args_batch_size, ctx=context)
        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args_bptt)):
            data, target = get_batch(train_data, i)
            # Truncated BPTT: keep the hidden state's value, drop its gradient history
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                L = smce_loss(output, target)
            L.backward()

            grads = [p.grad(context) for p in model.collect_params().values()]
            # Here the gradient is summed over the whole minibatch,
            # so we multiply max_norm by batch_size and bptt to balance it.
            gluon.utils.clip_global_norm(grads, args_clip * args_bptt * args_batch_size)

            trainer.step(args_batch_size)
            total_L += mx.nd.sum(L).asscalar()

            if ibatch % args_log_interval == 0 and ibatch > 0:
                # Average cross-entropy per token over the last args_log_interval minibatches
                cur_L = total_L / args_bptt / args_batch_size / args_log_interval
                print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (
                    epoch + 1, ibatch, cur_L, math.exp(cur_L)))
                total_L = 0.0

        val_L = evaluate(val_data, model)

        print('[Epoch %d] time cost %.2fs, validation loss %.2f, validation perplexity %.2f' % (
            epoch + 1, time.time() - start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(test_data, model)
            model.save_params(args_save)
            print('test loss %.2f, test perplexity %.2f' % (test_L, math.exp(test_L)))
        else:
            # No improvement: decay the learning rate and restore the best parameters
            # (_init_optimizer is a private Trainer method).
            args_lr = args_lr * 0.25
            trainer._init_optimizer('sgd',
                                    {'learning_rate': args_lr,
                                     'momentum': 0,
                                     'wd': 0})
            model.load_params(args_save, context)
    return


def evaluate(data_source, model):
    """Average per-token cross-entropy of model on data_source (no gradient recording)."""
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args_batch_size, ctx=context)
    for i in range(0, data_source.shape[0] - 1, args_bptt):
        data, target = get_batch(data_source, i)
        output, hidden = model(data, hidden)
        L = smce_loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal
```

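``gluon.utils.clip_global_norm`` rescales all gradient arrays jointly so that their combined L2 norm does not exceed the threshold. The plain-Python function below is our own sketch of that idea, not the Gluon implementation. It operates on lists of floats instead of NDArrays.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale every gradient by one common factor so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g * scale for g in grad] for grad in grads]
    return grads, total_norm

clipped, norm_before = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
```

Because all arrays share one scale factor, the direction of the overall update is preserved and only its length shrinks. In ``train_standard`` the threshold is multiplied by ``args_bptt * args_batch_size`` because the loss being differentiated is summed, not averaged, over that many tokens.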
## Train and Evaluate the Standard LSTM Baseline

```python
train_standard()
model.load_params(args_save, context)
test_L = evaluate(test_data, model)
print('Best test loss %.2f, test perplexity %.2f'%(test_L, math.exp(test_L)))
```

```
[Epoch 1 Batch 500] loss 6.87, perplexity 962.37
[Epoch 1 Batch 1000] loss 6.61, perplexity 738.91
[Epoch 1 Batch 1500] loss 6.40, perplexity 602.46
[Epoch 1 Batch 2000] loss 6.29, perplexity 538.11
[Epoch 1 Batch 2500] loss 6.18, perplexity 480.82
[Epoch 1 Batch 3000] loss 6.06, perplexity 428.86
[Epoch 1 Batch 3500] loss 6.06, perplexity 426.40
[Epoch 1 Batch 4000] loss 5.93, perplexity 375.06
[Epoch 1 Batch 4500] loss 5.91, perplexity 368.07
[Epoch 1 Batch 5000] loss 5.90, perplexity 364.07
[Epoch 1 Batch 5500] loss 5.89, perplexity 362.32
[Epoch 1] time cost 440.98s, validation loss 5.68, validation perplexity 292.24
test loss 5.65, test perplexity 283.49
Best test loss 5.65, test perplexity 283.49
```

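The perplexity printed in these logs is $\exp$ of the average per-token cross-entropy (in nats). It can be read as an effective number of equally likely choices per word. As a quick sanity check in plain Python: a model that spreads probability uniformly over $V$ words has per-token loss $\log V$ and perplexity exactly $V$. The vocabulary size below is hypothetical.

```python
import math

def perplexity(token_losses):
    """exp of the mean per-token cross-entropy (natural log), as in the training logs."""
    return math.exp(sum(token_losses) / len(token_losses))

vocab_size = 10000                            # hypothetical vocabulary size
uniform_loss = -math.log(1.0 / vocab_size)    # cross-entropy of a uniform prediction
ppl_uniform = perplexity([uniform_loss] * 50)
```

The same conversion relates the two numbers in each log line. A test loss of 5.65 corresponds to $e^{5.65} \approx 284$. The printed 283.49 differs slightly because the loss is rounded to two decimals before printing.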

## Define Bayes-by-Backprop Classes

Now we begin the Bayes-by-backprop portion of the chapter in earnest. To start, we define a Gaussian prior over the RNN model's parameters. It is initialized with a prior mean $\mu$ and standard deviation $\sigma$, shared by every weight. It has one important method, ``log_prob``, which computes the log-probability of the RNN model's parameters under the prior.

```python
def log_gaussian_prob(x, mu, sigma):
    """Elementwise log-density of N(mu, sigma^2) at x."""
    return -0.5 * np.log(2.0 * np.pi) - mx.nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)


class Prior(object):
    """Gaussian prior N(prior_mu, prior_sigma^2), applied independently to every weight."""

    def __init__(self, prior_mu, prior_sigma):
        self.prior_mu = mx.nd.array([prior_mu], ctx=context)
        self.prior_sigma = mx.nd.array([prior_sigma], ctx=context)
        return

    def log_prob(self, model_params):
        # Sum the log-densities over all entries of every parameter array
        log_probs = [
            mx.nd.sum(log_gaussian_prob(model_param, self.prior_mu, self.prior_sigma))
            for model_param in model_params
        ]
        total_log_prob = log_probs[0]
        for log_prob in log_probs[1:]:
            total_log_prob = total_log_prob + log_prob
        return total_log_prob
```

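As a sanity check, the scalar formula in ``log_gaussian_prob`` can be compared against Python's standard-library ``statistics.NormalDist`` (Python 3.8+). The sketch also mirrors ``Prior.log_prob``. Because the prior treats every weight as an independent Gaussian, the log-prior of the whole network is the sum of per-entry log-densities. The helper names and parameter values are made up for illustration. The defaults ``mu=0.0, sigma=2.0`` match the prior constructed later in this notebook.

```python
import math
from statistics import NormalDist

def log_gaussian_prob_py(x, mu, sigma):
    """Scalar version of log_gaussian_prob: log N(x; mu, sigma^2)."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

def prior_log_prob_py(params, mu=0.0, sigma=2.0):
    """Mirror of Prior.log_prob on nested lists: sum log-densities over every entry."""
    return sum(log_gaussian_prob_py(w, mu, sigma) for param in params for w in param)

params = [[0.1, -0.3], [1.5]]  # hypothetical parameter "arrays", flattened to lists
total = prior_log_prob_py(params)

# Independent cross-check with the standard library's normal distribution
reference = sum(math.log(NormalDist(0.0, 2.0).pdf(w)) for param in params for w in param)
```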
## The Variational Posterior

Next we define the variational posterior class. Like the prior, the variational posterior can compute the log-probability of the RNN's parameters, here under a Gaussian centered at the variational means.

However, unlike the prior, the variational posterior has parameters that get updated during training. In this implementation, these are the per-weight means, while the standard deviation ``var_sigma`` is held fixed. The prior's parameters are never updated. They are, after all, "prior".

Additionally, the variational posterior can sample the RNN's parameters from its current distribution with the ``sample_model_params`` method. It uses the reparameterization $w = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$.

```python
class VarPosterior(object):
    """Fully factorized Gaussian posterior q(w | theta) = N(mu, var_sigma^2).

    Each model parameter gets a learnable mean of the same shape; the standard
    deviation var_sigma is a fixed scalar shared by all weights.
    """

    def __init__(self, model, var_sigma, var_mu_init_scale):
        self.var_sigma = mx.nd.array([var_sigma], ctx=context)
        self.var_mus = []       # gluon Parameters, optimized by the Trainer
        self.raw_var_mus = []   # their underlying NDArrays, used in the computations below
        for i, model_param in enumerate(model.collect_params().values()):
            mu = gluon.Parameter(
                'mu_{}'.format(i), shape=model_param.shape,
                init=mx.init.Normal(var_mu_init_scale))
            mu.initialize(ctx=context)
            self.var_mus.append(mu)
            self.raw_var_mus.append(mu.data(context))
        return

    def sample_model_params(self):
        """Reparameterized sample: w = mu + sigma * epsilon, epsilon ~ N(0, 1)."""
        model_params = []
        for raw_var_mu in self.raw_var_mus:
            epsilon = mx.nd.random_normal(shape=raw_var_mu.shape, loc=0., scale=1.0, ctx=context)
            model_param = raw_var_mu + self.var_sigma * epsilon
            model_params.append(model_param)
        return model_params

    def log_prob(self, model_params):
        """Sum of log q(w | theta) over all entries of the given parameter arrays."""
        log_probs = [
            mx.nd.sum(log_gaussian_prob(model_param, raw_var_mu, self.var_sigma))
            for (model_param, raw_var_mu) in zip(model_params, self.raw_var_mus)
        ]
        total_log_prob = log_probs[0]
        for log_prob in log_probs[1:]:
            total_log_prob = total_log_prob + log_prob
        return total_log_prob

    def num_params(self):
        """Total number of scalar weights in the model."""
        return sum([
            np.prod(param.shape)
            for param in self.var_mus
        ])
```

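The reparameterization in ``sample_model_params`` can be checked numerically. The scalar sketch below, our own illustration with a hypothetical posterior mean, draws many samples $w = \mu + \sigma \epsilon$ and confirms that their mean and standard deviation match $\mu$ and $\sigma$. The value $\sigma = 0.1$ matches the one used to build ``var_posterior`` later.

```python
import math
import random

def sample_params(mus, sigma, rng):
    """Reparameterized sample w = mu + sigma * eps, eps ~ N(0, 1), as in sample_model_params."""
    return [mu + sigma * rng.gauss(0.0, 1.0) for mu in mus]

rng = random.Random(0)
mu, sigma = 0.3, 0.1  # hypothetical posterior mean; fixed sigma as in this chapter
draws = [sample_params([mu], sigma, rng)[0] for _ in range(100000)]
mean = sum(draws) / len(draws)
std = math.sqrt(sum((w - mean) ** 2 for w in draws) / (len(draws) - 1))
```

Because the randomness enters only through $\epsilon$, each sample is a differentiable function of $\mu$ ($\partial w / \partial \mu = 1$). That is what lets ``autograd`` carry gradients from the sampled RNN weights back to ``var_posterior.var_mus``.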
## Define the Bayes-by-Backprop Loss

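As discussed, the Bayes-by-backprop loss has two terms. The first is the negative expected log-likelihood of the training data under the variational posterior $q(w \mid \theta)$. The second is the KL-divergence between that posterior and the prior $p(w)$:

$$\mathcal{F}(\mathcal{D}, \theta) = \mathrm{KL}\big[q(w \mid \theta) \,\|\, p(w)\big] - \mathbb{E}_{q(w \mid \theta)}\big[\log p(\mathcal{D} \mid w)\big]$$

Suppose the training data is split into $M$ minibatches $\mathcal{D}_1, \dots, \mathcal{D}_M$ (``num_batches`` in the code) and a single sample $w \sim q(w \mid \theta)$ is drawn per step. Then ``BBB_RNN_Loss`` minimizes, for minibatch $i$,

$$\mathcal{F}_i(\mathcal{D}_i, \theta) \approx -\log p(\mathcal{D}_i \mid w) + \frac{1}{M}\big[\log q(w \mid \theta) - \log p(w)\big]$$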

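```python
class BBB_RNN_Loss(gluon.loss.Loss):
    """Minibatch Bayes-by-backprop loss: NLL + KL / num_batches, with a one-sample KL estimate."""

    def __init__(self, prior, var_posterior, log_likelihood, num_batches, weight=None, batch_axis=0, **kwargs):
        super(BBB_RNN_Loss, self).__init__(weight, batch_axis, **kwargs)
        self.prior = prior
        self.var_posterior = var_posterior
        # Despite its name, this is a loss that returns the *negative* log-likelihood
        # per token (here, SoftmaxCrossEntropyLoss).
        self.log_likelihood = log_likelihood
        self.num_batches = num_batches
        return

    def forward(self, yhat, y, sampled_params, sample_weight=None):
        # sample_weight is not used
        neg_log_likelihood = mx.nd.sum(self.log_likelihood(yhat, y))
        prior_log_prob = mx.nd.sum(self.prior.log_prob(sampled_params))
        var_post_log_prob = mx.nd.sum(self.var_posterior.log_prob(sampled_params))
        # Single-sample estimate of KL(q || p) at the sampled weights
        kl_loss = var_post_log_prob - prior_log_prob
        # Spread the KL penalty evenly over the num_batches minibatches of an epoch
        var_loss = neg_log_likelihood + kl_loss / self.num_batches
        return var_loss, neg_log_likelihood
```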

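``BBB_RNN_Loss`` does not compute the KL-divergence in closed form. Instead, it uses the single-sample estimate $\log q(w \mid \theta) - \log p(w)$ at the sampled weights, which is unbiased. For a single scalar weight with a Gaussian posterior and prior, the closed form is available, so we can check that averaging many such estimates recovers it. This is our own illustration with a hypothetical posterior mean of 0.3. The prior $\mathcal{N}(0, 2^2)$ and $\sigma = 0.1$ match the values used below.

```python
import math
import random

def log_normal(x, mu, sigma):
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

def kl_closed_form(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for scalar Gaussians."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2) - 0.5)

def kl_monte_carlo(mu_q, sigma_q, mu_p, sigma_p, n, rng):
    """Average of log q(w) - log p(w) over w ~ q; BBB_RNN_Loss uses one such sample (n = 1)."""
    total = 0.0
    for _ in range(n):
        w = mu_q + sigma_q * rng.gauss(0.0, 1.0)
        total += log_normal(w, mu_q, sigma_q) - log_normal(w, mu_p, sigma_p)
    return total / n

exact = kl_closed_form(0.3, 0.1, 0.0, 2.0)
estimate = kl_monte_carlo(0.3, 0.1, 0.0, 2.0, n=100000, rng=random.Random(0))
```

Any single estimate is noisy, but its expectation is the true KL. Dividing it by ``num_batches`` means that over one epoch the per-minibatch KL terms add up to an estimate of one full KL penalty, matching the full-data objective.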
## Initialize BBB-relevant classes

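```python
args_dropout = 0.0  # no dropout for the BBB model
model = RNNModel(args_model, ntokens, args_emsize, args_nhid, args_nlayers, args_dropout, args_tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)

prior = Prior(0.0, 2)                          # prior N(0, 2^2) on every weight
var_posterior = VarPosterior(model, 0.1, 0.1)  # fixed sigma = 0.1; means drawn from N(0, 0.1^2)
bbb_rnn_loss = BBB_RNN_Loss(prior,
                            var_posterior,
                            gluon.loss.SoftmaxCrossEntropyLoss(),
                            num_batches)

# Only the variational means are optimized; the model's own parameters
# are replaced by posterior samples at every step.
trainer = gluon.Trainer(
    var_posterior.var_mus, 'sgd',
    { 'learning_rate': args_lr, 'momentum': 0, 'wd': 0 }
)
```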

## Define the BBB-RNN training routine

```python
def train_bbb_rnn():
    global args_lr  # the learning rate is decayed below, so rebind the module-level value
    best_val = float("Inf")

    for epoch in range(args_epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func=mx.nd.zeros, batch_size=args_batch_size, ctx=context)

        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args_bptt)):
            x, y = get_batch(train_data, i)
            hidden = detach(hidden)

            with autograd.record():
                # Draw one set of weights from q(w | theta) and load it into the RNN
                sampled_params = var_posterior.sample_model_params()
                model.set_params_to(sampled_params)
                yhat, hidden = model(x, hidden)
                var_loss, L = bbb_rnn_loss(yhat, y, sampled_params)
            # Gradients flow through the sampled weights back to the variational means
            var_loss.backward()

            grads = [var_mu.grad(context) for var_mu in var_posterior.var_mus]
            # Here the clipping threshold is scaled by the total number of weights
            gluon.utils.clip_global_norm(grads, args_clip * var_posterior.num_params())
            trainer.step(args_batch_size)
            total_L += mx.nd.sum(L).asscalar()

            if ibatch % args_log_interval == 0 and ibatch > 0:
                cur_L = total_L / args_bptt / args_batch_size / args_log_interval
                print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (
                    epoch + 1, ibatch, cur_L, math.exp(cur_L)))
                total_L = 0.0

        # Evaluate with the posterior means rather than a random sample
        model.set_params_to(var_posterior.raw_var_mus)
        val_L = evaluate(val_data, model)

        print('[Epoch %d] time cost %.2fs, validation loss %.2f, validation perplexity %.2f' % (
            epoch + 1, time.time() - start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            model.set_params_to(var_posterior.raw_var_mus)
            test_L = evaluate(test_data, model)
            model.save_params(args_save)
            print('test loss %.2f, test perplexity %.2f' % (test_L, math.exp(test_L)))
        else:
            # No improvement: decay the learning rate and restore the best parameters
            # (_init_optimizer is a private Trainer method).
            args_lr = args_lr * 0.25
            trainer._init_optimizer('sgd',
                                    {'learning_rate': args_lr,
                                     'momentum': 0,
                                     'wd': 0})
            model.load_params(args_save, context)
    return
```

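The training loop above is the whole Bayes-by-backprop recipe:

1. Sample weights by reparameterization.
2. Run the model.
3. Add the scaled KL term.
4. Backpropagate into the variational means.

To see the recipe converge to something we can check, here is a self-contained toy version. It is our own illustration, not from the paper, with a single weight, a Gaussian likelihood $y \sim \mathcal{N}(w, 1)$, and hand-derived gradients instead of ``autograd``. For this conjugate model the exact posterior mean is known. With $\sigma$ held fixed, as in this chapter, the optimal variational mean coincides with it. The data and step size are hypothetical.

```python
import random

def toy_bbb(ys, prior_sigma=2.0, post_sigma=0.1, lr=0.01, steps=3000, seed=0):
    """Bayes by Backprop for y ~ N(w, 1) with prior w ~ N(0, prior_sigma^2).

    Mirrors train_bbb_rnn: one minibatch (a single observation) per step, one
    reparameterized sample w = mu + post_sigma * eps, and the KL term divided by
    the number of minibatches. Only the mean mu is learned; post_sigma is fixed.
    """
    rng = random.Random(seed)
    num_batches = len(ys)
    mu, history = 0.0, []
    for step in range(steps):
        y = ys[step % num_batches]  # cycle through the minibatches
        w = mu + post_sigma * rng.gauss(0.0, 1.0)
        # d/dw of -log N(y; w, 1) is (w - y); d/dw of -log p(w) is w / prior_sigma^2.
        # log q(w) adds nothing: with w = mu + sigma * eps it does not depend on mu.
        grad = (w - y) + (w / prior_sigma ** 2) / num_batches
        mu -= lr * grad  # chain rule with dw/dmu = 1
        history.append(mu)
    return sum(history[steps // 2:]) / (steps - steps // 2)  # average the second half

ys = [1.0, 2.0, 3.0]
mu_hat = toy_bbb(ys)
exact_posterior_mean = sum(ys) / (len(ys) + 1.0 / 2.0 ** 2)
```

Averaging the later iterates smooths out the noise from the single-sample gradients. The result lands close to the exact posterior mean $6 / 3.25 \approx 1.846$, slightly below the data mean of 2 because the prior pulls the weight toward zero.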
## Start BBB training and evaluation

```python
train_bbb_rnn()
model.load_params(args_save, context)
model.set_params_to(var_posterior.raw_var_mus)
test_L = evaluate(test_data, model)
print('Best test loss %.2f, test perplexity %.2f'%(test_L, math.exp(test_L)))
```

```
[Epoch 1 Batch 500] loss 6.85, perplexity 942.83
[Epoch 1 Batch 1000] loss 6.28, perplexity 533.58
[Epoch 1 Batch 1500] loss 6.08, perplexity 436.50

KeyboardInterrupt
```