In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score
from scipy.special import softmax as softmax
import matplotlib.pyplot as plt

from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from analysis.metrics import uq_ndcg

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer
from sklearn.preprocessing import StandardScaler

# Pin the notebook to GPU 1 when that device exists. The index is specific to
# the machine this was written on, so on hosts with fewer GPUs (or no CUDA at
# all) keep torch's default device instead of failing in the import cell.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
else:
    print('CUDA device 1 is not available; keeping the default torch device')

In [None]:

# Per-dataset training setup. The entry selected by config['dataset'] is merged
# into `config` at the end of this cell.
model_setups = dict(
    mnist=dict(
        model_class=SimpleConv,
        train_samples=5000,
        epochs=5,
        batch_size=256,
        log_interval=10,
        lr=0.01,
        num_classes=10,
    ),
    cifar_10=dict(
        model_class=StrongConv,
        train_samples=45000,
        epochs=50,
        batch_size=256,
        log_interval=150,
        lr=0.01,
        # 9 rather than 10: presumably because class '9' is removed from the
        # CIFAR-10 splits and used as OOD data -- TODO confirm.
        num_classes=9,
    ),
)

# Global experiment settings shared by every dataset.
config = dict(
    use_cuda=True,
    # NOTE(review): `retrain` samples with the global NumPy RNG; confirm where
    # this seed is actually applied.
    seed=1,

    nn_runs=150,
    patience=6,
    dropout_uq=0.5,

    n_models=10,

    # dataset='mnist',
    dataset='cifar_10',

    model_runs=3,
    repeat_runs=3,
)

# Flatten the chosen dataset's setup into the top-level config.
config.update(model_setups[config['dataset']])



#### Load data and preprocess

In [None]:
# Build the configured dataset and pull its 'train' and 'val' splits.
# val_size=10_000 presumably reserves that many samples for validation -- see
# dataloader.builder.build_dataset.
dataset = build_dataset(config['dataset'], val_size=10_000)
(x_train, y_train), (x_val, y_val) = (
    dataset.dataset(split) for split in ('train', 'val'))

def cut_class(x, y, class_num):
    """Split a labelled set into in-distribution data and one held-out class.

    Args:
        x: array of samples, indexed along its first axis.
        y: array of labels, one per row of ``x``. A column vector of shape
            ``(N, 1)`` is accepted as well as a flat ``(N,)`` array.
        class_num: label of the class to remove. It may be given as a string
            for numeric labels (``'9'`` vs ``9``) or as a number for string
            labels; it is coerced to the label dtype before comparing.

    Returns:
        ``(new_x, new_y, ood)``: the samples and labels whose label differs
        from ``class_num``, and the samples whose label equals it.
    """
    labels = np.asarray(y).reshape(-1)

    # Comparing numeric labels with a string (or the reverse) never matches
    # element-wise, which would silently leave the class in place. Bring the
    # requested class to the label dtype first.
    target = class_num
    if labels.dtype.kind in 'iuf' and isinstance(class_num, str):
        try:
            target = labels.dtype.type(class_num)
        except (ValueError, OverflowError):
            pass  # not a number: it cannot match any numeric label
    elif labels.dtype.kind == 'U' and not isinstance(class_num, str):
        target = str(class_num)

    is_ood = np.asarray(labels == target)
    if is_ood.shape != labels.shape:
        # Some NumPy versions collapse a comparison between incomparable types
        # to a single bool; treat that as "no sample matches".
        is_ood = np.zeros(labels.shape, dtype=bool)

    # One boolean mask over the first axis keeps x and y aligned and preserves
    # the original sample order.
    return x[~is_ood], y[~is_ood], x[is_ood]

# Pick the out-of-distribution (OOD) inputs for the configured dataset.
if config['dataset'] == 'mnist':
    # OOD inputs come from a separate dataset; its labels are discarded.
    ood = build_dataset('fashion_mnist', val_size=0)
    x_ood, _ = ood.dataset('train')
elif config['dataset'] == 'cifar_10':
    # Remove class '9' from both splits; the removed training samples become
    # the OOD inputs, the removed validation samples are discarded.
    x_train, y_train, x_ood = cut_class(x_train, y_train, '9')
    x_val, y_val, _ = cut_class(x_val, y_val, '9')
else:
    # Without this branch `x_ood` would stay undefined, or keep a stale value
    # from an earlier run of the notebook with another dataset.
    raise ValueError(
        f"No OOD data is defined for dataset {config['dataset']!r}")
    

In [None]:
print(x_train.shape)

# Fit the standardisation on the training inputs only, then apply the same
# fitted statistics to the validation and OOD inputs.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_val, x_ood = (scaler.transform(inputs) for inputs in (x_val, x_ood))

# x_train /= 255.0
# x_val /= 255.0
# x_ood /= 255.0


In [None]:
# Target image shape for the conv models; -1 lets NumPy infer the number of
# samples. The remaining axes look like (channels, height, width).
if config['dataset'] == 'mnist':
    input_shape = (-1, 1, 28, 28)
elif config['dataset'] == 'cifar_10':
    input_shape = (-1, 3, 32, 32)
else:
    # Fail here rather than reshaping with an undefined or stale `input_shape`.
    raise ValueError(
        f"No input shape is defined for dataset {config['dataset']!r}")
x_train = x_train.reshape(input_shape)
x_val = x_val.reshape(input_shape)
x_ood = x_ood.reshape(input_shape)

# Labels as flat integer vectors.
y_train = y_train.astype('long').reshape(-1)
y_val = y_val.astype('long').reshape(-1)



#### Train model

In [None]:
def retrain(
        train_samples, n_models=config['n_models'], epochs=config['epochs'],
        val_samples=2000, patience=config['patience']):
    """Fit a single model on random subsets and create an ensemble trainer.

    Reads the module-level arrays ``x_train``, ``y_train``, ``x_val``,
    ``y_val`` and ``config['model_class']``. The defaults taken from
    ``config`` are bound when this cell is executed, not at call time.

    Args:
        train_samples: number of training rows to draw without replacement;
            capped at ``len(x_train)``.
        n_models: passed to ``EnsembleTrainer`` as its third argument.
        epochs: forwarded to ``Trainer.fit``.
        val_samples: number of validation rows to draw without replacement;
            capped at ``len(x_val)``.
        patience: forwarded to ``Trainer.fit``.

    Returns:
        ``(trainer, ensemble)``: the fitted ``Trainer`` and the
        ``EnsembleTrainer`` instance.
    """
    # np.random.choice(..., replace=False) raises ValueError when asked for
    # more rows than exist, which can happen once a class has been removed
    # from the splits. Cap the request instead.
    n_train = min(train_samples, len(x_train))
    n_val = min(val_samples, len(x_val))
    if (n_train, n_val) != (train_samples, val_samples):
        print(
            f'Requested {train_samples}/{val_samples} train/val samples, '
            f'using {n_train}/{n_val}')

    idxs = np.random.choice(len(x_train), n_train, replace=False)
    train_set = (x_train[idxs], y_train[idxs])
    idxs = np.random.choice(len(x_val), n_val, replace=False)
    val_set = (x_val[idxs], y_val[idxs])

    model_class = config['model_class']
    model = model_class()
    print(model)
    trainer = Trainer(model)
    print(train_set[0].shape)
    trainer.fit(
        train_set, val_set, epochs=epochs, verbose=True, patience=patience)

    # NOTE(review): no fit call is made on the ensemble here -- confirm that
    # EnsembleTrainer is trained elsewhere before `ensemble.predict` is used.
    ensemble = EnsembleTrainer(model_class, {}, n_models)

    return trainer, ensemble

# Samples per class among the first `train_samples` training labels. `retrain`
# draws its own random subset, so this is only indicative of its class balance.
unique, counts = np.unique(
    y_train[:config['train_samples']], return_counts=True)
{label: count for label, count in zip(unique, counts)}

In [None]:
# Train the single model and create the ensemble with the configured sizes.
trainer, ensemble = retrain(
    train_samples=config['train_samples'],
    n_models=config['n_models'],
)

In [None]:
# Accuracy on the first 3000 samples of each split.
# NOTE(review): `retrain` never calls fit on the ensemble -- confirm it has
# been trained before reading the last figure.
for _title, _targets, _inputs, _predictor in (
        ('Model accuracy train', y_train, x_train, trainer),
        ('Model accuracy val', y_val, x_val, trainer),
        ('Ensemble accuracy', y_val, x_val, ensemble)):
    print(_title, accuracy_score(_targets[:3000], _predictor.predict(_inputs[:3000])))


