In [0]:
# !python -m pip install git+https://github.com/stat-ml/alpaca

In [0]:
import os
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from sklearn.model_selection import train_test_split

In [0]:
from alpaca.model.ensemble import MLPEnsemble
from alpaca.dataloader.builder import build_dataset
from alpaca.uncertainty_estimator import build_estimator
from alpaca.analysis.metrics import get_uq_metrics

In [0]:
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.model_selection import KFold
# the pretrained models were scaled with these helpers
def scale(train, val):
    """Standardize ``train`` and ``val`` using statistics fitted on ``train`` only.

    Returns:
        (train_scaled, val_scaled, scaler) where ``scaler`` is the fitted
        StandardScaler, kept so results can be mapped back later.
    """
    fitted = StandardScaler().fit(train)
    return fitted.transform(train), fitted.transform(val), fitted


def split_ood(x_all, y_all, percentile=10):
    """Split samples by target value at the given percentile of ``y_all``.

    Rows with ``y <= percentile-th value`` form the training part; rows strictly
    above it form the out-of-distribution part. Works for 1-D targets and for
    column-vector targets of shape (n, 1), since only the row index of
    ``np.argwhere`` is used.

    NOTE(review): with the default ``percentile=10`` only the lowest 10% of
    targets are kept for training and 90% become OOD — confirm this is intended.

    Returns:
        (x_train, y_train, x_ood, y_ood)
    """
    cutoff = np.percentile(y_all, percentile)
    in_rows = np.argwhere(y_all <= cutoff)[:, 0]
    out_rows = np.argwhere(y_all > cutoff)[:, 0]
    return x_all[in_rows], y_all[in_rows], x_all[out_rows], y_all[out_rows]


def multiple_kfold(k, data_size, max_iterations):
    """Yield ``max_iterations`` (train_idx, val_idx) pairs from repeated k-fold CV.

    Every pass over the k folds uses a fresh random permutation of
    ``range(data_size)`` (drawn from the global numpy RNG), so consecutive
    passes produce different splits.
    """
    splitter = KFold(k)
    remaining = max_iterations
    while remaining > 0:
        # reshuffle once per full pass over the k folds
        order = np.random.permutation(data_size)
        for train_pos, val_pos in splitter.split(order):
            if remaining <= 0:
                return
            remaining -= 1
            yield order[train_pos], order[val_pos]

In [0]:
# for reproducibility (given same models): seed the torch, numpy and Python RNGs
SEED = 10
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# run on GPU 0 and force deterministic (non-autotuned) cuDNN kernels
torch.cuda.set_device(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False



In [0]:
# Directory holding the pretrained regression checkpoints (one pickled dict per file).
# On Colab, mount Google Drive first (google.colab.drive.mount('/content/gdrive'))
# and point this at './gdrive/My Drive/chk/expD' instead.
path = './data/regression'

In [0]:
import os
import pandas as pd
# Checkpoint file names, sorted so the processing order (and therefore the seeded
# RNG stream consumed during evaluation) is the same on every machine; hidden
# entries such as '.ipynb_checkpoints' are skipped because they are not checkpoints.
files = sorted(name for name in os.listdir(path) if not name.startswith('.'))

In [0]:
import os
import pickle
import random
from pathlib import Path

import pandas as pd
import torch
from torch.nn.functional import elu
import numpy as np
import matplotlib.pyplot as plt

from alpaca.uncertainty_estimator.masks import build_masks, DEFAULT_MASKS
from alpaca.model.ensemble import MLPEnsemble
from alpaca.uncertainty_estimator import build_estimator
from alpaca.analysis.metrics import get_uq_metrics

# render figures on a white background
plt.rcParams.update({'figure.facecolor': 'white'})

In [0]:
# estimator that invokes low-level alpaca functions
def construct_estimator(model, model_type, name, nn_runs):
    """Build the uncertainty estimator used for one benchmark entry.

    Args:
        model: the model the estimator wraps.
        model_type: 'mask' -> 'mcdue_masked' with ``masks[name]``;
            'emask' -> 'emcdue_masked' with ``emasks[name]``;
            anything else -> ``build_estimator(name, model)``.
        name: mask name (masked types) or estimator name (other types).
        nn_runs: number of stochastic passes; only used by the masked types.

    Reads the notebook globals ``masks``, ``emasks`` and ``config``.
    """
    if model_type not in ('mask', 'emask'):
        return build_estimator(name, model)

    if model_type == 'mask':
        kind, dropout_mask = 'mcdue_masked', masks[name]
    else:
        kind, dropout_mask = 'emcdue_masked', emasks[name]

    estimator = build_estimator(
        kind, model, nn_runs=nn_runs,
        dropout_mask=dropout_mask,
        dropout_rate=config['dropout_uq'])
    estimator.tol_level = 1e-10
    return estimator

# evaluator that rescales dataset back and estimates the uncertainty
class Evaluator:
    """Benchmark uncertainty estimators on one held-out set.

    Every ``bench`` call appends ``config['n_ue_runs']`` rows of the form
    ``[acc, ndcg, ll, rmse, name, tag]`` to ``self.results``. The list is never
    cleared, so one instance accumulates rows across all of its ``bench`` calls.
    """

    def __init__(self, x_test, y_test, y_scaler, tag='standard', device='cuda'):
        """Store the test set and the inverse target scaling.

        Args:
            x_test: test features; converted to a float64 tensor on ``device``.
            y_test: test targets on the same (scaled) level as model predictions.
            y_scaler: fitted scaler whose ``inverse_transform`` maps targets back
                to original units for the RMSE.
            tag: label written into every result row.
            device: torch device for ``x_test``; 'cuda' keeps the previous behavior.
        """
        self.x_test = torch.DoubleTensor(x_test).to(device)
        self.y_test = y_test
        self.unscale = lambda y: y_scaler.inverse_transform(y)
        self.tag = tag
        self.results = []

    # nn_runs needs a default: a required parameter after `model_type='mask'`
    # is a SyntaxError. Positional callers passing all four arguments are unaffected.
    def bench(self, model, name, model_type='mask', nn_runs=100):
        """Evaluate ``model`` plus one uncertainty estimator on the test set.

        Args:
            model: callable torch model mapping ``self.x_test`` to predictions.
            name: mask / estimator name forwarded to ``construct_estimator``.
            model_type: 'mask', 'emask', or any other value for a plain estimator.
            nn_runs: stochastic passes; only used by the masked estimator types.
        """
        # plain forward pass; no autograd graph is needed for the point predictions
        with torch.no_grad():
            predictions = model(self.x_test).cpu().detach().numpy()

        # absolute errors on the scaled level are what the UQ metrics rank against
        errors = np.abs(predictions - self.y_test)

        # RMSE is reported in original target units
        scaled_errors = self.unscale(predictions) - self.unscale(self.y_test)
        rmse = np.sqrt(np.mean(np.square(scaled_errors)))

        estimator = construct_estimator(model, model_type, name, nn_runs)
        if model_type == 'emask':
            # keep ensemble-mask rows distinguishable from single-model rows
            name = 'e_' + name

        for _ in range(config['n_ue_runs']):
            estimations = estimator.estimate(self.x_test)
            acc, ndcg, ll = get_uq_metrics(estimations, errors,
                                           config['acc_percentile'],
                                           bins=[80, 95, 99])
            self.results.append([acc, ndcg, ll, rmse, name, self.tag])
            # reset estimators that support it between UE runs
            if hasattr(estimator, 'reset'):
                estimator.reset()

In [0]:
# pathlib handle on the checkpoint directory; each file is opened as `folder / file`
folder = Path(path)

In [0]:
# Mask families to benchmark. This rebinds the DEFAULT_MASKS name imported from
# alpaca.uncertainty_estimator.masks, so later cells use this list.
DEFAULT_MASKS = [
    'mc_dropout',
    'ht_leverages',
    'ht_dpp',
    'ht_k_dpp',
    'cov_dpp',
    'cov_k_dpp',
]

In [0]:
# Evaluate every pretrained checkpoint in `folder`. For each nn_runs setting, every
# single-model mask and every per-member ensemble mask is benchmarked. The plain
# ensemble ('eue') does not use nn_runs, so it is benchmarked once per file.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoints.
NN_RUNS_GRID = [10, 30, 100]  # stochastic passes per masked estimator
SINGLE_MODEL_IDX = 2  # ensemble member used for the single-model masks (needs n_ens > 2)
RESULT_COLUMNS = ['Acc', 'NDCG', 'LL', 'RMSE', 'Mask', 'Tag']


def results_frame(results, fname, nn_runs):
    """Turn Evaluator.results rows into a DataFrame tagged with file name and run count."""
    frame = pd.DataFrame(results, columns=RESULT_COLUMNS)
    frame['fname'] = fname
    frame['runs'] = nn_runs
    return frame


data = []  # one results DataFrame per (file, setting) chunk
errs = []  # [exception, file index, file name, mask name] for failed estimators
for cnt, file in enumerate(files):
    with open(folder / file, 'rb') as f:
        dct = pickle.load(f)
    print(file)
    # `config` is also read as a global by construct_estimator and Evaluator.bench
    config = dct['config']
    config['n_ue_runs'] = 1
    config['acc_percentile'] = .1
    state_dict = dct['state_dict']
    # loading data
    x_train, y_train, x_val, y_val, x_scaler, y_scaler = dct['data']
    # loading model
    model = MLPEnsemble(
        config['layers'], n_models=config['n_ens'],
        activation=elu,
        reduction='mean')
    model.load_state_dict(state_dict)
    single_model = model.models[SINGLE_MODEL_IDX]
    # `masks` / `emasks` are read as globals by construct_estimator: one mask set
    # for the single model, and one independently built set per ensemble member
    masks = build_masks(DEFAULT_MASKS)
    member_masks = [build_masks(DEFAULT_MASKS) for _ in range(config['n_ens'])]
    emasks = {key: [m[key] for m in member_masks] for key in masks.keys()}

    # plain ensemble: nn_runs is not used for it, so run it once and tag its rows
    # with the largest run count (the plotting cells put 'Ensemble' at runs=100)
    eue_evaluator = Evaluator(x_val, y_val, y_scaler, 'standard')
    eue_evaluator.bench(model, 'eue', 'ensemble', max(NN_RUNS_GRID))
    data.append(results_frame(eue_evaluator.results, file, max(NN_RUNS_GRID)))

    for nn_runs in NN_RUNS_GRID:
        # Fresh evaluator per setting: Evaluator.results only ever grows, so reusing
        # one would re-emit earlier settings' rows and relabel them with this nn_runs.
        standard_evaluator = Evaluator(x_val, y_val, y_scaler, 'standard')
        for model_type, bench_model, mask_names, sep in (
                ('mask', single_model, masks, '|'),
                ('emask', model, emasks, '*|')):
            for name in mask_names:
                print(name, end=sep)
                try:
                    standard_evaluator.bench(bench_model, name, model_type, nn_runs)
                except Exception as e:
                    # best-effort: record the failure and continue with the next mask
                    errs.append([e, cnt, file, name])
                    print('error!', end='|')
        data.append(results_frame(standard_evaluator.results, file, nn_runs))
        # checkpoint after every chunk so an interruption keeps finished results
        dfr = pd.concat(data, ignore_index=True)
        dfr.to_csv('fname.csv', index=False)

In [0]:
# PLOTTING

In [0]:
# Display label for every raw mask id. Each ensemble variant carries an 'e_' prefix
# and an ' ens.' suffix on the label; 'eue' is the plain ensemble.
masks_dct = {
    key: label
    for mask_id, pretty in (
        ('mc_dropout', 'MC dropout'),
        ('ht_dpp', 'DPP'),
        ('ht_k_dpp', 'k-DPP'),
        ('cov_dpp', 'cov DPP'),
        ('cov_k_dpp', 'cov k-DPP'),
        ('ht_leverages', 'leverage'),
        ('cov_leverages', 'cov leverage'),
    )
    for key, label in ((mask_id, pretty), ('e_' + mask_id, pretty + ' ens.'))
}
masks_dct['eue'] = 'Ensemble'

# dataset code -> full dataset name used as subplot title
dset_dct = dict(
    nava='naval propulsion',
    bost='boston housing',
    kin8='kin8nm',
    ccpp='ccpp',
    conc='concrete',
    red='red wine',
)

# x-axis order of the methods shown in the plots
cols_sorted = [masks_dct[key] for key in ('eue', 'mc_dropout', 'ht_leverages', 'ht_dpp', 'cov_dpp')]

# y-axis (log-likelihood) limits per dataset
dset_lims = dict(
    bost=(-1.2, .5),
    ccpp=(-7.5, .5),
    conc=(-1.5, .3),
    kin8=(-.4, .75),
    nava=(-.5, 1.2),
    red=(-7.5, .0),
)

In [0]:
# human-readable method label per row; ids missing from masks_dct are kept as-is
m2 = [masks_dct.get(raw_id, raw_id) for raw_id in dfr['Mask'].to_numpy()]
dfr['mask'] = m2

In [0]:
# 'eue' is built without nn_runs, so every Ensemble row is pinned to one run count
is_ensemble = dfr['mask'].eq('Ensemble')
dfr.loc[is_ensemble, 'runs'] = 100

In [0]:
# Log-likelihood of each UQ method per UCI dataset, one box per nn_runs setting.
# NOTE(review): nothing above creates a 'dset' column; if it is missing, derive it
# from the checkpoint file name, assuming names start with a dset_dct key — confirm.
if 'dset' in dfr.columns:
    plot_df = dfr
else:
    plot_df = dfr.assign(dset=[
        next((key for key in dset_dct if str(fname).startswith(key)), None)
        for fname in dfr['fname']
    ])

fig, axes = plt.subplots(2, 3, figsize=(12, 6))
last_idx = len(dset_lims) - 1
for cq, (ax, dset) in enumerate(zip(axes.flat, dset_lims)):
    sns.boxplot(
        x='mask',
        y='LL',
        hue='runs',
        width=0.5,
        fliersize=2,
        linewidth=.8,
        whis=2.5,
        data=plot_df[plot_df['dset'] == dset],
        order=cols_sorted,
        ax=ax,
    )
    # shaded background bands behind the 2nd and 4th method columns
    patches = [Rectangle((.5, -200), 1, 300, color='r', alpha=.1),
               Rectangle((2.5, -200), 1, 300, color='b', alpha=.1)]
    ax.add_collection(PatchCollection(patches, alpha=.075, color='k'))
    ax.tick_params(axis='x', labelrotation=90)
    ax.grid(linestyle=':')
    ax.set_title(dset_dct[dset])
    ax.set_xlabel('')
    # categories sit at integer positions 0..n-1; the old (0, 4.5) cut the first box in half
    ax.set_xlim((-.5, len(cols_sorted) - .5))
    ax.set_ylim(dset_lims[dset])
    # keep a single legend, on the last panel
    if cq != last_idx and ax.get_legend() is not None:
        ax.get_legend().remove()
fig.tight_layout()
Path('./images').mkdir(parents=True, exist_ok=True)
fig.savefig('./images/ll_uci_supp_D.png', dpi=600)