# Extract OOD test results

In [None]:
import sys
sys.path.append("./../")

In [None]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

from tbparse import SummaryReader
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8*" and 3.8
# removed the old names, so use('seaborn') raises there; fall back to the old
# name only on matplotlib versions that predate the rename.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def extract_results(model_dir):
    """
        Get OOD metrics from model dir.

        Reads ``<model_dir>/config.json`` for the run's method, ``lam_sl``
        (default 0.0) and dataset size (default 'Full'), then loads every
        ``<model_dir>/ece_results_*.pkl`` file in sorted filename order.

        Args:
            model_dir: path to a single trained-model directory.

        Returns:
            A list with one dict per result file, keyed by 'method', 'lam_sl',
            'ds_size', 'corruption', 'ece' and 'acc'. The list is empty (never
            None) when the directory holds no result files, so callers can
            always ``list.extend`` it.

        Raises:
            FileNotFoundError: if ``config.json`` does not exist.
            KeyError: if the config lacks 'method', 'method_params' or
                'ds_params'.
    """
    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
        config = json.load(f)

    method = config['method']
    lam_sl = config['method_params'].get('lam_sl', 0.0)
    ds_size = config['ds_params'].get('size', 'Full')

    # Escape model_dir so '[', '*' or '?' in a path are not treated as
    # wildcards; sort because glob order is filesystem-dependent.
    pattern = os.path.join(glob.escape(model_dir), 'ece_results_*.pkl')

    results = []
    for rfile in sorted(glob.glob(pattern)):
        # Corruption name from the file name, e.g.
        # "ece_results_gaussian_noise.pkl" -> "gaussian noise".
        stem = os.path.splitext(os.path.basename(rfile))[0]
        corr_name = ' '.join(stem.split('_')[2:])

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at result files produced by our own evaluation runs.
        with open(rfile, 'rb') as f:
            # Assumes the pickle holds a sequence whose first element is a
            # metrics dict with 'ece_uncal' and 'acc' -- TODO confirm with
            # the code that writes these files.
            logs = pickle.load(f)[0]

        results.append({
            'method': method,
            'lam_sl': lam_sl,
            'ds_size': ds_size,
            'corruption': corr_name,
            'ece': logs['ece_uncal'],
            'acc': logs['acc'],
        })

    return results

## Model dirs

In [None]:
# Which model family and training-set size to analyse.
# Known configurations: ARCH in {"LeNet", "ConvNet"}, N_TRAIN in {1000, 10000}.
ARCH = "ConvNet"
N_TRAIN = 10000

# S-ELBO ablation runs (directory uses the lower-cased architecture name).
models_root = "./../zoo/abl-alpha100-uniform-{}/BinaryMNISTC-{}-53-identity/{}".format(
    ARCH.lower(), N_TRAIN, ARCH
)
# Baseline ELBO runs trained on the same data split.
elbo_models_root = "./../zoo/bmnist53-mfvi/BinaryMNISTC-{}-53-identity/{}".format(
    N_TRAIN, ARCH
)

## S-ELBO results

In [None]:
# Each sub-directory of models_root is one trained model. Skip stray files
# (extract_results would fail looking for their config.json) and sort so the
# record order is reproducible across filesystems.
model_dirs = [
    os.path.join(models_root, d)
    for d in sorted(os.listdir(models_root))
    if os.path.isdir(os.path.join(models_root, d))
]

In [None]:
# Collect one record per (model dir, result file) for the S-ELBO runs.
# `or []` guards model dirs without any ece_results_*.pkl files, for which
# extract_results may return None rather than an empty list.
results = []
for _m in model_dirs:
    results.extend(extract_results(_m) or [])

## ELBO results

In [None]:
# Each sub-directory of elbo_models_root is one trained model. Skip stray
# files (extract_results would fail looking for their config.json) and sort
# so the record order is reproducible across filesystems.
model_dirs = [
    os.path.join(elbo_models_root, d)
    for d in sorted(os.listdir(elbo_models_root))
    if os.path.isdir(os.path.join(elbo_models_root, d))
]

In [None]:
# Append the ELBO runs to the same record list, then tabulate everything.
# `or []` guards model dirs without any ece_results_*.pkl files, for which
# extract_results may return None rather than an empty list.
for _m in model_dirs:
    results.extend(extract_results(_m) or [])
df_results = pd.DataFrame(results)

In [None]:
# Display the combined S-ELBO + ELBO records (one row per model dir and result file).
df_results

In [None]:
def _std_error(values):
    """Standard error of the mean, using the population std (np.std, ddof=0)."""
    return np.std(values) / np.sqrt(values.shape[0])


# One row per lam_sl value, pooling every (model, corruption) record.
# NOTE(review): ELBO runs get lam_sl=0.0 by default in extract_results, so
# they share a group with any S-ELBO run that also used lam_sl=0.0 -- confirm
# that is intended.
metrics_summ = df_results.groupby('lam_sl').agg(
    n=('acc', 'count'),
    acc_mean=('acc', 'mean'),
    acc_err=('acc', _std_error),
    ece_mean=('ece', 'mean'),
    ece_err=('ece', _std_error),
)

In [None]:
# Group the records by corruption name so lam_sl settings can be ranked
# separately within each corruption.
gdf_corr = df_results.groupby('corruption')

In [None]:
def _rank_lam_sl(corr_df):
    """Mean acc/ECE per lam_sl for one corruption, plus ranks (1 = best).

    Ties share the average rank (pandas' default rank method).
    """
    per_lam = (
        corr_df.groupby('lam_sl')
        .agg({'corruption': 'first', 'acc': 'mean', 'ece': 'mean'})
        .reset_index()
    )
    per_lam['ece_rank'] = per_lam.ece.rank(ascending=True)   # lower ECE ranks better
    per_lam['acc_rank'] = per_lam.acc.rank(ascending=False)  # higher accuracy ranks better
    return per_lam


rdfs = [_rank_lam_sl(corr_df) for _, corr_df in gdf_corr]
df_ranked = pd.concat(rdfs)


In [None]:
# Display per-lam_sl record count, accuracy/ECE means and their standard errors.
metrics_summ

In [None]:
def _rank_std_error(ranks):
    """Standard error of the mean rank, using the population std (np.std, ddof=0)."""
    return np.std(ranks) / np.sqrt(ranks.shape[0])


# Average each lam_sl's per-corruption rank across all corruptions.
df_rank_results = df_ranked.groupby('lam_sl').agg(
    ece_rank_mean=('ece_rank', 'mean'),
    ece_rank_err=('ece_rank', _rank_std_error),
    acc_rank_mean=('acc_rank', 'mean'),
    acc_rank_err=('acc_rank', _rank_std_error),
)

In [None]:
# Combine the metric summary and the rank summary on their shared lam_sl key
# (inner join, which is DataFrame.merge's default).
df_final = pd.merge(metrics_summ, df_rank_results, on='lam_sl', how='inner')

## Print out final results

In [None]:
# Display the merged summary: per-lam_sl acc/ECE mean and SE plus mean ranks.
df_final

Print out the LaTeX table rows

In [None]:
# Print one LaTeX table-row fragment per lam_sl value: accuracy (mean +- SE),
# mean accuracy rank, ECE (mean +- SE), mean ECE rank. The lam_sl label
# column is not printed. Raw strings keep "\pm" from being parsed as an
# invalid escape sequence, which newer Python versions warn about.
SHOW_RANK_ERR = False  # True: also print the rank standard errors

for row in df_final.itertuples():
    if SHOW_RANK_ERR:
        acc_rank_cell = r"& ${:.2f} \pm {:.2f}$".format(row.acc_rank_mean, row.acc_rank_err)
        ece_rank_cell = r"& ${:.2f} \pm {:.2f}$".format(row.ece_rank_mean, row.ece_rank_err)
    else:
        acc_rank_cell = r"& ${:.2f}$".format(row.acc_rank_mean)
        ece_rank_cell = r"& ${:.2f}$".format(row.ece_rank_mean)

    print(
        r"& ${:.3f} \pm {:.3f}$".format(row.acc_mean, row.acc_err),
        acc_rank_cell,
        r"& ${:.3f} \pm {:.3f}$".format(row.ece_mean, row.ece_err),
        ece_rank_cell,
    )