In [None]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from dataloader.rosen import RosenData
from experiment_setup import get_model, set_random, build_estimator
from analysis.metrics import uq_accuracy, uq_ndcg, uq_nll
from uncertainty_estimator.masks import build_masks 
from dataloader.toy import ToyQubicData, ToySinData
from model.mlp import MLP

# Render every figure on a white background.
plt.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# Select the GPU for this experiment. Index 1 is the experiment's choice; the guard
# keeps Restart-&-Run-All working on machines with no CUDA or fewer than two GPUs
# (torch itself is imported in the imports cell above).
CUDA_DEVICE = 1
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE:
    torch.cuda.set_device(CUDA_DEVICE)
else:
    print(f'CUDA device {CUDA_DEVICE} unavailable; keeping the default device')

In [None]:

# Experiment configuration: every knob a reader may want to tune lives here.
config = {
    'nn_runs': 200,          # passed as `nn_runs` to build_masks and build_estimator
    'verbose': False,        # not read in the cells shown here
    'use_cache': False,      # forwarded to the toy-data loaders
    # Layer widths handed to MLP; smaller variants that were also tried:
    # [1, 128, 128, 64, 1] and [1, 64, 64, 32, 1].
    'layers': [1, 256, 256, 128, 1],
    'patience': 10,          # passed as `patience` to model.fit
    'dropout_train': 0.2,    # dropout rate used while fitting the MLP
    'dropout_uq': 0.5,       # dropout rate used by the UQ estimators
    'seed': 42,              # seeds numpy and torch below for repeatable runs
}

# Seed the global RNGs so training and the stochastic UQ passes are repeatable.
# (set_random is imported, but its signature is not visible here, so seed directly.)
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

### Visualizing on toy data

In [None]:
# Generate the toy regression dataset; set `dataset` to 'qubic' or 'sin'.
dataset = 'qubic'

toy_datasets = {'qubic': ToyQubicData, 'sin': ToySinData}
if dataset not in toy_datasets:
    raise ValueError(f"Unknown toy dataset {dataset!r}; expected one of {sorted(toy_datasets)}")
data_cls = toy_datasets[dataset]

# Each split is read from its own loader instance, exactly as before.
x_train, y_train = data_cls(use_cache=config['use_cache']).dataset('train')
x_val, y_val = data_cls(use_cache=config['use_cache']).dataset('val')
# NOTE(review): ground truth uses the loader's default `use_cache` — confirm that is intended.
x_true, y_true = data_cls().dataset('ground_truth')

plt.plot(x_true, y_true, label='ground truth')
plt.scatter(x_train, y_train, color='red', label='train')
plt.scatter(x_val, y_val, color='green', label='val')
plt.title(f'Toy dataset: {dataset}')
plt.xlabel('x')
plt.ylabel('y')
plt.legend();

In [None]:
# Train the MLP and plot its predictions against the ground truth.
model = MLP(config['layers'], l2_reg=1e-5)
model.fit(
    (x_train, y_train),
    # Second argument is presumably the validation set consumed by `patience` /
    # `validation_step` (confirm against MLP.fit). It previously received the
    # training data, so early stopping never saw held-out points and x_val was unused.
    (x_val, y_val),
    patience=config['patience'],
    validation_step=50,
    batch_size=5,
    dropout_rate=config['dropout_train'],
)

y_pred = model(x_true).cpu().numpy()
plt.figure(figsize=(12, 8))
plt.plot(x_true, y_true, alpha=0.5, label='ground truth')
plt.scatter(x_train, y_train, color='red', label='train')
plt.scatter(x_true, y_pred, color='green', marker='+', label='prediction')
plt.title('MLP predictions')
plt.xlabel('x')
plt.ylabel('y')
plt.legend();

In [None]:
# Mapping of mask name -> dropout mask; the next cell builds one UQ estimator per entry.
masks = build_masks(
    nn_runs=config['nn_runs'],
)

In [None]:
# Evaluate UQ for every dropout mask plus an NNGP baseline, drawing one panel per
# method: predictions with a ± uncertainty band around them.
UQ_PLOT_COLS = 2
UQ_PANEL_HEIGHT = 5  # inches per subplot row; 6 rows reproduces the previous 16x30 figure
n_panels = len(masks) + 1  # one panel per mask + the NNGP panel
# Ceil division: the old fixed 6x2 grid raised once there were more than 12 panels.
n_rows = (n_panels + UQ_PLOT_COLS - 1) // UQ_PLOT_COLS
plt.figure(figsize=(16, UQ_PANEL_HEIGHT * n_rows))


def make_uq_graph(name, estimations):
    """Draw one uncertainty panel on the current matplotlib axes.

    Plots the ground-truth curve, the model predictions ``y_pred``, a band of
    ``y_pred ± estimations`` and the training points.

    Args:
        name: Panel title (the UQ method name).
        estimations: Per-point uncertainty for ``x_true``; flattened before use.

    Note:
        Reads the notebook globals ``x_true``, ``y_true``, ``y_pred``,
        ``x_train`` and ``y_train`` set in earlier cells.
    """
    x_flat = np.ravel(x_true)
    y_flat = np.ravel(y_pred)
    # Flatten so an (n, 1) estimate cannot broadcast against (n,) into an (n, n) band.
    spread = np.ravel(estimations)
    plt.title(name)
    plt.plot(x_true, y_true, alpha=0.5, label='ground truth')
    plt.scatter(x_true, y_pred, color='green', marker='+', label='prediction')
    plt.fill_between(x_flat, y_flat - spread, y_flat + spread,
                     alpha=0.3, color='green', label='± uncertainty')
    plt.scatter(x_train, y_train, color='red', label='train')
    plt.legend(fontsize='small')


for i, (name, mask) in enumerate(masks.items()):
    print(name)
    # Masks exposing reset() are presumably stateful; start each estimator from a clean mask.
    if hasattr(mask, 'reset'):
        mask.reset()
    estimator = build_estimator(
        'mcdue_masked', model, nn_runs=config['nn_runs'], dropout_mask=mask,
        dropout_rate=config['dropout_uq'])

    estimations = estimator.estimate(x_true)
    plt.subplot(n_rows, UQ_PLOT_COLS, i + 1)
    make_uq_graph(name, estimations)

# NNGP baseline; unlike the masked estimators it also receives the training inputs.
nngp = build_estimator('nngp', model, nn_runs=config['nn_runs'], dropout_rate=config['dropout_uq'])
estimations = nngp.estimate(x_true, x_train)
plt.subplot(n_rows, UQ_PLOT_COLS, n_panels)
make_uq_graph('nngp', estimations)
plt.tight_layout()