In [1]:
%load_ext autoreload
%autoreload 2

### Импорты

In [13]:
import sys
sys.path.append('..')
from control_variates.model import MLP
from control_variates.optim import LangevinSGD as SGLD, ScaleAdaSGHMC as H_SA_SGHMC
from mnist_utils import load_mnist_dataset
from control_variates.trainer import BNNTrainer
import torch
from torch.nn import functional as F

import numpy as np
import dill as pickle
from pathlib import Path
from functools import partial

### Параметры обучения

In [6]:
batch_size = 500
input_dim = 784
width = 100
depth = 2
output_dim = 2
lr = 1e-3
n_epoch = 200
alpha0, beta0 = 10, 10
resample_prior_every = 15
resample_momentum_every = 50
burn_in_epochs = 20
save_freq = 2
resample_prior_until = 100

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Берем два класса из МНИСТа

In [7]:
Path('../data', 'mnist').mkdir(exist_ok=True, parents=True)
trainloader, valloader = load_mnist_dataset(Path('../data', 'mnist'), batch_size, [6, 9])

In [8]:
#model = MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim)

In [9]:
from control_variates.model import LogRegression
model = LogRegression(input_dim)

In [14]:
optimizer = SGLD(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
#optimizer = H_SA_SGHMC(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=False)

In [15]:
def nll_func(y_hat, y):
    nll = F.cross_entropy(y_hat, y, reduction='sum')
    return nll

def err_func(y_hat, y):
    err = y_hat.argmax(-1).ne(y)
    return err

In [16]:
trainer = BNNTrainer(model, optimizer, nll_func, err_func, trainloader, valloader, device=device, 
    resample_prior_every=resample_prior_every,
    resample_momentum_every=resample_momentum_every,
    save_freq=save_freq,
    batch_size=batch_size
    )

In [17]:
trainer.train(n_epoch=n_epoch, burn_in_epochs=burn_in_epochs, resample_prior_until=resample_prior_until)

2020-08-15 00:53:44,409 Epoch 0 finished. Val loss 0.6892860531806946, Val error 0.005592272496187087
2020-08-15 00:53:45,326 Epoch 1 finished. Val loss 0.6376884579658508, Val error 0.006100660904931368
2020-08-15 00:53:46,223 Epoch 2 finished. Val loss 0.5597677230834961, Val error 0.005592272496187087
2020-08-15 00:53:47,121 Epoch 3 finished. Val loss 0.5181558132171631, Val error 0.004575495678698526
2020-08-15 00:53:48,017 Epoch 4 finished. Val loss 0.4870304763317108, Val error 0.004575495678698526
2020-08-15 00:53:48,898 Epoch 5 finished. Val loss 0.4641351103782654, Val error 0.004575495678698526
2020-08-15 00:53:49,784 Epoch 6 finished. Val loss 0.4641956090927124, Val error 0.006100660904931368
2020-08-15 00:53:50,673 Epoch 7 finished. Val loss 0.4078866243362427, Val error 0.0035587188612099642
2020-08-15 00:53:51,562 Epoch 8 finished. Val loss 0.388039231300354, Val error 0.004067107269954245
2020-08-15 00:53:52,453 Epoch 9 finished. Val loss 0.38133251667022705, Val error 

### Сохраняем сэмплы весов

In [18]:
weights_set = trainer.weight_set_samples[-(n_epoch - resample_prior_until) // save_freq:]

print(len(weights_set))

Path('../saved_samples', 'mnist_weights').mkdir(exist_ok=True, parents=True)
pickle.dump(weights_set, Path('../saved_samples', 'mnist_weights', 'weights.pkl').open('wb'))

50


In [19]:
weights_set = weights_set[::10]

In [20]:
len(weights_set)

5

### CV

In [None]:
from control_variates.cv_utils import state_dict_to_vec
from control_variates.cv_utils import compute_log_likelihood, compute_mc_estimate, compute_naive_variance, compute_tricky_divergence, stein_control_variate
from control_variates.model import get_prediction, get_binary_prediction

In [21]:



squeezed_weights = [state_dict_to_vec(w) for w in weights_set]

models = [LogRegression(input_dim)]*len(weights_set)
for w, model in zip(weights_set, models):
    model.load_state_dict(w)

In [22]:
opt_with_priors = trainer.optimizer

priors = {}
group_params = opt_with_priors.param_groups[0]['params']
for (n, _), p in zip(model.named_parameters(), group_params):  
    state = opt_with_priors.state[p]  
    priors[n] = state['weight_decay']

In [42]:
from control_variates.cv import PsyMLP, PsyDoubleMLP, SteinCV


In [63]:
psy_hidden = 150
psy_depth1 = 3
psy_depth2 = 2
n_iter = 1
psy_lr = 1e-7
psy_input1 = squeezed_weights[0].shape[0]
N_train = len(trainloader.dataset)

In [64]:
x_new, y_new = next(iter(valloader))
#x_new = x_new
train_x, train_y = next(iter(trainloader))

In [65]:
y_new

tensor([1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1,
        1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1,
        1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1,
        0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0,
        0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
        0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0,
        1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0,
        1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1,

### Фитим на батче

In [72]:
psy_model = PsyMLP(psy_input1, psy_hidden, psy_depth1)
psy_model.to(device)

neural_control_variate = SteinCV(psy_model, train_x, train_y, priors, N_train)

ncv_optimizer = torch.optim.Adam(psy_model.parameters(), lr=lr)

In [73]:
function_f = lambda model, x: get_binary_prediction(model, x, classes=[0, 1])

for it in range(n_iter):
    for x in x_new[:50]:
        ncv_optimizer.zero_grad()
        mc_variance, no_cv_variance = compute_naive_variance(function_f, neural_control_variate, models, x)
        print(mc_variance.data, no_cv_variance.data)
        mc_variance.backward()
        ncv_optimizer.step()


tensor([9035.6797]) tensor([0.])
tensor([64628.5117]) tensor([0.])
tensor([536011.8125]) tensor([0.])
tensor([1604206.7500]) tensor([0.])
tensor([4717175.]) tensor([0.])
tensor([12049512.]) tensor([0.])
tensor([25575036.]) tensor([0.])
tensor([47943772.]) tensor([0.])
tensor([81103184.]) tensor([0.])
tensor([1.4376e+08]) tensor([0.])
tensor([2.3611e+08]) tensor([0.])
tensor([3.7919e+08]) tensor([0.])
tensor([5.8015e+08]) tensor([0.])
tensor([8.4431e+08]) tensor([0.])
tensor([1.2331e+09]) tensor([0.])
tensor([1.7539e+09]) tensor([0.])
tensor([2.4182e+09]) tensor([0.])
tensor([3.2671e+09]) tensor([0.])
tensor([4.3580e+09]) tensor([0.])
tensor([5.7454e+09]) tensor([0.])
tensor([7.5150e+09]) tensor([0.])
tensor([9.7118e+09]) tensor([0.])
tensor([1.2454e+10]) tensor([0.])
tensor([1.5841e+10]) tensor([0.])
tensor([2.0003e+10]) tensor([0.])
tensor([2.5103e+10]) tensor([0.])
tensor([3.1429e+10]) tensor([0.])
tensor([3.8911e+10]) tensor([0.])
tensor([4.7887e+10]) tensor([0.])
tensor([5.8609e+10

In [56]:
for n, p in psy_model.named_parameters():
    print(n, p)

block1.0.weight Parameter containing:
tensor([[ 0.0227,  0.0327,  0.0079,  ...,  0.0422,  0.0046, -0.0429],
        [ 0.0402, -0.0279,  0.0282,  ...,  0.0178,  0.1172, -0.0346],
        [ 0.0450, -0.1537,  0.0589,  ...,  0.0482,  0.0731, -0.1426],
        ...,
        [ 0.0268,  0.0192, -0.0255,  ...,  0.0347,  0.0044, -0.0571],
        [ 0.0421, -0.0266, -0.0084,  ...,  0.0345,  0.0288, -0.0010],
        [-0.0019,  0.0370,  0.0134,  ...,  0.0454,  0.0290,  0.0124]],
       requires_grad=True)
block1.0.bias Parameter containing:
tensor([-0.0168,  0.0649,  0.0950,  0.0875,  0.0036, -0.0137,  0.0213,  0.0801,
         0.0089,  0.0109,  0.0843,  0.0825, -0.0136, -0.0168, -0.0124,  0.0028,
         0.0193,  0.1095,  0.1224,  0.0833,  0.1028, -0.0265, -0.0120, -0.0252,
        -0.0203,  0.0040,  0.0098, -0.0048,  0.0139,  0.1094,  0.0117,  0.0665,
        -0.0173,  0.0013,  0.0055,  0.0881, -0.0131,  0.0795,  0.0249,  0.0944,
        -0.0250, -0.0132,  0.0072, -0.0191, -0.0025, -0.0071,  0.