In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
from scipy.special import softmax as softmax
import matplotlib.pyplot as plt

from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from analysis.metrics import uq_ndcg

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer

# Prefer GPU #1 (the notebook's original choice), but only when that device
# actually exists. On single-GPU or CPU-only machines PyTorch's default device
# is kept, so the notebook still runs instead of failing here.
PREFERRED_CUDA_DEVICE = 1
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_CUDA_DEVICE:
    torch.cuda.set_device(PREFERRED_CUDA_DEVICE)

In [None]:
plt.rcParams['figure.facecolor'] = 'white'

# Per-dataset model and training hyper-parameters; the entry matching
# config['dataset'] is merged into `config` below.
model_setups = {
    'mnist': {
        'model_class': SimpleConv,
        'train_samples': 5000,
        'epochs': 5,
        'batch_size': 256,
        'log_interval': 10,
        'lr': 1e-2,
        'num_classes': 10
    },
    'cifar_10': {
        'model_class': StrongConv,
        'train_samples': 45_000,
        'epochs': 50,
        'batch_size': 256,
        'log_interval': 150,
        'lr': 1e-2,
        # One class is cut out of CIFAR-10 and used as OOD data in the
        # data-loading cell, which leaves 9 classes to train on.
        'num_classes': 9
    }
}

config = {
    'use_cuda': True,
    'seed': 1,

    # Settings passed to the uncertainty estimators / mask builders.
    'nn_runs': 150,
    'patience': 5,
    'dropout_uq': 0.5,

    # Ensemble size used by retrain() by default.
    'n_models': 3,

    'dataset': 'mnist',
    # 'dataset': 'cifar_10',

    # Number of retrainings and of evaluation repeats in the NDCG experiments.
    'model_runs': 1,
    'repeat_runs': 1,

    'activation': torch.nn.functional.celu
}

config.update(model_setups[config['dataset']])

# Apply the configured seed. Previously 'seed' was declared but never used, so
# every np.random.choice subset and every weight initialisation differed
# between runs and "Restart & Run All" was not reproducible.
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])



#### Load data and preprocess

In [None]:
# Load the dataset selected in `config` and pull its train / validation splits.
# NOTE(review): val_size=10_000 presumably sets the size of the validation
# split — confirm against build_dataset.
dataset = build_dataset(config['dataset'], val_size=10_000)
x_train, y_train = dataset.dataset('train')
x_val, y_val = dataset.dataset('val')

def cut_class(x, y, class_num):
    """Split a labelled set into "everything except one class" and that class.

    Args:
        x: array of samples, indexed consistently with `y`.
        y: array of labels.
        class_num: label value to remove; compared to `y` with `!=` / `==`.

    Returns:
        (new_x, new_y, ood): samples and labels whose label differs from
        `class_num`, followed by the samples whose label equals it.
    """
    kept = np.where(y != class_num)
    removed = np.where(y == class_num)
    return x[kept], y[kept], x[removed]

# Choose the out-of-distribution (OOD) inputs used by the sanity check later on.
if config['dataset'] == 'mnist':
    # For MNIST a different dataset (fashion_mnist) provides the OOD inputs;
    # its labels are discarded.
    ood = build_dataset('fashion_mnist', val_size=0)
    x_ood, _ = ood.dataset('train') 
elif config['dataset'] == 'cifar_10':
    # For CIFAR-10 one class is removed from train and val, and the removed
    # training images become the OOD set (hence 'num_classes': 9 in
    # model_setups).
    # NOTE(review): the class id is passed as the string '9'. This only matches
    # if the labels are still strings here (they are cast with astype('long')
    # in a later cell) — confirm against build_dataset.
    x_train, y_train, x_ood = cut_class(x_train, y_train, '9')
    x_val, y_val, _ = cut_class(x_val, y_val, '9')



In [None]:
def scale(images):
    """Shift and rescale pixel values: ``(images - 128) / 128``.

    Values in [0, 255] are mapped to [-1, 127/128].

    Integer inputs are promoted to float64 before the subtraction. Without
    that, an unsigned array such as uint8 wraps around on ``images - 128``
    (0 becomes 128) and the darkest pixels silently end up as the brightest.
    float64 is also what true division yields for integer arrays, so the
    result dtype is unchanged for inputs that did not overflow. Floating-point
    inputs keep their dtype.

    Args:
        images: array-like of pixel intensities.

    Returns:
        numpy array of the same shape as `images` holding the scaled values.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.integer):
        images = images.astype(np.float64)
    return (images - 128) / 128
# Rescale all three image sets, overwriting the raw arrays.
# NOTE: this makes the cell non-idempotent — re-running it without first
# re-running the data-loading cell rescales data that is already scaled.
x_train = scale(x_train)
x_val = scale(x_val)
x_ood = scale(x_ood)

In [None]:
# Target array shape per dataset; -1 lets reshape infer the number of samples.
# NOTE(review): the remaining axes are presumably (channels, height, width) —
# confirm against the model definitions.
input_shape = {
    'mnist': (-1, 1, 28, 28),
    'cifar_10': (-1, 3, 32, 32),
}[config['dataset']]

# Bring every image set to that shape ...
x_train, x_val, x_ood = (
    images.reshape(input_shape) for images in (x_train, x_val, x_ood))

# ... and turn the labels into flat integer vectors.
y_train, y_val = (
    labels.astype('long').reshape(-1) for labels in (y_train, y_val))



#### Train model

In [None]:
def retrain(
        train_samples, n_models=config['n_models'], epochs=config['epochs'],
        val_samples=2000, patience=config['patience']):
    """Fit a single model and an ensemble on fresh random subsets.

    `train_samples` training points and `val_samples` validation points are
    drawn without replacement from the notebook-level x_train/y_train and
    x_val/y_val. Both the single model and the ensemble are fitted on those
    same two subsets.

    The defaults for `n_models`, `epochs` and `patience` are read from
    `config` once, when this function is defined.

    Args:
        train_samples: number of training points to draw.
        n_models: number of models passed to EnsembleTrainer.
        epochs: epochs passed to both fit() calls.
        val_samples: number of validation points to draw.
        patience: patience passed to both fit() calls.

    Returns:
        (trainer, ensemble): the Trainer and EnsembleTrainer after fit().
    """
    # Draw both index sets up front (same RNG call order: train, then val).
    train_pick = np.random.choice(len(x_train), train_samples, replace=False)
    val_pick = np.random.choice(len(x_val), val_samples, replace=False)
    train_set = (x_train[train_pick], y_train[train_pick])
    val_set = (x_val[val_pick], y_val[val_pick])

    model_class = config['model_class']
    num_classes, activation = config['num_classes'], config['activation']

    # Single model.
    trainer = Trainer(model_class(num_classes, activation=activation))
    trainer.fit(
        train_set, val_set, epochs=epochs, verbose=True, patience=patience)

    # Ensemble built from the same architecture and constructor arguments.
    ensemble_kwargs = {'num_classes': num_classes, 'activation': activation}
    ensemble = EnsembleTrainer(model_class, ensemble_kwargs, n_models)
    ensemble.fit(train_set, val_set, epochs=epochs, patience=patience, verbose=True)

    return trainer, ensemble

def ll(trainer, x, y):
    """Per-sample log-likelihood of the true labels under `trainer`.

    The model is switched to eval mode, its logits for `x` are computed, and
    the log-softmax value of each sample's true class is returned.

    The log-softmax is computed directly, with the row maximum subtracted
    first, instead of as ``log(softmax(logits))``. The latter underflows to
    ``log(0) = -inf`` when the model is confidently wrong, and an infinite
    value would then be passed on to the scatter plots and to uq_ndcg.

    Args:
        trainer: callable model wrapper exposing eval(); calling it on `x`
            must return an object supporting .detach().cpu().numpy().
        x: batch of inputs.
        y: integer class index for every sample in `x`.

    Returns:
        1-D numpy array with one log-likelihood per sample.
    """
    trainer.eval()
    logits = trainer(x).detach().cpu().numpy()
    # Subtracting the row maximum makes the largest exponent exp(0) == 1, so
    # the sum below can neither overflow nor vanish.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(len(x)), y]

# Label counts among the first `train_samples` training labels. This is only a
# rough balance check: retrain() draws a random subset, not this prefix.
unique, counts = np.unique(y_train[:config['train_samples']], return_counts=True)
{label: count for label, count in zip(unique, counts)}

In [None]:
# Train the baseline single model and ensemble used by the BALD cells below.
trainer, ensemble = retrain(
    train_samples=config['train_samples'], n_models=config['n_models'])

In [None]:
# Accuracy on the first 3000 samples of each split, one line per check.
accuracy_checks = (
    ('Model accuracy train', trainer, x_train, y_train),
    ('Model accuracy val', trainer, x_val, y_val),
    ('Ensemble accuracy', ensemble, x_val, y_val),
)
for check_name, predictor, inputs, labels in accuracy_checks:
    print(check_name, accuracy_score(labels[:3000], predictor.predict(inputs[:3000])))


### BALD

#### UQ by different masks

In [None]:
# Build the default set of dropout masks, passing the configured nn_runs.
# The result is iterated below as a {mask_name: mask} mapping.
masks = build_masks(DEFAULT_MASKS, nn_runs=config['nn_runs'])

In [None]:
estimation_samples = 5000
# (label, inputs) pairs scored by every estimator; at most `estimation_samples`
# leading samples of each are used.
eval_sets = (('train', x_train), ('val', x_val), ('ood', x_ood))
uq_chunks, datasets, mask_type = [], [], []

print('Ensemble')
estimator = build_estimator('bald_ensemble', ensemble, num_classes=config['num_classes'])
for data_name, x_current in eval_sets:
    uq = estimator.estimate(x_current[:estimation_samples])
    uq_chunks.append(uq)
    # Label by the number of scores actually returned. A split can hold fewer
    # than `estimation_samples` items (e.g. a single held-out class), and
    # repeating the label `estimation_samples` times would then leave the
    # three columns with different lengths.
    datasets.extend([data_name] * len(uq))
    mask_type.extend(['ensemble'] * len(uq))

for mask_name, mask in masks.items():
    print(mask_name)
    estimator = build_estimator(
        'bald_masked', trainer, nn_runs=config['nn_runs'], dropout_mask=mask,
        dropout_rate=config['dropout_uq'], num_classes=config['num_classes'])

    for data_name, x_current in eval_sets:
        uq = estimator.estimate(x_current[:estimation_samples])
        uq_chunks.append(uq)
        datasets.extend([data_name] * len(uq))
        mask_type.extend([mask_name] * len(uq))
        estimator.reset()

# Assemble the columns once instead of re-concatenating on every iteration.
uqs = np.concatenate(uq_chunks).astype(np.float64)
datasets = np.asarray(datasets)
mask_type = np.asarray(mask_type)
        

In [None]:
# Box plot of the uncertainty scores per estimator, split by data source
# (train / val / ood).
fig, ax = plt.subplots(figsize=(16, 9))
ax.set_title('OOD sanity check')

df = pd.DataFrame({'uq': uqs, 'dataset': datasets, 'mask_type': mask_type})
sns.boxplot(data=df, x='mask_type', y='uq', hue='dataset', ax=ax)


#### LL prediction by UQ

In [None]:
# Take the first `pool_size` validation samples as a small pool and compute the
# single model's per-sample log-likelihood on it; the scatter plots in the
# following cells reuse x_pool and pool_ll.
pool_size = 300
x_pool, y_pool = x_val[:pool_size], y_val[:pool_size]
pool_ll = ll(trainer, x_pool, y_pool)
    

In [None]:
# One scatter series per mask on shared axes: uncertainty estimate (x) against
# the single model's log-likelihood (y) over the pool.
fig, ax = plt.subplots(figsize=(16, 18))

# Keyword arguments shared by every masked estimator.
masked_kwargs = {
    'nn_runs': config['nn_runs'],
    'dropout_rate': config['dropout_uq'],
    'num_classes': config['num_classes'],
}

for name, mask in masks.items():
    estimator = build_estimator(
        'bald_masked', trainer, dropout_mask=mask, **masked_kwargs)
    estimations = estimator.estimate(x_pool)
    estimator.reset()
    ax.scatter(estimations, pool_ll, label=name, alpha=0.5)
    ax.set_xlabel('Uncertainty estimation')
    ax.set_ylabel('Log likelihood')
ax.legend(loc='lower right')
    

In [None]:
# Grid of scatter plots, two per row: panel 1 is the ensemble, panels 2.. are
# the individual masks. Each shows uncertainty estimate vs. log-likelihood on
# the pool.
plt.figure(figsize=(15, 20))
# len(masks) + 1 panels at two per row -> ceil((len(masks) + 1) / 2) rows.
num = (len(masks) + 2) // 2
for i, (name, mask) in enumerate(masks.items()):
    plt.subplot(num, 2, i+2)
    estimator = build_estimator(
        'bald_masked', trainer, nn_runs=config['nn_runs'], dropout_mask=mask,
        dropout_rate=config['dropout_uq'], num_classes=config['num_classes'])
    estimations = estimator.estimate(x_pool)
    # Masked estimators wrap the single model, so they are compared with the
    # single model's log-likelihood (pool_ll).
    plt.scatter(estimations, pool_ll, alpha=0.5)
    plt.xlabel('Uncertainty estimation')
    plt.ylabel('Log likelihood')
    plt.title(name)
    estimator.reset()

# Ensemble panel: pair the ensemble's uncertainty with the ensemble's own
# log-likelihood, as the NDCG cells do, instead of the single model's.
estimator = build_estimator('bald_ensemble', ensemble, num_classes=config['num_classes'])
estimations = estimator.estimate(x_pool)
ensemble_ll = ll(ensemble, x_pool, y_pool)
plt.subplot(num, 2, 1)
plt.scatter(estimations, ensemble_ll, alpha=0.5)
plt.xlabel('Uncertainty estimation')
plt.ylabel('Log likelihood')
plt.title('ensemble')

 

#### NDCG estimation


In [None]:
# NDCG of each uncertainty estimator as a function of the training-set size.
# Every result is stored as one entry in three parallel lists (score, estimator
# name, training size) that the next cell turns into a DataFrame.
estimation_samples = 3000
ndcgs, estimator_type, train_size = [], [], []

# Keyword arguments shared by every masked estimator built below.
masked_kwargs = dict(
    nn_runs=config['nn_runs'], dropout_rate=config['dropout_uq'],
    num_classes=config['num_classes'])

for train_samples in [500, 2000, 5000, 20000, len(x_train)]:
    print('\n', train_samples)
    for _model_run in range(config['model_runs']):
        # Note: this replaces the notebook-level trainer / ensemble.
        trainer, ensemble = retrain(train_samples)
        for _repeat in range(config['repeat_runs']):
            # Fresh random evaluation subset of the validation data.
            idxs = np.random.choice(len(x_val), estimation_samples, replace=False)
            x_current, y_current = x_val[idxs], y_val[idxs]

            # ensemble: scored against the ensemble's own log-likelihood
            print('ensemble')
            estimator = build_estimator('bald_ensemble', ensemble, num_classes=config['num_classes'])
            current_ll = ll(ensemble, x_current, y_current)
            uq = estimator.estimate(x_current)
            ndcgs.append(uq_ndcg(-current_ll, uq))
            estimator_type.append('ensemble')
            train_size.append(train_samples)

            # masks: all scored against the single model's log-likelihood
            current_ll = ll(trainer, x_current, y_current)
            for mask_name, mask in masks.items():
                print(mask_name)
                estimator = build_estimator(
                    'bald_masked', trainer, dropout_mask=mask, **masked_kwargs)
                uq = estimator.estimate(x_current)
                estimator.reset()
                ndcgs.append(uq_ndcg(-current_ll, uq))
                estimator_type.append(mask_name)
                train_size.append(train_samples)
                # NOTE(review): second reset on the same estimator; it looks
                # redundant but is kept to preserve the exact call sequence.
                estimator.reset()
    

In [None]:
# Box plot of NDCG per estimator, one box per training-set size.
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("NDCG on different train samples")

ndcg_columns = {
    'ndcg': ndcgs,
    'estimator_type': estimator_type,
    'train_size': train_size,
}
df = pd.DataFrame(ndcg_columns)
sns.boxplot(data=df, x='estimator_type', y='ndcg', hue='train_size', ax=ax)


#### Different nn_runs

In [None]:
# NDCG of each uncertainty estimator as a function of nn_runs, at a fixed
# training-set size. The ensemble size is tied to nn_runs (nn_runs // 10).
estimation_samples = 3000
ndcgs, estimator_type, nn_size = [], [], []

train_samples = 3000

for nn_runs in [20, 50, 100, 150]:
    print('\n', nn_runs)
    # Note: this replaces the notebook-level `masks` with ones built for the
    # current nn_runs.
    masks = build_masks(DEFAULT_MASKS, nn_runs=nn_runs)
    n_models = nn_runs // 10
    # Keyword arguments shared by every masked estimator for this nn_runs.
    masked_kwargs = dict(
        nn_runs=nn_runs, dropout_rate=config['dropout_uq'],
        num_classes=config['num_classes'])
    for _model_run in range(config['model_runs']):
        trainer, ensemble = retrain(train_samples, n_models=n_models)
        for _repeat in range(config['repeat_runs']):
            # Fresh random evaluation subset of the validation data.
            idxs = np.random.choice(len(x_val), estimation_samples, replace=False)
            x_current, y_current = x_val[idxs], y_val[idxs]

            # ensemble: scored against the ensemble's own log-likelihood
            print('ensemble')
            estimator = build_estimator('bald_ensemble', ensemble, num_classes=config['num_classes'])
            current_ll = ll(ensemble, x_current, y_current)
            uq = estimator.estimate(x_current)
            ndcgs.append(uq_ndcg(-current_ll, uq))
            estimator_type.append('ensemble')
            nn_size.append(nn_runs)

            # masks: all scored against the single model's log-likelihood
            current_ll = ll(trainer, x_current, y_current)
            for mask_name, mask in masks.items():
                print(mask_name)
                estimator = build_estimator(
                    'bald_masked', trainer, dropout_mask=mask, **masked_kwargs)
                uq = estimator.estimate(x_current)
                estimator.reset()
                ndcgs.append(uq_ndcg(-current_ll, uq))
                estimator_type.append(mask_name)
                nn_size.append(nn_runs)
                # NOTE(review): second reset on the same estimator; it looks
                # redundant but is kept to preserve the exact call sequence.
                estimator.reset()

In [None]:
# Box plot of NDCG per estimator, one box per nn_runs value.
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title(f"NDCG on different nn_runs (ensemble size = nn_runs/10); train_size {train_samples}")

ndcg_columns = {
    'ndcg': ndcgs,
    'estimator_type': estimator_type,
    'nn_runs': nn_size,
}
df = pd.DataFrame(ndcg_columns)
sns.boxplot(data=df, x='estimator_type', y='ndcg', hue='nn_runs', ax=ax)

