In [None]:
%load_ext autoreload
%autoreload 2

In [84]:
import os

from tqdm import tqdm

os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc
from sbi_mcmc.utils.utils import *

In [None]:
# Sampler under test. Switch to "NUTS" to run the PyMC/NumPyro branch instead.
mcmc_method = "ChEES-HMC"
# mcmc_method = "NUTS"

# Task to evaluate. Exactly one assignment is left active so the selection is
# unambiguous; the alternatives are kept as comments for quick switching.
# task_name = "psychometric_curve_overdispersion"
# task_name = "CustomDDM(dt=0.0001)"
task_name = "GEV"
# task_name = "BernoulliGLM"

# When True, the cells guarded by `if not processed` (log-density setup and
# the sampling loop) are skipped.
processed = False

In [None]:
# Fetch the bundle of objects associated with the selected task and the first
# test-dataset chunk, without overwriting existing statistics.
stuff = get_stuff(
    task_name=task_name,
    test_dataset_name="test_dataset_chunk_1",
    overwrite_stats=False,
)

# Pull the individual entries out of the bundle under short notebook-level names.
task, paths, test_dataset, test_dataset_name, config = (
    stuff[entry]
    for entry in (
        "task",
        "paths",
        "test_dataset",
        "test_dataset_name",
        "config",
    )
)

In [None]:
# Total number of draws requested; it is divided across chains below.
target_num_draws = config.target_num_draws
init_step_size = config.chees_hmc.init_step_size
print(f"MCMC method: {mcmc_method}, Task: {task_name}")

# Chain layout. ChEES-HMC uses K superchains of M subchains each (K * M chains
# in total); NUTS uses a fixed 4 chains and does not define M.
if mcmc_method == "ChEES-HMC":
    K = config.chees_hmc.num_superchains
    M = config.chees_hmc.num_subchains_per_superchain
    num_chains = K * M
elif mcmc_method == "NUTS":
    K = 4
    num_chains = K
else:
    # Same exception type as before, now naming the offending value.
    raise NotImplementedError(
        f"Unsupported mcmc_method: {mcmc_method!r} "
        "(expected 'ChEES-HMC' or 'NUTS')"
    )

# Draws per chain, rounded up so the total is at least `target_num_draws`.
num_sampling = int(np.ceil(target_num_draws / num_chains))
print(f"Number of sampling: {num_sampling}")
D = task.D
print(f"Dimension of the task: {D}")

In [None]:
# The warmup study runs on the observations listed under "reject_inds" in the
# saved PSIS statistics.
psis_stats = read_from_file(paths["psis_stats"])
test_observation_ids = psis_failed_observation_ids = psis_stats["reject_inds"]
print(len(test_observation_ids))

In [None]:
# Only needed when the sampling loop is going to run.
if not processed:
    # Optional config flag (defaults to False). When set, a single log-density
    # function is built here and later called as `lp_fn_dynamic(x, obs)`.
    dynamic_logp = config.get("dynamic_logp", False)
    if dynamic_logp:
        _dynamic_logp_kwargs = dict(
            static=False,
            pymc_model=task.setup_pymc_model(),
        )
        lp_fn_dynamic = get_task_logp_func(task, **_dynamic_logp_kwargs)

In [None]:
import arviz as az
import jax
from sbi_mcmc.tasks.tasks import ndarray_values_as_dict

# --- Experiment grid for the warmup study ---------------------------------
# Initialisation strategies compared against each other.
init_options = ["abi_psis", "abi", "stan-like"]
# Warmup lengths tried for every (observation, initialisation) pair.
num_warmup_values = [10, 50, 100, 200, 300, 500]
# The sampling loop stops once this many observations have been processed.
num_runs = 20
# Forwarded to the sampler helpers and reflected in result file names.
sort = False
# Root PRNG key; split before each "stan-like" uniform initialisation draw.
rng_key = jax.random.key(42)


def get_save_path(observation_id, num_warmup, init_option, sort=False):
    """Return the pickle path for one (observation, warmup, init) run.

    The file lives in ``paths["chees_hmc_result_dir"]`` under a
    ``warmup_tests_<mcmc_method>`` sub-directory. Its name encodes the test
    dataset name, the observation id, the warmup length and the
    initialisation option, with a ``_sorted`` marker when ``sort`` is truthy.

    Side effect: the parent directory of the returned path is created if it
    does not exist yet, so merely asking for a path touches the file system.

    Reads the notebook globals ``test_dataset_name``, ``paths`` and
    ``mcmc_method``.
    """
    sorted_marker = "_sorted" if sort else ""
    filename = (
        f"{test_dataset_name}_{observation_id}_num_warmup_{num_warmup}_init_option_{init_option}"
        + sorted_marker
        + ".pkl"
    )
    target = (
        paths["chees_hmc_result_dir"]
        / f"warmup_tests_{mcmc_method}/{filename}"
    )
    target.parent.mkdir(parents=True, exist_ok=True)
    return target

In [None]:
# Debug switch: when True, the loop below only builds and records the initial
# positions for every configuration and skips the MCMC runs (so no result files
# are written). Set to False to actually sample.
check_init_positions_only = (
    True  # Debug check: init positions versus reference posterior samples
)

# observation_id -> {init_option -> initial positions used for that option}.
# Entries are overwritten across warmup lengths, so the last one built wins.
init_positions_record = {}
if not processed:
    proceesed_inds = []
    # Human-readable descriptions of configurations that were skipped.
    exceptions = []
    # Number of candidate initial positions prepared per configuration.
    buffer = 5 * K
    for observation_id in test_observation_ids:
        psis_results = read_from_file(paths["psis_result"](observation_id))
        pareto_log_weights = psis_results["pareto_log_weights"]
        abi_samples = psis_results["abi_samples"]
        if task.name == "BernoulliGLM":
            observation = test_dataset["observables_raw"][observation_id]
        else:
            observation = test_dataset["observables"][observation_id]
        if not dynamic_logp:
            lp_fn = get_task_logp_func(task, observation=observation)
        else:
            observation_data = task.observation_to_pymc_data(observation)
            # Bind the observation as a default argument so the closure does
            # not pick up a later loop iteration's data.
            lp_fn = lambda x, obs=observation_data: lp_fn_dynamic(x, obs)
        try:
            # Resample without replacement from the amortized samples using
            # the Pareto-smoothed log weights.
            abi_psis_resamples_unique = _sir(
                abi_samples,
                log_weights=pareto_log_weights,
                num_samples=buffer,
                with_replacement=False,
            )
        except Exception as e:
            # Best effort: skip the whole observation if resampling fails.
            exceptions.append(
                f"id: {observation_id}: can't use abi_psis" + str(e)
            )
            continue

        for num_warmup in num_warmup_values:
            for init_option in init_options:
                print(observation_id, num_warmup, init_option)
                if init_option == "abi":
                    initial_positions = abi_samples[:buffer]
                elif init_option == "abi_psis":
                    initial_positions = abi_psis_resamples_unique[:buffer]
                elif init_option == "stan-like":
                    # Uniform draws on [-2, 2] in the sampler's coordinates.
                    rng_key, init_key = jax.random.split(rng_key)
                    initial_positions = jax.random.uniform(
                        init_key, (buffer, D), minval=-2, maxval=2
                    )

                    if "DDM" in task.name:
                        print("Make initial values for ndt meaningful")
                        # Copy into a writable NumPy array: the slice
                        # assignment below would fail on an immutable array
                        # (e.g. if the transform returns a JAX array).
                        initial_positions_constrained = np.array(
                            task.transform_to_constrained_space(
                                initial_positions
                            )
                        )
                        num_positions = initial_positions_constrained.shape[0]
                        # Overwrite the last two constrained coordinates with
                        # half of `task.min_rts`. `np.ones` (not `ones_like`)
                        # builds the intended (num_positions, 2) array.
                        initial_positions_constrained[:, -2:] = (
                            np.ones((num_positions, 2)) * task.min_rts / 2
                        )
                        initial_positions = (
                            task.transform_to_unconstrained_space(
                                initial_positions_constrained
                            )
                        )
                else:
                    raise ValueError(f"Unknown init_option: {init_option}")

                init_positions_record.setdefault(observation_id, {})[
                    init_option
                ] = initial_positions
                if check_init_positions_only:
                    continue
                try:
                    if mcmc_method == "ChEES-HMC":
                        result = run_chees_hmc(
                            initial_positions,
                            K,
                            M,
                            lp_fn,
                            num_warmup,
                            num_sampling,
                            D,
                            init_step_size=init_step_size,
                            sort=sort,
                        )
                    elif mcmc_method == "NUTS":
                        from sbi_mcmc.utils.tf_chees_hmc_utils import (
                            filter_invalid_init_positions,
                        )

                        initial_positions, _ = filter_invalid_init_positions(
                            initial_positions, lp_fn, sort=sort
                        )
                        if len(initial_positions) < num_chains:
                            raise ValueError(
                                "Not enough valid and unique initial positions"
                            )
                        initial_positions = initial_positions[:num_chains]

                        initial_positions = (
                            task.transform_to_constrained_space(
                                initial_positions
                            )
                        )
                        # One initval dict per chain, each built from a
                        # single-row slice of the constrained positions.
                        initvals = [
                            ndarray_values_as_dict(
                                initial_positions[i : i + 1], task.var_dims
                            )
                            for i in range(len(initial_positions))
                        ]
                        sampler_kwargs = {
                            "nuts_sampler": "numpyro",
                            "nuts_sampler_kwargs": {"jitter": False},
                            "target_accept": 0.99,
                        }
                        pymc_model = task.setup_pymc_model(
                            observation=observation
                        )
                        # NOTE(review): `pm` is not imported explicitly in this
                        # notebook; presumably it comes from one of the
                        # wildcard imports at the top -- confirm.
                        idata_post = pm.sample(
                            tune=num_warmup,
                            draws=num_sampling,
                            chains=num_chains,
                            model=pymc_model,
                            initvals=initvals,
                            progressbar=False,
                            **sampler_kwargs,
                        )
                        rhat_az = az.rhat(
                            idata_post, var_names=list(task.var_names)
                        )
                        _, rhat = az.sel_utils.xarray_to_ndarray(
                            rhat_az, var_names=list(task.var_names)
                        )
                        # Stored under the same key as the ChEES-HMC results
                        # so the downstream cells can read both uniformly.
                        result = {"n_rhat": rhat.squeeze()}
                    else:
                        continue
                except ValueError as e:
                    # Record the failing configuration and move on.
                    exceptions.append(
                        f"id: {observation_id}_{num_warmup}_{init_option}: "
                        + str(e)
                    )
                    continue
                result_save_path = get_save_path(
                    observation_id, num_warmup, init_option, sort=sort
                )
                # print(result["n_rhat"])
                save_to_file(result, result_save_path)
        proceesed_inds.append(observation_id)
        print("total processed: ", len(proceesed_inds))
        if len(proceesed_inds) >= num_runs:
            break

    # Surface skipped configurations instead of leaving them silently in a list.
    if exceptions:
        print(
            f"{len(exceptions)} configuration(s) skipped; "
            "inspect `exceptions` for details"
        )

In [None]:
# Keep only the observations for which a result file exists for every
# (initialisation option, warmup length) combination.
valid_inds = [
    observation_id
    for observation_id in test_observation_ids
    if all(
        get_save_path(
            observation_id, num_warmup, init_option, sort=sort
        ).exists()
        for init_option in ["abi_psis", "stan-like", "abi"]
        for num_warmup in num_warmup_values
    )
]
print(len(valid_inds))

In [None]:
from collections import OrderedDict

from sbi_mcmc.metrics import gskl, mtv, wasserstein_distance

# init_option -> {num_warmup -> list with one value per valid observation},
# where each value is the largest entry of the stored "n_rhat" minus one.
n_rhats = {}

for init_option in ["abi_psis", "stan-like", "abi"]:
    per_warmup = n_rhats.setdefault(init_option, {})
    for num_warmup in num_warmup_values:
        per_warmup[num_warmup] = [
            np.mean(
                read_from_file(
                    get_save_path(
                        observation_id, num_warmup, init_option, sort=sort
                    )
                )["n_rhat"].max()
                - 1
            )
            for observation_id in valid_inds
        ]

In [None]:
import matplotlib.pyplot as plt

data = n_rhats

# Text styling is set before the axes are created so that every label picks
# it up: serif text, serif math, larger base font.
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "dejavuserif"
plt.rcParams.update({"font.size": 12})

fig, ax = plt.subplots(figsize=(5, 3))

# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
colors = {
    "abi_psis": "#DDAA33",
    "stan-like": "#BB5566",
    "abi": "#004488",
}
labels = {
    "abi_psis": "Amortized + PSIS",
    "stan-like": "Random Initialization",
    "abi": "Amortized",
}

markers = {
    "abi_psis": "D",
    "stan-like": "s",
    "abi": "^",
}

# Choose which percentile band the error bars span around the median.
interquartile = True
# interquartile = False
if interquartile:
    title = "median, interquartile"
    lower = 25
    upper = 75
else:
    title = "median, [Min,Max]"
    lower = 0
    upper = 100

# One series per initialisation option.
for idx, key in enumerate(data):
    x_values = sorted(data[key].keys())
    # Small horizontal shift per series so the error bars do not overlap.
    x_offset = -4 + 4 * idx

    # Rows: warmup lengths; columns: observations.
    rhat = np.array([data[key][x] for x in x_values])
    x = np.array(x_values) + x_offset

    center = np.median(rhat, axis=1)
    band_low = np.percentile(rhat, lower, axis=1)
    band_high = np.percentile(rhat, upper, axis=1)

    ax.errorbar(
        x,
        center,
        # `yerr` is interpreted as distances from `center` (first row below,
        # second row above), not as absolute bounds, so convert the
        # percentiles to offsets. Clamp at zero to guard against rounding.
        yerr=[
            np.maximum(center - band_low, 0),
            np.maximum(band_high - center, 0),
        ],
        fmt=markers[key],
        lw=1.5,
        color=colors[key],
        capsize=4,
        linestyle="-",
        label=labels[key],
        # marker edges black
        markeredgecolor="black",
    )


ax.set_xlabel("Number of warmup iterations", fontsize="x-large")
ax.set_xticks(x_values)
if mcmc_method == "NUTS":
    ylabel = r"$\widehat{R} - 1$"
elif mcmc_method == "ChEES-HMC":
    ylabel = r"Nested $\widehat{R} - 1$"
else:
    raise ValueError(f"Unknown mcmc_method: {mcmc_method!r}")
ax.set_ylabel(ylabel, fontsize="x-large")
ax.set_yscale("log")
ax.grid(True, alpha=0.5)


# Legend: keep only the marker/line artist of each errorbar container and list
# the entries alphabetically. A separate name is used for the label list so the
# `labels` dict above is not overwritten.
handles, legend_labels = ax.get_legend_handles_labels()
handles = [h[0] for h in handles]

order = sorted(range(len(legend_labels)), key=lambda i: legend_labels[i])
ax.legend(
    [handles[i] for i in order],
    [legend_labels[i] for i in order],
    handlelength=1,
    borderpad=0.2,
    fontsize=9,
    handletextpad=0.3,
)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_title(title)

# NOTE(review): the file name ends in a stray "_" when `sort` is False and gets
# a double "__" when it is True. Left unchanged in case other tooling refers to
# the existing name.
fig_path = (
    paths["chees_hmc_result_dir"]
    / f"figures/amortized_inits_mcmc_{mcmc_method}_{'_sorted' if sort else ''}.pdf"
)
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path, dpi=300, bbox_inches="tight")