In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from scipy.special import softmax as softmax
import matplotlib.pyplot as plt

from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from analysis.metrics import uq_ndcg

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer
from active_learning.al_trainer import ALTrainer

# Pin CUDA work to the second GPU (index 1) when the host actually has one.
# On machines with fewer CUDA devices (or none) the unconditional call raises,
# so in that case torch's default device is left untouched.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

In [None]:

# Dataset-specific settings, keyed by dataset name: which network class to
# instantiate, the reshape target for the images, and training hyperparameters.
model_setups = dict(
    mnist=dict(
        model_class=SimpleConv,
        train_samples=5000,
        epochs=5,
        batch_size=256,
        log_interval=10,
        lr=1e-2,
        num_classes=10,
        input_shape=(-1, 1, 28, 28),
    ),
    cifar_10=dict(
        model_class=StrongConv,
        train_samples=45_000,
        epochs=50,
        batch_size=256,
        log_interval=150,
        lr=1e-2,
        num_classes=10,
        input_shape=(-1, 3, 32, 32),
    ),
)

# Experiment-wide settings shared by every dataset.
config = {
    'use_cuda': True,
    'seed': 1,

    # Uncertainty estimation: forwarded to the estimator as nn_runs / dropout_rate.
    'nn_runs': 100,
    'patience': 5,
    'dropout_uq': 0.5,

    'n_models': 10,

    # Selects the entry of model_setups merged in below; swap the comment to switch.
    'dataset': 'mnist',
    # 'dataset': 'cifar_10',

    'model_runs': 3,
    'repeat_runs': 3,

    # Active learning: initial labelled set size, step size, unlabelled pool size.
    'al_start': 500,
    'al_step': 200,
    'pool_size': 5000,
}

# Overlay the chosen dataset's settings on top of the shared ones.
config = {**config, **model_setups[config['dataset']]}



#### Load data and preprocess

In [None]:
# Build the dataset selected in the config (val_size=10_000 is forwarded to the
# builder) and pull out the 'train' and 'val' parts as (inputs, labels) pairs.
dataset = build_dataset(config['dataset'], val_size=10_000)
(x_set, y_set), (x_val, y_val) = (dataset.dataset(part) for part in ('train', 'val'))



In [None]:
def scale(images):
    """Shift and rescale pixel values: ``(images - 128) / 128``.

    Integer and boolean NumPy arrays are promoted to float64 before the shift.
    Without that, subtracting 128 from a ``uint8`` array wraps around
    (``0 - 128`` becomes ``128``) and every value below 128 is silently
    corrupted. Floating-point arrays and plain scalars are left as they are.

    Args:
        images: array (or scalar) of raw pixel values; presumably 8-bit
            intensities in [0, 255] -- confirm against the dataset loader.

    Returns:
        The rescaled values in floating point; an input of 0 maps to -1.0 and
        128 maps to 0.0.
    """
    if isinstance(images, np.ndarray) and images.dtype.kind in 'iub':
        # Same result dtype the integer path produced before, minus the wraparound.
        images = images.astype(np.float64)
    return (images - 128) / 128
# Rescale both image sets with the same transform.
# NOTE: this rebinds x_set / x_val, so re-running the cell scales them twice.
x_set, x_val = scale(x_set), scale(x_val)

In [None]:
# Reshape both image sets to the layout stored under config['input_shape'].
x_set, x_val = (images.reshape(config['input_shape']) for images in (x_set, x_val))

# Cast labels to the 'long' integer dtype and flatten them to 1-D.
y_set, y_val = (labels.astype('long').reshape(-1) for labels in (y_set, y_val))



#### Train model

In [None]:
# Carve out the initial labelled set (config['al_start'] samples, stratified by
# class); everything else becomes the candidate pool.
# random_state is taken from config['seed'] so the split is the same on every
# run -- without it each run draws a different starting set.
x_train, x_pool, y_train, y_pool = train_test_split(
    x_set, y_set, train_size=config['al_start'], stratify=y_set,
    random_state=config['seed'])

# Keep only config['pool_size'] samples of the remainder as the unlabelled pool.
x_pool, _, y_pool, _ = train_test_split(
    x_pool, y_pool, train_size=config['pool_size'], random_state=config['seed'])

In [None]:
# Build the dropout masks from the project's DEFAULT_MASKS definition.
# The result is iterated later with .items(), yielding (name, mask) pairs.
masks = build_masks(DEFAULT_MASKS)

In [None]:
# Run one active-learning experiment per dropout mask and store whatever
# ALTrainer.train returns under the mask's name (plotted as error curves later).
errors = {}
for name, mask in masks.items():
    # A fresh network and trainer for every mask, so runs start independently.
    model = config['model_class']()
    trainer = Trainer(model)
    estimator = build_estimator(
        'bald_masked', trainer, nn_runs=config['nn_runs'], dropout_mask=mask,
        dropout_rate=config['dropout_uq'], num_classes=config['num_classes'])
    # update_size is read from config['al_step'] (same value as the former
    # literal 200) so the step size shown in the plot title always matches the
    # step size actually used.
    active_teacher = ALTrainer(
        trainer, estimator, y_pool=y_pool, patience=3,
        update_size=config['al_step'], iterations=10, verbose=False)
    errors[name] = active_teacher.train(x_train, y_train, x_val, y_val, x_pool)
    

In [None]:
# One curve per mask on a single axis, labelled by mask name.
fig, ax = plt.subplots(figsize=(16, 10))
for mask_name, curve in errors.items():
    ax.plot(curve, label=mask_name)
ax.set_ylabel("Error")
ax.set_xlabel("Step")
ax.set_title(f"AL with {config['al_start']} starting samples and {config['al_step']} samples per step")
ax.legend()
