# Extract eval results - SSTBERT

In [None]:
import sys
sys.path.append("./../")

In [None]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

#### Utility functions

In [None]:
# Accuracy threshold for models to be included in analysis.
# Result rows whose `acc` value (read from the `acc_val` log entry) is not
# strictly greater than this threshold are dropped before aggregation.
ACC_THRESHOLD = 0.80

In [None]:
def extract_results(model_dir):
    """
    Collect the evaluation metrics stored in one model directory.

    Reads ``config.json`` to build a label for the run, then loads every
    ``ece_results_*.pkl`` file located directly inside ``model_dir``.

    Args:
        model_dir: Path of a directory containing ``config.json`` (with the
            keys ``method`` and ``method_params``) and zero or more
            ``ece_results_<corruption>.pkl`` files.

    Returns:
        A list with one dict per result file, in sorted file-name order.
        Each dict holds ``method``, ``params`` (method name followed by
        ``-<name>=<value>`` for every method parameter), ``corruption``
        (taken from the file name) and the validation / test metrics.
        The list is empty when the directory has no result files.

    Raises:
        OSError: If ``config.json`` or a result file cannot be read.
        KeyError: If the config or a log entry lacks an expected key.
    """

    # Get config (the context manager closes the file handle promptly)
    config_json = os.path.join(model_dir, 'config.json')
    with open(config_json, 'r') as f:
        config = json.load(f)

    # Extract config values
    method = config['method']

    # Create a table entry for parameters as string, starting with the method
    # label. Numbers are zero-padded so that labels sort in numeric order.
    # NOTE(review): bool is a subclass of int, so JSON true/false values are
    # rendered as 0001/0000 here - confirm that this is acceptable.
    param_parts = [method]
    for name, value in config['method_params'].items():
        if isinstance(value, int):
            value_str = '{:04d}'.format(value)
        elif isinstance(value, float):
            value_str = '{:08.3f}'.format(value)
        else:
            value_str = '{}'.format(value)
        param_parts.append('{}={}'.format(name, value_str))
    param = '-'.join(param_parts)

    # Get ECE result files. The directory part is escaped so that glob
    # metacharacters in the path are taken literally, and the matches are
    # sorted so the row order does not depend on the filesystem.
    pattern = os.path.join(glob.escape(model_dir), 'ece_results_*.pkl')
    ece_result_files = sorted(glob.glob(pattern))

    results = []
    for rfile in ece_result_files:
        # Get corruption name from file name:
        # 'ece_results_<corruption>.pkl' -> '<corruption>'
        stem = os.path.splitext(os.path.basename(rfile))[0]
        corr_name = ' '.join(stem.split('_')[2:])

        # NOTE: unpickling can execute arbitrary code - only point this at
        # result files produced by trusted runs.
        with open(rfile, 'rb') as f:
            # The metrics dict is the first element of the pickled object
            logs = pickle.load(f)[0]

        results.append({
            'method': method,
            'params': param,
            'corruption': corr_name,
            'ece': logs['ece_uncal_val'],
            'acc': logs['acc_val'],
            'nll': logs['nll_uncal_val'],
            'auroc': logs['auroc_val'],
            'ece_test': logs['ece_uncal_test'],
            'acc_test': logs['acc_test'],
            'nll_test': logs['nll_uncal_test'],
            'auroc_test': logs['auroc_test'],
        })

    return results

#### Specify experiments

In [None]:
# SST - BERT
# Experiment root directories. Every entry found directly inside a root is
# treated as one model directory when the results are loaded.
# Commented-out roots are excluded from the analysis.
result_dirs = [
    "./../zoo/sst/edl/SSTBERT/SSTNetEDL",
    "./../zoo/sst/ls/SSTBERT/SSTNet",
    "./../zoo/sst/mfvi/SSTBERT/SSTNet",
    # "./../zoo/sst/sl-eqbin/auto-prior-alphavar/SSTBERT/SSTNet",
    "./../zoo/sst/sl-eqbin/uniform-prior-alphavar/SSTBERT/SSTNet",
    # "./../zoo/sst/sl-uneqbin/auto-prior-alphavar/SSTBERT/SSTNet",
    # "./../zoo/sst/sl-uneqbin/uniform-prior-alphavar/SSTBERT/SSTNet",
]

#### Load results

In [None]:
# Enumerate model directories and load evaluation results
results = []
for models_root in result_dirs:
    # Only sub-directories are model directories; stray files that sit next to
    # them (logs, OS metadata, ...) are skipped instead of crashing the load.
    model_dirs = [
        os.path.join(models_root, entry)
        for entry in sorted(os.listdir(models_root))
        if os.path.isdir(os.path.join(models_root, entry))
    ]
    for _m in model_dirs:
        # A run without any result file may yield None; treat it as "no rows"
        results.extend(extract_results(_m) or [])
df_results = pd.DataFrame(results)

In [None]:
# Basic QA: keep only the rows whose accuracy is strictly above the threshold,
# which discards models that failed to train satisfactorily
df_results = df_results.loc[df_results['acc'] > ACC_THRESHOLD]

In [None]:
# Show the results that passed the accuracy filter (notebook cell output)
df_results

## Get results for Clean dataset

In [None]:
# Keep the rows tagged 'eps-0.00' and drop the now-constant corruption column.
# reset_index() keeps the previous row labels in an extra `index` column.
df_clean = (
    df_results
    .loc[df_results['corruption'] == 'eps-0.00']
    .drop(columns=['corruption'])
    .reset_index()
)

In [None]:
# Show the clean-data subset (notebook cell output)
df_clean

In [None]:
def _std_err(values):
    """Return np.std(values) divided by the square root of the sample count.

    NOTE(review): np.std defaults to ddof=0 (population standard deviation);
    confirm that this is the intended estimator for the error bars.
    """
    return np.std(values) / np.sqrt(values.shape[0])


# Per-configuration summary of the validation metrics: number of rows (`n`)
# followed by `<metric>_mean` and `<metric>_err` for every metric.
metrics_summ = df_clean.groupby('params').agg(
    n=('acc', 'count'),
    **{
        '{}_{}'.format(metric, stat): (metric, func)
        for metric in ('acc', 'ece', 'nll', 'auroc')
        for stat, func in (('mean', 'mean'), ('err', _std_err))
    },
)

In [None]:
# Show the per-configuration validation summary (notebook cell output)
metrics_summ

In [None]:
# Print the validation summary as LaTeX table rows, one per configuration.
# Column order is NLL, accuracy, AUROC, ECE; the best mean of every column
# (lowest NLL / ECE, highest accuracy / AUROC) is typeset in bold.
nll_min = metrics_summ.nll_mean.min()
acc_max = metrics_summ.acc_mean.max()
auroc_max = metrics_summ.auroc_mean.max()
ece_min = metrics_summ.ece_mean.min()

# Raw strings: "\m" and "\p" are invalid escape sequences in normal literals.
# The plain template is padded so that numbers line up with the bold cells.
_bold_cell = r"& $\mathbf{{{:.3f} \pm {:.3f}}}$"
_plain_cell = r"&          ${:.3f} \pm {:.3f}$"

for row in metrics_summ.itertuples():
    buffer = "{:56s}".format(row.Index)

    for _metric, _best in (
        ('nll', nll_min),
        ('acc', acc_max),
        ('auroc', auroc_max),
        ('ece', ece_min),
    ):
        _mean = getattr(row, _metric + '_mean')
        _err = getattr(row, _metric + '_err')
        # Exact comparison is safe: `_best` is one of the compared means
        _template = _bold_cell if _mean == _best else _plain_cell
        buffer += _template.format(_mean, _err)

    print(buffer)

#### Get results for test dataset

In [None]:
def _std_err(values):
    """Return np.std(values) divided by the square root of the sample count.

    NOTE(review): np.std defaults to ddof=0 (population standard deviation);
    confirm that this is the intended estimator for the error bars.
    """
    return np.std(values) / np.sqrt(values.shape[0])


# Same summary layout as for validation, computed from the `*_test` columns:
# number of rows (`n`) followed by `<metric>_mean` and `<metric>_err`.
metrics_summ_test = df_clean.groupby('params').agg(
    n=('acc_test', 'count'),
    **{
        '{}_{}'.format(metric, stat): ('{}_test'.format(metric), func)
        for metric in ('acc', 'ece', 'nll', 'auroc')
        for stat, func in (('mean', 'mean'), ('err', _std_err))
    },
)

In [None]:
# Show the per-configuration test summary (notebook cell output)
metrics_summ_test

In [None]:
def get_prefix(x):
    """Return the label `x` without its last '-'-separated component.

    A label that contains no '-' is returned unchanged. When the remaining
    part starts with 'sl', the suffix '-alpha=' is appended to it.
    """
    head, sep, _last = x.rpartition('-')
    if not sep:
        # No separator at all: the label is its own prefix
        return x
    if head.startswith('sl'):
        return head + '-alpha='
    return head

# Find groups of experiments without looking at last param.
# Each configuration is assigned to exactly the prefix that get_prefix()
# computes for it. Matching on equality (instead of str.startswith) keeps a
# short prefix from also capturing configurations of a longer one, and never
# produces an empty group.
_prefix_by_config = metrics_summ.index.map(get_prefix)
unique_prefixes = sorted(set(_prefix_by_config))

# For each of the unique prefixes, find the best in validation group according
# to NLL
r = []
for pfx in unique_prefixes:
    _df_val = metrics_summ[_prefix_by_config == pfx]
    idx = _df_val.nll_mean.idxmin()

    # Now get the corresponding results from test set
    r.append(metrics_summ_test.loc[idx])

df_nll_best = pd.DataFrame(r)

In [None]:
# Show the test metrics of the selected configurations (notebook cell output)
df_nll_best

In [None]:
# For the SST-BERT results table: test metrics of the configurations that were
# selected by validation NLL, printed as LaTeX rows. Column order is NLL,
# accuracy, AUROC, ECE; the best mean of every column is typeset in bold.

nll_min = df_nll_best.nll_mean.min()
acc_max = df_nll_best.acc_mean.max()
auroc_max = df_nll_best.auroc_mean.max()
ece_min = df_nll_best.ece_mean.min()

# Raw strings: "\m" and "\p" are invalid escape sequences in normal literals.
# The plain template is padded so that numbers line up with the bold cells.
_bold_cell = r" & $\mathbf{{{:.3f} \pm {:.3f}}}$"
_plain_cell = r" &          ${:.3f} \pm {:.3f}$"

for row in df_nll_best.itertuples():
    buffer = "{:56s}".format(row.Index)

    for _metric, _best in (
        ('nll', nll_min),
        ('acc', acc_max),
        ('auroc', auroc_max),
        ('ece', ece_min),
    ):
        _mean = getattr(row, _metric + '_mean')
        _err = getattr(row, _metric + '_err')
        # Exact comparison is safe: `_best` is one of the compared means
        _template = _bold_cell if _mean == _best else _plain_cell
        buffer += _template.format(_mean, _err)

    print(buffer)

## Get results for Corrupted dataset

In [None]:
# Work on an independent copy: df_results comes from a boolean-mask selection,
# so adding a column through a plain alias would also modify df_results and
# can trigger pandas' SettingWithCopyWarning.
df_corrupted = df_results.copy()

In [None]:
# Show the results for all corruption levels (notebook cell output)
df_corrupted

In [None]:
# Parse the corruption level from the number after the last '-' of the tag
# (for example 'eps-0.10' gives 0.1)
df_corrupted['gamma'] = [
    float(tag.rsplit('-', 1)[-1]) for tag in df_corrupted.corruption
]

In [None]:
# List the parameter labels of the selected configurations (notebook cell output)
df_nll_best.index.tolist()

In [None]:
# Restrict the corrupted-data results to the best configuration of each family
# and renumber the rows from zero
df_corrupted_test = (
    df_corrupted
    .loc[df_corrupted['params'].isin(df_nll_best.index.tolist())]
    .reset_index(drop=True)
)

In [None]:
# Show the corrupted-data results of the selected configurations (notebook cell output)
df_corrupted_test

In [None]:
# Mean test ECE / NLL / accuracy per configuration and corruption level
gdf_corrupted_ece_mean = (
    df_corrupted_test
    .groupby(['params', 'gamma'])[['ece_test', 'nll_test', 'acc_test']]
    .mean()
)

# Flatten the group keys back into columns and keep corruption levels up to 0.5
df = gdf_corrupted_ece_mean.reset_index()
df = df.loc[df['gamma'] <= 0.5]

In [None]:
def label_mapper(label, *, detailed=False):
    """Map a parameter label to the name shown in plot legends.

    Args:
        label: Parameter label; only its leading method name and, in detailed
            mode, a few substrings are inspected.
        detailed: When True, labels starting with 'sl' are rendered as
            ``"SL (<alpha>, <binning>, <prior>)"`` instead of ``"Proposed"``.
            In that mode the text after the last '=' must be numeric.

    Returns:
        "EDL", "ELBO" or "LS" for labels starting with 'edl', 'mfvi' or 'ls';
        "Proposed" (or the detailed form) for labels starting with 'sl'; the
        unchanged label for anything else, so that unknown methods still get a
        readable legend entry.

    Raises:
        ValueError: In detailed mode, if the text after the last '=' of an
            'sl' label cannot be parsed as a float.
    """
    for prefix, name in (('edl', "EDL"), ('mfvi', "ELBO"), ('ls', "LS")):
        if label.startswith(prefix):
            return name

    if label.startswith("sl"):
        if not detailed:
            return "Proposed"

        # Detailed variant: alpha value, binning scheme and prior type
        alpha = float(label.split("=")[-1])
        params = [
            "$\\alpha = {:.0f}$".format(alpha),
            "uneqbin" if 'adahist' in label else "eqbin",
            "auto" if 'ea' in label else "uniform",
        ]
        return "SL ({})".format(', '.join(params))

    return label

In [None]:
# Compact figure: test NLL, accuracy and ECE against the corruption level,
# one panel per metric and one line per selected configuration
fig, ax = plt.subplots(1, 3, figsize=(2.5*3, 2))

for _params, _df in df.groupby(by='params'):
    label = label_mapper(_params)
    for _ax, _column in zip(ax, ('nll_test', 'acc_test', 'ece_test')):
        _ax.plot(_df.gamma, _df[_column], label=label)

for _ax, _title in zip(ax, ("NLL", "Accuracy", "ECE")):
    _ax.set_title(_title)
    _ax.set_xlabel("$\\gamma$")
    _ax.set_xticks([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
    _ax.legend()

In [None]:
# Large figure: test NLL, accuracy and ECE against the corruption level,
# one panel per metric and one line per selected configuration
fig, ax = plt.subplots(1, 3, figsize=(6*3, 4))

for _params, _df in df.groupby(by='params'):
    label = label_mapper(_params)
    for _ax, _column in zip(ax, ('nll_test', 'acc_test', 'ece_test')):
        _ax.plot(_df.gamma, _df[_column], label=label)

for _ax, _title in zip(ax, ("NLL", "Accuracy", "ECE")):
    _ax.set_title(_title)
    _ax.set_xlabel("$\\gamma$")
    _ax.set_xticks([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
    _ax.legend()