In [1]:
%load_ext autoreload
%autoreload 2

### Импорты

In [2]:
import sys
sys.path.append('..')
from control_variates.model import MLP
from control_variates.optim import LangevinSGD as SGLD, ScaleAdaSGHMC as H_SA_SGHMC
from mnist_utils import load_mnist_dataset
from control_variates.trainer import BNNTrainer
import torch
from torch.nn import functional as F

import numpy as np
import dill as pickle
from pathlib import Path
from functools import partial

### Параметры обучения

In [3]:
batch_size = 500
input_dim = 784
width = 100
depth = 2
output_dim = 2
lr = 1e-3
n_epoch = 200
alpha0, beta0 = 10, 10
resample_prior_every = 15
resample_momentum_every = 50
burn_in_epochs = 20
save_freq = 4
resample_prior_until = 100

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Берем два класса из МНИСТа

In [4]:
Path('../data', 'mnist').mkdir(exist_ok=True, parents=True)
trainloader, valloader = load_mnist_dataset(Path('../data', 'mnist'), batch_size, [3, 5])

In [5]:
#model = MLP(input_dim=input_dim, width=width, depth=depth, output_dim=output_dim)

In [6]:
from control_variates.model import LogRegression
model = LogRegression(input_dim)

In [7]:
#optimizer = SGLD(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
optimizer = H_SA_SGHMC(model.parameters(), lr=lr, alpha0=alpha0, beta0=beta0)
#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=False)

In [8]:
def nll_func(y_hat, y):
    nll = F.cross_entropy(y_hat, y, reduction='sum')
    return nll

def err_func(y_hat, y):
    err = y_hat.argmax(-1).ne(y)
    return err

In [9]:
trainer = BNNTrainer(model, optimizer, nll_func, err_func, trainloader, valloader, device=device, 
    resample_prior_every=resample_prior_every,
    resample_momentum_every=resample_momentum_every,
    save_freq=save_freq,
    batch_size=batch_size
    )

In [9]:
trainer.train(n_epoch=n_epoch, burn_in_epochs=burn_in_epochs, resample_prior_until=resample_prior_until)

-08-15 11:50:23,809 Epoch 7 finished. Val loss 0.09057198464870453, Val error 0.033123028391167195
2020-08-15 11:50:26,516 Epoch 8 finished. Val loss 0.09757083654403687, Val error 0.03470031545741325
2020-08-15 11:50:29,020 Epoch 9 finished. Val loss 0.09389007091522217, Val error 0.03785488958990536
2020-08-15 11:50:31,001 Epoch 10 finished. Val loss 0.08884655684232712, Val error 0.03470031545741325
2020-08-15 11:50:32,738 Epoch 11 finished. Val loss 0.09444264322519302, Val error 0.03890641430073607
2020-08-15 11:50:34,488 Epoch 12 finished. Val loss 0.09036420285701752, Val error 0.03785488958990536
2020-08-15 11:50:36,218 Epoch 13 finished. Val loss 0.09465666115283966, Val error 0.035226077812828605
2020-08-15 11:50:37,950 Epoch 14 finished. Val loss 0.09984655678272247, Val error 0.03627760252365931
2020-08-15 11:50:39,702 Epoch 15 finished. Val loss 0.09045618772506714, Val error 0.033648790746582544
2020-08-15 11:50:41,460 Epoch 16 finished. Val loss 0.09838180989027023, Val 

### Сохраняем сэмплы весов

In [10]:
weights_set = trainer.weight_set_samples[-(n_epoch - resample_prior_until) // save_freq:]

print(len(weights_set))

Path('../saved_samples', 'mnist_weights').mkdir(exist_ok=True, parents=True)
pickle.dump(weights_set, Path('../saved_samples', 'mnist_weights', 'weights.pkl').open('wb'))

25


In [10]:
weights_set = pickle.load(Path('../saved_samples', 'mnist_weights', 'weights.pkl').open('rb'))

In [11]:
#weights_set = weights_set[::10]

In [12]:
len(weights_set)

25

### CV

In [107]:
from control_variates.cv_utils import state_dict_to_vec
from control_variates.cv_utils import compute_log_likelihood, compute_mc_estimate, compute_naive_variance, compute_tricky_divergence
from control_variates.model import get_prediction, get_binary_prediction

In [108]:
squeezed_weights = [state_dict_to_vec(w) for w in weights_set]

models = [LogRegression(input_dim) for _ in range(len(weights_set))]
for w, model in zip(weights_set, models):
    model.load_state_dict(w)

In [109]:
opt_with_priors = trainer.optimizer

priors = {}
group_params = opt_with_priors.param_groups[0]['params']
for (n, _), p in zip(model.named_parameters(), group_params):  
    state = opt_with_priors.state[p]  
    priors[n] = state['weight_decay']

pickle.dump(priors, Path('../saved_samples', 'mnist_weights', 'priors.pkl').open('wb'))

KeyError: 'weight_decay'

In [110]:
priors = pickle.load(Path('../saved_samples', 'mnist_weights', 'priors.pkl').open('rb'))

In [111]:
from control_variates.cv import PsyMLP, PsyDoubleMLP, PsyLinear, SteinCV


In [120]:
psy_hidden = 150
psy_depth1 = 3
psy_depth2 = 2
n_iter = 1000
psy_lr = 1e-2
psy_input1 = squeezed_weights[0].shape[0]
N_train = len(trainloader.dataset)

In [133]:

_, new_valloader = load_mnist_dataset(Path('../data', 'mnist'), batch_size, [4, 4])

In [134]:
x_new, y_new = next(iter(new_valloader))

In [121]:
x_new, y_new = next(iter(valloader))
#x_new = x_new
train_x, train_y = next(iter(trainloader))

### Фитим на батче

In [136]:
#psy_model = PsyMLP(psy_input1, psy_hidden, psy_depth1)
psy_model = PsyLinear(psy_input1)
psy_model.init_zero()
psy_model.to(device)

neural_control_variate = SteinCV(psy_model, train_x, train_y, priors, N_train)

ncv_optimizer = torch.optim.Adam(psy_model.parameters(), lr=lr)

In [137]:
len(models)

25

In [138]:
function_f = lambda model, x: get_binary_prediction(model, x, classes=[0, 1])

for it in range(n_iter):
    for x in x_new[30:31]:
        ncv_optimizer.zero_grad()
        mc_variance, no_cv_variance = compute_naive_variance(function_f, neural_control_variate, models[::5], x)
        print(mc_variance.data, no_cv_variance.data)
        mc_variance.backward()
        # for n, p in psy_model.named_parameters():
        #     print(n, p, p.grad)
        ncv_optimizer.step()
        # for n, p in psy_model.named_parameters():
        #     print(n, p)


16])tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([2.6645e-15]) tensor([0.0012])
tensor([3.5527e-15]) tensor([0.0012])
tensor([3.5527e-15]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([1.7764e-15]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([8.8818e-16]) tensor([0.0012])
tensor([3.5527e-15]) tensor([0.0012])
tensor([0.]) tensor([0.0012])
tensor([3.5527e-15]

In [52]:
psy_model.state_dict()

OrderedDict([('layer.weight',
              tensor([[-0.0445,  0.0381,  0.0149,  ...,  0.0248,  0.0312, -0.0313]]))])

In [76]:
torch.autograd.grad(psy_model.forward(squeezed_weights[0], None), squeezed_weights[0])

(tensor([-0.0049, -0.0049, -0.0049,  ...,  0.0049,  0.0049, -0.0049]),)

In [77]:
x_new.requires_grad = True

In [86]:
torch.autograd.grad(psy_model.forward(squeezed_weights[0], None), squeezed_weights[0])

(tensor([-0.0049, -0.0049, -0.0049,  ...,  0.0049,  0.0049, -0.0049]),)