In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from dataloader.rosen import RosenData
from experiment_setup import get_model, set_random, build_estimator
from analysis.metrics import uq_accuracy, uq_ndcg
from uncertainty_estimator.masks import BasicMask, LHSMask, MirrorMask, DecorrelationMask

# Experiment configuration shared by the cells below.
#
# NOTE: 'patience' used to be listed twice in this literal (5, then 10). Python
# keeps the last value for a duplicated key, so the single entry below carries
# the value that was actually in effect (10).
# NOTE: 'runs', 'nn_runs' and 'model_runs' are (re)assigned later in the
# "Masking" section, so the values here are only the initial ones.
config = {
    'estimator': 'nngp',
    'random_seed': 43,
    'n_dim': 10,                      # forwarded to RosenData
    'nn_runs': 25,
    'data_size': 2000,                # forwarded to RosenData
    'data_split': [0.4, 0.6, 0, 0],   # forwarded to RosenData
    'update_size': 100,
    'al_iterations': 10,
    'verbose': True,
    'use_cache': True,                # forwarded to RosenData(use_cache=...)
    'layers': [10, 128, 64, 32, 1],   # forwarded to get_model
    'patience': 10,                   # forwarded to get_model(patience=...)
    'retrain': True,
    'model_path': 'model/data/rosen_visual.ckpt',
    'epochs': 20_000,                 # forwarded to get_model(epochs=...)
    'runs': 10,
    'acc_percentile': 0.1,            # forwarded to uq_accuracy
}


### UQ estimation on 10-dimensional Rosenbrock function data

In [None]:
# Build the dataset object from the configured dimensionality, sample count and
# split fractions; `use_cache` is forwarded unchanged to RosenData.
rosen = RosenData(
    config['n_dim'], config['data_size'], config['data_split'],
    use_cache=config['use_cache'])

# Fetch one (inputs, targets) pair per split label.
x_pool, y_pool = rosen.dataset('pool')
x_train, y_train = rosen.dataset('train')
# NOTE(review): the validation arrays are requested with the 'train' label, so
# they presumably duplicate x_train/y_train. Confirm this is intentional (the
# last two entries of config['data_split'] are 0).
x_val, y_val = rosen.dataset('train')

# The seeding call below is commented out, so this cell does not pin the
# random state to config['random_seed'].
# set_random(config['random_seed'])


In [None]:
def print_uq_at_error(model, estimator, x_val, x_train, y_train, y_true=None):
    """Scatter-plot the estimator's uncertainty against the model's relative error.

    Args:
        model: Callable whose result for ``x_val`` supports ``.cpu().numpy()``.
        estimator: Object exposing ``estimate(x_val, x_train, y_train)``.
        x_val: Inputs to evaluate.
        x_train: Training inputs, forwarded to the estimator.
        y_train: Training targets, forwarded to the estimator.
        y_true: Ground-truth targets for ``x_val``. When omitted, the
            notebook-level ``y_val`` is used, which was the only behaviour
            before this parameter existed.

    Returns:
        None. A new matplotlib figure is created as a side effect.
    """
    estimations = estimator.estimate(x_val, x_train, y_train)
    predictions = model(x_val).cpu().numpy()

    # Fall back to the global only when the caller did not supply targets, so
    # existing five-argument calls keep working.
    targets = y_val if y_true is None else y_true

    # Symmetric relative error. Absolute values in the denominator keep it
    # non-negative; the previous `predictions + y_val` could be zero or
    # negative, producing infinite or negative "errors". Both forms agree
    # whenever predictions and targets are positive.
    errors = np.abs(predictions - targets) / (np.abs(predictions) + np.abs(targets))

    plt.figure(figsize=(12, 9))
    plt.ylabel('Uncertainty')
    plt.xlabel('Error')
    plt.scatter(errors, estimations)

### Masking

In [None]:
config['model_runs'] = 5
target_loss = 750

# One checkpoint path per model that will be trained below.
model_paths = [f"model/data/rosen_visual_{i}.ckpt" for i in range(config['model_runs'])]

# Train models one slot at a time. A slot is only advanced once the freshly
# trained model reports a val_loss below target_loss; otherwise the same slot
# is trained again.
count = 0
while count < config['model_runs']:
    model = get_model(
        config['layers'],
        model_paths[count],
        (x_train, y_train),
        (x_val, y_val),
        epochs=config['epochs'],
        retrain=True,
        verbose=False,
        patience=config['patience'],
    )
    print(f"Loss {model.val_loss}/{target_loss}")

    # Written as `not (a < b)` on purpose: a NaN loss must trigger a retry.
    if not (model.val_loss < target_loss):
        continue

    print("Fix model", count)
    count += 1


In [None]:
# Settings for the mask comparison; 'runs' and 'nn_runs' replace the values
# from the initial config.
config.update(runs=2, nn_runs=100, model_runs=5)

# Dropout-mask variants under comparison, keyed by the label that ends up in
# the results table. None is handed to the estimator as-is.
masks = dict(
    vanilla=None,
    basic_mask=BasicMask(),
    lhs=LHSMask(config['nn_runs']),
    lhs_shuffled=LHSMask(config['nn_runs'], shuffle=True),
    mirror_random=MirrorMask(),
    decorrelating=DecorrelationMask(),
    decorrelating_scaled=DecorrelationMask(scaling=True, dry_run=False),
)

# Rows of [accuracy, ndcg, mask label], one per (model, mask, run).
mask_results = []

for model_run in range(config['model_runs']):
    print(f"===MODEL RUN {model_run+1}====")
    model = get_model(
        config['layers'],
        model_paths[model_run],
        (x_train, y_train),
        (x_val, y_val),
        epochs=config['epochs'],
    )

    # Absolute prediction error on the validation inputs; shared by every mask
    # evaluated for this model.
    predictions = model(x_val).cpu().numpy()
    errors = np.absolute(predictions - y_val)

    for name, mask in masks.items():
        estimator = build_estimator(
            'mcdue_masked',
            model,
            nn_runs=config['nn_runs'],
            dropout_mask=mask,
            dropout_rate=0.3,
        )

        for run in range(config['runs']):
            estimations = estimator.estimate(x_val, x_train, y_train)
            acc = uq_accuracy(estimations, errors, config['acc_percentile'])
            ndcg = uq_ndcg(errors, estimations)
            mask_results += [[acc, ndcg, name]]

            # Masks that expose reset() are reset after every run; the
            # 'vanilla' entry (None) has no such method and is skipped.
            if hasattr(mask, 'reset'):
                mask.reset()


In [None]:
# Collect the per-run scores into a table: one row per (model, mask, run).
mask_df = pd.DataFrame(mask_results, columns=['acc', 'ndcg', 'mask'])

# One box plot per metric, given as (column, upper y-limit, title). Explicit
# axes replace the duplicated pyplot state-machine calls, and rotation/limits
# are applied after plotting so they act on the final tick labels.
for metric, y_max, title in (
        ('acc', 0.8, 'UQ accuracy by dropout mask'),
        ('ndcg', 0.9, 'UQ NDCG by dropout mask')):
    fig, ax = plt.subplots()
    sns.boxplot(data=mask_df, x='mask', y=metric, ax=ax)
    ax.set_ylim(0, y_max)
    ax.tick_params(axis='x', labelrotation=30)
    ax.set_title(title)