# Extract eval results - CIFAR

In [1]:
import sys
sys.path.append("./../")

In [2]:
import os
import json
import glob
import pickle

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

#### Utility functions

In [3]:
# Accuracy threshold for models to be included in analysis
ACC_THRESHOLD = 0.1 # For CIFAR10

In [12]:
def extract_ood_results(model_dir, dataset_str='FMNIST'):
    """
        Get OOD metrics from model dir
    """

    # Get config
    config_json = os.path.join(model_dir, 'config.json')
    config = json.load(open(config_json, 'r'))
    
    # Extract config values
    method = config['method']
    
    # Create a table entry for parameters  as string
    param_str = method # start with method label
    for _p, _p_value in config['method_params'].items():
        _p_value_str = None
        if isinstance(_p_value, int):
            _p_value_str = '{:04d}'.format(_p_value)
        elif isinstance(_p_value, float):
            _p_value_str = '{:08.3f}'.format(_p_value)
        else:
            _p_value_str = '{}'.format(_p_value)
        param_str += '-{}={}'.format(_p, _p_value_str)
    param = param_str
    
    results = None
    
    # Get OOD result files
    ood_result_files = glob.glob(model_dir + "/ood_pred_results_{}.pkl".format(dataset_str))
    
    assert len(ood_result_files) <= 1, "More than one OOD results exists"

    # Get accuracy on clean dataset also for quality checks
    acc_results_file = os.path.join(model_dir, "ece_results_identity-1.pkl")
    
    # Get results
    for rfile in ood_result_files:
        with open(acc_results_file, 'rb') as f:
            clean_results = pickle.load(f)[0]
        with open(rfile, 'rb') as f:
            logs = pickle.load(f)[0]
            r = {
                'method': method,
                'params': param,
                'ent_ood': logs['ent_ood'],
                'ent_test': logs['ent_test'],
                'ent_delta': logs['ent_delta'],
                'acc': clean_results['acc_val'],
                'preds_ood': logs['preds_ood'],
                'preds_indomain': logs['preds_indomain']
            }
            
            if results is not None:
                results.append(r)
            else:
                results = [r]
    
    return results

#### Specify experiments

In [13]:
# CIFAR10 + VGG11
fig_prefix = "cifar10im-vgg11"
result_dirs = [
    "./../zoo/multiclass/slim/CIFAR10Im/VGG11",
    "./../zoo/multiclass/mfvi/CIFAR10Im/VGG11",
    # "./../zoo/multiclass/ls/CIFAR10Im/VGG11",
    # "./../zoo/multiclass/edl/uniform-prior/CIFAR10Im/VGG11EDL",
    # "./../zoo/multiclass/edl/computed-prior/CIFAR10/VGG11EDL"
]

## OOD Results

In [14]:
# Enumerate model directories and load evaluation results
ood_results = []
for models_root in result_dirs:
    model_dirs = list(map(lambda d: os.path.join(models_root, d), os.listdir(models_root)))
    for _m in model_dirs:
        _r = extract_ood_results(_m, dataset_str='SVHN')
        if _r:
            ood_results.extend(_r)
        else:
            print("Skipping", _m)
df_ood_results = pd.DataFrame(ood_results)

Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha5e+03-1-20221008231609
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha1e+03-1-20221008205200
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha1e+03-3-20221008205200
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha1e+03-4-20221008205150
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha5e+03-2-20221008231609
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha5e+03-5-20221008231634
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha1e+03-2-20221008205200
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha5e+03-4-20221008231634
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha1e+03-5-20221008205150
Skipping ./../zoo/multiclass/slim/CIFAR10Im/VGG11/slim-alpha5e+03-3-20221008231609


In [15]:
df_ood_results

Unnamed: 0,method,params,ent_ood,ent_test,ent_delta,acc,preds_ood,preds_indomain
0,slim,slim-alpha=0500,0.836934,0.276498,0.560436,0.862588,"[[0.003388115, 0.0011053467, 0.004107215, 0.64...","[[1.7554448e-06, 4.853828e-15, 0.9999575, 3.33..."
1,slim,slim-alpha=0500,0.477874,0.111562,0.366312,0.85657,"[[0.978045, 0.00074251316, 0.00039773155, 4.54...","[[0.00011624745, 1.25305e-07, 0.99814737, 6.23..."
2,slim,slim-alpha=0500,0.422121,0.138586,0.283535,0.859579,"[[0.9355871, 0.0083082095, 0.0163987, 0.031681...","[[0.0029744087, 9.273199e-05, 0.9563514, 0.000..."
3,slim,slim-alpha=0500,0.347263,0.118718,0.228545,0.865597,"[[0.9960215, 0.00042134887, 0.00039003772, 0.0...","[[1.3276577e-06, 2.9811653e-10, 0.9999938, 4.5..."
4,slim,slim-alpha=0500,0.500496,0.113566,0.38693,0.85657,"[[0.0018223327, 0.0008067546, 0.034741826, 0.8...","[[2.7520198e-07, 3.9047192e-16, 0.99999964, 5...."
5,mfvi,mfvi,0.233352,0.088368,0.144984,0.861585,"[[0.9969841, 9.8212826e-05, 0.0021724652, 6.65...","[[2.518763e-10, 1.4796154e-19, 1.0, 8.053085e-..."
6,mfvi,mfvi,0.416558,0.10544,0.311118,0.852558,"[[0.0010561463, 0.00048093608, 0.00031851121, ...","[[2.2839561e-08, 1.960713e-10, 1.0, 4.3077595e..."
7,mfvi,mfvi,0.327041,0.118063,0.208978,0.864594,"[[0.9982009, 0.0005297746, 0.00085096154, 0.00...","[[1.4852645e-05, 8.297025e-09, 0.99997604, 2.3..."
8,mfvi,mfvi,0.325159,0.102283,0.222876,0.853561,"[[0.999987, 1.5101775e-06, 1.1020057e-05, 2.28...","[[1.8850683e-07, 2.7793004e-11, 0.9999958, 5.7..."
9,mfvi,mfvi,0.261203,0.083718,0.177485,0.864594,"[[0.99999607, 1.942358e-08, 3.899033e-06, 7.14...","[[3.1230033e-07, 2.4974143e-09, 0.9999985, 1.4..."


In [8]:
df_ood_results = df_ood_results[df_ood_results.acc > ACC_THRESHOLD]

In [9]:
ood_metrics_summ = df_ood_results.groupby('params').agg(
    n = pd.NamedAgg(column='ent_ood', aggfunc='count'),
    ent_ood_mean = pd.NamedAgg(column='ent_ood', aggfunc='mean'),
    ent_ood_err = pd.NamedAgg(column='ent_ood', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    ent_test_mean = pd.NamedAgg(column='ent_test', aggfunc='mean'),
    ent_test_err = pd.NamedAgg(column='ent_test', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    ent_delta_mean = pd.NamedAgg(column='ent_delta', aggfunc='mean'),
    ent_delta_err = pd.NamedAgg(column='ent_delta', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
    ent_acc_mean = pd.NamedAgg(column='acc', aggfunc='mean'),
    ent_acc_err = pd.NamedAgg(column='acc', aggfunc=lambda x: np.std(x) / np.sqrt(x.shape[0])),
)
ood_metrics_summ['ood_ratio'] = ood_metrics_summ.ent_delta_mean / ood_metrics_summ.ent_test_mean

In [10]:
ood_metrics_summ

Unnamed: 0_level_0,n,ent_ood_mean,ent_ood_err,ent_test_mean,ent_test_err,ent_delta_mean,ent_delta_err,ent_acc_mean,ent_acc_err,ood_ratio
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
mfvi,5,0.312663,0.028352,0.099574,0.005516,0.213089,0.025003,0.859378,0.002363,2.139998
slim-alpha=0500,5,0.516937,0.075359,0.151786,0.028213,0.365152,0.05057,0.860181,0.001569,2.405704


In [11]:
def get_prefix(x):
    parts = x.split('-')
    if len(parts) < 2:
        return '-'.join(parts)
    else:
        return '-'.join(parts[:-1])

# Find groups of experiments without looking at last param
unique_prefixes = list(sorted(set(
        list(map(get_prefix, ood_metrics_summ.index))
    )))

# For each of the unique prefixes, find the best in validation group according to
# NLL and printout the entropy in OOD
r = []
for pfx in unique_prefixes:
    _df_val = metrics_summ[metrics_summ.index.str.startswith(pfx)]
    idx = _df_val.nll_mean.idxmin()

    # Now get the corresponding results from OOD set
    r.append(ood_metrics_summ.loc[idx])

df_ood_best = pd.DataFrame(r)

NameError: name 'metrics_summ' is not defined

In [None]:
df_ood_best

In [None]:
# Table - Results OOD
ent_delta_max = df_ood_best.ent_delta_mean.max()
ent_ratio_max = df_ood_best.ood_ratio.max()

for row in df_ood_best.itertuples():
    buffer = "{:50s}".format(row.Index)

    # In-domain entropy
    buffer += "& ${:.3f} \pm {:.3f}$".format(row.ent_test_mean, row.ent_test_err)

    # OOD entropy
    buffer += "& ${:.3f} \pm {:.3f}$".format(row.ent_ood_mean, row.ent_ood_err)

    if row.ent_delta_mean == ent_delta_max:
        buffer += "& $\mathbf{{{:.3f} \pm {:.3f}}}$".format(row.ent_delta_mean, row.ent_delta_err)
    else:
        buffer += "&          ${:.3f} \pm {:.3f}$".format(row.ent_delta_mean, row.ent_delta_err)

    if row.ood_ratio == ent_ratio_max:
        buffer += "& $\mathbf{{{:.2f}}}$".format(row.ood_ratio)
    else:
        buffer += "&          ${:.2f}$".format(row.ood_ratio)

    print(buffer)