# Extract eval results

Analyze evaluation results for BMNIST
- Clean data
- OOD with corruptions

In [126]:
import sys
sys.path.append("./../")

In [127]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

In [128]:
def extract_results(model_dir):
    """
        Get OOD metrics from model dir
    """

    # Get config
    config_json = os.path.join(model_dir, 'config.json')
    config = json.load(open(config_json, 'r'))
    
    # Extract config values
    method = config['method']
    lam_sl = config['method_params'].get('lam_sl', 0.0)
    ds_size = config['ds_params'].get('size', 'Full')
    
    
    results = None
    
    # Get OOD result files
    ood_result_files = glob.glob(model_dir + "/ece_results_*.pkl")
    
    # Get results
    for rfile in ood_result_files:
        filename = os.path.basename(rfile)
        # Get corruption name from file name
        corr_name = ' '.join(filename.split('_')[2:])[:-4]
        with open(rfile, 'rb') as f:
            logs = pickle.load(f)[0]
            r = {
                'method': method,
                'lam_sl': lam_sl,
                'ds_size': ds_size,
                'corruption': corr_name,
                'ece': logs['ece_uncal'],
                'acc': logs['acc'],
                'nll': logs['nll_uncal_test'],
                'auroc': logs['auroc']
            }
            
            if results is not None:
                results.append(r)
            else:
                results = [r]
    
    return results

## Model dirs

In [177]:
# # # LeNet + 1000
# models_root = "./../zoo/abl-alpha100-uniform-lenet/BinaryMNISTC-1000-53-identity/LeNet"
# elbo_models_root = "./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet"
# ls_models_root = "./../zoo/bmnist53-ls/BinaryMNISTC-1000-53-identity/LeNet"

# # LeNet + 10000
# models_root = "./../zoo/abl-alpha100-uniform-lenet/BinaryMNISTC-10000-53-identity/LeNet"
# elbo_models_root = "./../zoo/bmnist53-mfvi/BinaryMNISTC-10000-53-identity/LeNet"
# ls_models_root = "./../zoo/bmnist53-ls/BinaryMNISTC-10000-53-identity/LeNet"

# # # ConvNet + 1000
# models_root = "./../zoo/abl-alpha100-uniform-convnet/BinaryMNISTC-1000-53-identity/ConvNet"
# elbo_models_root = "./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/ConvNet"
# ls_models_root = "./../zoo/bmnist53-ls/BinaryMNISTC-1000-53-identity/ConvNet"

# ConvNet + 10000
models_root = "./../zoo/abl-alpha100-uniform-convnet/BinaryMNISTC-10000-53-identity/ConvNet"
elbo_models_root = "./../zoo/bmnist53-mfvi/BinaryMNISTC-10000-53-identity/ConvNet"
ls_models_root = "./../zoo/bmnist53-ls/BinaryMNISTC-10000-53-identity/ConvNet"

## S-ELBO results

In [178]:
model_dirs = list(map(lambda d: os.path.join(models_root, d), os.listdir(models_root)))

In [179]:
results = []
for _m in model_dirs:
    results.extend(extract_results(_m))

## EBLO results

In [180]:
model_dirs = list(map(lambda d: os.path.join(elbo_models_root, d), os.listdir(elbo_models_root)))

In [181]:
for _m in model_dirs:
    results.extend(extract_results(_m))
df_results = pd.DataFrame(results)

## LS results

In [182]:
model_dirs = list(map(lambda d: os.path.join(ls_models_root, d), os.listdir(ls_models_root)))

In [183]:
for _m in model_dirs:
    results.extend(extract_results(_m))
df_results = pd.DataFrame(results)

In [184]:
# patch up for additional methods
df_results.lam_sl[df_results.method=='mfvi'] = -5.0 # For MFVI
df_results.lam_sl[df_results.method=='ls'] = -1.0 # For label smoothing

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_results.lam_sl[df_results.method=='mfvi'] = -5.0 # For label smoothing
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_results.lam_sl[df_results.method=='ls'] = -1.0 # For label smoothing


In [185]:
df_results

Unnamed: 0,method,lam_sl,ds_size,corruption,ece,acc,nll,auroc
0,sl,1.0,10000,impulse noise,0.099643,0.904837,0.289460,0.966292
1,sl,1.0,10000,stripe,0.015664,0.840694,0.349133,0.963795
2,sl,1.0,10000,glass blur,0.170904,0.913249,0.357369,0.965926
3,sl,1.0,10000,fog,0.117304,0.750263,0.567576,0.836822
4,sl,1.0,10000,motion blur,0.038230,0.697687,0.604941,0.848617
...,...,...,...,...,...,...,...,...
491,ls,-1.0,10000,dotted line,0.008858,0.986856,0.051835,0.999677
492,ls,-1.0,10000,shear,0.006639,0.990010,0.031546,0.999670
493,ls,-1.0,10000,spatter,0.006380,0.993165,0.026982,0.999781
494,ls,-1.0,10000,brightness,0.006558,0.978444,0.054525,0.999816


In [186]:
metrics_summ = df_results.groupby('lam_sl').agg(
    n = pd.NamedAgg(column='acc', aggfunc='count'),
    acc_mean = pd.NamedAgg(column='acc', aggfunc='mean'),
    acc_err = pd.NamedAgg(column='acc', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    ece_mean = pd.NamedAgg(column='ece', aggfunc='mean'),
    ece_err = pd.NamedAgg(column='ece', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    nll_mean = pd.NamedAgg(column='nll', aggfunc='mean'),
    nll_err = pd.NamedAgg(column='nll', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    auroc_mean = pd.NamedAgg(column='auroc', aggfunc='mean'),
    auroc_err = pd.NamedAgg(column='auroc', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
)

In [187]:
gdf_corr = df_results.groupby('corruption')

In [188]:
rdfs = []
for k, _df in gdf_corr:
#     r1 = _df.groupby('lam_sl').agg({'corruption': 'first','ece': 'mean'}).sort_values(by='ece').reset_index()
    r1 = _df.groupby('lam_sl').agg({
                    'corruption': 'first', 
                    'acc': 'mean', 
                    'ece': 'mean', 
                    'nll': 'mean',
                    'auroc': 'mean'}).reset_index()
    r1['ece_rank'] = r1.ece.rank(ascending=True)
    r1['acc_rank'] = r1.acc.rank(ascending=False)
    r1['nll_rank'] = r1.nll.rank(ascending=True)
    r1['auroc_rank'] = r1.auroc.rank(ascending=False)
    rdfs.append(r1)

df_ranked = pd.concat(rdfs)
# df_ranked.reset_index(inplace=True)


In [189]:
metrics_summ

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,nll_mean,nll_err,auroc_mean,auroc_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
-5.0,80,0.950598,0.008744,0.030032,0.00661,0.140794,0.022851,0.992741,0.002146
-1.0,80,0.956151,0.008525,0.030378,0.007055,0.153734,0.02889,0.995076,0.001339
1e-06,48,0.93578,0.01564,0.035378,0.01083,0.151808,0.030813,0.992934,0.002319
1e-05,48,0.929099,0.016956,0.044304,0.01396,0.187946,0.045804,0.99406,0.0019
0.0001,48,0.948169,0.01165,0.032888,0.0083,0.133286,0.023751,0.992531,0.003025
0.001,48,0.938584,0.015263,0.065426,0.009684,0.163765,0.022008,0.989734,0.003627
0.01,48,0.927346,0.014836,0.164915,0.00795,0.306104,0.022635,0.97972,0.008171
0.1,48,0.895373,0.015447,0.201627,0.008863,0.399507,0.019313,0.967586,0.010627
1.0,48,0.833651,0.016689,0.134005,0.007398,0.424787,0.024743,0.926505,0.01396


In [190]:
df_rank_results = df_ranked.groupby('lam_sl').agg(
    ece_rank_mean = pd.NamedAgg(column='ece_rank', aggfunc='mean'),
    ece_rank_err = pd.NamedAgg(column='ece_rank', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    acc_rank_mean = pd.NamedAgg(column='acc_rank', aggfunc='mean'),
    acc_rank_err = pd.NamedAgg(column='acc_rank', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    nll_rank_mean = pd.NamedAgg(column='nll_rank', aggfunc='mean'),
    nll_rank_err = pd.NamedAgg(column='nll_rank', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    auroc_rank_mean = pd.NamedAgg(column='auroc_rank', aggfunc='mean'),
    auroc_rank_err = pd.NamedAgg(column='auroc_rank', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0]))
)

In [191]:
df_final = metrics_summ.merge(df_rank_results, on='lam_sl')

## Printout final results

In [192]:
df_final

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,nll_mean,nll_err,auroc_mean,auroc_err,ece_rank_mean,ece_rank_err,acc_rank_mean,acc_rank_err,nll_rank_mean,nll_rank_err,auroc_rank_mean,auroc_rank_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
-5.0,80,0.950598,0.008744,0.030032,0.00661,0.140794,0.022851,0.992741,0.002146,3.1875,0.387487,3.625,0.403839,3.125,0.340897,3.9375,0.39989
-1.0,80,0.956151,0.008525,0.030378,0.007055,0.153734,0.02889,0.995076,0.001339,3.5,0.405046,2.28125,0.441873,3.4375,0.483871,2.1875,0.296051
1e-06,48,0.93578,0.01564,0.035378,0.01083,0.151808,0.030813,0.992934,0.002319,3.0625,0.446066,3.84375,0.384562,3.4375,0.404745,3.1875,0.452586
1e-05,48,0.929099,0.016956,0.044304,0.01396,0.187946,0.045804,0.99406,0.0019,3.4375,0.522679,4.75,0.555512,3.6875,0.591071,3.25,0.511585
0.0001,48,0.948169,0.01165,0.032888,0.0083,0.133286,0.023751,0.992531,0.003025,3.4375,0.394976,4.1875,0.509433,3.25,0.390312,3.625,0.304587
0.001,48,0.938584,0.015263,0.065426,0.009684,0.163765,0.022008,0.989734,0.003627,5.4375,0.432731,4.25,0.463512,4.625,0.440835,5.25,0.347985
0.01,48,0.927346,0.014836,0.164915,0.00795,0.306104,0.022635,0.97972,0.008171,7.6875,0.170449,6.125,0.431884,6.875,0.195156,6.625,0.173993
0.1,48,0.895373,0.015447,0.201627,0.008863,0.399507,0.019313,0.967586,0.010627,8.375,0.373696,7.1875,0.461136,8.125,0.195156,7.9375,0.060515
1.0,48,0.833651,0.016689,0.134005,0.007398,0.424787,0.024743,0.926505,0.01396,6.875,0.413399,8.75,0.242061,8.4375,0.249511,9.0,0.0


Print out latex table

In [193]:
for row in df_final.itertuples():
    print(
#         "${:.0e}$".format(row.Index),
        "& ${:.3f} \pm {:.3f}$".format(row.acc_mean, row.acc_err),
        # "& ${:.2f} \pm {:.2f}$".format(row.acc_rank_mean, row.acc_rank_err),
        "& ${:.2f}$".format(row.acc_rank_mean),
        "& ${:.3f} \pm {:.3f}$".format(row.ece_mean, row.ece_err),
#         "& ${:.2f} \pm {:.2f}$".format(row.ece_rank_mean, row.ece_rank_err)
        "& ${:.2f}$".format(row.ece_rank_mean)
    )

& $0.951 \pm 0.009$ & $3.62$ & $0.030 \pm 0.007$ & $3.19$
& $0.956 \pm 0.009$ & $2.28$ & $0.030 \pm 0.007$ & $3.50$
& $0.936 \pm 0.016$ & $3.84$ & $0.035 \pm 0.011$ & $3.06$
& $0.929 \pm 0.017$ & $4.75$ & $0.044 \pm 0.014$ & $3.44$
& $0.948 \pm 0.012$ & $4.19$ & $0.033 \pm 0.008$ & $3.44$
& $0.939 \pm 0.015$ & $4.25$ & $0.065 \pm 0.010$ & $5.44$
& $0.927 \pm 0.015$ & $6.12$ & $0.165 \pm 0.008$ & $7.69$
& $0.895 \pm 0.015$ & $7.19$ & $0.202 \pm 0.009$ & $8.38$
& $0.834 \pm 0.017$ & $8.75$ & $0.134 \pm 0.007$ & $6.88$


Create latex table for aggregate OOD performance over all corruptions

In [194]:
#  Table 13 - 16
for row in df_final.itertuples():
    print(
        "& ${:.3f} \pm {:.3f} ({:.2f})$".format(row.nll_mean, row.nll_err, row.nll_rank_mean),
        "& ${:.3f} \pm {:.3f} ({:.2f})$".format(row.acc_mean, row.acc_err, row.acc_rank_mean),
        "& ${:.3f} \pm {:.3f} ({:.2f})$".format(row.auroc_mean, row.auroc_err, row.auroc_rank_mean),
        "& ${:.3f} \pm {:.3f} ({:.2f})$".format(row.ece_mean, row.ece_err, row.ece_rank_mean)
    )

& $0.141 \pm 0.023 (3.12)$ & $0.951 \pm 0.009 (3.62)$ & $0.993 \pm 0.002 (3.94)$ & $0.030 \pm 0.007 (3.19)$
& $0.154 \pm 0.029 (3.44)$ & $0.956 \pm 0.009 (2.28)$ & $0.995 \pm 0.001 (2.19)$ & $0.030 \pm 0.007 (3.50)$
& $0.152 \pm 0.031 (3.44)$ & $0.936 \pm 0.016 (3.84)$ & $0.993 \pm 0.002 (3.19)$ & $0.035 \pm 0.011 (3.06)$
& $0.188 \pm 0.046 (3.69)$ & $0.929 \pm 0.017 (4.75)$ & $0.994 \pm 0.002 (3.25)$ & $0.044 \pm 0.014 (3.44)$
& $0.133 \pm 0.024 (3.25)$ & $0.948 \pm 0.012 (4.19)$ & $0.993 \pm 0.003 (3.62)$ & $0.033 \pm 0.008 (3.44)$
& $0.164 \pm 0.022 (4.62)$ & $0.939 \pm 0.015 (4.25)$ & $0.990 \pm 0.004 (5.25)$ & $0.065 \pm 0.010 (5.44)$
& $0.306 \pm 0.023 (6.88)$ & $0.927 \pm 0.015 (6.12)$ & $0.980 \pm 0.008 (6.62)$ & $0.165 \pm 0.008 (7.69)$
& $0.400 \pm 0.019 (8.12)$ & $0.895 \pm 0.015 (7.19)$ & $0.968 \pm 0.011 (7.94)$ & $0.202 \pm 0.009 (8.38)$
& $0.425 \pm 0.025 (8.44)$ & $0.834 \pm 0.017 (8.75)$ & $0.927 \pm 0.014 (9.00)$ & $0.134 \pm 0.007 (6.88)$


## For only identity

In [195]:
df_iden = df_results[df_results.corruption == 'identity'].drop(['corruption'], axis=1).reset_index()

In [196]:
df_iden

Unnamed: 0,index,method,lam_sl,ds_size,ece,acc,nll,auroc
0,15,sl,1.0,10000,0.110566,0.928496,0.260321,0.976682
1,31,sl,1e-06,10000,0.00333,0.99632,0.015017,0.999962
2,47,sl,0.1,10000,0.161214,0.979495,0.219453,0.998018
3,63,sl,0.0001,10000,0.002353,0.996845,0.010566,0.999918
4,79,sl,0.001,10000,0.009012,0.997371,0.016153,0.999926
5,95,sl,1e-05,10000,0.001578,0.996845,0.013137,0.999881
6,111,sl,0.0001,10000,0.001776,0.995794,0.013491,0.999938
7,127,sl,1e-05,10000,0.002715,0.99632,0.009364,0.999954
8,143,sl,0.001,10000,0.042937,0.99632,0.056694,0.999709
9,159,sl,1e-06,10000,0.001496,0.997371,0.009903,0.999948


In [197]:
metrics_summ = df_iden.groupby('lam_sl').agg(
    n = pd.NamedAgg(column='acc', aggfunc='count'),
    acc_mean = pd.NamedAgg(column='acc', aggfunc='mean'),
    acc_err = pd.NamedAgg(column='acc', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    ece_mean = pd.NamedAgg(column='ece', aggfunc='mean'),
    ece_err = pd.NamedAgg(column='ece', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    nll_mean = pd.NamedAgg(column='nll', aggfunc='mean'),
    nll_err = pd.NamedAgg(column='nll', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    auroc_mean = pd.NamedAgg(column='auroc', aggfunc='mean'),
    auroc_err = pd.NamedAgg(column='auroc', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
)

In [198]:
metrics_summ

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,nll_mean,nll_err,auroc_mean,auroc_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
-5.0,5,0.996635,0.000484,0.001836,0.000283,0.01149,0.001075,0.999931,7e-06
-1.0,5,0.996845,0.000297,0.002677,0.000165,0.015104,0.001615,0.999946,3e-06
1e-06,3,0.996845,0.000248,0.002412,0.000432,0.01164,0.001379,0.999952,4e-06
1e-05,3,0.99667,0.000143,0.001902,0.000334,0.010815,0.000958,0.999922,1.8e-05
0.0001,3,0.995969,0.000379,0.002244,0.0002,0.012381,0.000747,0.999921,7e-06
0.001,3,0.996144,0.000624,0.02523,0.008019,0.036492,0.009556,0.999797,5.4e-05
0.01,3,0.993165,0.00138,0.124627,0.010171,0.158087,0.010828,0.99955,0.000107
0.1,3,0.977042,0.002221,0.189735,0.012791,0.259859,0.016772,0.997933,0.000104
1.0,3,0.94129,0.005236,0.141265,0.0133,0.267487,0.006712,0.984215,0.003273


In [199]:
# For Table 1
for row in metrics_summ.itertuples():
    print(
        "& ${:.3f} \pm {:.3f}$".format(row.nll_mean, row.nll_err),
        "& ${:.3f} \pm {:.3f}$".format(row.acc_mean, row.acc_err),
        "& ${:.3f} \pm {:.3f}$".format(row.auroc_mean, row.auroc_err),
        "& ${:.3f} \pm {:.3f}$".format(row.ece_mean, row.ece_err)
    )

& $0.011 \pm 0.001$ & $0.997 \pm 0.000$ & $1.000 \pm 0.000$ & $0.002 \pm 0.000$
& $0.015 \pm 0.002$ & $0.997 \pm 0.000$ & $1.000 \pm 0.000$ & $0.003 \pm 0.000$
& $0.012 \pm 0.001$ & $0.997 \pm 0.000$ & $1.000 \pm 0.000$ & $0.002 \pm 0.000$
& $0.011 \pm 0.001$ & $0.997 \pm 0.000$ & $1.000 \pm 0.000$ & $0.002 \pm 0.000$
& $0.012 \pm 0.001$ & $0.996 \pm 0.000$ & $1.000 \pm 0.000$ & $0.002 \pm 0.000$
& $0.036 \pm 0.010$ & $0.996 \pm 0.001$ & $1.000 \pm 0.000$ & $0.025 \pm 0.008$
& $0.158 \pm 0.011$ & $0.993 \pm 0.001$ & $1.000 \pm 0.000$ & $0.125 \pm 0.010$
& $0.260 \pm 0.017$ & $0.977 \pm 0.002$ & $0.998 \pm 0.000$ & $0.190 \pm 0.013$
& $0.267 \pm 0.007$ & $0.941 \pm 0.005$ & $0.984 \pm 0.003$ & $0.141 \pm 0.013$


In [200]:
# For Table 7 - 
for row in metrics_summ.itertuples():
    print(
#         "& ${:.3f} \pm {:.3f}$".format(row.nll_mean, row.nll_err),
        "& ${:.3f} \pm {:.3f}$".format(row.acc_mean, row.acc_err),
#         "& ${:.3f} \pm {:.3f}$".format(row.auroc_mean, row.auroc_err),
        "& ${:.3f} \pm {:.3f}$".format(row.ece_mean, row.ece_err)
    )

& $0.997 \pm 0.000$ & $0.002 \pm 0.000$
& $0.997 \pm 0.000$ & $0.003 \pm 0.000$
& $0.997 \pm 0.000$ & $0.002 \pm 0.000$
& $0.997 \pm 0.000$ & $0.002 \pm 0.000$
& $0.996 \pm 0.000$ & $0.002 \pm 0.000$
& $0.996 \pm 0.001$ & $0.025 \pm 0.008$
& $0.993 \pm 0.001$ & $0.125 \pm 0.010$
& $0.977 \pm 0.002$ & $0.190 \pm 0.013$
& $0.941 \pm 0.005$ & $0.141 \pm 0.013$
