In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

from tqdm import tqdm

os.environ["KERAS_BACKEND"] = "jax"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
import keras
import matplotlib.pyplot as plt
import numpy as np
from exp_utils import *
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.utils import *
from tqdm.autonotebook import tqdm

In [None]:
# Benchmark tasks to aggregate; each name is passed to `get_stuff(task_name=...)`.
task_names: list[str] = [
    "GEV",
    "BernoulliGLM",
    "psychometric_curve_overdispersion",
    "CustomDDM(dt=0.0001)",  # DDM variant with time step dt=1e-4
]
# Intended upper bound on the number of runs per method group
# (TODO confirm where this cap should be applied).
max_num_runs: int = 100

In [None]:
# Collect, for every task and metric, the metric values of each method group
# (plus the matching observation ids) from the pickled experiment statistics.
metrics_results_tasks = {k: {} for k in task_names}
metrics_results_tasks_ids = {k: {} for k in task_names}
metric_names = ["W1", "mmtv_FFTKDE", "GsKL"]
for task_name in task_names:
    print("=" * 10)
    print(task_name)
    stuff = get_stuff(
        task_name=task_name,
        test_dataset_name="test_dataset_chunk_1",
        job=None,
        overwrite_stats=False,
    )
    task = stuff["task"]
    paths = stuff["paths"]
    test_dataset_name = stuff["test_dataset_name"]
    config = stuff["config"]

    # Load the stored statistics of each pipeline stage (read-only).
    stats = {}
    for job in ["ood", "psis", "abi", "chees_hmc"]:
        stats_logger = PickleStatLogger(
            paths[f"{job}_stats"], overwrite=False, verbose=True
        )
        stats[job] = stats_logger.data

    stats["training"] = PickleStatLogger(
        paths["training_result_dir"] / "training_record.pkl", overwrite=False
    ).data

    # The Mahalanobis OOD check splits observations into ABI accepted/rejected.
    ood_stats = stats["ood"][f"Mahalanobis_{test_dataset_name}"]
    abi_accept_inds = set(ood_stats["ood_accept_inds"])
    abi_reject_inds = set(ood_stats["ood_failed_inds"])

    # PSIS outcomes cover exactly the ABI-rejected observations; the equality
    # already implies every PSIS-accepted index was rejected by ABI.
    psis_accept_inds = set(stats["psis"]["accept_inds"])
    psis_reject_inds = set(stats["psis"]["reject_inds"])
    assert abi_reject_inds == psis_accept_inds | psis_reject_inds

    # ChEES-HMC outcomes cover exactly the PSIS-rejected observations.
    chees_hmc_reject_inds = set(stats["chees_hmc"]["chees_hmc_reject_inds"])
    chees_hmc_accept_inds = set(stats["chees_hmc"]["chees_hmc_accept_inds"])
    assert psis_reject_inds == chees_hmc_reject_inds | chees_hmc_accept_inds

    metrics_logger = PickleStatLogger(
        paths["metrics_stats"], overwrite=False, verbose=True
    )

    # Method group -> observation ids, sorted for a deterministic order.
    inds_dict = {
        "ABI(accepted)": sorted(abi_accept_inds),
        "ABI(rejected)": sorted(abi_reject_inds),
        "PSIS": sorted(psis_accept_inds),
        "ChEES-HMC": sorted(chees_hmc_accept_inds),
    }

    for metric_name in tqdm(metric_names):
        # Records are keyed "<group>-<observation_id>"; ids without a record
        # (or with a None value) are skipped.
        metric_records = metrics_logger.data[metric_name]
        metric_values_dict = {k: [] for k in inds_dict}
        corresponding_ids = {k: [] for k in inds_dict}
        for id_type, inds in inds_dict.items():
            for observation_id in inds:
                m_value = metric_records.get(f"{id_type}-{observation_id}")
                if m_value is not None:
                    metric_values_dict[id_type].append(m_value)
                    corresponding_ids[id_type].append(observation_id)

        metrics_results_tasks[task_name][metric_name] = metric_values_dict
        metrics_results_tasks_ids[task_name][metric_name] = corresponding_ids

    # Report, for every metric, how many values each group has
    # (capped at max_num_runs).
    task_results = metrics_results_tasks[task_name]
    for metric_name, values_by_group in task_results.items():
        print(metric_name)
        for id_type, values in values_by_group.items():
            print(id_type, min(len(values), max_num_runs))

## Plot for paper: W1 and MMTV boxplots per task and method group

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
from exp_utils import set_default_plot_settings
from matplotlib.ticker import MaxNLocator

# Reset matplotlib to its defaults, then layer the project defaults and a
# compact, publication-oriented seaborn theme on top.
mpl.rcParams.update(mpl.rcParamsDefault)
set_default_plot_settings()
# `sns.set` is only an alias of `sns.set_theme`, the preferred interface.
sns.set_theme(style="whitegrid", context="paper", font_scale=0.9)
# Render all text with LaTeX; pifont provides the check/cross glyphs that are
# exposed as \cmark and \xmark for the method labels.
plt.rcParams["text.usetex"] = True
plt.rcParams["text.latex.preamble"] = r"""
\usepackage{pifont}
\newcommand{\cmark}{\ding{51}}
\newcommand{\xmark}{\ding{55}}
"""
# Small fonts and thin lines sized for a compact multi-panel figure.
plt.rcParams.update(
    {
        "font.family": "serif",
        "figure.dpi": 300,
        "axes.labelsize": 9,
        "axes.titlesize": 8.5,
        "xtick.labelsize": 7,
        "ytick.labelsize": 8,
        "lines.linewidth": 0.8,
    }
)

# Metrics drawn in the figure (one row each) and their axis labels.
metrics = ["W1", "mmtv_FFTKDE"]
metrics_display_names = dict(W1="W1", mmtv_FFTKDE="MMTV")
# Column titles per task; tasks missing here are titled with their raw name.
tasks_display_names = dict(
    [
        ("GEV", "GEV"),
        ("BernoulliGLM", "Bernoulli GLM"),
        ("psychometric_curve_overdispersion", "Psychometric curve"),
        ("CustomDDM(dt=0.0001)", "Decision model"),
    ]
)


# Replace verbose method labels with compact tick labels: LaTeX check/cross
# glyphs for the ABI groups and a shorter name for ChEES-HMC.
def format_abi_labels(labels):
    """Return compact x-tick labels for the method groups.

    Rules are tried in order and only the first matching rule is applied to a
    label; labels that match no rule are returned unchanged.

    Args:
        labels: Iterable of method-group labels, e.g. ``"ABI(accepted)"``.

    Returns:
        list[str]: The formatted labels, in input order.
    """
    # Ordered (pattern, replacement) rules. Both ABI replacements share the
    # same form (no leading space) so the tick labels render consistently.
    rules = (
        ("ABI(accepted)", r"ABI(\cmark)"),
        ("ABI(rejected)", r"ABI(\xmark)"),
        ("ChEES-HMC", "C-HMC"),
    )
    formatted_labels = []
    for label in labels:
        for pattern, replacement in rules:
            if pattern in label:
                label = label.replace(pattern, replacement)
                break
        formatted_labels.append(label)
    return formatted_labels


# Figure: one row per metric, one column per task, one box per method group.
fig, axes = plt.subplots(
    len(metrics),
    len(task_names),
    figsize=(5.5, 2.3),
    sharey="row",
    # Keep `axes` 2-D even for a single metric or task so axes[row, col] works.
    squeeze=False,
)
# Subplot spacing is determined by fig.tight_layout() below, which overrides
# any plt.subplots_adjust() call, so no separate adjustment is made here.

# Custom color palette - distinctive colors, one per method group.
colors = sns.color_palette("Set2", 8)
# Reference level drawn as a dashed horizontal line on the MMTV panels.
mmtv_reference_level = 0.2
figure_path = os.path.join("figures", "metrics_boxplots.pdf")

for row, metric in enumerate(metrics):
    for col, task_name in enumerate(task_names):
        ax = axes[row, col]
        data_dict = metrics_results_tasks[task_name][metric]

        # Boxes follow the insertion order of data_dict, the same order used
        # for the x-tick labels below.
        bplot = ax.boxplot(
            list(data_dict.values()),
            patch_artist=True,
            widths=0.55,
            showfliers=True,
            medianprops={"color": "black", "linewidth": 1.0},
            # Smaller, transparent outliers
            flierprops={"marker": ".", "markersize": 2, "alpha": 0.7},
        )

        # Color boxes consistently across panels
        for patch, color in zip(bplot["boxes"], colors, strict=False):
            patch.set_facecolor(color)
            patch.set_edgecolor("black")
            patch.set_linewidth(0.5)
            patch.set_alpha(0.85)

        # Task titles only on the top row
        if row == 0:
            ax.set_title(tasks_display_names.get(task_name, task_name))

        # Metric label and tick locator only on the leftmost column
        # (the y-axis is shared within each row).
        if col == 0:
            ax.set_ylabel(metrics_display_names.get(metric, metric))
            ax.yaxis.set_major_locator(MaxNLocator(nbins=4, prune=None))

        # x-tick labels only on the bottom row, with check/cross symbols
        if row == len(metrics) - 1:
            formatted_labels = format_abi_labels(list(data_dict.keys()))
            ax.set_xticklabels(
                formatted_labels, rotation=0, ha="center", fontsize=5
            )
        else:
            ax.set_xticklabels([])

        ax.margins(0.03)
        ax.grid(axis="y", linestyle="--", alpha=0.6, linewidth=0.5)
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # One decimal on the y-axis for every metric (no scientific notation)
        ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.1f"))

        if "mmtv" in metric:
            ax.axhline(
                y=mmtv_reference_level,
                color="k",
                linestyle="--",
                alpha=0.7,
                linewidth=0.8,
            )
            ax.set_yticks([0.0, 0.2, 0.5, 0.8])

# Tight layout to maximize use of the available space
fig.tight_layout(pad=0.2, h_pad=0.5, w_pad=0.1)
# Create the output directory so saving does not fail on a fresh checkout.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
fig.savefig(figure_path, bbox_inches="tight")
plt.show()