In [1]:
%load_ext autoreload
%load_ext tensorboard
%matplotlib inline

In [2]:
import matplotlib
import numpy as np
import os
import random
import yaml
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib import cm
import seaborn as sns
from importlib import reload
from pathlib import Path
import sklearn
import joblib
import torch
import pandas as pd
import copy

# Set the font to a nicer font
rc('text', usetex=True)
plt.style.use('seaborn-whitegrid')
plt.rcParams["font.family"] = "serif"

In [3]:
def get_metric_from_dict(results, method, metric):
    vals = []
    for _, metric_keys in results.items():
        for candidate_method, metric_map in metric_keys.items():
            if method != candidate_method:
                continue
            for metric_name, val in metric_map.items():
                if metric_name == metric:
                    vals.append(val)
    return vals


# Standarized Colors and Markers

In [30]:
# Plot Hyperparameters
clrs = sns.color_palette("deep", 15)
MAIN_METHOD = "MixCEM Final" #"Entropy CMCMixCEM"

color_map = {
    "Bayes MLP": "black",
    "Bayes Classifier": "black",
    "MixCEM (ours)": "red",
    "MixCEM (No Calibration)": "salmon",
    "MixCEM (no IntCEM loss)": "black",
    "MixCEM + IntCEM": "orange",
    "MixIntCEM (ours)": "orange",
    "IntCEM": "cyan",
    "Logit Joint CBM": "salmon",
    "Independent CBM": "lightgreen",
    "Sequential CBM": "slateblue",
}

baselines_to_include = [
    "Joint CBM",
    "Hybrid-CBM",
    "CEM",
    # "Sigmoidal CEM",
    "IntCEM",
    "ProbCBM",
    "Posthoc CBM",
    "Posthoc Hybrid CBM",
    "Entropy CMCMixCEM",
    "Bayes MLP",
    "Sequential CBM",
    "Independent CBM",
    "Logit Joint CBM",
]
for idx, baseline in enumerate(baselines_to_include):
    if baseline not in color_map:
        color_map[baseline] = clrs[idx]

markers = {
    "Joint CBM": '-o',
    "Vanilla CBM": '-o',
    "Sigmoidal Joint CBM": '-o',
    "Logit Joint CBM": '-v',
    "Sequential CBM": '-^',
    "Independent CBM": '-2',
    "Hybrid-CBM": '-v',
    "Hybrid CBM": '-v',
    "CEM": '-^',
    "IntCEM": '-x',
    "ProbCBM": '-s',
    "Posthoc CBM": '-p',
    "P-CBM": '-p',
    "Posthoc Hybrid CBM": '-1',
    "Hybrid Posthoc CBM": '-1',
    "Hybrid P-CBM": '-1',
    "MixCEM (ours)": '--*',
    "MixCEM (No Calibration)": ':*',
    "MixCEM (no IntCEM loss)": ":*",
    "MixCEM + IntCEM": ":*",
    "MixIntCEM (ours)": ":*",
    "Bayes MLP": ":.",
    "Bayes Classifier": ":.",
}

max_limit = 10

select_metric = 'val_acc_y_random_group_level_True_use_prior_False_int_auc'

rename_map = {
    "Entropy CMCMixIntCEM": "MixIntCEM (ours)",
    "Entropy CMCMixCEM": "MixCEM (ours)",
    "MixCEM Final": "MixCEM (ours)",
    "MixCEM Final All": "MixCEM (ours)",
    "Entropy CMCMixIntCEM No Calibration": "MixIntCEM (No Calibration)",
    "Entropy CMCMixCEM No Calibration": "MixCEM (No Calibration)",
    "MixCEM Final No Calibration": "MixCEM (No Calibration)",
    "Posthoc Hybrid CBM": "Residual P-CBM",
    "Hybrid Posthoc CBM": "Residual P-CBM",
    "Posthoc CBM": "P-CBM",
    "Bayes MLP": "Bayes Classifier",
    "Joint CBM": "Vanilla CBM",
    "Hybrid-CBM": "Hybrid CBM",
}
used_rename_map = rename_map
show_variance = True
ood_suffix = 'OOD_sap_0.1_'

# Training Times Comparisons

In [34]:
from texttable import Texttable
import latextable
from collections import defaultdict

results_to_include = [
    dict(
        path='/anfs/bigdisc/me466/mixcem_results/cub_complete/',
        name='\\texttt{CUB}',
    ),
    # dict(
    #     path='/anfs/bigdisc/me466/mixcem_results/cub_incomplete/',
    #     name='\\texttt{CUB-Incomplete}',
    #     rename={'Posthoc Hybrid CBM': 'Hybrid Posthoc CBM'},
    # ),
    dict(
        path='/anfs/bigdisc/me466/mixcem_results/awa2_complete/',
        name='\\texttt{AwA2}',
    ),
    # dict(
    #     path='/anfs/bigdisc/me466/mixcem_results/awa2_incomplete/',
    #     name='\\texttt{AwA2-Incomplete}',
    # ),
    dict(
        path='/anfs/bigdisc/me466/mixcem_results/cifar10/',
        name='\\texttt{Cifar10}',
    ),
    # dict(
    #     path='/anfs/bigdisc/me466/mixcem_results/celeba/',
    #     name='\\texttt{CelebA}',
    # ),
]

for res in results_to_include:
    print(res['name'])
    res['results'] = joblib.load(os.path.join(res['path'], 'results.joblib'))

latex_table = Texttable()
baselines_to_include = [
    # "DNN",
    # "Joint CBM",
    # "Hybrid-CBM",
    # "Sigmoidal CEM",
    "ProbCBM",
    # "Posthoc CBM",
    # "Posthoc Hybrid CBM",
    "CEM",
    "IntCEM",
    MAIN_METHOD,
]
num_stds = 2
select_metric = 'val_acc_y_random_group_level_True_use_prior_False_int_auc'
cols = ["Method"] + [x['name'] for x in results_to_include]
rows = []
metrics_to_include = [('training_time', 'black')]
col_results =  [defaultdict(list) for _ in metrics_to_include]
best_col_results = [defaultdict(lambda: (float("inf"), None)) for _ in metrics_to_include]

for idx, label in enumerate(baselines_to_include):
    new_row = []
    for col, dataset_results in enumerate(results_to_include):
        real_label = dataset_results.get('rename', {}).get(label, label)
        used_select_metric = dataset_results.get(
            'select_metric',
            select_metric,
        )
        dataset_results['selected_models'] = joblib.load(os.path.join(dataset_results['path'], f'selected_models_{used_select_metric}.joblib'))
        selected_models = dataset_results['selected_models']
        model_name = selected_models.get(
            real_label + " (Baseline)",
            real_label + " (Baseline)",
        )
        results = dataset_results['results']
        metric_vals = []
        for metric_idx, (metric, color) in enumerate(metrics_to_include):
            metric_val = np.array(get_metric_from_dict(
                results,
                model_name,
                metric,
            ))
            num_epochs = np.array(get_metric_from_dict(
                results,
                model_name,
                'num_epochs',
            ))
            metric_val = metric_val / num_epochs
            used_name = rename_map.get(label, label)
            if len(new_row) == 0:
                new_row.append(used_name)
            mean = np.mean(metric_val, axis=0)
            std = np.std(metric_val, axis=0)
            if mean < best_col_results[metric_idx][col][0]:
                best_col_results[metric_idx][col] = (mean, std)
            col_results[metric_idx][col].append((mean, std))
    rows.append(new_row)

for idx, label in enumerate(baselines_to_include):
    new_row = rows[idx]
    for col, dataset_results in enumerate(results_to_include):
        entry = '$'
        for metric_idx, (metric, color) in enumerate(metrics_to_include):
            mean, std = col_results[metric_idx][col][idx]
            if metric_idx:
                entry += " \; / \;"
            if "DNN" in label and metric == "test_auc_c":
                # Then this is not applicable
                entry += " \\textcolor{gray}{\\text{N/A \; \; \; \; \;}} "
            else:
                if mean + num_stds * std <= (best_col_results[metric_idx][col][0] - num_stds * best_col_results[metric_idx][col][1]):
                    entry += (f" \\textcolor{{{color}}}{{\\underline{{{mean:.2f}_{{\pm {std:.2f}}}}}}}")
                else:
                    entry += (f" \\textcolor{{{color}}}{{{mean:.2f}_{{\pm {std:.2f}}}}}")
        new_row.append(entry + "$")

latex_table.set_cols_align(["c" for _ in cols])
latex_table.set_cols_valign(["m" for _ in cols])
latex_table.add_rows([cols] + rows)
print('-- Example 1: Basic --')
print('Texttable Output:')
print(latex_table.draw())
print('\nLatextable Output:')
print(
    latextable.draw_latex(
        latex_table,
        caption=(
            "Task Accuracy results"
        ),
        caption_above=True,
        label="tab:task_accuracy_summary",
        position="ht",
        use_booktabs=True,
))


\texttt{CUB}
\texttt{AwA2}
\texttt{Cifar10}
-- Example 1: Basic --
Texttable Output:
+---------------+--------------------+--------------------+--------------------+
|    Method     |    \texttt{CUB}    |   \texttt{AwA2}    |  \texttt{Cifar10}  |
|               | $ \textcolor{black | $ \textcolor{black | $ \textcolor{black |
|    ProbCBM    |    }{67.90_{\pm    |   }{182.38_{\pm    |   }{1066.27_{\pm   |
|               |      6.54}}$       |      6.09}}$       |      0.00}}$       |
+---------------+--------------------+--------------------+--------------------+
|               | $ \textcolor{black | $ \textcolor{black | $ \textcolor{black |
|      CEM      |    }{30.34_{\pm    |    }{82.44_{\pm    |    }{32.37_{\pm    |
|               |      2.87}}$       |      9.38}}$       |      0.52}}$       |
+---------------+--------------------+--------------------+--------------------+
|               | $ \textcolor{black | $ \textcolor{black | $ \textcolor{black |
|    IntCEM     |    }{5

In [35]:
from texttable import Texttable
import latextable
from collections import defaultdict

dataset_results = dict(
    path='/anfs/bigdisc/me466/mixcem_results/cub_incomplete_smaller_ablation/',
    name='\\texttt{CUB-Incomplete}',
)
select_metric = 'val_acc_y_random_group_level_True_use_prior_False_int_auc'
dataset_results['results'] = joblib.load(os.path.join(dataset_results['path'], 'results.joblib'))
dataset_results['selected_models'] = joblib.load(os.path.join(dataset_results['path'], f'selected_models_{select_metric}.joblib'))

latex_table = Texttable()
baselines_to_include = [
    "Base IntCEM",
    "Base MixCEM"
]
used_rename_map = {
    "Base IntCEM": "IntCEM",
    "Base MixCEM": "MixCEM (ours)",
}
num_stds = 2
metrics_to_include = [('training_time', 'Training Time (min)', None, 'min'), ('num_epochs', 'Epochs to Convergence', None, 'min'), ('sec_per_epoch', 'Seconds per Epoch', None, 'min')]
cols = [""] + [name for (_, name, _, _) in metrics_to_include]
rows = []
col_results =  defaultdict(list)
best_col_results = defaultdict(lambda: (float("inf"), None))

for idx, label in enumerate(baselines_to_include):
    new_row = []
    selected_models = dataset_results['selected_models']
    model_name = selected_models.get(
        label,
        label,
    )
    results = dataset_results['results']
    metric_vals = []
    for col, (metric, col_name, color, mode) in enumerate(metrics_to_include):
        if metric == 'sec_per_epoch':
            metric_val = np.array(get_metric_from_dict(
                results,
                model_name,
                'training_time',
            ))
            num_epochs = np.array(get_metric_from_dict(
                results,
                model_name,
                'num_epochs',
            ))
            metric_val = metric_val / num_epochs
        elif metric == 'training_time':
            metric_val = np.array(get_metric_from_dict(
                results,
                model_name,
                metric,
            )) / 60
        else:
            metric_val = np.array(get_metric_from_dict(
                results,
                model_name,
                metric,
            ))
        used_name = used_rename_map.get(label, label)
        if len(new_row) == 0:
            new_row.append(used_name)
        mean = np.mean(metric_val, axis=0)
        std = np.std(metric_val, axis=0)
        if mode == 'min':
            if mean < best_col_results[col][0]:
                best_col_results[col] = (mean, std)
        else:
            if mean > best_col_results[col][0]:
                best_col_results[col] = (mean, std)
        col_results[col].append((mean, std))
    rows.append(new_row)

for idx, label in enumerate(baselines_to_include):
    new_row = rows[idx]
    for col, (metric, col_name, color, mode) in enumerate(metrics_to_include):
        entry = '$'
        mean, std = col_results[col][idx]
        if mode == 'max':
            if mean + num_stds * std >= (best_col_results[col][0] - num_stds * best_col_results[col][1]):
                if color is not None:
                    entry += (f" \\textcolor{{{color}}}{{\\underline{{{mean:.2f}_{{\pm {std:.2f}}}}}}}")
                else:
                    entry += (f" \\underline{{{mean:.2f}_{{\pm {std:.2f}}}}}")
            else:
                if color is not None:
                    entry += (f" \\textcolor{{{color}}}{{{mean:.2f}_{{\pm {std:.2f}}}}}")
                else:
                    entry += (f" {mean:.2f}_{{\pm {std:.2f}}}")
        else:
            if mean + num_stds * std <= (best_col_results[col][0] - num_stds * best_col_results[col][1]):
                if color is not None:
                    entry += (f" \\textcolor{{{color}}}{{\\underline{{{mean:.2f}_{{\pm {std:.2f}}}}}}}")
                else:
                    entry += (f" \\underline{{{mean:.2f}_{{\pm {std:.2f}}}}}")
            else:
                if color is not None:
                    entry += (f" \\textcolor{{{color}}}{{{mean:.2f}_{{\pm {std:.2f}}}}}")
                else:
                    entry += (f" {mean:.2f}_{{\pm {std:.2f}}}")
        new_row.append(entry + "$")

latex_table.set_cols_align(["c" for _ in cols])
latex_table.set_cols_valign(["m" for _ in cols])
latex_table.add_rows([cols] + rows)
print(latex_table.draw())
print('\nLatextable Output:')
print(
    latextable.draw_latex(
        latex_table,
        caption=(
            "Training ablation"
        ),
        caption_above=True,
        label="tab:training_times",
        position="ht",
        use_booktabs=True,
))


+---------------+--------------------+--------------------+--------------------+
|               |   Training Time    |     Epochs to      | Seconds per Epoch  |
|               |       (min)        |    Convergence     |                    |
|    IntCEM     |    $ 31.66_{\pm    |   $ 115.00_{\pm    |    $ 16.54_{\pm    |
|               |       3.87}$       |      15.00}$       |       0.14}$       |
+---------------+--------------------+--------------------+--------------------+
| MixCEM (ours) |    $ 32.95_{\pm    |   $ 160.00_{\pm    |    $ 12.35_{\pm    |
|               |       1.65}$       |       0.00}$       |       0.62}$       |
+---------------+--------------------+--------------------+--------------------+

Latextable Output:
\begin{table}[ht]
	\caption{Training ablation}
	\begin{center}
		\begin{tabular}{cccc}
			\toprule
			 & Training Time (min) & Epochs to Convergence & Seconds per Epoch \\
			\midrule
			IntCEM & $ 31.66_{\pm 3.87}$ & $ 115.00_{\pm 15.00}$ & $ 16.54_{

In [23]:
best_col_results

defaultdict(<function __main__.<lambda>()>,
            {0: (-inf, None), 1: (-inf, None), 2: (-inf, None)})