In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from tqdm import tqdm

os.environ["KERAS_BACKEND"] = "jax"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
import keras
import numpy as np
import pandas as pd
from exp_utils import read_reference_posterior
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.utils import *
from tqdm.autonotebook import tqdm

In [None]:
# Task whose results are analysed by the cells below; the commented-out lines
# are the other task names that can be selected instead.
task_name = "psychometric_curve_overdispersion"
# task_name = "CustomDDM(dt=0.0001)"
# task_name = "GEV"
# task_name = "BernoulliGLM"
# Per-group cap on how many observations the metrics loop processes.
max_num_runs = 100

In [None]:
# Load the experiment bundle for the first test-dataset chunk and unpack the
# pieces used throughout the notebook.
stuff = get_stuff(
    task_name=task_name,
    test_dataset_name="test_dataset_chunk_1",
    job=None,
    overwrite_stats=False,
)
task, paths, config = stuff["task"], stuff["paths"], stuff["config"]
N_testdata = stuff["N_testdata"]
test_dataset = stuff["test_dataset"]
test_dataset_name = stuff["test_dataset_name"]

# Open each job's pickled statistics (without overwriting them) and keep only
# the recorded data, keyed by job name.
stats = {}
for job in ("ood", "psis", "abi", "chees_hmc"):
    stats_path = paths[f"{job}_stats"]
    stats_logger = PickleStatLogger(stats_path, overwrite=False, verbose=True)
    stats[job] = stats_logger.data

In [None]:
# Observation indices accepted / rejected at the amortized (OOD check) stage.
ood_record = stats["ood"][f"Mahalanobis_{test_dataset_name}"]
abi_accept_inds = set(ood_record["ood_accept_inds"])
abi_reject_inds = set(ood_record["ood_failed_inds"])

# Every observation rejected at the amortized stage must have a PSIS verdict.
psis_accept_inds = set(stats["psis"]["accept_inds"])
psis_reject_inds = set(stats["psis"]["reject_inds"])
assert abi_reject_inds == psis_accept_inds.union(psis_reject_inds)
assert abi_reject_inds >= psis_accept_inds

# Every observation rejected by PSIS must have a ChEES-HMC verdict.
chees_record = stats["chees_hmc"]
chees_hmc_reject_inds = set(chees_record["chees_hmc_reject_inds"])
chees_hmc_accept_inds = set(chees_record["chees_hmc_accept_inds"])
assert psis_reject_inds == chees_hmc_reject_inds.union(chees_hmc_accept_inds)
assert psis_reject_inds >= chees_hmc_reject_inds

In [None]:
# Location key of the reference PyMC runs for this test dataset.
pymc_source = f"pymc_runs/{test_dataset_name}"

# Stored amortized-inference results for the whole batch of observations.
abi_result_path = paths["abi_result"]
batch_results = read_from_file(abi_result_path)

# Logger holding the metric values; existing records are kept, not overwritten.
metrics_stats_path = paths["metrics_stats"]
metrics_logger = PickleStatLogger(
    metrics_stats_path, verbose=True, overwrite=False
)

In [None]:
# Observation indices per comparison group, each stored as a sorted list so
# that the groups are always traversed in a deterministic order.
inds_dict = {
    group_name: sorted(group_inds)
    for group_name, group_inds in (
        ("ABI(accepted)", abi_accept_inds),
        ("ABI(rejected)", abi_reject_inds),
        ("PSIS", psis_accept_inds),
        ("ChEES-HMC", chees_hmc_accept_inds),
    )
}

In [None]:
def read_posterior_draws(id_type, observation_id):
    """Fetch the stored posterior draws of one observation for a given group.

    Args:
        id_type: Group label. Any label containing "ABI" reads from the
            in-memory batch results; "PSIS" and "ChEES-HMC" read the
            per-observation result files.
        observation_id: Index of the observation to look up.

    Returns:
        The posterior draws for the observation, or ``None`` when ``id_type``
        matches none of the known groups.
    """
    if "ABI" in id_type:
        return batch_results["abi_samples_batch"][observation_id]

    if id_type == "PSIS":
        psis_file = paths["psis_result"](observation_id)
        return read_from_file(psis_file)["abi_psis_resamples"]

    if id_type == "ChEES-HMC":
        chees_file = paths["chees_hmc_result"](observation_id)
        all_draws = read_from_file(chees_file)["chees_draws_tfp"]
        # Keep the first `target_num_draws` entries along axis 0 at index 0 of
        # axis 1 (presumably a single chain -- TODO confirm axis meaning).
        return all_draws[: config.target_num_draws, 0, :]

    return None

In [None]:
from collections import OrderedDict
from functools import partial

from sbi_mcmc.metrics import gskl, mtv, wasserstein_distance
from sbi_mcmc.utils.plot_utils import corner_plot

# Metrics evaluated for every observation, in the order they are computed.
# `mtv_FFTKDE` is `mtv` with its `kde` argument fixed to "FFTKDE".
mtv_FFTKDE = partial(mtv, kde="FFTKDE")
metric_dict = OrderedDict(
    [
        ("W1", wasserstein_distance),
        ("mtv_FFTKDE", mtv_FFTKDE),
        ("GsKL", gskl),
    ]
)

# Number of observations already processed per group. Pre-filling every group
# with `max_num_runs` makes the loop below skip all of them; switch to the
# empty dict on the commented line to (re)compute the metrics.
computed_runs = {
    "ABI(accepted)": max_num_runs,
    "ABI(rejected)": max_num_runs,
    "PSIS": max_num_runs,
    "ChEES-HMC": max_num_runs,
}
# computed_runs = {}

all_observation_ids = list(range(stuff["N_testdata"]))

# Observation ids whose reference PyMC result could not be found.
pymc_missing_ids = []
# for observation_id in tqdm(all_observation_ids):
for id_type, inds in tqdm(inds_dict.items(), desc="Groups"):
    # Initialise the counter once per group instead of once per observation.
    computed_runs.setdefault(id_type, 0)

    for observation_id in tqdm(inds, desc=f"Obs ({id_type})", leave=False):
        if computed_runs[id_type] >= max_num_runs:
            # The cap can only stay reached, so stop instead of walking the
            # remaining observations one `continue` at a time.
            break

        try:
            samples_pymc_unconstrained = read_reference_posterior(
                task, observation_id, pymc_source, raise_error=True
            )
        except ValueError as e:
            if "Rhat is too large" in str(e):
                print(f"Skipping {observation_id} due to Rhat error.")
                continue
            elif "not found" in str(e):
                print(
                    f"Skipping {observation_id} due to PyMC result not found."
                )
                pymc_missing_ids.append(observation_id)
                continue
            else:
                raise

        posterior_draws = read_posterior_draws(id_type, observation_id)
        if posterior_draws is None:
            continue

        # Compute metrics; values already present in the logger are reused.
        record_key = f"{id_type}-{observation_id}"
        for metric_name, metric_fn in metric_dict.items():
            if (
                metrics_logger.data.get(metric_name, {}).get(record_key)
                is not None
            ):
                continue

            metric_value = metric_fn(
                posterior_draws, samples_pymc_unconstrained
            )
            metrics_logger.update(metric_name, {record_key: metric_value})
            if metric_name in ("mtv", "mtv_FFTKDE"):
                # Also store the mean of the MTV value under an "m"-prefixed
                # metric name.
                metrics_logger.update(
                    f"m{metric_name}", {record_key: np.mean(metric_value)}
                )

        computed_runs[id_type] += 1

In [None]:
import matplotlib.pyplot as plt

# Metrics to summarise; "mmtv_FFTKDE" is the mean-MTV record written by the
# metrics loop under the "m"-prefixed name.
metric_names = ["W1", "mmtv_FFTKDE", "GsKL"]

# Upper bound on the number of corner plots written per group directory.
max_plots = 10

metrics_results = {}
metrics_results_ids = {}
for metric_name in tqdm(metric_names):
    metric_values_dict = {k: [] for k in inds_dict.keys()}
    corresponding_ids = {k: [] for k in inds_dict.keys()}
    # A metric that was never recorded yields empty groups instead of a
    # KeyError, matching the `.get(metric_name, {})` lookup used when the
    # metrics are computed.
    recorded_values = metrics_logger.data.get(metric_name, {})

    for id_type, inds in inds_dict.items():
        print(id_type)
        n_plots = 0
        fig_dir = paths["metrics_result_dir"] / id_type
        fig_dir.mkdir(parents=True, exist_ok=True)
        # Number of files already in the directory before this pass.
        num_files = sum(1 for f in fig_dir.iterdir() if f.is_file())

        for observation_id in inds:
            record_key = f"{id_type}-{observation_id}"
            m_value = recorded_values.get(record_key)
            if m_value is None:
                continue

            metric_values_dict[id_type].append(m_value)
            corresponding_ids[id_type].append(observation_id)

            # Strict comparisons: with `<=` the loop produced max_plots + 1
            # figures.
            if n_plots < max_plots and num_files < max_plots:
                samples_pymc_unconstrained = read_reference_posterior(
                    task, observation_id, pymc_source, raise_error=False
                )
                posterior_draws = read_posterior_draws(
                    id_type, observation_id
                )
                fig = corner_plot(
                    samples_pymc_unconstrained,
                    posterior_draws,
                    save_as=fig_dir / f"{observation_id}.png",
                    dpi=100,
                )
                # Close each corner plot so figures do not pile up in memory.
                plt.close(fig)
                n_plots += 1

    metrics_results[metric_name] = metric_values_dict
    metrics_results_ids[metric_name] = corresponding_ids

    for k, v in metric_values_dict.items():
        print(k, len(v))

    data_dict = metric_values_dict

    # One box per group; materialise the dict views as lists before handing
    # them to matplotlib.
    fig, ax = plt.subplots(figsize=(5, 4))
    ax.boxplot(
        list(data_dict.values()),
        labels=list(data_dict.keys()),
        widths=0.6,
        showfliers=True,
    )
    ax.set_ylabel(metric_name)
    if metric_name in ["GsKL"]:
        ax.set_yscale("log")
    fig_dir = paths["metrics_result_dir"] / "figures"
    fig_dir.mkdir(parents=True, exist_ok=True)

    fig.savefig(fig_dir / f"{metric_name}.png", bbox_inches="tight", dpi=100)

## Get the summary table for accepted datasets

In [None]:
# Test-dataset chunk ids aggregated for each task. Each id is formatted into
# the dataset name "test_dataset_chunk_{id}" when the chunk is loaded, and the
# tasks are processed in the order listed here.
id_chunks_dict = {
    "GEV": [1],
    "BernoulliGLM": [1, 2],
    "psychometric_curve_overdispersion": [1, 2],
    "CustomDDM(dt=0.0001)": [1, 3, 4],
}
# We estimated the training time for V100 separately
# These values replace the recorded "training" wall time in the Step 0 total.
# NOTE(review): presumably seconds, since the totals are divided by 60 to
# report minutes -- confirm.
V100_training_times = {
    "GEV": 141,
    "CustomDDM(dt=0.0001)": 2200,
    "psychometric_curve_overdispersion": 267,
    "BernoulliGLM": 38,
}
def format_g(i):
    """Format a number for the LaTeX table.

    Values below 1 are printed with one significant digit; all other values
    (including NaN) are printed with no decimals.
    """
    if i < 1:
        return f"{i:.1g}"
    return f"{i:.0f}"


# NOTE: this loop rebinds notebook-level names set by earlier cells
# (`task_name`, `stuff`, `task`, `paths`, `stats`, the accept/reject index
# sets, ...). Re-run the earlier cells before using them again.
speed_ups = []
for task_name, chunks in id_chunks_dict.items():
    print("===")
    print(task_name)

    # Initialize accumulators (one slot per workflow step 1-3)
    total_step_times = [0.0, 0.0, 0.0]
    total_num_accepted = [0, 0, 0]
    total_N_testdata = 0
    total_abi_reject = 0
    total_psis_reject = 0

    for id_chunk in chunks:
        stuff = get_stuff(
            task_name=task_name,
            test_dataset_name=f"test_dataset_chunk_{id_chunk}",
        )
        task = stuff["task"]
        paths = stuff["paths"]
        N_testdata = stuff["N_testdata"]
        test_dataset_name = stuff["test_dataset_name"]

        total_N_testdata += N_testdata

        stats = {}
        for job in ["ood", "psis", "abi", "chees_hmc"]:
            stats_logger = PickleStatLogger(
                paths[f"{job}_stats"], overwrite=False, verbose=True
            )
            stats[job] = stats_logger.data

        stats["training"] = PickleStatLogger(
            paths["training_result_dir"] / "training_record.pkl",
            overwrite=False,
        ).data

        abi_accept_inds = set(
            stats["ood"][f"Mahalanobis_{test_dataset_name}"]["ood_accept_inds"]
        )
        abi_reject_inds = set(
            stats["ood"][f"Mahalanobis_{test_dataset_name}"]["ood_failed_inds"]
        )
        psis_accept_inds = set(stats["psis"]["accept_inds"])
        psis_reject_inds = set(stats["psis"]["reject_inds"])
        chees_hmc_accept_inds = set(
            stats["chees_hmc"]["chees_hmc_accept_inds"]
        )
        chees_hmc_reject_inds = set(
            stats["chees_hmc"]["chees_hmc_reject_inds"]
        )
        # Each stage must have produced a verdict for everything the previous
        # stage rejected.
        assert abi_reject_inds == psis_accept_inds | psis_reject_inds
        assert abi_reject_inds.issuperset(psis_accept_inds)
        assert (
            psis_reject_inds == chees_hmc_reject_inds | chees_hmc_accept_inds
        )
        assert psis_reject_inds.issuperset(chees_hmc_reject_inds)

        # Step 0: training-phase cost. This is recomputed for every chunk and
        # only the value from the last chunk is added to the totals below, so
        # the training phase is counted once per task.
        step_0_time_total = 0
        for k, t in stats["training"]["wall_time"].items():
            if k == "training":
                print("Use training time for V100")
                t = V100_training_times[task_name]
            if "lc2st_cal" not in k:
                step_0_time_total += t

        if task_name == "CustomDDM(dt=0.0001)":
            assert "simulation_train" not in stats["training"]
            step_0_time_total += (
                2208 / 8 * 10.0
            )  # Simulation cost for training dataset

        # Step 1
        ood_time = sum(stats["ood"]["wall_time"].values())
        abi_time = sum(stats["abi"]["wall_time"].values())
        total_step_times[0] += ood_time + abi_time

        # Step 2
        for obs_id in abi_reject_inds:
            total_step_times[1] += stats["psis"]["wall_time"][obs_id]

        # Step 3: observations without a recorded ChEES-HMC wall time are
        # skipped (probably failed runs).
        for obs_id in psis_reject_inds:
            if obs_id not in stats["chees_hmc"]["wall_time"].keys():
                continue
            total_step_times[2] += stats["chees_hmc"]["wall_time"][obs_id]

        # Accumulate counts
        total_num_accepted[0] += len(abi_accept_inds)
        total_num_accepted[1] += len(psis_accept_inds)
        total_num_accepted[2] += len(chees_hmc_accept_inds)
        total_abi_reject += len(abi_reject_inds)
        total_psis_reject += len(psis_reject_inds)

    total_step_times[0] += step_0_time_total  # Plus the training phase time
    total_time = sum(total_step_times)
    times = total_step_times + [total_time]
    num_accepted = total_num_accepted + [sum(total_num_accepted)]

    # Extrapolate the mean ChEES-HMC time per dataset that reached step 3 to
    # all accepted datasets. If nothing reached step 3 the mean is undefined;
    # use NaN instead of raising ZeroDivisionError.
    if total_psis_reject:
        mean_chees_hmc_time = times[2] / total_psis_reject
    else:
        mean_chees_hmc_time = float("nan")
    estimate_total_time_chees_hmc = mean_chees_hmc_time * num_accepted[3]
    speed_ups.append(estimate_total_time_chees_hmc // total_time)

    data = {
        "name": [
            "& Step 1: Amortized inference",
            "& Step 2: Amortized + PSIS",
            "& Step 3: ChEES-HMC w/ inits",
            "& Workflow total",
            "& Direct ChEES-HMC",
        ],
        "Accepted": [
            f"${num_accepted[0]}/{total_N_testdata}$",
            f"${num_accepted[1]}/{total_abi_reject}$",
            f"${num_accepted[2]}/{total_psis_reject}$",
            f"${num_accepted[3]}/{total_N_testdata}$",
            r"\NA",
        ],
        "Time (minutes)": [format_g(t / 60) for t in times]
        + [format_g(estimate_total_time_chees_hmc / 60)],
        # Time per accepted dataset; a step that accepted nothing gets \NA
        # rather than a division by zero.
        "TPA": [
            format_g(t / a) if a else r"\NA"
            for t, a in zip(times, num_accepted, strict=False)
        ]
        + [r"\NA"],
    }

    df = pd.DataFrame(data)
    latex_table = df.to_latex(index=False)
    print(latex_table)
print("Speed up: ", speed_ups)