# Extract eval results

Analyze evaluation results for BMNIST with modified $s_0$
- Clean data

In [45]:
import sys
# Put the parent of the current working directory on the module search path.
sys.path.append("./../")

In [46]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

In [47]:
def extract_results(model_dir):
    """
        Get evaluation metrics from a model dir.

        Reads ``config.json`` inside ``model_dir`` for the run settings and
        every ``ece_results.pkl`` matched inside the same directory for the
        metrics.

        Args:
            model_dir: path of a directory that contains ``config.json`` and,
                optionally, ``ece_results.pkl``.

        Returns:
            A list with one dict per result file found. Each dict has the
            keys ``method``, ``lam_sl``, ``ds_size``, ``ece``, ``acc``,
            ``nll`` and ``auroc``. The list is empty when no result file
            exists, so callers can always ``extend`` with the return value.

        Raises:
            FileNotFoundError: if ``config.json`` is missing.
            KeyError: if the config or the pickled logs lack an expected key.
    """

    # Get config; the context manager closes the handle deterministically.
    config_json = os.path.join(model_dir, 'config.json')
    with open(config_json, 'r') as f:
        config = json.load(f)

    # Extract config values (optional ones fall back to defaults)
    method = config['method']
    lam_sl = config['method_params'].get('lam_sl', 0.0)
    ds_size = config['ds_params'].get('size', 'Full')

    results = []

    # Get result files
    result_files = glob.glob(model_dir + "/ece_results.pkl")

    for rfile in result_files:
        # NOTE: unpickling can execute arbitrary code; only point this at
        # result files produced by trusted runs.
        with open(rfile, 'rb') as f:
            # The pickle holds a sequence; the metrics dict is its first item.
            logs = pickle.load(f)[0]

        results.append({
            'method': method,
            'lam_sl': lam_sl,
            'ds_size': ds_size,
            'ece': logs['ece_uncal'],
            'acc': logs['acc'],
            'nll': logs['nll_uncal_test'],
            'auroc': logs['auroc'],
        })

    return results

## Model dirs

In [48]:
# Candidate model roots: keep exactly one `models_root` assignment uncommented.

# LeNet + 1000
# models_root = "./../zoo/abl-a100-mnist-a050b050/BinaryMNISTC-1000-53-identity/LeNet"

# LeNet + 10000
# models_root = "./../zoo/abl-a100-mnist-a050b050/BinaryMNISTC-10000-53-identity/LeNet"

# ConvNet + 1000
# models_root = "./../zoo/abl-a100-mnist-a050b050/BinaryMNISTC-1000-53-identity/ConvNet"

# ConvNet + 10000 (currently selected)
models_root = "./../zoo/abl-a100-mnist-a050b050/BinaryMNISTC-10000-53-identity/ConvNet"

## S-ELBO results

In [49]:
# One entry per run directory under models_root.
# - os.listdir returns names in arbitrary order, so sort for a reproducible
#   row order in the results table.
# - Skip non-directory entries (stray files), which have no config.json to
#   read and would make the extraction fail.
model_dirs = sorted(
    path
    for path in (os.path.join(models_root, name) for name in os.listdir(models_root))
    if os.path.isdir(path)
)

In [50]:
# Flatten the per-directory metric lists into a single list of row dicts.
results = [
    row
    for model_dir in model_dirs
    for row in extract_results(model_dir)
]

In [51]:
# One row per result dict; the dict keys become the columns.
df_results = pd.DataFrame(results)

In [52]:
df_results

Unnamed: 0,method,lam_sl,ds_size,ece,acc,nll,auroc
0,sl,0.1,Full,0.174654,0.953749,0.282441,0.9922
1,sl,0.0001,Full,0.001475,0.994394,0.018434,0.999761
2,sl,0.01,Full,0.11434,0.985985,0.157406,0.999132
3,sl,0.1,Full,0.10144,0.934828,0.240448,0.987329
4,sl,0.001,Full,0.028824,0.992992,0.045612,0.999622
5,sl,1e-05,Full,0.002738,0.997898,0.010346,0.999967
6,sl,1e-06,Full,0.002422,0.996496,0.008897,0.999963
7,sl,0.01,Full,0.135422,0.989488,0.178798,0.998762
8,sl,0.001,Full,0.042093,0.993693,0.063582,0.999328
9,sl,1e-05,Full,0.000717,0.996496,0.011352,0.999907


In [53]:
def _std_err(values):
    # Standard deviation of the group divided by sqrt(group size).
    # NOTE(review): np.std defaults to ddof=0 (population std); confirm that
    # is the intended estimator for the reported error bars.
    return np.std(values) / np.sqrt(values.shape[0])


# Per-lam_sl summary: run count, then mean and error for each metric.
_named_aggs = {'n': pd.NamedAgg(column='acc', aggfunc='count')}
for _metric in ('acc', 'ece', 'nll', 'auroc'):
    _named_aggs[_metric + '_mean'] = pd.NamedAgg(column=_metric, aggfunc='mean')
    _named_aggs[_metric + '_err'] = pd.NamedAgg(column=_metric, aggfunc=_std_err)

metrics_summ = df_results.groupby('lam_sl').agg(**_named_aggs)

In [54]:
metrics_summ

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,nll_mean,nll_err,auroc_mean,auroc_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1e-06,5,0.996776,0.000154,0.001897,0.000405,0.009853,0.000707,0.999943,1.6e-05
1e-05,5,0.996216,0.000509,0.002682,0.00073,0.012884,0.002641,0.999944,1.3e-05
0.0001,5,0.996076,0.000425,0.00288,0.000398,0.013201,0.001304,0.999896,3.1e-05
0.001,5,0.992992,0.000443,0.036033,0.003,0.055384,0.00378,0.999514,7.7e-05
0.01,5,0.983462,0.002282,0.1078,0.006967,0.155837,0.006514,0.998375,0.000358
0.1,5,0.9452,0.002882,0.160763,0.017176,0.281172,0.015329,0.989527,0.000795
1.0,5,0.844429,0.014612,0.038363,0.010176,0.346525,0.025321,0.952231,0.00646


In [55]:
# NOTE(review): this cell recomputes the same per-lam_sl summary as the
# aggregation cell earlier in the notebook; one of the two is redundant.
def _std_err(values):
    # Standard deviation of the group divided by sqrt(group size).
    # NOTE(review): np.std defaults to ddof=0 (population std); confirm that
    # is the intended estimator for the reported error bars.
    return np.std(values) / np.sqrt(values.shape[0])


# Per-lam_sl summary: run count, then mean and error for each metric.
_named_aggs = {'n': pd.NamedAgg(column='acc', aggfunc='count')}
for _metric in ('acc', 'ece', 'nll', 'auroc'):
    _named_aggs[_metric + '_mean'] = pd.NamedAgg(column=_metric, aggfunc='mean')
    _named_aggs[_metric + '_err'] = pd.NamedAgg(column=_metric, aggfunc=_std_err)

metrics_summ = df_results.groupby('lam_sl').agg(**_named_aggs)

In [56]:
metrics_summ

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,nll_mean,nll_err,auroc_mean,auroc_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1e-06,5,0.996776,0.000154,0.001897,0.000405,0.009853,0.000707,0.999943,1.6e-05
1e-05,5,0.996216,0.000509,0.002682,0.00073,0.012884,0.002641,0.999944,1.3e-05
0.0001,5,0.996076,0.000425,0.00288,0.000398,0.013201,0.001304,0.999896,3.1e-05
0.001,5,0.992992,0.000443,0.036033,0.003,0.055384,0.00378,0.999514,7.7e-05
0.01,5,0.983462,0.002282,0.1078,0.006967,0.155837,0.006514,0.998375,0.000358
0.1,5,0.9452,0.002882,0.160763,0.017176,0.281172,0.015329,0.989527,0.000795
1.0,5,0.844429,0.014612,0.038363,0.010176,0.346525,0.025321,0.952231,0.00646


In [57]:
# Emit one LaTeX table row fragment per lam_sl group: NLL, accuracy, ECE.
# Raw string: "\p" is not a valid escape sequence, so a plain literal triggers
# an invalid-escape warning; the printed text is unchanged.
_latex_cell = r"& ${:.3f} \pm {:.3f}$"

for row in metrics_summ.itertuples():
    print(
        _latex_cell.format(row.nll_mean, row.nll_err),
        _latex_cell.format(row.acc_mean, row.acc_err),
        # AUROC column is intentionally left out of the table; re-enable with:
        # _latex_cell.format(row.auroc_mean, row.auroc_err),
        _latex_cell.format(row.ece_mean, row.ece_err),
    )

& $0.010 \pm 0.001$ & $0.997 \pm 0.000$ & $0.002 \pm 0.000$
& $0.013 \pm 0.003$ & $0.996 \pm 0.001$ & $0.003 \pm 0.001$
& $0.013 \pm 0.001$ & $0.996 \pm 0.000$ & $0.003 \pm 0.000$
& $0.055 \pm 0.004$ & $0.993 \pm 0.000$ & $0.036 \pm 0.003$
& $0.156 \pm 0.007$ & $0.983 \pm 0.002$ & $0.108 \pm 0.007$
& $0.281 \pm 0.015$ & $0.945 \pm 0.003$ & $0.161 \pm 0.017$
& $0.347 \pm 0.025$ & $0.844 \pm 0.015$ & $0.038 \pm 0.010$
