In [None]:
from tqdm import tqdm, tqdm_notebook
from mxnet import np, npx
import mxnet as mx
from mxnet import nd, autograd, gluon
from mxnet.gluon import nn
import mxnet.gluon.probability as mgp
from mxnet.gluon.probability import StochasticBlock, StochasticSequential
import matplotlib
import matplotlib.pyplot as plt
from scipy import ndimage, misc

npx.set_np()
data_ctx = mx.cpu()
model_ctx = mx.gpu(2)

In [None]:
def load_data(batch_size):
    """
    Load MNIST
    """
    mnist_train = gluon.data.vision.MNIST(train=True)
    mnist_test = gluon.data.vision.MNIST(train=False)
    num_worker = 4
    transformer = gluon.data.vision.transforms.ToTensor()
    return (gluon.data.DataLoader(mnist_train.transform_first(transformer),
                                batch_size, shuffle=True,
                                num_workers=num_worker),
          gluon.data.DataLoader(mnist_test.transform_first(transformer),
                                batch_size, shuffle=False,
                                num_workers=num_worker))

Construct A Bayesian dense layer using local reparameterization trick

In [105]:
class LocalReparamDense(StochasticBlock):
    def __init__(self, in_features, out_features, activation=None, flatten=True, dtype='float32'):
        super(LocalReparamDense, self).__init__()
        self._flatten = flatten
        self.qw_x = None
        self._in_features = in_features
        self._out_features = out_features
        # Parameter of weight
        self.loc_w = gluon.Parameter('loc_w', shape=(out_features, in_features),
                                    dtype=dtype)
        self.scale_w = gluon.Parameter('log_scale_w', shape=(out_features, in_features),
                                    dtype=dtype)
        # Parameter of bias
        self.bias = gluon.Parameter('bias', shape=(out_features,),
                                    dtype=dtype)
        if activation is not None:
            self.act = gluon.nn.Activation(activation)
        else:
            self.act = None

    @StochasticBlock.collectLoss
    def hybrid_forward(self, F, x, loc_w, scale_w, bias):
        # We use `fc` operator for matrix multiplication.
        fc = F.npx.fully_connected
        # Directly acquire parameter for A = dot(x, W)
        # with local reparameterization trick:
        qa_loc = fc(x, loc_w, bias=None, no_bias=True, num_hidden=self._out_features,
                    flatten=self._flatten)
        qa_scale = F.np.sqrt(fc(x ** 2, scale_w ** 2, bias=None, no_bias=True,
                      num_hidden=self._out_features, flatten=self._flatten))
        self.qw_x = mgp.Normal(
            loc=qa_loc,
            scale=qa_scale
        )
        # KL(qw_x || px), where px ~ N(0, 1)
        kl = mgp.kl_divergence(self.qw_x, mgp.Normal(0, 1)).sum(-1)
        self.add_loss(kl)
        # Sampling from the network
        fc_samples = self.qw_x.sample() + bias
        if self.act is not None:
            out = self.act(fc_samples)
        else:
            out = fc_samples
        return out

## MNIST classfication with BNN

In [106]:
def train(net, n_epoch, train_iter, test_iter, baseline=False):
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                      {'learning_rate': .001})
    training_loss = []
    validation_loss = []
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(from_logits=(not baseline))
    metric = mx.gluon.metric.Accuracy()
    for epoch in tqdm_notebook(range(n_epoch), desc='epochs'):
        epoch_loss = 0
        metric.reset()
        for batch in train_iter:
            data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
            label = batch[1].as_in_context(model_ctx)
            kl_loss = 0
            with autograd.record():
                logits = net(data)
                classification_loss = loss_func(logits, label)
                # `baseline` model stands for deterministic MLP
                if baseline:
                    loss = classification_loss
                else:
                    for layer_kl_loss in net.losses:
                        kl_loss = kl_loss + layer_kl_loss[0]
                    loss = classification_loss + kl_loss / data.shape[0]
            loss.backward()
            trainer.step(data.shape[0])
            epoch_loss += np.mean(classification_loss)
        print(epoch_loss)
        test_loss = 0
        for batch in test_iter:
            data = batch[0].as_in_context(model_ctx).reshape(-1, 28 * 28)
            label = batch[1].as_in_context(model_ctx)
            logits = net(data)
            classification_loss = loss_func(logits, label)
            test_loss += np.mean(classification_loss)
            metric.update([label], [logits.as_nd_ndarray()])
        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f'%(epoch, name, acc))

        

In [None]:
mlp = nn.HybridSequential()
mlp.add(nn.Dense(256, activation='relu'))
mlp.add(nn.Dense(256, activation='relu'))
mlp.add(nn.Dense(10))
mlp.initialize(ctx=model_ctx)
mlp.hybridize()

In [None]:
batch_size = 256
train_set, test_set = load_data(batch_size)
train(
    net=mlp,
    n_epoch=1,
    train_iter=train_set,
    test_iter=test_set,
    baseline=True)

In [107]:
bnn = StochasticSequential()
bnn.add(LocalReparamDense(784, 256, activation='relu'))
bnn.add(LocalReparamDense(256, 256, activation='relu'))
bnn.add(LocalReparamDense(256, 10))
bnn.initialize(ctx=model_ctx)
bnn.hybridize()

In [108]:
batch_size = 256
train_set, test_set = load_data(batch_size)
train(
    net=bnn,
    n_epoch=2,
    train_iter=train_set,
    test_iter=test_set,
    baseline=False)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, description='epochs', max=2.0, style=ProgressStyle(description_width='…

-21708.85
[Epoch 0] Training: accuracy=0.913200
-43544.555
[Epoch 1] Training: accuracy=0.944400



Next, we perform classfication on white noise to demonstrate the advantages of a Bayesian neural network.

In [None]:
# Generate random noise
x = np.random.randn(28, 28)
plt.imshow(x.asnumpy(), cmap='gray')

In [None]:
num_samples = 100
bnn_prediction = npx.softmax(bnn(np.repeat(np.expand_dims(x, 0), 100, 0).as_in_context(model_ctx))).mean(0).asnumpy()
mlp_prediction = npx.softmax(mlp(np.repeat(np.expand_dims(x, 0), 100, 0).as_in_context(model_ctx))).mean(0).asnumpy()

In [None]:
labels = [str(i) for i in range(10)]
x = np.arange(len(labels)).asnumpy()  # the label locations
width = 0.35  # the width of the bars

fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, bnn_prediction, width, label='BNN prediction')
rects2 = ax.bar(x + width/2, mlp_prediction, width, label='MLP prediction')

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()