In [None]:
# Reload edited project modules (dataloader, visualize, analysis, ...) automatically
# before each cell runs, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from dataloader.rosen import RosenData
from visualize import get_vae, get_model, set_random, build_estimator
from analysis.metrics import uq_accuracy, uq_ndcg

# Experiment configuration.
# NOTE(review): 'estimator', 'update_size', 'al_iterations', 'verbose', 'patience'
# and 'retrain' are not read by any cell shown here; cells below use their own values.
config = dict(
    estimator='nngp',
    random_seed=43,
    n_dim=10,                        # dimensionality passed to RosenData (matches layers[0])
    data_size=2000,
    data_split=[0.4, 0.6, 0, 0],     # split fractions; order presumably [train, val, test, pool] - confirm in RosenData
    update_size=100,
    al_iterations=10,
    verbose=True,
    use_cache=True,
    layers=[10, 128, 64, 32, 1],     # network layer widths passed to get_model
    patience=5,
    retrain=True,
    model_path='model/data/rosen_visual.ckpt',
    epochs=20_000,
    runs=10,                         # repetitions of each stochastic UQ estimate
)


### Uncertainty quantification (UQ) on 10-dimensional Rosenbrock function data

In [None]:
# Seed before data generation and model training so that Restart & Run All is
# reproducible (set_random presumably seeds numpy/torch - confirm in visualize).
set_random(config['random_seed'])

rosen = RosenData(
    config['n_dim'], config['data_size'], config['data_split'],
    use_cache=config['use_cache'])

x_pool, y_pool = rosen.dataset('pool')
x_train, y_train = rosen.dataset('train')
# Held-out validation split ('val' presumably maps to the 0.6 fraction of
# data_split - confirm against RosenData). Loading 'train' here made every UQ
# metric below measure uncertainty on the same points the model is fit to.
x_val, y_val = rosen.dataset('val')


In [None]:
def print_uq_at_error(model, estimator, x_val, x_train, y_train, y_val=None):
    """Scatter-plot uncertainty estimates against relative prediction error.

    Parameters
    ----------
    model : callable
        Trained model; ``model(x_val)`` must return an object exposing
        ``.cpu().numpy()`` (e.g. a detached torch tensor).
    estimator : object
        Uncertainty estimator with an ``estimate(x, x_train, y_train)`` method.
    x_val : array-like
        Points at which predictions and uncertainties are evaluated.
    x_train, y_train : array-like
        Training data forwarded to ``estimator.estimate``.
    y_val : array-like, optional
        Targets for ``x_val``; must broadcast against the model predictions.
        If omitted, falls back to the notebook-level ``y_val`` so that calls
        written before this parameter existed keep working.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the scatter plot.
    """
    if y_val is None:
        # Backward compatibility: older calls relied on the global y_val.
        y_val = globals()['y_val']

    estimations = estimator.estimate(x_val, x_train, y_train)
    predictions = model(x_val).cpu().numpy()

    # Symmetric relative error. Absolute values in the denominator keep the
    # error non-negative even when the model predicts negative values.
    errors = np.abs(predictions - y_val) / (np.abs(predictions) + np.abs(y_val))

    fig, ax = plt.subplots(figsize=(12, 9))
    ax.scatter(errors, estimations)
    ax.set(xlabel='Error', ylabel='Uncertainty', title='Uncertainty vs. relative error')
    return ax

### Base experiment: comparing UQ estimators (`mcdue`, `nngp`, `random`)

In [None]:
# Per-cell overrides of the config values.
runs = 1
# Retrain only when several runs are requested; presumably get_model reuses the
# checkpoint at config['model_path'] otherwise - confirm in visualize.get_model.
retrain = runs != 1
epochs = 30_000
percentile = 0.10

estimator_names = ('mcdue', 'nngp', 'random')

# One row per (run, estimator): [run, estimator, acc, ndcg].
results = []
for run in range(runs):
    print(f"====== RUN {run} =====")
    model = get_model(
        config['layers'],
        config['model_path'],
        (x_train, y_train),
        (x_val, y_val),
        epochs=epochs,
        retrain=retrain,
    )

    predictions = model(x_val).cpu().numpy()
    errors = np.abs(predictions - y_val)

    for estimator_name in estimator_names:
        estimator = build_estimator(estimator_name, model)
        estimations = estimator.estimate(x_val, x_train, y_train)
        results.append([
            run,
            estimator_name,
            uq_accuracy(estimations, errors, percentile),
            uq_ndcg(errors, estimations),
        ])

In [None]:
df = pd.DataFrame(results, columns=['run', 'estimator', 'acc', 'ndcg'])

# One panel per metric so the estimators can be compared side by side.
fig, (ax_acc, ax_ndcg) = plt.subplots(1, 2, figsize=(12, 5))
sns.boxplot(data=df, x='estimator', y='acc', ax=ax_acc)
ax_acc.set_title('UQ accuracy by estimator')
sns.boxplot(data=df, x='estimator', y='ndcg', ax=ax_ndcg)
ax_ndcg.set_title('UQ NDCG by estimator')
plt.show()

### Masking: comparing dropout masks for the `mcdue` estimator

In [None]:
# Model shared by the masking experiments below. `retrain` is left at
# get_model's default here (unlike the base experiment) - check its behavior.
model = get_model(
    config['layers'],
    config['model_path'],
    (x_train, y_train),
    (x_val, y_val),
    epochs=config['epochs'],
)

In [None]:
percentile = 0.10

# Baseline for the mask comparison: the unmasked 'mcdue' estimator.
estimator = build_estimator('mcdue', model)

predictions = model(x_val).cpu().numpy()
errors = np.abs(predictions - y_val)

# Repeat the estimate config['runs'] times, presumably because it is stochastic.
# Each row is [acc, ndcg, mask_label].
mask_results = []
for _ in range(config['runs']):
    estimations = estimator.estimate(x_val, x_train, y_train)
    mask_results.append([
        uq_accuracy(estimations, errors, percentile),
        uq_ndcg(errors, estimations),
        'vanilla',
    ])



In [None]:
from uncertainty_estimator.masks import BasicMask

# Start from the vanilla (unmasked) rows so every mask is plotted next to the
# same baseline; deepcopy keeps this cell idempotent when re-run.
results_2 = deepcopy(mask_results)

# Mask label -> zero-argument factory (e.g. a mask class) producing a dropout mask.
# TODO: add 'lhs_shuffled', 'mirror_random' and 'decorrelating' once available
# (wrap them in a zero-argument factory if they are not classes).
masks = {
    'basic_mask': BasicMask,
}

for name, mask in masks.items():
    # Instantiate the mask registered under `name`. The previous code always
    # built BasicMask(), so every entry would have been evaluated as BasicMask.
    mf = mask()
    estimator = build_estimator('mcdue_masked', model, dropout_mask=mf)
    for _ in range(config['runs']):
        estimations = estimator.estimate(x_val, x_train, y_train)
        results_2.append([
            uq_accuracy(estimations, errors, percentile),
            uq_ndcg(errors, estimations),
            name,
        ])

mask_df = pd.DataFrame(results_2, columns=['acc', 'ndcg', 'mask'])

# One panel per metric, masks compared side by side against 'vanilla'.
fig, (ax_acc, ax_ndcg) = plt.subplots(1, 2, figsize=(12, 5))
sns.boxplot(data=mask_df, x='mask', y='acc', ax=ax_acc)
ax_acc.set_title('UQ accuracy by dropout mask')
sns.boxplot(data=mask_df, x='mask', y='ndcg', ax=ax_ndcg)
ax_ndcg.set_title('UQ NDCG by dropout mask')
plt.show()