In [None]:
%load_ext autoreload
%autoreload 2

In [12]:
import os

from tqdm import tqdm

os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc
from sbi_mcmc.utils.utils import *

In [None]:
# Sampler whose warmup results are analysed and plotted.
# Supported values: "NUTS" or "ChEES-HMC" (the latter is labelled as nested
# R-hat in the figure below).
mcmc_method = "NUTS"
assert mcmc_method in ("NUTS", "ChEES-HMC"), f"Unsupported mcmc_method: {mcmc_method!r}"

# Default task. The analysis loop further down iterates over all tasks and
# rebinds `task_name`, so this value only matters for interactive use.
# Other tasks: "CustomDDM(dt=0.0001)", "GEV", "BernoulliGLM".
task_name = "psychometric_curve_overdispersion"
# NOTE(review): not referenced elsewhere in this notebook.
processed = False

In [None]:
import arviz as az
import jax
from sbi_mcmc.tasks.tasks import ndarray_values_as_dict

# Experiment grid: warmup lengths and initialization strategies whose stored
# results are read back by the analysis loop below.
num_warmup_values = [10, 50, 100, 200, 300, 500]
init_options = ["abi_psis", "abi", "stan-like"]

# NOTE(review): not referenced in this notebook; presumably the number of
# runs/observations per setting — confirm.
num_runs = 20

# Whether to read the "_sorted" variant of the result files.
sort = False

# Fixed JAX PRNG key (NOTE(review): not used by the cells shown here).
rng_key = jax.random.key(42)


def get_save_path(
    observation_id,
    num_warmup,
    init_option,
    sort=False,
    *,
    create_dir=True,
    result_dir=None,
    dataset_name=None,
    method=None,
):
    """Return the pickle path holding one warmup-experiment result.

    The path has the form
    ``<result_dir>/warmup_tests_<method>/<dataset>_<observation_id>_num_warmup_<num_warmup>_init_option_<init_option>[_sorted].pkl``.

    Args:
        observation_id: Identifier of the observation in the test dataset.
        num_warmup: Number of warmup iterations of the run.
        init_option: Initialization strategy name (e.g. "abi", "stan-like").
        sort: If True, append "_sorted" to the file stem.
        create_dir: If True (the default, matching the original behavior),
            create the parent directory. Pass False for pure look-ups such as
            ``.exists()`` checks so they have no filesystem side effects.
        result_dir: Base directory; defaults to the notebook global
            ``paths["chees_hmc_result_dir"]`` (read at call time).
        dataset_name: Dataset name prefix; defaults to the notebook global
            ``test_dataset_name`` (read at call time).
        method: Sampler name used in the sub-directory; defaults to the
            notebook global ``mcmc_method`` (read at call time).

    Returns:
        The result file path (a path object joined from ``result_dir``).
    """
    from pathlib import Path

    # Fall back to notebook globals only when no explicit value is given, so
    # the function can also be used (and tested) without that hidden state.
    if result_dir is None:
        result_dir = paths["chees_hmc_result_dir"]
    else:
        result_dir = Path(result_dir)
    if dataset_name is None:
        dataset_name = test_dataset_name
    if method is None:
        method = mcmc_method

    filename = f"{dataset_name}_{observation_id}_num_warmup_{num_warmup}_init_option_{init_option}"
    if sort:
        filename += "_sorted"
    filename += ".pkl"
    result_save_path = result_dir / f"warmup_tests_{method}/{filename}"
    if create_dir:
        result_save_path.parent.mkdir(parents=True, exist_ok=True)
    return result_save_path

In [None]:
task_names = [
    "GEV",
    "BernoulliGLM",
    "psychometric_curve_overdispersion",
    "CustomDDM(dt=0.0001)",
]

# Imported once here instead of on every loop iteration.
# NOTE(review): none of these names are used in this notebook.
from collections import OrderedDict  # noqa: F401

from sbi_mcmc.metrics import gskl, mtv, wasserstein_distance  # noqa: F401

# Init options in the order they are stored in each per-task result dict.
# The plotting cell derives each method's horizontal offset from this order.
# Note that it differs from the order of `init_options`.
plot_init_options = ["abi_psis", "stan-like", "abi"]
# Minimum number of observations with a complete set of result files.
min_valid_observations = 20

n_rhats_tasks = {}
for task_name in tqdm(task_names):
    stuff = get_stuff(
        task_name=task_name,
        test_dataset_name="test_dataset_chunk_1",
        overwrite_stats=False,
    )
    task = stuff["task"]
    # `get_save_path` reads `paths` and `test_dataset_name` as notebook
    # globals, so they must be rebound here for every task.
    paths = stuff["paths"]
    test_dataset = stuff["test_dataset"]
    test_dataset_name = stuff["test_dataset_name"]
    config = stuff["config"]

    # Analyse only the observations listed in psis_stats["reject_inds"].
    psis_stats = read_from_file(paths["psis_stats"])
    psis_failed_observation_ids = psis_stats["reject_inds"]
    test_observation_ids = psis_failed_observation_ids

    # Keep observations for which every (init option, warmup) result exists.
    valid_inds = [
        observation_id
        for observation_id in test_observation_ids
        if all(
            get_save_path(observation_id, num_warmup, init_option, sort=sort).exists()
            for init_option in plot_init_options
            for num_warmup in num_warmup_values
        )
    ]
    assert len(valid_inds) >= min_valid_observations, (
        f"{task_name}: only {len(valid_inds)} observations have complete "
        f"results (need at least {min_valid_observations})"
    )

    # n_rhats[init_option][num_warmup] -> list with one value per valid
    # observation: max(result["n_rhat"]) - 1.
    n_rhats = {}
    for init_option in plot_init_options:
        n_rhats[init_option] = {}
        for num_warmup in num_warmup_values:
            max_nrhats = []
            for observation_id in tqdm(valid_inds):
                file_path = get_save_path(
                    observation_id, num_warmup, init_option, sort=sort
                )
                result = read_from_file(file_path)
                # np.mean of the scalar max is a no-op for array inputs; kept
                # since the stored type of result["n_rhat"] is not visible here.
                max_nrhats.append(np.mean(result["n_rhat"].max() - 1))
            n_rhats[init_option][num_warmup] = max_nrhats
    n_rhats_tasks[task_name] = n_rhats

In [None]:
# Human-readable panel titles, keyed by the internal task identifier.
tasks_display_names = dict(
    [
        ("GEV", "GEV"),
        ("BernoulliGLM", "Bernoulli GLM"),
        ("psychometric_curve_overdispersion", "Psychometric curve"),
        ("CustomDDM(dt=0.0001)", "Decision model"),
    ]
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Figure size in inches, chosen to fit the NeurIPS page width while leaving
# room for a legend underneath the panels.
fig_width, fig_height = 5, 3

# Per-init-option styling (colour, legend label, marker shape). Keys are the
# init-option names under which the results were collected.
colors = {"abi_psis": "#DDAA33", "stan-like": "#BB5566", "abi": "#004488"}
labels = {
    "abi_psis": "Amortized + PSIS",
    "stan-like": "Random Init.",
    "abi": "Amortized",
}
markers = {"abi_psis": "D", "stan-like": "s", "abi": "^"}

# Global matplotlib defaults for a compact, publication-style figure.
# "text.usetex" renders all text through LaTeX, so a LaTeX installation is
# required to run this cell.
plt.rcParams.update(
    {
        # Typography
        "font.family": "serif",
        "font.size": 9,
        "axes.labelsize": 10,
        "axes.titlesize": 10,
        "legend.fontsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        # Line weights and markers
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.6,
        "lines.linewidth": 1.0,
        "lines.markersize": 5,
        # Ticks: thin and pointing outward
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
        "xtick.direction": "out",
        "ytick.direction": "out",
        # LaTeX text rendering
        "text.usetex": True,
    }
)

# 2x2 grid of panels (one per task): shared warmup x-axis, independent y-axes.
fig, axs = plt.subplots(
    nrows=2,
    ncols=2,
    sharex=True,
    sharey=False,
    figsize=(fig_width, fig_height),
)
axs = axs.flatten()

# Percentiles used for the error bars: interquartile range by default,
# otherwise the full min-to-max range.
interquartile = True
lower, upper = (25, 75) if interquartile else (0, 100)

# Y-axis label per sampler. Fail early on an unsupported method instead of
# hitting a NameError when the first left-column panel is labelled.
ylabels = {
    "NUTS": r"$\widehat{R} - 1$",
    "ChEES-HMC": r"Nested $\widehat{R} - 1$",
}
if mcmc_method not in ylabels:
    raise ValueError(f"No y-axis label defined for mcmc_method={mcmc_method!r}")
ylabel = ylabels[mcmc_method]

# Horizontal spacing (in warmup-iteration units) between the markers of the
# different init options at the same warmup value, so error bars don't overlap.
method_spacing = 5

# One panel per task.
for task_idx, task_name in enumerate(task_names):
    ax = axs[task_idx]

    data = n_rhats_tasks[task_name]
    n_methods = len(data)

    # One error-bar series per init option.
    for method_idx, method_key in enumerate(data):
        x_values = sorted(data[method_key].keys())
        # Centre the group on each warmup value (offsets -5, 0, +5 for three
        # methods, matching the original hard-coded layout).
        x_offset = method_spacing * (method_idx - (n_methods - 1) / 2)

        # Rows: warmup values; columns: per-observation values.
        rhat = np.array([data[method_key][x] for x in x_values])
        x = np.array(x_values) + x_offset

        # Median with asymmetric error bars spanning the [lower, upper] percentiles.
        medians = np.median(rhat, axis=1)
        lower_errs = medians - np.percentile(rhat, lower, axis=1)
        upper_errs = np.percentile(rhat, upper, axis=1) - medians

        ax.errorbar(
            x,
            medians,
            yerr=[lower_errs, upper_errs],
            fmt=markers[method_key],
            lw=1.0,
            color=colors[method_key],
            capsize=2.5,
            linestyle="-",
            label=labels[method_key],
            markeredgecolor="black",
            markeredgewidth=0.5,
            markersize=4.5,
            zorder=3,  # Keep data points above the grid and reference line
        )

    # Panel styling; fall back to the raw task name when no display name exists.
    ax.set_title(tasks_display_names.get(task_name, task_name))
    ax.set_yscale("log")
    ax.grid(True, alpha=0.25, linestyle="-", which="major", zorder=0)

    # Reference line at R-hat - 1 = 0.01, i.e. R-hat = 1.01.
    ax.axhline(
        y=0.01, color="k", linestyle="--", alpha=0.7, linewidth=0.8, zorder=2
    )

    ax.set_yticks([0.001, 0.01, 0.1])

    # Remove top and right spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    # Only the bottom row gets an x-label and explicit warmup ticks
    if task_idx >= 2:
        ax.set_xlabel("Warmup iterations")
        ax.set_xticks(x_values)

    # Only the left column gets a y-label
    if task_idx % 2 == 0:
        ax.set_ylabel(ylabel)

# Hide any panels left over when there are fewer tasks than axes.
for unused_ax in axs[len(task_names):]:
    unused_ax.set_visible(False)

# One shared legend centred below the grid. Its entries are taken from the
# first panel (all panels plot the same init options) and sorted by label.
handles, labels_list = axs[0].get_legend_handles_labels()
if handles:
    # sorted() is stable, so entries with equal labels keep their plot order.
    ordered_pairs = sorted(zip(handles, labels_list), key=lambda pair: pair[1])
    legend = fig.legend(
        [pair[0] for pair in ordered_pairs],
        [pair[1] for pair in ordered_pairs],
        loc="lower center",
        bbox_to_anchor=(0.5, -0.05),  # just below the bottom row of panels
        ncol=3,
        frameon=False,
        handlelength=1.2,
        handletextpad=0.5,
        borderaxespad=0.2,
    )

fig.tight_layout()
# Leave room at the bottom for the figure-level legend and set panel spacing.
fig.subplots_adjust(bottom=0.17, wspace=0.25, hspace=0.35)
# NOTE: the "comparision" spelling is kept unchanged in case other documents
# reference this filename.
fig_path = f"figures/warmup_comparision_{mcmc_method}"
# savefig does not create missing directories, so make sure "figures/" exists.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(f"{fig_path}.pdf", dpi=300, bbox_inches="tight")

plt.show()