In [4]:
from __future__ import division, print_function
import time
import torch.utils.data
from torchvision import transforms, datasets
import argparse
import matplotlib
from src.Stochastic_Gradient_HMC_SA.model_binary import BNN_cat
from src.utils import *
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

matplotlib.use('Agg')
import matplotlib.pyplot as plt

In [141]:
parser = argparse.ArgumentParser(description='Train Bayesian Neural Net on Simulated Dataset with Stochastic Gradient HMC')
parser.add_argument('--epochs', type=int, nargs='?', action='store', default=2000,
                    help='How many epochs to train. Default: 20.')
parser.add_argument('--sample_freq', type=int, nargs='?', action='store', default=1,
                    help='How many epochs pass between saving samples. Default: 2.')
parser.add_argument('--burn_in', type=int, nargs='?', action='store', default=40,
                    help='How many epochs to burn in for?. Default: 20.')
parser.add_argument('--lr', type=float, nargs='?', action='store', default=0.01,
                    help='learning rate. I recommend 1e-2. Default: 1e-2.')
parser.add_argument('--models_dir', type=str, nargs='?', action='store', default='SGHMC_models',
                    help='Where to save learnt weights and train vectors. Default: \'SGHMC_models\'.')
parser.add_argument('--results_dir', type=str, nargs='?', action='store', default='SGHMC_results',
                    help='Where to save learnt training plots. Default: \'SGHMC_results\'.')
args = parser.parse_args(args=[])

In [142]:
# Where to save models weights
models_dir = args.models_dir
# Where to save plots and error, accuracy vectors
results_dir = args.results_dir

mkdir(models_dir)
mkdir(results_dir)
# ------------------------------------------------------------------------------------------------------
# train config
NTrainPoints = 800
batch_size = 32
nb_epochs = args.epochs
log_interval = 1
nb_its_dev = log_interval
flat_ims=True
# ------------------------------------------------------------------------------------------------------
# dataset
cprint('c', '\nData:')

# load data

class CustomNNDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return x, y

with open(os.path.join('./nn.pickle'), 'rb') as f:
  [pr_cov,x,true_input,y] = pickle.load(f)

use_cuda = torch.cuda.is_available()


x = torch.from_numpy(x).type(torch.float)
y = torch.from_numpy(y).type(torch.int).reshape(y.shape[0],)

# split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)

# define the train and validation sets
train_dataset = CustomNNDataset(x_train, y_train)
val_dataset = CustomNNDataset(x_val, y_val)
all_dataset = CustomNNDataset(x, y)

# test set
with open(os.path.join('./nn_test.pickle'), 'rb') as f:
  [pr_cov_test,x_test,true_input_test,y_test] = pickle.load(f)
x_test = torch.from_numpy(x_test).type(torch.float)
y_test = torch.from_numpy(y_test).type(torch.int).reshape(y_test.shape[0],)
test_dataset = CustomNNDataset(x_test, y_test)


if use_cuda:
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                                              num_workers=3)
    valloader = torch.utils.data.DataLoader(all_dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
                                            num_workers=3)

else:
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False,
                                              num_workers=3)
    valloader = torch.utils.data.DataLoader(all_dataset, batch_size=batch_size, shuffle=False, pin_memory=False,
                                            num_workers=3)

[36m
Data:[0m


In [143]:
## ---------------------------------------------------------------------------------------------------------------------
# net dims
cprint('c', '\nNetwork:')
from collections import OrderedDict
lr = args.lr
########################################################################################

import torch
net = BNN_cat(NTrainPoints, lr=lr, cuda=False, grad_std_mul=20)
burn_in = args.burn_in
sim_steps = args.sample_freq
N_saves=100000
resample_its = 50
resample_prior_its = 15
re_burn = 1e8


epoch = 0
it_count = 0
## ----------------------------------------------------------------------------------------
# train
cprint('c', '\nTrain:')

print('  init cost variables:')
cost_train = np.zeros(nb_epochs)
err_train = np.zeros(nb_epochs)
cost_dev = np.zeros(nb_epochs)
err_dev = np.zeros(nb_epochs)
best_cost = np.inf
best_err = np.inf
test_accuracy_dev = np.zeros(nb_epochs)


probs_dev = np.zeros([nb_epochs, 1000])
_dev = np.zeros([nb_epochs, 1000])

tic0 = time.time()
for i in range(epoch, nb_epochs):
    net.set_mode_train(True)
    tic = time.time()
    nb_samples = 0
    for x, y in trainloader:

        if flat_ims:
            x = x.view(x.shape[0], -1)
            y = y.unsqueeze(1)

        cost_pred, err = net.fit(x, y, burn_in=(i % re_burn < burn_in),
                                 resample_momentum=(it_count % resample_its == 0),
                                 resample_prior=(it_count % resample_prior_its == 0))
        it_count += 1
        err_train[i] += err
        cost_train[i] += cost_pred
        nb_samples += len(x)

    cost_train[i] /= nb_samples
    err_train[i] /= nb_samples
    toc = time.time()

    # ---- print
    print("it %d/%d, Jtr_pred = %f, err = %f, " % (i, nb_epochs, cost_train[i], err_train[i]), end="")
    cprint('r', '   time: %f seconds\n' % (toc - tic))
    net.update_lr(i)

    # ---- save weights
    if i % re_burn >= burn_in and i % sim_steps == 0:
        net.save_sampled_net(max_samples=N_saves)

    # ---- dev
    if i % nb_its_dev == 0:
        nb_samples = 0
        for j, (x, y) in enumerate(valloader):
            if flat_ims:
                x = x.view(x.shape[0], -1)
                y = y.unsqueeze(1)

            cost, err, probs = net.eval(x, y)

            cost_dev[i] += cost
            err_dev[i] += err
            nb_samples += len(x)
            probs_dev[i, batch_size*j : batch_size*j + len(y)] = probs.view(-1)

        cost_dev[i] /= nb_samples
        err_dev[i] /= nb_samples

        cprint('g', '    Jdev = %f, err = %f\n' % (cost_dev[i], err_dev[i]))
        if err_dev[i] < best_err:
            best_err = err_dev[i]
            cprint('b', 'best test error')

    test_accuracy = 1 - net.eval(test_dataset.x, test_dataset.y.unsqueeze(1))[1].numpy() / test_dataset.y.shape[0]
    test_accuracy_dev[i] = test_accuracy

toc0 = time.time()
runtime_per_it = (toc0 - tic0) / float(nb_epochs)
runtime_total = toc0 - tic0
cprint('r', '   average time: %f seconds\n' % runtime_per_it)
cprint('r', '   total time: %f seconds\n' % runtime_total)

## SAVE WEIGHTS
net.save_weights(models_dir + '/state_dicts.pkl')

save_object(probs_dev, models_dir + '/probs.pkl')

[36m
Network:[0m
[36m
Net:[0m
[33mBNN categorical output[0m
    Total params: 0.00M
[36m
Train:[0m
  init cost variables:
it 0/2000, Jtr_pred = 0.476476, err = 0.280000, [31m   time: 0.231317 seconds
[0m
[36m saving weight samples 1/100000[0m
[32m    Jdev = 0.208839, err = 0.054000
[0m
[34mbest test error[0m
it 1/2000, Jtr_pred = 0.225960, err = 0.068750, [31m   time: 0.215816 seconds
[0m
[36m saving weight samples 2/100000[0m
[32m    Jdev = 0.177885, err = 0.052000
[0m
[34mbest test error[0m
it 2/2000, Jtr_pred = 0.202164, err = 0.058750, [31m   time: 0.233400 seconds
[0m
[36m saving weight samples 3/100000[0m
[32m    Jdev = 0.166427, err = 0.049000
[0m
[34mbest test error[0m
it 3/2000, Jtr_pred = 0.200617, err = 0.061250, [31m   time: 0.246078 seconds
[0m
[36m saving weight samples 4/100000[0m
[32m    Jdev = 0.171181, err = 0.056000
[0m
it 4/2000, Jtr_pred = 0.186485, err = 0.058750, [31m   time: 0.231210 seconds
[0m
[36m saving weight samples

In [150]:
import numpy as np

# Assuming y_hat is your original (10, 1000) array
threshold = 0.5
y_hat_thresholded = np.where(probs >= threshold, 1, 0)
save_object(y_hat_thresholded, models_dir + '/yhats.pkl')

In [153]:
# save result
with open('./SGHMC_models/calibration_bnn.pickle','wb') as f:
    pickle.dump([init_state, state_dicts, yhats, probs],f)
# load result
with open(os.path.join('./SGHMC_models/calibration_bnn.pickle'), 'rb') as f:
  [init_weights, state_dicts, yhats, ps] = pickle.load(f)
init_weights.shape, yhats.shape, ps.shape

((56,), (2000, 1000), (2000, 1000))

In [155]:
np.mean(test_accuracy_dev)

0.9601675000000001