In [1]:
#%load_ext autoreload
#%autoreload 2

In [1]:
import sys
sys.path.append('..')
from control_variates.model import MLP
from control_variates.optim import LangevinSGD as SGLD, ScaleAdaSGHMC as H_SA_SGHMC
from mnist_utils import load_mnist_dataset
from control_variates.trainer import BNNTrainer
import torch
from torch.nn import functional as F

import numpy as np
import dill as pickle
from pathlib import Path
from functools import partial

In [2]:
batch_size = 500
input_dim = 784
width = 100
depth = 2
output_dim = 10
lr = 1e-3
n_epoch = 300
alpha0, beta0 = 10, 10
resample_prior_every = 15
resample_momentum_every = 50
burn_in_epochs = 20
save_freq = 2
resample_prior_until = 100

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
Path('../data', 'mnist').mkdir(exist_ok=True, parents=True)
trainloader, valloader = load_mnist_dataset(Path('../data', 'mnist'), batch_size)

In [4]:
model = MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim)

In [5]:
#optimizer = SGLD(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
optimizer = H_SA_SGHMC(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=False)

In [6]:
def nll_func(y_hat, y):
    nll = F.cross_entropy(y_hat, y, reduction='sum')
    return nll

def err_func(y_hat, y):
    err = y_hat.argmax(-1).ne(y)
    return err

In [7]:
trainer = BNNTrainer(model, optimizer, nll_func, err_func, trainloader, valloader, device=device, 
    resample_prior_every=resample_prior_every,
    resample_momentum_every=resample_momentum_every,
    save_freq=save_freq,
    batch_size=batch_size
    )

In [8]:
trainer.train(n_epoch=n_epoch, burn_in_epochs=burn_in_epochs, resample_prior_until=resample_prior_until)

 loss 0.06701449304819107, Val error 0.0199
2020-08-14 12:53:41,426 Epoch 80 finished. Val loss 0.07195600867271423, Val error 0.0225
2020-08-14 12:53:49,692 Epoch 81 finished. Val loss 0.07047393172979355, Val error 0.0203
2020-08-14 12:53:58,345 Epoch 82 finished. Val loss 0.0689258798956871, Val error 0.0214
2020-08-14 12:54:07,573 Epoch 83 finished. Val loss 0.07056793570518494, Val error 0.0219
2020-08-14 12:54:16,372 Epoch 84 finished. Val loss 0.06595133990049362, Val error 0.0207
2020-08-14 12:54:27,838 Epoch 85 finished. Val loss 0.0649980679154396, Val error 0.0202
2020-08-14 12:54:37,796 Epoch 86 finished. Val loss 0.07060126960277557, Val error 0.0218
2020-08-14 12:54:49,623 Epoch 87 finished. Val loss 0.0725543275475502, Val error 0.0226
2020-08-14 12:54:58,163 Epoch 88 finished. Val loss 0.05860758200287819, Val error 0.0185
2020-08-14 12:55:06,278 Epoch 89 finished. Val loss 0.06630147993564606, Val error 0.0204
2020-08-14 12:55:14,271 Epoch 90 finished. Val loss 0.06601

In [352]:
weights_set = trainer.weight_set_samples[-(n_epoch - resample_prior_until) // save_freq:]

print(len(weights_set))

Path('../saved_samples', 'mnist_weights').mkdir(exist_ok=True, parents=True)
pickle.dump(weights_set, Path('../saved_samples', 'mnist_weights', 'weights.pkl').open('wb'))

100


In [353]:
weights_set = weights_set[::10]

In [354]:
len(weights_set)

10

In [406]:
def state_dict_to_vec(state_dict):
    return torch.cat([w_i.view(-1) for w_i in state_dict.values()])

squeezed_weights = [state_dict_to_vec(w) for w in weights_set]

models = [MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim) for w in weights_set]
for w, model in zip(weights_set, models):
    model.load_state_dict(w)

In [435]:
opt_with_priors = trainer.optimizer

priors = {}
group_params = opt_with_priors.param_groups[0]['params']
for (n, _), p in zip(model.named_parameters(), group_params):  
    state = opt_with_priors.state[p]  
    priors[n] = state['weight_decay']

In [436]:
def get_prediction(x, model):
    return F.softmax(model(x), dim=-1)

def get_binary_prediction(x, model, classes):
    assert len(classes) == 2
    return F.softmax(model(x)[..., classes], dim=-1)[..., -1]

In [437]:
def compute_log_likelihood(x, y, model):
    y_hat = model(x)
    log_likelihood = -F.cross_entropy(y_hat, y, reduction='mean')
    return log_likelihood

def compute_mc_estimate(function: callable, models, x: torch.tensor):
    res = 0.0
    for model in models:
        res += function(x, model)
    return res / len(models)

def compute_naive_variance(function:callable, control_variate: callable, models, x: torch.tensor):
    sample_mean = compute_mc_estimate(lambda x_, model: function(x_, model) - control_variate(x_, model), models, x)
    v = 0
    for model in models:
        v += (function(x, model) - control_variate(x, model) - sample_mean)**2 / (len(models) - 1)
    return v

In [476]:
def compute_tricky_divergence(model):
    div = 0
    for n, p in model.named_parameters():
        if p.grad is not None:
            div += p.grad.sum()

    return div

def stein_control_variate(psy_model, model, train_x, train_y, new_x, priors, N_train):
    model.zero_grad()
    log_likelihood = compute_log_likelihood(train_x, train_y, model) * N_train
    log_likelihood.backward()
    
    ncv_value = 0
    psy_value = psy_model(state_dict_to_vec(model.state_dict()), x_new)
    psy_value.backward(retain_graph=True)
    psy_div = compute_tricky_divergence(psy_model)

    for n, p in model.named_parameters():
        if p.grad is not None:
            d_p = p.grad.data
            d_p.add_(p.data, alpha=-priors[n])

            ncv_value += d_p.sum()

    ncv_value = torch.cat([ncv_value.view(1)]*psy_value.shape[0], dim=0)

    ncv_value *= psy_value
    ncv_value += psy_div

    return ncv_value



In [485]:
from torch import nn

class PsyLinear(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, 1)

    def forward(self, weights, x):
        return  self.layer(weights)

class PsyMLP(nn.Module):
    def __init__(self, input_dim, width, depth):
        super().__init__()

        self.input_dim = input_dim
        self.width = width
        self.depth = depth

        layers = [nn.Linear(input_dim, width), nn.ReLU()]
        for i in range(depth - 1):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(width, 1, bias=False))
        #layers.append(nn.Tanh())

        self.block = nn.Sequential(*layers)

        #for p in self.parameters():
        #    torch.nn.init.zeros_(p)

    def forward(self, weights, x):
        return self.block(weights)


In [486]:
psy_hidden = 30
n_iter = 100
psy_lr = 1e-4
psy_input = squeezed_weights[0].shape[0]
N_train = batch_size * len(trainloader)

In [487]:
x_new, y_new = next(iter(valloader))
x_new = x_new[:1]
train_x, train_y = next(iter(trainloader))

In [491]:
psy_model = PsyMLP(psy_input, psy_hidden, 1)
psy_model.to(device)

neural_control_variate = lambda x, model : stein_control_variate(psy_model, model, train_x, train_y, x, priors, N_train)

ncv_optimizer = torch.optim.Adam(psy_model.parameters(), lr=lr)

In [492]:
psy_model(squeezed_weights[0], x_new)

tensor([0.0068], grad_fn=<SqueezeBackward3>)

In [493]:
function_f = lambda x, model: get_binary_prediction(x, model, classes=[3, 5])

for it in range(n_iter):
    ncv_optimizer.zero_grad()
    mc_variance = compute_naive_variance(function_f, neural_control_variate, models, x_new)
    print(mc_variance)
    mc_variance.backward()
    ncv_optimizer.step()


tensor([11993.8486], grad_fn=<AddBackward0>)
tensor([2.7918e+09], grad_fn=<AddBackward0>)
tensor([1.7531e+08], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tensor([0.0232], grad_fn=<AddBackward0>)
tens

In [None]:
for n, p in psy_model.named_parameters():
    print(p)