In [1]:
# Put the parent directory on the import path for the project packages
# imported in the next cell.
# NOTE(review): '..' is resolved against the kernel's working directory, so this
# presumably assumes the notebook is started from its own folder — confirm.
import sys
sys.path.append('..')
# Load IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS

from model.cnn import SimpleConv, Trainer

In [3]:
# Experiment settings gathered in one place so they can be tuned together.
# Of these keys, the cells visible in this notebook read only 'epochs',
# 'nn_runs' and 'dropout_uq'.
config = dict(
    use_cuda=True,
    batch_size=128,
    epochs=3,
    lr=1e-2,
    momentum=0.5,
    seed=1,
    log_interval=10,
    nn_runs=10,
    dropout_uq=0.5,
)



#### Load data and preprocess

In [4]:
# MNIST: images and labels for the 'train' and 'val' splits.
# NOTE(review): val_size presumably sets how many samples are held out for
# validation — confirm in build_dataset.
mnist = build_dataset('mnist', val_size=10_000)
x_train, y_train = mnist.dataset('train')
x_val, y_val = mnist.dataset('val')

# Fashion-MNIST images, bound to x_ood; only the images are needed.
ood = build_dataset('fashion_mnist', val_size=0)
# The labels are deliberately NOT unpacked into `_`: at notebook top level that
# rebinds IPython's last-output variable, which IPython then stops updating for
# the rest of the session. Use a real name and drop it straight away instead.
x_ood, ood_labels = ood.dataset('train')
del ood_labels

In [5]:
def _scale_images(images):
    """Return ``images`` reshaped to (-1, 1, 28, 28) and divided by 255.0.

    The division is out-of-place on purpose. ``reshape`` may return a view, so
    an in-place ``/=`` would also rescale the array it was viewed from, and for
    integer-typed input it raises because the float result cannot be written
    back into an integer array.
    """
    return images.reshape(-1, 1, 28, 28) / 255.0


# NOTE: the names are rebound, so running this cell twice rescales twice;
# re-run the loading cell first when re-executing.
x_train = _scale_images(x_train)
x_val = _scale_images(x_val)
x_ood = _scale_images(x_ood)

# Labels: cast to the 'long' integer dtype and flatten to one dimension.
y_train = y_train.astype('long').reshape(-1)
y_val = y_val.astype('long').reshape(-1)

#### Train model

In [6]:
# Fit a fresh SimpleConv on the first `train_samples` training rows, then
# report accuracy on the validation split as the cell's output.
# NOTE(review): config['seed'] is not applied in the cells visible here, so the
# result presumably varies between runs — confirm whether Trainer seeds itself.
model = SimpleConv()
trainer = Trainer(model)
train_samples = 5000
trainer.fit(
    x_train[:train_samples],
    y_train[:train_samples],
    epochs=config['epochs'],
)
accuracy_score(y_true=y_val, y_pred=trainer.predict(x_val))


#### BALD (Bayesian Active Learning by Disagreement) uncertainty estimation


In [7]:
# Build the 'vanilla' and 'dpp' dropout masks; the result is looked up by mask
# name in the cells below.
masks = build_masks(['vanilla', 'dpp'])
# Alternative: build the masks listed in DEFAULT_MASKS instead.
# masks = build_masks(DEFAULT_MASKS)

In [11]:
# Disabled experiment (kept for reference; uncomment to run): for each mask in
# `masks`, build a 'bald_masked' estimator, score the first
# `estimation_samples` rows of the train / val / ood images, and draw a seaborn
# boxplot of the scores grouped by dataset and coloured by mask type.
#
# estimation_samples = 50 
# uqs, datasets, mask_type = [], [], []
# 
# for mask_name, mask in masks.items():
#     estimator = build_estimator(
#         'bald_masked', trainer, nn_runs=config['nn_runs'], dropout_mask=mask,
#         dropout_rate=config['dropout_uq'], num_classes=10)
#     
#     for data_name, x_current in (('train', x_train), ('val', x_val), ('ood', x_ood)):
#         uq = estimator.estimate(x_current[:estimation_samples])
#         uqs = np.concatenate((uqs, uq))
#         datasets = np.concatenate((datasets, [data_name]*estimation_samples))
#         mask_type = np.concatenate((mask_type, [mask_name]*estimation_samples))
# 
#     
# df = pd.DataFrame({'uq': uqs, 'dataset': datasets, 'mask_type': mask_type})
# sns.boxplot(data=df, x='dataset',  y='uq', hue='mask_type')

In [13]:
# Build a 'bald_masked' estimator around the trained model, using the 'dpp'
# mask and the run count / dropout rate from the config.
mask = masks['dpp']
estimator = build_estimator(
    'bald_masked',
    trainer,
    nn_runs=config['nn_runs'],
    dropout_mask=mask,
    dropout_rate=config['dropout_uq'],
    num_classes=10,
)
# estimator.reset()


In [17]:
# Run the estimator on the first ten training images; the returned value is
# displayed as the cell's output.
estimator.estimate(x_train[:10])