In [1]:
#%load_ext autoreload
#%autoreload 2

In [1]:
from model import MLP
from optimizers import SGLD, H_SA_SGHMC
from mnist_utils import load_mnist_dataset
from trainer import BNNTrainer
import torch
from torch.nn import functional as F

import numpy as np
import dill as pickle
from pathlib import Path

In [2]:
# --- Data loading and network shape (used by load_mnist_dataset, MLP and BNNTrainer) ---
batch_size = 500
input_dim, output_dim = 784, 10
width, depth = 100, 2

# --- Sampler settings (forwarded to the optimizer constructor) ---
lr = 1e-3
alpha0 = 10
beta0 = 10

# --- Schedule (forwarded to BNNTrainer and trainer.train) ---
n_epoch = 100
burn_in_epochs = 10
resample_prior_every = 15
resample_prior_until = 50
resample_momentum_every = 50
save_freq = 2

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Build the train and validation loaders with the configured batch size.
# NOTE(review): 'data' is presumably the dataset root directory — confirm in mnist_utils.
trainloader, valloader = load_mnist_dataset('data', batch_size)

In [4]:
# Instantiate the network with the dimensions chosen in the configuration cell.
model = MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim)

In [5]:
# Active sampler: H_SA_SGHMC. The commented-out lines are alternative optimizers
# kept for quick switching; exactly one `optimizer = ...` line should be live.
#optimizer = SGLD(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
optimizer = H_SA_SGHMC(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=False)

In [6]:
def nll_func(y_hat, y):
    """Return the cross-entropy between logits ``y_hat`` and targets ``y``.

    ``reduction='sum'`` is used, so the value is the total over the batch
    rather than a per-example mean.
    """
    return F.cross_entropy(y_hat, y, reduction='sum')

def err_func(y_hat, y):
    """Flag, element-wise, where the arg-max over the last axis of ``y_hat`` differs from ``y``.

    Returns a boolean tensor (True = prediction differs from target); no
    averaging is performed here.
    """
    predicted = torch.argmax(y_hat, dim=-1)
    return predicted != y

In [7]:
# Wire the model, sampler, metric functions and data loaders into the trainer.
# Schedule-related settings come from the configuration cell.
trainer = BNNTrainer(
    model,
    optimizer,
    nll_func,
    err_func,
    trainloader,
    valloader,
    device=device,
    resample_prior_every=resample_prior_every,
    resample_momentum_every=resample_momentum_every,
    save_freq=save_freq,
    batch_size=batch_size,
)

In [8]:
# Run training/sampling with the schedule from the configuration cell.
# NOTE(review): the recorded log below stops after epoch 8 with a KeyboardInterrupt,
# so anything read from `trainer` afterwards reflects a partial run unless this
# cell is re-executed to completion.
trainer.train(n_epoch=n_epoch, burn_in_epochs=burn_in_epochs, resample_prior_until=resample_prior_until)

2020-08-14 10:31:14,779 Epoch 0 finished. Val loss 0.17838871479034424, Val error 0.0536
2020-08-14 10:31:23,712 Epoch 1 finished. Val loss 0.14581242203712463, Val error 0.0433
2020-08-14 10:31:35,410 Epoch 2 finished. Val loss 0.1339890956878662, Val error 0.0406
2020-08-14 10:31:45,928 Epoch 3 finished. Val loss 0.12925595045089722, Val error 0.0386
2020-08-14 10:31:56,968 Epoch 4 finished. Val loss 0.1298900991678238, Val error 0.0385
2020-08-14 10:32:07,354 Epoch 5 finished. Val loss 0.14497973024845123, Val error 0.0375
2020-08-14 10:32:15,923 Epoch 6 finished. Val loss 0.14951393008232117, Val error 0.0391
2020-08-14 10:32:26,081 Epoch 7 finished. Val loss 0.14880023896694183, Val error 0.0386
2020-08-14 10:32:35,523 Epoch 8 finished. Val loss 0.15642748773097992, Val error 0.0399


KeyboardInterrupt: 

In [None]:
# Keep a handle on the trainer's optimizer under a separate name.
# NOTE(review): presumably kept because it carries the resampled prior state — confirm.
opt_with_priors = trainer.optimizer

In [None]:
# Keep only the trailing part of the stored weight samples.
# `-(a) // 2` parses as `(-a) // 2`, i.e. a floor division of the negated value;
# with n_epoch=100 and resample_prior_until=50 this is -25, so the last 25 entries.
# NOTE(review): the count is derived from epochs, while the list holds saved samples
# (save_freq may make these differ) — confirm this is the intended window.
weights_set = trainer.weight_set_samples[-(n_epoch - resample_prior_until) // 2:]

# Write through a context manager so the file is flushed and closed deterministically;
# the previous bare `.open('wb')` left the handle open.
with Path('weights.pkl').open('wb') as weights_file:
    pickle.dump(weights_set, weights_file)

In [None]:
def state_dict_to_vec(state_dict):
    """Flatten every tensor of a state dict into one 1-D tensor.

    Entries are concatenated in the mapping's iteration order, each flattened
    in row-major order.

    Args:
        state_dict: mapping whose values are tensors (e.g. ``module.state_dict()``).

    Returns:
        A 1-D tensor holding all entries of all values.
    """
    # reshape(-1) rather than view(-1): view raises on non-contiguous tensors
    # (e.g. transposed weights), reshape copies only when it has to.
    return torch.cat([w_i.reshape(-1) for w_i in state_dict.values()])

In [13]:
# Flatten each stored weight sample (a state dict) into a single 1-D vector.
squeezed_weights = [state_dict_to_vec(w) for w in weights_set]

In [None]:
def get_prediction(x, model):
    """Return ``model(x)`` normalised with a softmax over the last axis."""
    logits = model(x)
    return torch.softmax(logits, dim=-1)

In [None]:
def get_binary_prediction(x, model, classes):
    """Softmax restricted to two selected output columns of ``model(x)``.

    ``classes`` must contain exactly two indices into the last axis; the
    returned probabilities are renormalised over just those two entries.
    """
    assert len(classes) == 2
    pair_logits = model(x)[..., classes]
    return pair_logits.softmax(dim=-1)

In [15]:
# One fresh MLP per stored weight sample, each then loaded with its sample.
# The loop variables are deliberately NOT called `model`: reusing that name would
# rebind the notebook-level `model` (the network handed to the trainer) to the
# last sampled network, a hidden-state trap when earlier cells are re-run.
models = [MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim) for _ in weights_set]
for sample_state, sampled_model in zip(weights_set, models):
    sampled_model.load_state_dict(sample_state)

In [None]:
def compute_mc_estimate(function: callable, models, x: torch.tensor):
    """Average ``function(x, model)`` over all entries of ``models``.

    The accumulation starts from the float ``0.0`` and the total is divided by
    ``len(models)``, so an empty ``models`` raises ``ZeroDivisionError``.
    """
    total = sum((function(x, member) for member in models), 0.0)
    return total / len(models)

In [69]:
def compute_naive_variance(function: callable, control_variate: callable, models, x: torch.Tensor):
    """Sum of squared residuals ``function - control_variate`` over models, divided by ``n - 1``.

    Args:
        function: callable ``(x, model) -> value`` being estimated.
        control_variate: callable ``(x, model) -> value`` subtracted from ``function``.
        models: sequence of models (one residual is computed per entry).
        x: input forwarded to both callables.

    Returns:
        A tensor: scalar when the callables return scalars, otherwise one value
        per output element (the reduction runs over the model axis only).

    Notes:
        * Residuals are combined with ``torch.stack`` instead of
          ``torch.tensor([...])``: the latter rejects non-scalar tensors and
          detaches the result from autograd, so the value could not be used
          as a differentiable objective.
        * NOTE(review): residuals are not mean-centred, so this equals the
          sample variance only if the residual has zero mean — confirm intent.
        * The divisor is ``len(models) - 1``; a single model divides by zero.
    """
    squared_residuals = torch.stack([
        # as_tensor lets plain Python numbers through and leaves tensors (and
        # their autograd history) untouched.
        torch.as_tensor(function(x, model) - control_variate(x, model)) ** 2
        for model in models
    ])
    return squared_residuals.sum(dim=0) / (len(models) - 1)

In [None]:
def stein_control_variate(phi_function, x, y, model):
    """Unfinished Stein control variate for one weight sample.

    The original cell ended in a dangling ``control_variate =`` (a SyntaxError
    that broke the whole notebook) and misspelled ``phi_weight``. Both are
    fixed so the cell parses; the final formula was never written, so the
    function fails loudly instead of returning a made-up value.

    Args:
        phi_function: callable ``(weight_vector, x) -> tensor``.
        x: model inputs.
        y: targets forwarded to ``compute_log_likelihood``.
        model: network whose weights form the sample.

    Raises:
        NotImplementedError: always, after the gradient passes below have run.
    """
    log_likelihood = compute_log_likelihood(x, y, model)
    # Accumulates the log-likelihood gradient into each parameter's .grad.
    log_likelihood.backward()
    # NOTE(review): state_dict() returns detached tensors by default, so this
    # vector is not connected to the autograd graph of `model` — confirm that
    # is acceptable for the derivative of phi that is needed here.
    weight = state_dict_to_vec(model.state_dict())
    phi_weight = phi_function(weight, x)
    # NOTE(review): backward() without arguments only works for scalar outputs;
    # a vector-valued phi needs an explicit `gradient=` argument.
    phi_weight.backward()

    raise NotImplementedError(
        "stein_control_variate: the control variate formula has not been implemented yet"
    )


In [None]:
def compute_log_likelihood(x, y, model):
    """Return the summed log-likelihood of targets ``y`` under ``model(x)``.

    Computed as the negated cross-entropy with ``reduction='sum'``.

    The original body began with ``assert len(classes) = 2``: a SyntaxError
    (``=`` instead of ``==``) on a name that is not a parameter of this
    function. It was removed.
    """
    y_hat = model(x)
    log_likelihood = -F.cross_entropy(y_hat, y, reduction='sum')

    return log_likelihood

In [1]:
from torch import nn

class LinearPhi(nn.Module):
    """Single square linear layer applied to a flattened weight vector."""

    def __init__(self, input_dim):
        """Create an ``input_dim -> input_dim`` linear layer."""
        super().__init__()
        self.layer = nn.Linear(input_dim, input_dim)

    def forward(self, weights, x=None):
        """Apply the linear layer to ``weights``.

        ``x`` is accepted for signature compatibility with callers that pass
        the model input, but it is not used. It defaults to ``None`` so the
        module can also be called with the weight vector alone, as the
        notebook does further down.
        """
        return self.layer(weights)

class BottleneckPhi(nn.Module):
    """Two linear layers mapping ``input_dim -> hidden_dim -> input_dim``."""

    def __init__(self, input_dim, hidden_dim, depth=2):
        """Create the down- and up-projection layers.

        ``depth`` is accepted but currently unused: exactly two layers are built.
        """
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, weights, x=None):
        """Project ``weights`` down to ``hidden_dim`` and back up to ``input_dim``.

        The previous body called ``self.layer``, which this class never
        defines, so every call raised ``AttributeError``.
        ``x`` is unused and optional, matching ``LinearPhi.forward``.
        NOTE(review): no non-linearity sits between the two layers, so the map
        is a low-rank linear one — add an activation here if that was intended.
        """
        return self.layer_2(self.layer_1(weights))

In [49]:
# Build a LinearPhi sized to one flattened weight sample (d = number of entries in it).
# NOTE(review): LinearPhi wraps nn.Linear(d, d), whose weight matrix has d*d entries;
# that grows quadratically with the model size — confirm it fits in memory.
phi_linear = LinearPhi(squeezed_weights[0].shape[0])

# Move the module to the selected device; as the cell's last expression this
# also displays the module's repr.
phi_linear.to(device)

In [None]:
# LinearPhi.forward is declared as (weights, x); the earlier one-argument call
# raised a TypeError. `x` is not used by LinearPhi, so None is passed explicitly.
# The sample is also moved to `device`, because phi_linear was moved there above.
out = phi_linear(squeezed_weights[0].to(device), None)

In [None]:
# Display the output of phi_linear for the first flattened sample.
out

In [41]:
# `out` is a vector, and backward() without arguments is only defined for
# scalars. Passing an all-ones gradient back-propagates the sum of its entries.
# NOTE(review): replace with the real scalar objective once it exists.
out.backward(torch.ones_like(out))

In [None]:
def 