# Extract OOD test results

In [133]:
import sys
sys.path.append("./../")

In [134]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

from tbparse import SummaryReader
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
# deprecated the old names (later releases drop them). Prefer the new name
# and fall back to the legacy one on older matplotlib versions.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')

In [135]:
def extract_results(model_dir):
    """
        Get OOD metrics from model dir.

        Reads `config.json` inside `model_dir` and every
        `ece_results_<corruption>.pkl` file that sits next to it.

        Args:
            model_dir: path of a single trained-model directory.

        Returns:
            A list with one dict per result file, with keys 'method',
            'lam_sl', 'ds_size', 'corruption', 'ece' and 'acc'. The list is
            empty (never None) when the directory has no result files, so
            callers can always `extend` with the return value.

        Raises:
            FileNotFoundError: if `config.json` is missing.
            KeyError: if the config or a result record lacks an expected key.
    """

    # Get config (context manager so the file handle is closed right away)
    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
        config = json.load(f)

    # Extract config values; runs without these params fall back to defaults
    method = config['method']
    lam_sl = config['method_params'].get('lam_sl', 0.0)
    ds_size = config['ds_params'].get('size', 'Full')

    results = []

    # Get OOD result files
    ood_result_files = glob.glob(os.path.join(model_dir, "ece_results_*.pkl"))

    for rfile in ood_result_files:
        # Get corruption name from file name:
        # 'ece_results_glass_blur.pkl' -> 'glass blur'
        stem = os.path.splitext(os.path.basename(rfile))[0]
        corr_name = ' '.join(stem.split('_')[2:])

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this function at result files produced by trusted runs.
        with open(rfile, 'rb') as f:
            # Only the first element of the pickled sequence is used.
            logs = pickle.load(f)[0]

        results.append({
            'method': method,
            'lam_sl': lam_sl,
            'ds_size': ds_size,
            'corruption': corr_name,
            'ece': logs['ece_uncal'],
            'acc': logs['acc'],
        })

    return results

## Model dirs

In [136]:
# Experiment selection: change these two values to switch configuration
# instead of commenting path blocks in and out.
arch = "LeNet"       # "LeNet" or "ConvNet"
train_size = 1000    # 1000 or 10000

# Both model zoos share the same dataset sub-directory naming.
dataset_tag = "BinaryMNISTC-{}-53-identity".format(train_size)

# S-ELBO ablation runs (the zoo folder name uses the lower-cased architecture)
models_root = "./../zoo/abl-alpha100-uniform-{}/{}/{}".format(arch.lower(), dataset_tag, arch)
# ELBO (MFVI) baseline runs
elbo_models_root = "./../zoo/bmnist53-mfvi/{}/{}".format(dataset_tag, arch)

## S-ELBO results

In [137]:
# One path per entry found directly under the S-ELBO models root.
model_dirs = [os.path.join(models_root, entry) for entry in os.listdir(models_root)]

In [138]:
# Flatten the per-corruption records of every S-ELBO run into a single list.
results = []
for run_dir in model_dirs:
    results += extract_results(run_dir)

## ELBO results

In [139]:
# One path per entry found directly under the ELBO (MFVI) models root.
model_dirs = [os.path.join(elbo_models_root, entry) for entry in os.listdir(elbo_models_root)]

In [140]:
# Add the ELBO baseline runs to the records gathered so far (echoing each
# directory as it is processed), then tabulate everything.
for run_dir in model_dirs:
    print(run_dir)
    results += extract_results(run_dir)
df_results = pd.DataFrame(data=results)

./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet/mfvi-sz1000-2-20220508171636
./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet/mfvi-sz1000-5-20220508171615
./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet/mfvi-sz1000-4-20220508171611
./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet/mfvi-sz1000-1-20220508171611
./../zoo/bmnist53-mfvi/BinaryMNISTC-1000-53-identity/LeNet/mfvi-sz1000-3-20220508171632


In [141]:
df_results

Unnamed: 0,method,lam_sl,ds_size,corruption,ece,acc
0,sl,0.00001,1000,impulse noise,0.014754,0.978970
1,sl,0.00001,1000,stripe,0.010404,0.985804
2,sl,0.00001,1000,glass blur,0.020491,0.969506
3,sl,0.00001,1000,fog,0.072853,0.984227
4,sl,0.00001,1000,motion blur,0.030828,0.958991
...,...,...,...,...,...,...
635,mfvi,0.00000,1000,dotted line,0.006007,0.988959
636,mfvi,0.00000,1000,shear,0.018195,0.972135
637,mfvi,0.00000,1000,spatter,0.007474,0.987382
638,mfvi,0.00000,1000,brightness,0.010573,0.982124


In [142]:
# Per-lam_sl summary over all rows: row count, mean accuracy / ECE and their
# standard errors (population std, ddof=0, divided by sqrt(n)).
metrics_summ = df_results.groupby('lam_sl').agg(
    n=('acc', 'count'),
    acc_mean=('acc', 'mean'),
    acc_err=('acc', lambda s: s.std(ddof=0) / np.sqrt(len(s))),
    ece_mean=('ece', 'mean'),
    ece_err=('ece', lambda s: s.std(ddof=0) / np.sqrt(len(s))),
)

In [143]:
# Split the results by corruption type for the per-corruption ranking below.
gdf_corr = df_results.groupby(by='corruption')

In [144]:
# Within each corruption type, average acc / ece per lam_sl and rank the
# lam_sl settings: rank 1 is the lowest ECE and the highest accuracy.
rdfs = []
for _, corr_df in gdf_corr:
    per_lam = (
        corr_df.groupby('lam_sl')
        .agg({'corruption': 'first', 'acc': 'mean', 'ece': 'mean'})
        .reset_index()
    )
    per_lam['ece_rank'] = per_lam['ece'].rank(ascending=True)
    per_lam['acc_rank'] = per_lam['acc'].rank(ascending=False)
    rdfs.append(per_lam)

# Stack the per-corruption tables; the original row index is kept as is.
df_ranked = pd.concat(rdfs)



In [145]:
metrics_summ

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.0,80,0.955527,0.006963,0.036677,0.006378
1e-06,80,0.956671,0.006722,0.034794,0.006145
1e-05,80,0.958563,0.006735,0.034409,0.006146
0.0001,80,0.956283,0.007078,0.035745,0.006323
0.001,80,0.950145,0.007306,0.03987,0.005675
0.01,80,0.94423,0.007341,0.108541,0.00773
0.1,80,0.925743,0.008882,0.190696,0.007879
1.0,80,0.864012,0.012176,0.198021,0.006546


In [146]:
# Average each lam_sl's per-corruption ranks, with standard errors
# (population std, ddof=0, divided by sqrt(n)).
df_rank_results = df_ranked.groupby('lam_sl').agg(
    ece_rank_mean=('ece_rank', 'mean'),
    ece_rank_err=('ece_rank', lambda s: s.std(ddof=0) / np.sqrt(len(s))),
    acc_rank_mean=('acc_rank', 'mean'),
    acc_rank_err=('acc_rank', lambda s: s.std(ddof=0) / np.sqrt(len(s))),
)

In [147]:
# Combine the metric summary and the rank summary, matched on lam_sl.
df_final = pd.merge(metrics_summ, df_rank_results, on='lam_sl')

## Print out final results

In [148]:
df_final

Unnamed: 0_level_0,n,acc_mean,acc_err,ece_mean,ece_err,ece_rank_mean,ece_rank_err,acc_rank_mean,acc_rank_err
lam_sl,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0.0,80,0.955527,0.006963,0.036677,0.006378,2.875,0.373696,3.25,0.419263
1e-06,80,0.956671,0.006722,0.034794,0.006145,3.1875,0.355962,3.1875,0.268368
1e-05,80,0.958563,0.006735,0.034409,0.006146,3.1875,0.333293,2.25,0.272431
0.0001,80,0.956283,0.007078,0.035745,0.006323,2.625,0.466662,2.4375,0.404745
0.001,80,0.950145,0.007306,0.03987,0.005675,4.0625,0.347634,4.625,0.173993
0.01,80,0.94423,0.007341,0.108541,0.00773,5.8125,0.181546,5.4375,0.384959
0.1,80,0.925743,0.008882,0.190696,0.007879,6.9375,0.39989,6.8125,0.131659
1.0,80,0.864012,0.012176,0.198021,0.006546,7.3125,0.361407,8.0,0.0


Print out LaTeX table

In [149]:
# Emit the body cells of one LaTeX table row per lam_sl setting:
# accuracy (mean +/- err), mean accuracy rank, ECE (mean +/- err), mean ECE rank.
# The leading lam_sl cell and the rank errors are intentionally left out.
# Raw strings keep the LaTeX '\pm' from being parsed as an (invalid) Python
# escape sequence, which newer interpreters warn about.
for row in df_final.itertuples():
    print(
        r"& ${:.3f} \pm {:.3f}$".format(row.acc_mean, row.acc_err),
        "& ${:.2f}$".format(row.acc_rank_mean),
        r"& ${:.3f} \pm {:.3f}$".format(row.ece_mean, row.ece_err),
        "& ${:.2f}$".format(row.ece_rank_mean),
    )

& $0.956 \pm 0.007$ & $3.25$ & $0.037 \pm 0.006$ & $2.88$
& $0.957 \pm 0.007$ & $3.19$ & $0.035 \pm 0.006$ & $3.19$
& $0.959 \pm 0.007$ & $2.25$ & $0.034 \pm 0.006$ & $3.19$
& $0.956 \pm 0.007$ & $2.44$ & $0.036 \pm 0.006$ & $2.62$
& $0.950 \pm 0.007$ & $4.62$ & $0.040 \pm 0.006$ & $4.06$
& $0.944 \pm 0.007$ & $5.44$ & $0.109 \pm 0.008$ & $5.81$
& $0.926 \pm 0.009$ & $6.81$ & $0.191 \pm 0.008$ & $6.94$
& $0.864 \pm 0.012$ & $8.00$ & $0.198 \pm 0.007$ & $7.31$
