## Tasks
We will implement a Bayesian Neural Network (BNN) and fit it to toy data.

#### Todo:
* Implement the `VariationalLinear` class (in the **Create Layers** section)
* Implement the `VariationalFlow` class (in the **Create Flow Layers** section)
* Train mean-field & flow-based BNNs and compare
* **Bonus:** Study the effect of some parameter (e.g. flow architecture, NN architecture, initialization, etc.) and improve performance.

## Create Dataset

In [None]:
!pip install einops

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import Dataset, DataLoader
from einops import rearrange, reduce, repeat
import plotly.express as px
import plotly.graph_objects as go

# Compute device used for models and training batches: the GPU when CUDA is
# available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def plot_preds(dataset, model=None, title=None):
    """Scatter-plot a dataset and optionally overlay sampled model predictions.

    Args:
        dataset: Object exposing 2-D tensors ``x`` and ``y``; column 0 of each
            is plotted as the data points.
        model: Optional callable returning a tuple whose first element is the
            prediction tensor. It is evaluated 10 times on a grid of 200 points
            in [-5, 5], so a stochastic model produces 10 different curves.
        title: Optional figure title (centred).

    Returns:
        The plotly ``Figure`` holding the data trace and any prediction traces.
    """
    fig = go.Figure()
    fig.update_layout(title=title, title_x=0.5)
    fig.add_trace(go.Scatter(x=dataset.x[:, 0], y=dataset.y[:, 0], mode='markers', name='data'))
    # Explicit None check: the truthiness of a module can depend on __len__
    # (e.g. nn.Sequential), which is not what "was a model passed?" means.
    if model is not None:
        x = torch.linspace(-5, 5, steps=200).to(device)
        # Plotting needs no gradients, so skip building the autograd graph.
        with torch.no_grad():
            # Move every curve back to the CPU: plotly cannot read CUDA tensors.
            ys = [model(x.unsqueeze(-1))[0].detach()[:, 0].cpu() for _ in range(10)]
        x_cpu = x.cpu()
        for y in ys:
            fig.add_trace(go.Scatter(x=x_cpu, y=y, mode='lines', name='pred', marker={'color': '#EF553B'}))
    return fig

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
class ToyData(Dataset):
    """Synthetic 1-D regression data: ``y = 0.1 * x**3 - x`` plus Gaussian noise.

    Inputs are drawn uniformly from [-4, 4); the targets receive additive
    normal noise with standard deviation ``sigma``.

    Attributes:
        x: Inputs, shape ``(num, 1)``.
        y: Noisy targets, shape ``(num, 1)``.
        num: Number of points requested.
        sigma: Noise standard deviation.
    """

    def __init__(self, num, sigma=0.5):
        # Draw the inputs first and the noise second so seeded runs reproduce.
        inputs = 8 * torch.rand(num, 1) - 4
        noise = sigma * torch.randn(num, 1)
        self.x = inputs
        self.y = 0.1 * inputs**3 - inputs + noise
        self.num = num
        self.sigma = sigma

    def __getitem__(self, i):
        """Return the ``(x, y)`` pair stored at index ``i``."""
        return self.x[i], self.y[i]

    def __len__(self):
        """Return the number of stored points."""
        return self.x.shape[0]

# Seed the RNG before sampling so the generated data sets are reproducible.
# Statement order matters: the training set is drawn before the test set.
torch.manual_seed(42)
dataset = ToyData(128)   # training data (128 points)
testset = ToyData(1024)  # larger sample, used later for the test-MSE comparison
# Mini-batches of 32, reshuffled on every pass over the training data.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
plot_preds(dataset)

## Create Layers

In [None]:
class StandardGaussian(nn.Module):
    """Fixed factorised N(0, 1) distribution over ``dim`` coordinates.

    ``loc`` and ``scale`` are registered as buffers rather than parameters, so
    they follow the module across devices but are never trained.
    """

    def __init__(self, dim):
        super().__init__()
        self.register_buffer('loc', torch.zeros(dim))
        self.register_buffer('scale', torch.ones(dim))

    def _dist(self):
        # Rebuilt on every call so it always reflects the buffers' current device.
        return Normal(self.loc, self.scale)

    def log_prob(self, x):
        """Return the joint log-density of each row of ``x`` (summed over dim 1)."""
        per_coordinate = self._dist().log_prob(x)
        return per_coordinate.sum(dim=1)

    def sample(self, num_samples):
        """Draw ``num_samples`` samples; the result has shape ``(num_samples, dim)``."""
        return self._dist().sample(torch.Size([num_samples]))

    def sample_with_log_prob(self, num_samples):
        """Draw samples and return them together with their log-densities."""
        dist = self._dist()
        draws = dist.sample(torch.Size([num_samples]))
        return draws, dist.log_prob(draws).sum(dim=1)


class Gaussian(nn.Module):
    """Learnable factorised (mean-field) Gaussian over ``dim`` coordinates.

    Args:
        dim: Number of coordinates.
        mean_std_init: Standard deviation of the random initial means.
        std_init: Initial standard deviation of every coordinate; it is stored
            as ``log_scale`` so the scale stays positive during optimisation.
    """

    def __init__(self, dim, mean_std_init=0.05, std_init=0.05):
        super().__init__()
        initial_loc = mean_std_init * torch.randn(dim)
        initial_log_scale = np.log(std_init) * torch.ones(dim)
        self.loc = nn.Parameter(initial_loc)
        self.log_scale = nn.Parameter(initial_log_scale)

    def sample_with_log_prob(self, num_samples):
        """Return reparameterised samples and their log-densities.

        The samples come from ``rsample``, so gradients flow back to ``loc``
        and ``log_scale``. Shapes: ``(num_samples, dim)`` and ``(num_samples,)``.
        """
        posterior = Normal(self.loc, torch.exp(self.log_scale))
        draws = posterior.rsample(torch.Size([num_samples]))
        log_density = posterior.log_prob(draws).sum(dim=1)
        return draws, log_density

In [None]:
def split_wb(wb, i, o):
    """Split a flat parameter vector into a weight matrix and a bias vector.

    Args:
        wb: 1-D tensor of length ``i*o + o`` (weights first, then biases).
        i: Number of input features.
        o: Number of output features.

    Returns:
        ``(w, b)`` where ``w`` has shape ``(o, i)`` and ``b`` has shape ``(o,)``.
    """
    w, b = torch.split_with_sizes(wb, (i*o, o))
    return w.reshape(o, i), b


class VariationalLinear(nn.Module):
    """Linear layer whose weights and bias are sampled on every forward pass.

    Args:
        q: Variational distribution; must provide ``sample_with_log_prob(n)``
            returning samples of shape ``(n, dim_in*dim_out + dim_out)`` and
            their log-densities of shape ``(n,)``.
        p: Prior distribution; must provide ``log_prob(x)`` for such samples.
        dim_in: Number of input features.
        dim_out: Number of output features.
    """

    def __init__(self, q, p, dim_in, dim_out):
        super().__init__()
        self.q = q # Variational distribution
        self.p = p # Prior distribution
        self.dim_in = dim_in
        self.dim_out = dim_out

    def forward(self, x):
        """Apply the layer with freshly sampled parameters.

        Returns:
            ``(y, kl)`` where ``y = F.linear(x, w, b)`` and ``kl`` (shape
            ``(1,)``) is a single-sample Monte Carlo estimate of KL(q || p).
        """
        # One joint draw of all weights and biases, shared by the whole batch.
        wb, log_q = self.q.sample_with_log_prob(1)
        # KL(q || p) = E_q[log q(w) - log p(w)], estimated from that one draw.
        kl = log_q - self.p.log_prob(wb)
        # Drop the sample dimension before reshaping into (w, b).
        w, b = split_wb(wb[0], self.dim_in, self.dim_out)
        return F.linear(x, w, b), kl

    
# Test VariationalLinear: a 10 -> 5 layer needs 10*5 weights + 5 biases, so both
# the variational distribution and the prior are built over 55 dimensions.
m = VariationalLinear(Gaussian(10*5+5), StandardGaussian(10*5+5), 10, 5)
m(torch.randn(1, 10))

(tensor([[ 0.4509,  0.6837,  2.1835, -4.6944,  7.5346]]), tensor([0.]))

## Create Flow Layers

In [None]:
def build_mask(in_dim, out_dim, ar_dim, causal=False):
    """Build a block-tiled lower-triangular 0/1 mask of shape ``(out_dim, in_dim)``.

    An ``ar_dim x ar_dim`` lower-triangular tile is repeated
    ``out_dim // ar_dim`` times vertically and ``in_dim // ar_dim`` times
    horizontally. With ``causal=True`` the tile's diagonal is excluded
    (strictly lower triangular); otherwise the diagonal is kept.

    Both ``in_dim`` and ``out_dim`` must be multiples of ``ar_dim``.
    """
    assert in_dim % ar_dim == 0
    assert out_dim % ar_dim == 0
    tile = torch.ones(ar_dim, ar_dim)
    if causal:
        tile = tile.tril(-1)
    else:
        tile = tile.tril(0)
    rows_of_tiles = out_dim // ar_dim
    cols_of_tiles = in_dim // ar_dim
    return repeat(tile, 'h w -> (o h) (i w)', i=cols_of_tiles, o=rows_of_tiles)


class MaskedLinear(nn.Module):
    """Linear layer whose weight matrix is multiplied element-wise by a fixed mask.

    The mask comes from ``build_mask(in_dim, out_dim, ar_dim, causal)`` and is
    stored as a buffer, so it is not trained.
    """

    def __init__(self, in_dim, out_dim, ar_dim, causal=False):
        super().__init__()
        mask = build_mask(in_dim, out_dim, ar_dim, causal)
        self.register_buffer('mask', mask)
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim))
        self.bias = nn.Parameter(torch.empty(out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise: Kaiming-normal weights, zero bias."""
        nn.init.kaiming_normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply the affine map using only the unmasked weight entries."""
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight, self.bias)


def make_net(dim):
    """Build the three-layer masked MLP used as the flow's conditioner network.

    Maps ``dim`` inputs to ``2 * dim`` outputs through two hidden layers of
    width ``4 * dim`` with GELU activations. Only the last layer is built with
    ``causal=True``.
    """
    hidden = dim * 4
    layers = [
        MaskedLinear(dim, hidden, ar_dim=dim),
        nn.GELU(),
        MaskedLinear(hidden, hidden, ar_dim=dim),
        nn.GELU(),
        MaskedLinear(hidden, dim * 2, ar_dim=dim, causal=True),
    ]
    return nn.Sequential(*layers)

In [None]:
class Reverse(nn.Module):
    """Bijection that reverses the order of the last dimension."""

    def forward(self, x):
        """Return the flipped input and a zero log-det-Jacobian per row.

        A permutation does not change volume, so the log-determinant is 0.
        """
        flipped = torch.flip(x, dims=(-1,))
        zero_ldj = x.new_zeros((x.shape[0],))
        return flipped, zero_ldj

    def inverse(self, z):
        """Undo ``forward``: flipping twice restores the original order."""
        return torch.flip(z, dims=(-1,))


class IAF(nn.Module):
    """Gated affine flow step driven by a conditioner network.

    ``ar_net`` maps an input of shape ``(batch, dim)`` to ``(batch, 2*dim)``:
    the first half is a shift ``m``, the second half a pre-activation gate ``u``.
    NOTE(review): the returned ``ldj`` is the true log-determinant only if
    ``ar_net`` is strictly autoregressive; confirm for any custom network.
    """

    def __init__(self, ar_net):
        super().__init__()
        self.ar_net = ar_net

    def forward(self, x):
        """Return ``(z, ldj)`` with ``z = s * x + (1 - s) * m``, ``s = sigmoid(u + 1)``."""
        params = self.ar_net(x)
        shift, gate_logit = params.chunk(2, dim=1)
        # Offset the gate so it starts closer to 1 (improves initialisation).
        gate_logit = gate_logit + 1.0
        gate = torch.sigmoid(gate_logit)
        log_gate = F.logsigmoid(gate_logit)
        z = gate * x + (1 - gate) * shift
        # dz_k/dx_k = gate_k, so the log-determinant is the sum of the log-gates.
        ldj = log_gate.sum(dim=1)
        return z, ldj


class VariationalFlow(nn.Module):
    """Distribution defined by pushing base samples through a chain of bijections.

    Args:
        base_dist: Object providing ``sample_with_log_prob(num_samples)``, which
            returns samples and their log-densities.
        bijections: Iterable of modules whose ``forward(x)`` returns
            ``(z, ldj)``: the transformed sample and the log-determinant of
            the Jacobian of that transformation, one value per sample.
    """

    def __init__(self, base_dist, bijections):
        super().__init__()
        self.base_dist = base_dist
        self.bijections = nn.ModuleList(bijections)

    def sample_with_log_prob(self, num_samples):
        """Draw ``num_samples`` samples and return them with their log-densities.

        Uses the change-of-variables formula: for ``z = f(x)``,
        ``log q(z) = log q0(x) - log|det df/dx|``, applied once per bijection.
        """
        z, log_prob = self.base_dist.sample_with_log_prob(num_samples)
        for bijection in self.bijections:
            z, ldj = bijection(z)
            # Expanding volume (ldj > 0) lowers the density, hence the minus sign.
            log_prob = log_prob - ldj
        return z, log_prob


def make_flow(dim):
    """Build a ``dim``-dimensional flow: four IAF steps separated by ``Reverse``.

    The base distribution is a standard Gaussian. Each IAF step gets its own
    conditioner network from ``make_net``; there is no ``Reverse`` after the
    last step.
    """
    base = StandardGaussian(dim=dim)
    num_steps = 4
    steps = []
    for k in range(num_steps):
        steps.append(IAF(make_net(dim)))
        if k < num_steps - 1:
            steps.append(Reverse())
    return VariationalFlow(base_dist=base, bijections=steps)

# Test VariationalFlow: draw 2 samples from a 5-dimensional flow; the call
# returns the samples together with their log-densities.
flow = make_flow(5)
flow.sample_with_log_prob(2)

(tensor([[ 0.4556,  0.0111,  0.0035,  0.2923, -0.1062],
         [-0.5595,  0.1934, -0.0331, -0.0753,  0.4863]], grad_fn=<AddBackward0>),
 tensor([-0.1170, -0.5190], grad_fn=<SubBackward0>))

## Create BNN

In [None]:
class BNN(nn.Module):
    """Two-layer Bayesian neural network trained by variational inference.

    Args:
        I: Number of input features.
        O: Number of output features.
        H: Hidden width.
        scale: Standard deviation of the Gaussian likelihood.
        num: Divisor applied to the KL term in ``elbo``.
        mean_field: If True each layer uses a ``Gaussian`` posterior, otherwise
            a flow built by ``make_flow``. Both layers use a standard Gaussian
            prior.
    """

    def __init__(self, I, O, H, scale, num, mean_field=True):
        super().__init__()

        def num_params(fan_in, fan_out):
            # Weights plus biases of one linear layer.
            return fan_in * fan_out + fan_out

        def posterior(fan_in, fan_out):
            size = num_params(fan_in, fan_out)
            if mean_field:
                return Gaussian(size, mean_std_init=1/np.sqrt(fan_in))
            return make_flow(size)

        self.layer0 = VariationalLinear(posterior(I, H), StandardGaussian(num_params(I, H)), I, H)
        self.layer1 = VariationalLinear(posterior(H, O), StandardGaussian(num_params(H, O)), H, O)
        self.scale = scale
        self.num = num

    def forward(self, x):
        """Return the prediction and the summed KL terms of both layers."""
        hidden, kl_first = self.layer0(x)
        out, kl_second = self.layer1(F.gelu(hidden))
        return out, kl_first + kl_second

    def elbo(self, x, y):
        """Return the mean Gaussian log-likelihood of the batch minus ``kl / num``."""
        prediction, kl = self(x)
        log_likelihood = Normal(prediction, self.scale).log_prob(y).mean()
        return log_likelihood - kl / self.num

def get_model(mean_field):
    """Build a 1-in / 1-out BNN with 32 hidden units and move it to ``device``.

    The likelihood scale and the KL divisor are taken from the module-level
    training ``dataset`` (``dataset.sigma`` and ``dataset.num``).
    """
    model = BNN(1, 1, 32, scale=dataset.sigma, num=dataset.num, mean_field=mean_field)
    return model.to(device)

In [None]:
def train(model, epochs=1000):
    """Fit ``model`` with Adam (lr 1e-3) by minimising the negative ELBO.

    Batches come from the module-level ``train_loader`` and are moved to the
    module-level ``device``. Progress is printed every 100 epochs.

    Returns:
        ``(model, losses)``: the model switched to eval mode, and the
        per-epoch mean of the batch losses (negative ELBO).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print('Training...')
    model = model.train()
    epoch_losses = []
    num_batches = len(train_loader)
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            loss = -model.elbo(x, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_losses.append(running_loss / num_batches)
        if epoch % 100 == 0:
            print('Epoch: {}/{}, Loss: {:.3f}'.format(epoch, epochs, epoch_losses[-1]))
    model = model.eval()
    return model, epoch_losses

## Train Mean-Field BNN

In [None]:
# Re-seed so initialisation, batch shuffling and weight sampling are reproducible.
torch.manual_seed(42)
# elbos_mf holds the per-epoch mean training loss, i.e. the negative ELBO.
model_mf, elbos_mf = train(get_model(True))
plot_preds(dataset, model_mf, title='Mean-Field BNN')

Training...
Epoch: 100/1000, Loss: 4.521
Epoch: 200/1000, Loss: 3.637
Epoch: 300/1000, Loss: 3.533
Epoch: 400/1000, Loss: 3.192
Epoch: 500/1000, Loss: 2.846
Epoch: 600/1000, Loss: 2.807
Epoch: 700/1000, Loss: 2.817
Epoch: 800/1000, Loss: 2.699
Epoch: 900/1000, Loss: 2.565
Epoch: 1000/1000, Loss: 2.480


## Train Flow-Based BNN

In [None]:
# Same seed as the mean-field run, so both models start from the same RNG state.
torch.manual_seed(42)
# elbos_flow holds the per-epoch mean training loss, i.e. the negative ELBO.
model_flow, elbos_flow = train(get_model(False))
plot_preds(dataset, model_flow, title='Flow-Based BNN')

Training...
Epoch: 100/1000, Loss: 4.943
Epoch: 200/1000, Loss: 3.913
Epoch: 300/1000, Loss: 3.706
Epoch: 400/1000, Loss: 2.041
Epoch: 500/1000, Loss: 2.156
Epoch: 600/1000, Loss: 1.848
Epoch: 700/1000, Loss: 2.123
Epoch: 800/1000, Loss: 1.870
Epoch: 900/1000, Loss: 1.732
Epoch: 1000/1000, Loss: 1.984


## Compare Models

In [None]:
# Plot both training curves on one chart (per-epoch mean loss; lower is better).
df = pd.DataFrame({'mean_field':elbos_mf, 'flow':elbos_flow})
px.line(df, y=['mean_field','flow'])

In [None]:
# Average over the last 100 epochs to smooth out per-epoch sampling noise.
print(f'Mean-Field: Avg. loss over last 100 epochs: {np.mean(elbos_mf[-100:]):.4f}')
print(f'Flow:       Avg. loss over last 100 epochs: {np.mean(elbos_flow[-100:]):.4f}')

Mean-Field: Avg. loss over last 100 epochs: 2.5634
Flow:       Avg. loss over last 100 epochs: 1.8884


In [None]:
# The models live on `device`, but the test tensors are created on the CPU;
# move them over once so the evaluation also works on a GPU.
x_test, y_test = testset.x.to(device), testset.y.to(device)

# Each forward pass samples new weights, so average the MSE over 100 draws.
# No gradients are needed for evaluation.
with torch.no_grad():
    test_mse_mf = [F.mse_loss(model_mf(x_test)[0], y_test).item() for _ in range(100)]
    test_mse_flow = [F.mse_loss(model_flow(x_test)[0], y_test).item() for _ in range(100)]

print(f'Mean-Field: Test MSE: {np.mean(test_mse_mf):.4f}')
print(f'Flow:       Test MSE: {np.mean(test_mse_flow):.4f}')

Mean-Field: Test MSE: 0.4711
Flow:       Test MSE: 0.4392
