# Bayes by Backprop for Recurrent Neural Networks (RNNs)

In this chapter, we apply [Bayes by Backprop](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb), or "BBB" for short, to a more challenging modeling problem: learning [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb) for sequence prediction.

As we've seen, Bayes-by-backprop lets us fit expressive models efficiently and lets us represent uncertainty about our model's parameters. Representing uncertainty not only helps to avoid overfitting, it is also an important part of sound decision making.

Thankfully, ``BBB`` for RNNs is not much more difficult than in the feed-forward case. It really just requires replacing the feed-forward neural network with a recurrent one, and changing the log-likelihood to something appropriate for sequence modeling.

In what follows, we reimplement the sequence model from [''Bayesian Recurrent Neural Networks'', by Fortunato et al.](https://arxiv.org/pdf/1704.02798.pdf) and rerun the authors' experiments on the Penn Treebank dataset, which you may recall from the chapter on [recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb).

If you have not looked at the chapters [Bayes by Backprop](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb) and [Recurrent Neural Networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb), it is worth doing so, since we reuse a lot of that code.

## Import Packages and Initialize Configuration and Hyperparameters

First we make some necessary package imports, perform basic configuration and set some model hyperparameters.

In [1]:
import math
import os
import time
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn

In [2]:
context = mx.gpu(0)
args_data = '../data/nlp/ptb.'
args_model = 'lstm'
args_emsize = 100
args_nhid = 100
args_nlayers = 2
args_lr = 10.0
args_clip = 0.2
args_epochs = 2
args_batch_size = 32
args_bptt = 5
args_dropout = 0.2
args_tied = True
args_cuda = 'store_true'
args_log_interval = 500
args_save = 'model.param'

## Define Classes for Loading the Language Data

Now let's load the Penn Treebank data as we did in [chapter 5, recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb).

In [3]:
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(path + 'train.txt')
        self.valid = self.tokenize(path + 'valid.txt')
        self.test = self.tokenize(path + 'test.txt')

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = np.zeros((tokens,), dtype='int32')
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return mx.nd.array(ids, dtype='int32')


def batchify(data, batch_size):
    """Reshape data into (num_example, batch_size)"""
    nbatch = data.shape[0] // batch_size
    data = data[:nbatch * batch_size]
    data = data.reshape((batch_size, nbatch)).T
    return data

def get_batch(source, i):
    seq_len = min(args_bptt, source.shape[0] - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len]
    return data, target.reshape((-1,))


In [4]:
corpus = Corpus(args_data)
ntokens = len(corpus.dictionary)
train_data = batchify(corpus.train, args_batch_size).as_in_context(context)
val_data = batchify(corpus.valid, args_batch_size).as_in_context(context)
test_data = batchify(corpus.test, args_batch_size).as_in_context(context)
num_batches = int(np.ceil( (train_data.shape[0] - 1)/args_bptt) )
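
To make the reshaping concrete, here is a small pure-Python sketch of what ``batchify`` and ``get_batch`` do, using plain lists instead of MXNet arrays. The names ``batchify_list`` and ``get_batch_list`` are ours, introduced only for illustration, and the toy sequence stands in for real token ids.

```python
def batchify_list(tokens, batch_size):
    """List-based analogue of batchify: nbatch rows, each of length batch_size."""
    nbatch = len(tokens) // batch_size
    tokens = tokens[:nbatch * batch_size]          # drop the ragged tail
    # Column r holds the r-th contiguous chunk of the token stream,
    # which is what reshape((batch_size, nbatch)).T produces.
    return [[tokens[r * nbatch + t] for r in range(batch_size)]
            for t in range(nbatch)]


def get_batch_list(source, i, bptt):
    """List-based analogue of get_batch: inputs and next-token targets."""
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    # Targets are the same rows shifted one step forward, flattened.
    target = [tok for row in source[i + 1:i + 1 + seq_len] for tok in row]
    return data, target


toy = list(range(10))                               # hypothetical token ids 0..9
grid = batchify_list(toy, batch_size=2)
print(grid)    # [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]

data, target = get_batch_list(grid, 0, bptt=3)
print(data)    # [[0, 5], [1, 6], [2, 7]]
print(target)  # [1, 6, 2, 7, 3, 8]
```

Each column is an independent stream of consecutive tokens, and every target is simply the token that follows the corresponding input in its stream.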

## Define the Recurrent Neural Network Model

Now let's resurrect our recurrent neural network from [chapter 5, recurrent neural networks](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb).

Here we've added a convenience method to the RNN model class called ``set_params_to``. This method is used by Bayes-by-backprop to set the RNN parameters to ones sampled from our variational posterior (details below).

We've also defined an auxiliary function, ``detach``, that detaches a hidden state from the computation graph. By detaching the hidden state after each minibatch, we relieve MXNet of trying to back-propagate the gradient across minibatches, and thus, indefinitely far back in time.

In [5]:
class RNNModel(gluon.Block):
    """A model with an encoder, recurrent layer, and a decoder."""

    def __init__(self, mode, vocab_size, num_embed, num_hidden,
                 num_layers, dropout=0.5, tie_weights=False, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        with self.name_scope():
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(vocab_size, num_embed,
                                        weight_initializer = mx.init.Uniform(0.1))
            if mode == 'rnn_relu':
                self.rnn = rnn.RNN(num_hidden, num_layers, activation='relu',
                                   dropout=dropout, input_size=num_embed)
            elif mode == 'rnn_tanh':
                self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            elif mode == 'lstm':
                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,
                                    input_size=num_embed)
            elif mode == 'gru':
                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,
                                   input_size=num_embed)
            else:
                raise ValueError("Invalid mode %s. Options are rnn_relu, "
                                 "rnn_tanh, lstm, and gru"%mode)
            if tie_weights:
                self.decoder = nn.Dense(vocab_size, in_units = num_hidden,
                                        params = self.encoder.params)
            else:
                self.decoder = nn.Dense(vocab_size, in_units = num_hidden)
            self.num_hidden = num_hidden

    def forward(self, inputs, hidden):
        emb = self.drop(self.encoder(inputs))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.reshape((-1, self.num_hidden)))
        return decoded, hidden

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)

    def set_params_to(self, new_values):
        for model_param, new_value in zip(self.collect_params().values(), new_values):
            model_param_ctx = model_param.list_ctx()[0]
            model_param._data[ model_param_ctx ] = new_value
        return



def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [i.detach() for i in hidden]
    else:
        hidden = hidden.detach()
    return hidden

## Initialize a Baseline RNN

For comparison purposes, let's initialize our RNN from [chapter 5](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb) so we can verify that our "BBB RNN" performs just as well. Of course we also need to train and evaluate this baseline model. The ``train_baseline`` and ``evaluate`` routines, also from chapter 5, do this.

In [6]:
baseline_model = RNNModel(args_model, ntokens, args_emsize, args_nhid, args_nlayers, args_dropout, args_tied)
baseline_model.collect_params().initialize(mx.init.Xavier(), ctx=context)

trainer = gluon.Trainer(
    baseline_model.collect_params(), 'sgd',
    {'learning_rate': args_lr, 'momentum': 0, 'wd': 0})

smce_loss = gluon.loss.SoftmaxCrossEntropyLoss()

In [7]:
def train_baseline(model):
    global args_lr
    best_val = float("Inf")
    for epoch in range(args_epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func = mx.nd.zeros, batch_size = args_batch_size, ctx = context)
        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args_bptt)):
            data, target = get_batch(train_data, i)
            hidden = detach(hidden)
            with autograd.record():
                output, hidden = model(data, hidden)
                L = smce_loss(output, target)
                L.backward()

            grads = [i.grad(context) for i in model.collect_params().values()]
            # Here gradient is for the whole batch.
            # So we multiply max_norm by batch_size and bptt size to balance it.
            gluon.utils.clip_global_norm(grads, args_clip * args_bptt * args_batch_size)

            trainer.step(args_batch_size * args_bptt)
            total_L += mx.nd.sum(L).asscalar()

            if ibatch % args_log_interval == 0 and ibatch > 0:
                cur_L = total_L / args_bptt / args_batch_size / args_log_interval
                print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (
                    epoch + 1, ibatch, cur_L, math.exp(cur_L)))
                total_L = 0.0

        val_L = evaluate(val_data, model)

        print('[Epoch %d] time cost %.2fs, validation loss %.2f, validation perplexity %.2f' % (
            epoch + 1, time.time() - start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(test_data, model)
            model.save_parameters(args_save)
            print('test loss %.2f, test perplexity %.2f' % (test_L, math.exp(test_L)))
        else:
            args_lr = args_lr * 0.25
            trainer._init_optimizer('sgd',
                                    {'learning_rate': args_lr,
                                     'momentum': 0,
                                     'wd': 0})
            model.load_parameters(args_save, context)
    return


def evaluate(data_source, model):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func = mx.nd.zeros, batch_size = args_batch_size, ctx=context)
    for i in range(0, data_source.shape[0] - 1, args_bptt):
        data, target = get_batch(data_source, i)
        output, hidden = model(data, hidden)
        L = smce_loss(output, target)
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal
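
Both routines report perplexity, which is just the exponential of the average per-token negative log-likelihood (in nats). As a quick sanity check on that definition, here is a minimal sketch in plain Python; the helper name ``perplexity`` and the vocabulary size are illustrative, not part of the chapter's code.

```python
import math


def perplexity(token_nlls):
    """exp(mean negative log-likelihood per token), with NLLs in nats."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that spreads probability uniformly over V words pays log(V) nats
# per token, so its perplexity should come out as V.
vocab_size = 10000                      # illustrative vocabulary size
uniform_nlls = [math.log(vocab_size)] * 50
uniform_ppl = perplexity(uniform_nlls)
print(round(uniform_ppl))               # 10000
```

A perplexity of around 200 can therefore be read as the model being, on average, about as uncertain as a uniform choice among 200 words.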

## Train and Evaluate our Baseline RNN

Okay, let's refresh our memory on how well the RNN from [chapter 5](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/rnns-gluon.ipynb) performs on the Penn Treebank data.

In [8]:
train_baseline(baseline_model)
baseline_model.load_parameters(args_save, context)
test_L = evaluate(test_data, baseline_model)
print('Best test loss %.2f, test perplexity %.2f'%(test_L, math.exp(test_L)))

[Epoch 1 Batch 500] loss 6.73, perplexity 834.37
[Epoch 1 Batch 1000] loss 6.14, perplexity 462.27
[Epoch 1 Batch 1500] loss 5.89, perplexity 360.56
[Epoch 1 Batch 2000] loss 5.81, perplexity 333.35
[Epoch 1 Batch 2500] loss 5.68, perplexity 292.76
[Epoch 1 Batch 3000] loss 5.56, perplexity 260.22
[Epoch 1 Batch 3500] loss 5.56, perplexity 260.70
[Epoch 1 Batch 4000] loss 5.43, perplexity 227.89
[Epoch 1 Batch 4500] loss 5.41, perplexity 222.73
[Epoch 1 Batch 5000] loss 5.40, perplexity 222.27
[Epoch 1 Batch 5500] loss 5.41, perplexity 223.43
[Epoch 1] time cost 56.27s, validation loss 5.30, validation perplexity 199.93
test loss 5.27, test perplexity 194.90
[Epoch 2 Batch 500] loss 5.38, perplexity 217.45
[Epoch 2 Batch 1000] loss 5.31, perplexity 202.51
[Epoch 2 Batch 1500] loss 5.27, perplexity 194.25
[Epoch 2 Batch 2000] loss 5.31, perplexity 201.45
[Epoch 2 Batch 2500] loss 5.26, perplexity 192.06
[Epoch 2 Batch 3000] loss 5.18, perplexity 177.65
[Epoch 2 Batch 3500] loss 5.22, pe

# Bayes-by-Backprop for RNNs

With our baseline RNN trained and evaluated, we can now move on to Bayes-by-backprop for RNNs.

Being good Bayesians, the first thing we should do is define a prior probability distribution over the parameters of our model.

As in [''Bayesian Recurrent Neural Networks'', by Fortunato et al.](https://arxiv.org/pdf/1704.02798.pdf) and [chapter 18](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb), we define a "scale mixture" prior over the parameters to be a mixture of two Gaussians with different scales, or variances.

\begin{equation*}
\text{Prior}(\mathbf{w}) = \prod_i \bigg ( \alpha \mathcal{N}(w_i\ |\ 0,\sigma_1^2) + (1 - \alpha) \mathcal{N}(w_i\ |\ 0,\sigma_2^2)\bigg )
\end{equation*}

The first Gaussian has a small scale, preferring parameters which are close to zero. The second has a larger scale, allowing parameter values to stray from zero. By making the prior be a mixture of these two scales, we can induce models where many parameters are close to zero, but some are far from zero. In other words, this prior prefers sparse models, models in which many parameters are effectively zero.

The amount of sparsity is determined by the hyperparameter $\alpha \in [0,1]$ which controls how much emphasis is placed on each Gaussian in the prior. Of course, the scale parameters $\sigma_1$ and $\sigma_2$ also control the sparsity since smaller $\sigma$'s induce smaller parameter values.

In [9]:
class ScaleMixturePrior(object):

    def __init__(self, alpha, sigma1, sigma2):
        self.alpha = mx.nd.array([alpha], ctx=context)
        self.one_minus_alpha = mx.nd.array([1 - alpha], ctx=context)
        self.zero = mx.nd.array([0.0], ctx=context)
        self.sigma1 = mx.nd.array([sigma1], ctx=context)
        self.sigma2 = mx.nd.array([sigma2], ctx=context)
        return

    def log_prob(self, model_params):
        total_log_prob = None
        for i, model_param in enumerate(model_params):
            p1 = gaussian_prob(model_param, self.zero, self.sigma1)
            p2 = gaussian_prob(model_param, self.zero, self.sigma2)
            log_prob = mx.nd.sum(mx.nd.log(self.alpha * p1 + self.one_minus_alpha * p2))
            if i == 0: total_log_prob = log_prob
            else: total_log_prob = total_log_prob + log_prob
        return total_log_prob


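# Define some auxiliary functions
def log_gaussian_prob(x, mu, sigma):
    # Log density of N(mu, sigma^2). The constant -0.5 * log(2 * pi) is included
    # so that this agrees with log(gaussian_prob(...)); it does not affect gradients.
    return - mx.nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2) - 0.5 * math.log(2.0 * math.pi)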

def gaussian_prob(x, mu, sigma):
    scaling = 1.0 / mx.nd.sqrt(2.0 * np.pi * (sigma ** 2))
    bell = mx.nd.exp(-(x - mu)**2 / (2.0 * sigma ** 2))
    return scaling * bell
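
One practical caveat: with a very small $\sigma_1$, evaluating the mixture as a product of densities and then taking the log may underflow for weights far from zero. A common workaround is to stay in log space and combine the two components with log-sum-exp. The following is a minimal scalar sketch in plain Python under that assumption; ``log_normal_pdf`` and ``log_scale_mixture`` are hypothetical helper names, not part of the chapter's classes, and the sketch assumes $0 < \alpha < 1$.

```python
import math


def log_normal_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x, including the normalising constant."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - (x - mu) ** 2 / (2.0 * sigma ** 2))


def log_scale_mixture(x, alpha, sigma1, sigma2):
    """log(alpha * N(x|0, sigma1^2) + (1 - alpha) * N(x|0, sigma2^2)) via log-sum-exp."""
    a = math.log(alpha) + log_normal_pdf(x, 0.0, sigma1)
    b = math.log(1.0 - alpha) + log_normal_pdf(x, 0.0, sigma2)
    m = max(a, b)                       # factor out the larger term
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# Same hyperparameter values as the prior instantiated later in this chapter.
alpha, sigma1, sigma2 = 0.75, 0.001, 0.75
near_zero = log_scale_mixture(0.0, alpha, sigma1, sigma2)
far_away = log_scale_mixture(0.5, alpha, sigma1, sigma2)
very_far = log_scale_mixture(100.0, alpha, sigma1, sigma2)   # still finite
```

The narrow component dominates near zero and the wide component takes over further out, which is exactly the sparsity-inducing behaviour described above.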

## Define the Variational Posterior

What comes after the prior? Why, the posterior of course!

In this case, since we are doing variational Bayes, we will define a _variational posterior_. A variational posterior is just a parametric distribution of our choosing that we fit to the actual posterior distribution during learning. The variational posterior is itself parameterized by a set of parameters, aptly named the "variational parameters". Variational inference consists of finding the variational parameters that make the variational posterior match the actual posterior distribution of the model parameters as closely as possible.

As in [chapter 18](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb), we define the variational posterior so that the posterior for each parameter in the model is Gaussian with its own mean and variance. Given our scale mixture prior and the data $\mathcal{D}$, we define the variational posterior as:

\begin{equation*}
P(w_i\ |\ \mathcal{D}, \alpha, \sigma_1, \sigma_2) 
= \mathcal{N}\left(w_i\ \big|\ \mu^{\text{var}}_i, \left(\sigma^{\text{var}}_i\right)^2\right)
\end{equation*}

where $\mu^{\text{var}}_i$ and $\left(\sigma^{\text{var}}_i\right)^2$ are the variational parameters determining the variational posterior for the model parameter $w_i$.

As in [chapter 18](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop-gluon.ipynb), we avoid the need for positivity constraints on the variational scale parameters by reparameterizing $\sigma^{\text{var}}_i$ as $\rho^{\text{var}}_i$ such that

\begin{equation*}
\sigma^{\text{var}}_i = \log(1 + \exp(\rho^{\text{var}}_i))
\end{equation*}

You might recognize $f(x) = \log(1 + \exp(x))$ as the "softplus" function.
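
To see the reparameterization at work, here is a small plain-Python sketch of softplus and its inverse, written in a form that avoids overflow for large inputs. The names ``stable_softplus`` and ``stable_inv_softplus`` are hypothetical and chosen so they do not clash with the helpers defined later in this chapter.

```python
import math


def stable_softplus(rho):
    """log(1 + exp(rho)), rearranged so exp() never sees a large positive argument."""
    return max(rho, 0.0) + math.log1p(math.exp(-abs(rho)))


def stable_inv_softplus(sigma):
    """Inverse of softplus for sigma > 0: log(exp(sigma) - 1) = sigma + log(1 - exp(-sigma))."""
    if sigma <= 0:
        raise ValueError("sigma must be > 0: {}".format(sigma))
    return sigma + math.log(-math.expm1(-sigma))


# Any real rho maps to a strictly positive sigma, so rho needs no constraint.
for rho in (-20.0, -1.0, 0.0, 3.0, 800.0):   # math.exp(800.0) alone would overflow
    assert stable_softplus(rho) > 0.0

# The inverse recovers rho, which is how an initial sigma is turned into an initial rho.
rho_init = stable_inv_softplus(0.01)
```

Because the optimizer updates $\rho^{\text{var}}_i$ freely over the reals, the implied standard deviation stays positive without any projection step.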

### The Variational Posterior Class

Like our ``ScaleMixturePrior`` class, the ``VariationalPosterior`` class needs a method for computing the log-probability of the model parameters. But in order to run Bayes-by-backprop, it also needs a method for sampling a set of model parameters from the variational posterior distribution. We've implemented this method and named it ``sample_model_params``.

In [10]:
class VariationalPosterior(object):

    def __init__(self, model, var_mu_init_scale, var_sigma_init_scale):
        self.var_mus = []
        self.var_rhos = []
        self.raw_var_mus = []
        self.raw_var_rhos = []
        var_rho_init_scale = inv_softplus(var_sigma_init_scale)

        for i, model_param in enumerate(model.collect_params().values()):

            var_mu = gluon.Parameter(
                'var_mu_{}'.format(i), shape=model_param.shape,
                init=mx.init.Normal(var_mu_init_scale))
            var_mu.initialize(ctx=context)
            self.var_mus.append(var_mu)
            self.raw_var_mus.append(var_mu.data(context))

            var_rho = gluon.Parameter(
                'var_rho_{}'.format(i), shape=model_param.shape,
                init=mx.init.Constant(var_rho_init_scale))
            var_rho.initialize(ctx=context)
            self.var_rhos.append(var_rho)
            self.raw_var_rhos.append(var_rho.data(context))

        self.var_params = self.var_mus + self.var_rhos
        return

    def log_prob(self, model_params):
        log_probs = [
            mx.nd.sum(log_gaussian_prob(model_param, raw_var_mu, softplus(raw_var_rho)))
            for (model_param, raw_var_mu, raw_var_rho)
            in zip(model_params, self.raw_var_mus, self.raw_var_rhos)
        ]
        total_log_prob = log_probs[0]
        for log_prob in log_probs[1:]:
            total_log_prob = total_log_prob + log_prob
        return total_log_prob

    def sample_model_params(self):
        model_params = []
        for raw_var_mu, raw_var_rho in zip(self.raw_var_mus, self.raw_var_rhos):
            epsilon = mx.nd.random_normal(shape=raw_var_mu.shape, loc=0., scale=1.0, ctx=context)
            var_sigma = softplus(raw_var_rho)
            model_param = raw_var_mu + var_sigma * epsilon
            model_params.append(model_param)
        return model_params

    def num_params(self):
        return sum([
            2 * np.prod(param.shape)
            for param in self.var_mus
        ])


# Define some auxiliary functions
def softplus(x):
    return mx.nd.log(1. + mx.nd.exp(x))

def inv_softplus(x):
    if x <= 0: raise ValueError("x must be > 0: {}".format(x))
    return np.log(np.exp(x) - 1.0)
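
The reason ``sample_model_params`` draws $\epsilon$ first and then forms $\mu + \sigma \epsilon$ is that the sampled weight becomes a differentiable function of the variational parameters. The scalar sketch below, in plain Python with illustrative values, checks the two derivatives by finite differences; ``sample_weight`` and ``analytic_grads`` are hypothetical names used only here.

```python
import math


def sample_weight(mu, rho, epsilon):
    """Reparameterised draw: w = mu + softplus(rho) * epsilon."""
    return mu + math.log1p(math.exp(rho)) * epsilon


def analytic_grads(rho, epsilon):
    """dw/dmu = 1 and dw/drho = epsilon * sigmoid(rho)."""
    sigmoid = 1.0 / (1.0 + math.exp(-rho))
    return 1.0, epsilon * sigmoid


mu, rho, epsilon = 0.2, -1.5, 0.7     # illustrative values; epsilon is held fixed
h = 1e-6                              # finite-difference step

fd_mu = (sample_weight(mu + h, rho, epsilon)
         - sample_weight(mu - h, rho, epsilon)) / (2 * h)
fd_rho = (sample_weight(mu, rho + h, epsilon)
          - sample_weight(mu, rho - h, epsilon)) / (2 * h)
an_mu, an_rho = analytic_grads(rho, epsilon)
```

With the noise held fixed, gradients of the loss with respect to the sampled weight flow straight back to $\mu^{\text{var}}_i$ and $\rho^{\text{var}}_i$, which is what lets ordinary backpropagation train the posterior.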

## Implementing the Bayes-by-Backprop Loss

We're almost done setting up the Bayes-by-backprop infrastructure. The final piece is approximating the variational loss, which is defined as the expected negative log-likelihood of the data (under the variational posterior) plus the KL-divergence between the variational posterior and the prior over the model parameters.

Denoting the set of variational parameters $\theta = \{(\mu^{\text{var}}_i, \sigma^{\text{var}}_i) \}$, we write the variational loss as

\begin{equation*}
\begin{split}
\text{loss}_{\text{var}}(\theta) = 
  \mathbb{E}_{q(\mathbf{w}\ |\ \mathbf{\theta})}[- \log P(\mathcal{D}\ |\ \mathbf{w})] +
  \text{KL}[q(\mathbf{w}\ |\ \mathbf{\theta})\ ||\ P(\mathbf{w})] .
\end{split}
\end{equation*}

Notice that computing this loss involves an integral over $\mathbf{w}$. In Bayes-by-backprop, we approximate this integral with a Monte Carlo estimate obtained by drawing samples of the model parameters from the variational posterior and approximating the loss on these samples.

Specifically, let $\{ \mathbf{w}^{(1)}, \ldots, \mathbf{w}^{(M)} \}$ be a sample of model parameters drawn from $q(\mathbf{w}\ |\ \theta)$. Then we can approximate the variational loss with the Monte Carlo estimate

\begin{equation*}
\text{loss}_{\text{mc}}(\theta\ ;\ \mathbf{w}^{(1)}, \ldots, \mathbf{w}^{(M)} ) =
\frac{1}{M} \sum_{m=1}^M \left(
 -\log P(\mathcal{D}\ |\ \mathbf{w}^{(m)}) 
 +\log q(\mathbf{w}^{(m)}\ |\ \theta)
 -\log \text{Prior}(\mathbf{w}^{(m)})
\right).
\end{equation*}

Of course, this requires evaluating the negative log-likelihood on the full data, $\mathcal{D}$. We can make a further approximation by merely evaluating the negative log-likelihood on a randomly sampled mini-batch of data, $\mathcal{D}^{(n)}$, at each iteration. Assuming that $N$ minibatches constitute a full pass over the entire data set, the Bayes-by-backprop loss function we seek to minimize is

\begin{equation*}
\text{loss}_{\text{bbb}}(\theta\ ;\ \mathcal{D}^{(n)}, \mathbf{w}^{(1)}, \ldots, \mathbf{w}^{(M)} ) =
\frac{1}{M} \sum_{m=1}^M \left(
 -\log P(\mathcal{D}^{(n)}\ |\ \mathbf{w}^{(m)}) 
 +\frac{1}{N} \left( \log q(\mathbf{w}^{(m)}\ |\ \theta)
 -\log \text{Prior}(\mathbf{w}^{(m)}) \right)
\right)
\end{equation*}

where we scale the approximation to the KL term by $1/N$ so that it has the right magnitude after we sum over the $N$ minibatches in a full pass over the data. Also, in practice, we set $M = 1$ so the outer sum from $m = 1 \ldots M$ disappears.
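
To get a feel for the Monte Carlo approximation of the KL term, it helps to try it on a case with a known answer. The sketch below is plain Python and makes one deliberate simplification: it swaps the scale-mixture prior for a single zero-mean Gaussian, because then the KL divergence has a closed form to compare against. All names and values are illustrative.

```python
import math
import random


def log_normal_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - (x - mu) ** 2 / (2.0 * sigma ** 2))


def mc_kl(mu, sigma, prior_sigma, num_samples, rng):
    """Average of log q(w) - log prior(w) over samples w drawn from q."""
    total = 0.0
    for _ in range(num_samples):
        w = mu + sigma * rng.gauss(0.0, 1.0)
        total += log_normal_pdf(w, mu, sigma) - log_normal_pdf(w, 0.0, prior_sigma)
    return total / num_samples


def exact_kl(mu, sigma, prior_sigma):
    """Closed-form KL[N(mu, sigma^2) || N(0, prior_sigma^2)]."""
    return (math.log(prior_sigma / sigma)
            + (sigma ** 2 + mu ** 2) / (2.0 * prior_sigma ** 2) - 0.5)


rng = random.Random(0)                       # fixed seed for repeatability
mu, sigma, prior_sigma = 0.5, 0.3, 1.0       # assumed toy values
estimate = mc_kl(mu, sigma, prior_sigma, 20000, rng)
exact = exact_kl(mu, sigma, prior_sigma)
print(abs(estimate - exact) < 0.05)          # True
```

With $M = 1$, as used in training, each individual estimate is noisy, but it is unbiased, and the noise averages out over many gradient steps.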

The ``forward`` method of the ``BBB_Loss`` class implements the $\text{loss}_{\text{bbb}}$ function.


In [11]:
class BBB_Loss(gluon.loss.Loss):

    def __init__(self, prior, var_posterior, log_likelihood, num_batches, weight=None, batch_axis=0, **kwargs):
        super(BBB_Loss, self).__init__(weight, batch_axis, **kwargs)
        self.prior = prior
        self.var_posterior = var_posterior
        self.log_likelihood = log_likelihood
        self.num_batches = num_batches
        return
    
    def forward(self, yhat, y, sampled_params, sample_weight=None):
        # self.log_likelihood is a cross-entropy loss, so it already returns
        # the *negative* log-likelihood of the minibatch.
        neg_log_likelihood = mx.nd.sum(self.log_likelihood(yhat, y))
        prior_log_prob = mx.nd.sum(self.prior.log_prob(sampled_params))
        var_post_log_prob = mx.nd.sum(self.var_posterior.log_prob(sampled_params))
        # Single-sample (M = 1) estimate of the KL term, scaled by 1/N below.
        kl_loss = var_post_log_prob - prior_log_prob
        var_loss = neg_log_likelihood + kl_loss / self.num_batches
        return var_loss, neg_log_likelihood

## Bayes-by-Backprop Training for Recurrent Neural Nets

Okay, now that we've defined the classes needed for Bayes-by-backprop, let's write our BBB training and evaluation routine. It's much the same as what we used for the baseline model. However, the ``autograd.record`` block and evaluation code have some differences.

Here, since we are doing BBB, we need to sample the model parameters from our variational posterior and set the model's parameters to these sampled parameters. Then we can run the training forward pass with this "sampled model" and update the variational parameters accordingly to minimize the BBB-loss.

Additionally, at evaluation time, rather than sampling a set of model parameters like we do for training, we set the model's parameters to the variational $\mu$'s, since these represent typical model parameters.

One final detail: we need to be careful about the size of our gradient step. Since the BBB-loss is a function of both the training data and the variational parameters, the effective sample size for each minibatch is proportional to the number of training instances plus the number of variational parameters. In the code, the variable ``effective_batch_size`` holds this quantity, computed as the number of tokens in a minibatch plus that minibatch's share of the variational parameters (their total count divided by the number of minibatches). We use it to control the step size and gradient norm accordingly.
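
As a quick worked example of that bookkeeping, the sketch below reproduces the arithmetic behind the ``effective_batch_size`` computed in ``train_bbb``. The ``bptt`` and ``batch_size`` values match this chapter's settings, while the parameter and minibatch counts are made-up round numbers rather than measurements from the Penn Treebank run.

```python
def effective_batch_size(bptt, batch_size, num_var_params, num_batches):
    """Tokens per minibatch plus this minibatch's share of the variational parameters."""
    return bptt * batch_size + num_var_params / num_batches


ebs = effective_batch_size(bptt=5, batch_size=32,
                           num_var_params=2_000_000,   # hypothetical count
                           num_batches=5000)           # hypothetical count
clip_threshold = 0.2 * ebs     # mirrors args_clip * effective_batch_size
print(ebs)                     # 560.0
```

In this toy setting the KL term contributes more to the effective size (400) than the 160 tokens in the minibatch do, which is why ignoring it would noticeably change the clipping threshold.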


In [12]:
def train_bbb(model):
    global args_lr
    best_val = float("Inf")

    for epoch in range(args_epochs):
        total_L = 0.0
        start_time = time.time()
        hidden = model.begin_state(func = mx.nd.zeros, batch_size = args_batch_size, ctx = context)

        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args_bptt)):
            x, y = get_batch(train_data, i)
            hidden = detach(hidden)

            with autograd.record():
                sampled_params = var_posterior.sample_model_params()
                model.set_params_to(sampled_params)
                yhat, hidden = model(x, hidden)
                var_loss, L = bbb_loss(yhat, y, sampled_params)
                var_loss.backward()

            grads = [var_mu.grad(context) for var_mu in var_posterior.var_mus]
            effective_batch_size = (args_bptt * args_batch_size) + (var_posterior.num_params() / num_batches)
            gluon.utils.clip_global_norm(grads, args_clip * effective_batch_size)
            trainer.step(args_clip * effective_batch_size)
            total_L += mx.nd.sum(L).asscalar()

            if ibatch % args_log_interval == 0 and ibatch > 0:
                cur_L = total_L / args_bptt / args_batch_size / args_log_interval
                print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (
                    epoch + 1, ibatch, cur_L, math.exp(cur_L)))
                total_L = 0.0

        model.set_params_to(var_posterior.raw_var_mus)
        val_L = evaluate(val_data, model)

        print('[Epoch %d] time cost %.2fs, validation loss %.2f, validation perplexity %.2f' % (
            epoch + 1, time.time() - start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            model.set_params_to(var_posterior.raw_var_mus)
            test_L = evaluate(test_data, model)
            model.save_parameters(args_save)
            print('test loss %.2f, test perplexity %.2f' % (test_L, math.exp(test_L)))
        else:
            args_lr = args_lr * 0.25
            trainer._init_optimizer('sgd',
                                    {'learning_rate': args_lr,
                                     'momentum': 0,
                                     'wd': 0})
            model.load_parameters(args_save, context)
    return

We are just about ready to go. We can instantiate our model, the scale-mixture prior, and the variational posterior.

Some things to note before we pull the trigger are:
* Since we are learning the variational parameters, these are what need to be updated by the model trainer.
* Since dropout can be interpreted as a different approach to Bayesian learning [\[Gal and Ghahramani, 2016\]](http://proceedings.mlr.press/v48/gal16.pdf), we should turn it off for training.


In [13]:
bbb_model = RNNModel(args_model, ntokens, args_emsize, args_nhid, args_nlayers, dropout=0.0, tie_weights=args_tied)
bbb_model.collect_params().initialize(mx.init.Xavier(), ctx=context)

prior = ScaleMixturePrior(alpha = 0.75, sigma1 = 0.001, sigma2 = 0.75)

var_posterior = VariationalPosterior(bbb_model,
                                     var_mu_init_scale = 0.05,
                                     var_sigma_init_scale = 0.01)

bbb_loss = BBB_Loss(prior,
                    var_posterior,
                    gluon.loss.SoftmaxCrossEntropyLoss(),
                    num_batches)

trainer = gluon.Trainer(
    var_posterior.var_params, 'sgd',
    { 'learning_rate': args_lr, 'momentum': 0, 'wd': 0 })

Now, let's start training!

In [14]:
train_bbb(bbb_model)
bbb_model.load_parameters(args_save, context)
bbb_model.set_params_to(var_posterior.raw_var_mus)
test_L = evaluate(test_data, bbb_model)
print('Best test loss %.2f, test perplexity %.2f'%(test_L, math.exp(test_L)))

[Epoch 1 Batch 500] loss 6.56, perplexity 706.71
[Epoch 1 Batch 1000] loss 5.89, perplexity 362.56
[Epoch 1 Batch 1500] loss 5.64, perplexity 280.99
[Epoch 1 Batch 2000] loss 5.58, perplexity 264.55
[Epoch 1 Batch 2500] loss 5.45, perplexity 233.16
[Epoch 1 Batch 3000] loss 5.33, perplexity 206.05
[Epoch 1 Batch 3500] loss 5.33, perplexity 207.40
[Epoch 1 Batch 4000] loss 5.21, perplexity 182.22
[Epoch 1 Batch 4500] loss 5.18, perplexity 177.47
[Epoch 1 Batch 5000] loss 5.19, perplexity 179.08
[Epoch 1 Batch 5500] loss 5.20, perplexity 182.08
[Epoch 1] time cost 278.34s, validation loss 5.24, validation perplexity 189.40
test loss 5.21, test perplexity 183.07
[Epoch 2 Batch 500] loss 5.19, perplexity 178.91
[Epoch 2 Batch 1000] loss 5.10, perplexity 163.53
[Epoch 2 Batch 1500] loss 5.05, perplexity 156.11
[Epoch 2 Batch 2000] loss 5.10, perplexity 164.57
[Epoch 2 Batch 2500] loss 5.07, perplexity 158.48
[Epoch 2 Batch 3000] loss 4.98, perplexity 144.95
[Epoch 2 Batch 3500] loss 5.01, p

Okay, not bad. We do about as well as the baseline RNN trained with dropout. But with BBB, we also have the ability to assess the certainty of our model parameters. This is useful in pruning weights from a model, for example, or in applications where the model needs to interact with its environment. In such cases, having a model that "knows what it knows" is quite useful.
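
To illustrate the pruning idea, one heuristic from the Bayes-by-backprop literature is to rank weights by their signal-to-noise ratio $|\mu| / \sigma$ and zero out those that fall below a threshold. This chapter does not implement pruning, so the following plain-Python sketch is only an assumption about how it could look; the function names, the toy values, and the threshold are all hypothetical.

```python
import math


def signal_to_noise(mu, rho):
    """|mu| / sigma, with sigma = softplus(rho)."""
    sigma = math.log1p(math.exp(rho))
    return abs(mu) / sigma


def prune_by_snr(mus, rhos, threshold):
    """Zero out posterior means whose signal-to-noise ratio is below threshold."""
    return [mu if signal_to_noise(mu, rho) >= threshold else 0.0
            for mu, rho in zip(mus, rhos)]


mus = [0.80, -0.02, 0.30, 0.01]      # toy variational means
rhos = [-3.0, -1.0, -2.0, 0.5]       # toy variational rhos
pruned = prune_by_snr(mus, rhos, threshold=1.0)
print(pruned)                        # [0.8, 0.0, 0.3, 0.0]
```

Weights whose posterior mean is small relative to their posterior spread are the ones the model is least sure it needs, so they are natural candidates for removal.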

## Conclusion

We have implemented Bayes-by-backprop for recurrent neural networks as described in [''Bayesian Recurrent Neural Networks'' by Fortunato et al.](https://arxiv.org/pdf/1704.02798.pdf), and rerun the authors' experiments on the Penn Treebank data. The comparable results show that Bayes-by-backprop applies to problems more sophisticated than classification and regression.

For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)