In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from scipy.special import softmax as softmax
import matplotlib.pyplot as plt

from dataloader.builder import build_dataset
from experiment_setup import build_estimator
from uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from analysis.metrics import uq_ndcg

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer
from active_learning.al_trainer import ALTrainer
from active_learning.sample_selector import EagerSampleSelector, StochasticSampleSelector

# torch.cuda.set_device(1)

In [None]:
plt.rcParams['figure.facecolor'] = 'white'

# Run-wide settings. The dataset-specific entries from `model_setups` are
# merged into this dict at the bottom of the cell.
config = dict(
    use_cuda=True,
    seed=1,

    nn_runs=100,
    patience=2,
    dropout_uq=0.5,

    n_models=10,

    # dataset='mnist',
    dataset='cifar_10',

    model_runs=10,

    al_iterations=10,
    sampler_type='eager',
    # sampler_type='stochastic',
)


# Per-dataset settings, keyed by the value of config['dataset'].
model_setups = dict(
    mnist=dict(
        model_class=SimpleConv,
        train_samples=5000,
        epochs=5,
        batch_size=32,
        log_interval=10,
        lr=1e-2,
        num_classes=10,
        input_shape=(-1, 1, 28, 28),
        al_start=300,
        al_step=200,
        pool_size=5_000,
    ),
    cifar_10=dict(
        model_class=StrongConv,
        train_samples=45_000,
        epochs=50,
        batch_size=256,
        log_interval=150,
        lr=1e-2,
        num_classes=10,
        input_shape=(-1, 3, 32, 32),
        al_start=10_000,
        al_step=1000,
        pool_size=25_000,
    ),
)

# Flatten the selected dataset's settings into `config`.
config.update(**model_setups[config['dataset']])



#### Load data and preprocess

In [None]:
# Build the dataset selected in `config` and unpack its 'train' and 'val'
# splits into (inputs, labels) pairs.
dataset = build_dataset(config['dataset'], val_size=10_000)
(x_set, y_set), (x_val, y_val) = (
    dataset.dataset(split) for split in ('train', 'val'))



In [None]:
def scale(images):
    """Map pixel values from the 0..255 range to roughly [-1, 1).

    Computes ``(images - 128) / 128``.

    Integer NumPy arrays are promoted to float64 first: for ``uint8`` input
    the subtraction would otherwise stay in ``uint8`` and wrap around
    (0 - 128 -> 128), so dark pixels would come out as +1.0 instead of -1.0.
    Floating-point inputs are left untouched and keep their dtype.

    Args:
        images: array of pixel intensities.

    Returns:
        The scaled values, as a floating-point array for NumPy input.
    """
    if isinstance(images, np.ndarray) and np.issubdtype(images.dtype, np.integer):
        images = images.astype(np.float64)
    return (images - 128) / 128
# Rescale both splits. The names are rebound in place, so re-running this
# cell applies the scaling a second time.
x_set, x_val = scale(x_set), scale(x_val)

In [None]:
# Bring the inputs into the shape configured for the selected dataset.
x_set, x_val = (
    images.reshape(config['input_shape']) for images in (x_set, x_val))

# Cast the labels to 'long' and flatten them to 1-D.
y_set, y_val = (
    labels.astype('long').reshape(-1) for labels in (y_set, y_val))



#### Train model

In [None]:
def al_evaluate(
        model, estimator, x_train, y_train, x_val, y_val, x_pool):
    """Run one active-learning experiment and return what the trainer reports.

    Picks the sample selector named by ``config['sampler_type']`` ('eager'
    selects ``EagerSampleSelector``, anything else ``StochasticSampleSelector``),
    wraps ``model`` and ``estimator`` in an ``ALTrainer`` and trains it.
    Copies of all arrays are handed to the trainer.

    Args:
        model: object passed to ``ALTrainer`` as the model (callers in this
            notebook pass a ``Trainer``).
        estimator: uncertainty estimator passed to ``ALTrainer``.
        x_train, y_train: initial labelled samples.
        x_val, y_val: validation samples.
        x_pool: pool inputs.

    Returns:
        The value returned by ``ALTrainer.train``; the notebook stores it as
        the error values of one run.

    Note:
        The pool labels are NOT a parameter: they are read from the
        notebook-level variable ``y_pool``, which must be defined (and match
        ``x_pool``) before this function is called.
    """
    if config['sampler_type'] == 'eager':
        sampler = EagerSampleSelector()
    else:
        sampler = StochasticSampleSelector()

    # update_size follows config['al_step'] so the dataset-specific step size
    # is used, instead of a fixed 200 that only matched the MNIST setup.
    # NOTE(review): patience is fixed at 3 while config['patience'] is 2 --
    # confirm whether the config value was meant to be used here.
    active_teacher = ALTrainer(
        model, estimator, y_pool=y_pool.copy(), patience=3,
        update_size=config['al_step'],
        iterations=config['al_iterations'], verbose=False, sampler=sampler)
    errors = active_teacher.train(
        x_train.copy(), y_train.copy(), x_val.copy(), y_val.copy(), x_pool.copy())
    return errors

In [None]:
# Results per method name: one entry (the output of al_evaluate) per run.
errors = defaultdict(list)

for i in range(config['model_runs']):
    # Give every run its own seed derived from config['seed'], so runs differ
    # from each other but the whole experiment is repeatable.
    run_seed = config['seed'] + i
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)

    # Stratified split into the initial training set and the remainder, then
    # cut the remainder down to the configured pool size.
    # `y_pool` has to keep this exact name: al_evaluate reads it as a global.
    x_train, x_pool, y_train, y_pool = train_test_split(
        x_set, y_set, train_size=config['al_start'], stratify=y_set,
        random_state=run_seed)
    x_pool, _, y_pool, _ = train_test_split(
        x_pool, y_pool, train_size=config['pool_size'], random_state=run_seed)
    masks = build_masks(DEFAULT_MASKS)

    # Random estimator
    # NOTE(review): only this baseline passes activation=celu; the masked
    # models below use the model's default -- confirm this is intended.
    print("\nrandom\n")
    model = config['model_class'](num_classes=config['num_classes'], activation=torch.nn.functional.celu)
    trainer = Trainer(model, batch_size=config['batch_size'])
    estimator = build_estimator('random', trainer)
    model_errors = al_evaluate(trainer, estimator, x_train, y_train, x_val, y_val, x_pool)
    errors['random'].append(model_errors)

    # # Ensemble
    # print(f"\nensemble\n")
    # ensemble = EnsembleTrainer(
    #     config['model_class'], {'num_classes': config['num_classes']}, config['n_models'],
    #     batch_size=config['batch_size'])
    # estimator = build_estimator('bald_ensemble', ensemble, num_classes=config['num_classes'])
    # model_errors = al_evaluate( ensemble, estimator, x_train, y_train, x_val, y_val, x_pool)
    # errors['ensemble'].append(model_errors)

    # Masks: a fresh model and trainer per mask, evaluated on the same split.
    for name, mask in masks.items():
        print(f"\n{name}\n")
        model = config['model_class'](num_classes=config['num_classes'])
        trainer = Trainer(model, batch_size=config['batch_size'])
        estimator = build_estimator(
            'bald_masked', trainer, nn_runs=config['nn_runs'], dropout_mask=mask,
            dropout_rate=config['dropout_uq'], num_classes=config['num_classes'])

        model_errors = al_evaluate(trainer, estimator, x_train, y_train, x_val, y_val, x_pool)
        errors[name].append(model_errors)
        

In [None]:
# One curve per method: mean over runs, with a +/- one standard deviation band.
plt.figure(figsize=(16, 10))
for name, err_values in errors.items():
    err_values = np.stack(err_values)  # stack the runs along axis 0
    means = np.mean(err_values, axis=0)
    stds = np.std(err_values, axis=0)
    steps = np.arange(len(means))
    plt.plot(steps, means, label=name)
    plt.fill_between(steps, means - stds, means + stds, alpha=.1)
plt.ylabel("Error")
plt.xlabel("Step")
# The trailing space after "samples" keeps the two title halves from running
# together ("samplesand").
title = (
    f"Active learning with {config['al_start']} starting samples "
    f"and {config['al_step']} samples per step")
plt.title(title)
plt.legend();
