In [None]:
import torch
from torch import nn
import numpy as np
import pyro
import pyro.contrib.bnn as bnn
import pyro.optim
import torch.nn.functional as F
from pyro.infer import SVI, TraceMeanField_ELBO
from torch.distributions import constraints
from pyro import poutine
import pyro.distributions as dist

In [None]:
class BayesianNeuralRegression(nn.Module):
    """
    Bayesian neural network regressor trained with stochastic variational
    inference.

    The network is three ``pyro.contrib.bnn.HiddenLayer`` sample sites
    ('a1', 'a2', 'last_activation'), each followed by a leaky ReLU.

    See the variational dropout and the local reparameterization trick paper.
    The pre-activations are sampled directly:
        z ~ N(Xμ, X²σ²)
    For a hidden layer with 1000 inputs and 1000 outputs this reduces
    sampling 1M weights to sampling 1000 pre-activations.

    This also leads to Gaussian dropout, where the posterior of the weights is

        q(w) = N(Φ, αΦ²) where α = p / (1 - p)
    """

    # (sample-site name, include_hidden_bias) for every layer, in order.
    # The last layer produces the regression output, so no bias column of
    # ones is appended to it.
    _LAYER_SITES = (('a1', True), ('a2', True), ('last_activation', False))

    def __init__(self, input_size, hidden_size, output_size, dropout=0.2):
        """
        :param input_size: number of columns of the input ``x``.
        :param hidden_size: width of the two hidden layers.
        :param output_size: width of the output layer.
        :param dropout: dropout rate ``p``; used for the prior scale of every
            layer in ``model`` and as the initial value of the learnable
            per-layer dropout parameters in ``guide``.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout = dropout

    def _weight_shapes(self):
        """Return the (rows, cols) shape of each layer's weight matrix.

        Layers two and three have one extra input row: the preceding
        HiddenLayer appends a bias column of ones to its activations.
        """
        return (
            (self.input_size, self.hidden_size),
            (self.hidden_size + 1, self.hidden_size),
            (self.hidden_size + 1, self.output_size),
        )

    @staticmethod
    def _batch_starts(n_samples, batch_size):
        """Return the start offset of every mini-batch covering ``n_samples`` rows.

        The final batch may be shorter than ``batch_size``.

        :raises ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                'batch_size must be >= 1, got {}'.format(batch_size))
        return range(0, n_samples, batch_size)

    def _sample_network(self, x, layers, kl_factor):
        """Pass ``x`` through all layers, registering one sample site per layer.

        Must be called inside the 'data' plate.

        :param x: input batch.
        :param layers: one ``(weight_mean, weight_scale)`` pair per layer.
        :param kl_factor: forwarded to every ``HiddenLayer`` as ``KL_factor``.
        :return: the value sampled at the 'last_activation' site.
        """
        a = x
        for (site, with_bias), (mean, scale) in zip(self._LAYER_SITES, layers):
            a = pyro.sample(site,
                            bnn.HiddenLayer(a, mean, scale,
                                            non_linearity=F.leaky_relu,
                                            KL_factor=kl_factor,
                                            include_hidden_bias=with_bias))
        return a

    def model(self, x, y, kl_factor=1):
        """
        Generative model: zero-mean weights whose scale is the Gaussian
        dropout alpha, a HalfNormal(1) noise scale 'sigma', and a Normal
        likelihood at the 'obs' site.

        :param x: input batch; its first dimension is the plate size.
        :param y: observed targets for the 'obs' site.
        :param kl_factor: forwarded to every ``HiddenLayer`` as ``KL_factor``.
        :return: the value of the 'obs' sample site.
        """
        # gaussian_dropout_alpha is not defined in this class; it has to be
        # available at module level. Every layer uses the configured dropout
        # rate (the last layer used to be hard-coded to 0.2).
        alpha = torch.tensor(gaussian_dropout_alpha(self.dropout))
        layers = [(torch.zeros(*shape), alpha * torch.ones(*shape))
                  for shape in self._weight_shapes()]

        with pyro.plate('data', size=x.shape[0]):
            a = self._sample_network(x, layers, kl_factor)
            sigma = pyro.sample('sigma', dist.HalfNormal(scale=torch.tensor(1.)))
            return pyro.sample('obs', dist.Normal(loc=a, scale=sigma), obs=y)

    def guide(self, x, y=None, kl_factor=1):
        """
        Variational guide with learnable weight means, scales and dropout
        multipliers per layer ('z<i>_mean', 'z<i>_scale', 'z<i>_dropout')
        plus a learnable 'sigma_scale'.

        :param x: input batch; its first dimension is the plate size.
        :param y: unused; present so the signature matches ``model``.
        :param kl_factor: forwarded to every ``HiddenLayer`` as ``KL_factor``.
        """
        layers = []
        for i, shape in enumerate(self._weight_shapes(), start=1):
            mean = pyro.param('z{}_mean'.format(i), 0.01 * torch.randn(*shape))
            scale = pyro.param('z{}_scale'.format(i), 0.1 * torch.ones(*shape),
                               constraint=constraints.greater_than(0.01))
            # NOTE(review): the initial value self.dropout lies outside this
            # interval when dropout < 0.1 - confirm that is never the case.
            dropout = pyro.param('z{}_dropout'.format(i), torch.tensor(self.dropout),
                                 constraint=constraints.interval(0.1, 1.0))
            layers.append((mean, dropout * scale))
        sigma_scale = pyro.param('sigma_scale', torch.tensor(1.),
                                 constraint=constraints.interval(0.01, 2.))

        with pyro.plate('data', size=x.shape[0]):
            self._sample_network(x, layers, kl_factor)
            pyro.sample('sigma', dist.HalfNormal(scale=sigma_scale))

    def fit(self, x_train, y_train, lr=0.01, epochs=30, batch_size=1024):
        """
        Train with SVI (Adam, ``TraceMeanField_ELBO``) on shuffled mini-batches.

        Every row is visited once per epoch; the last batch of an epoch may
        be smaller than ``batch_size``.

        :param x_train: training inputs, indexable by an integer index array.
        :param y_train: training targets, same first dimension as ``x_train``.
        :param lr: Adam learning rate.
        :param epochs: number of passes over the data.
        :param batch_size: maximum number of rows per SVI step.
        :return: list of per-row losses, one entry for every logged batch.
        :raises ValueError: if ``batch_size`` is smaller than 1.
        """
        optim = pyro.optim.Adam({'lr': lr})
        elbo = TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optim, elbo)

        train_idx = np.arange(y_train.shape[0])
        batch_starts = self._batch_starts(train_idx.shape[0], batch_size)
        losses = []
        for e in range(epochs):
            np.random.shuffle(train_idx)
            print('\nEpoch:', e)
            for c in batch_starts:
                selection = train_idx[c: c + batch_size]
                x = x_train[selection]
                y = y_train[selection]

                # Weight the KL term by the fraction of the data in this
                # batch, so a short final batch is scaled correctly.
                kl_factor = len(selection) / x_train.shape[0]
                loss = svi.step(x, y, kl_factor=kl_factor)

                # c is the row offset of the batch, not a step counter, so
                # this logs only batches whose offset is a multiple of 100.
                if c % 100 == 0:
                    loss /= len(selection)
                    losses.append(loss)
                    print("[iteration {:6}] loss: {:.2f}".format(c + 1, loss))
        return losses

    def forward(self, x, n_samples=10):
        """
        Draw ``n_samples`` stochastic forward passes through the guide.

        :param x: input batch.
        :param n_samples: number of guide traces to draw.
        :return: the 'last_activation' values of all traces, stacked along a
            new leading dimension.
        """
        with torch.no_grad():
            draws = []
            for _ in range(n_samples):
                trace = poutine.trace(self.guide).get_trace(x, None)
                draws.append(trace.nodes['last_activation']['value'])
            return torch.stack(draws, dim=0)