In [None]:
# Enable IPython's autoreload extension in mode 2, so that edited modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from tqdm import tqdm

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc
from sbi_mcmc.utils.utils import *

In [None]:
# Fetch the experiment objects for the `chees_hmc` job and bind the entries
# used throughout the notebook to their own names in one unpacking step.
stuff = get_stuff(job="chees_hmc")
(
    task,
    paths,
    test_dataset,
    test_dataset_name,
    stats_logger,
    config,
) = (
    stuff[key]
    for key in (
        "task",
        "paths",
        "test_dataset",
        "test_dataset_name",
        "stats_logger",
        "config",
    )
)

In [None]:
# Sampler settings, all read from the experiment config.
target_num_draws = config.target_num_draws
chees_cfg = config.chees_hmc
K = chees_cfg.num_superchains  # number of superchains
M = chees_cfg.num_subchains_per_superchain  # subchains inside each superchain
init_step_size = chees_cfg.init_step_size
num_warmup = chees_cfg.num_warmup

# Draws per chain, rounded up so that all chains together produce at least
# `target_num_draws` draws.
num_chains = K * M
num_sampling = int(np.ceil(target_num_draws / num_chains))
print(f"Number of sampling: {num_sampling}")

D = task.D
print(f"Dimension of the task: {D}")

In [None]:
# Observation id -> error message for runs that raise in the sampling loop.
chees_abi_psis_failed = {}

# The PSIS stats file lists the rejected observations under "reject_inds";
# those are the ones this notebook processes.
psis_stats = read_from_file(paths["psis_stats"])
psis_failed_observation_ids = psis_stats["reject_inds"]
num_psis_failed = len(psis_failed_observation_ids)
print(num_psis_failed)

In [None]:
# Optional mode (off unless the config sets `dynamic_logp`): build a single
# non-static log-density function from the task's PyMC model. The sampling
# loop then calls it with per-observation data instead of building a new
# function for every observation.
dynamic_logp = config.get("dynamic_logp", False)
if dynamic_logp:
    pymc_model = task.setup_pymc_model()
    lp_fn_dynamic = get_task_logp_func(task, static=False, pymc_model=pymc_model)

In [None]:
# Run ChEES-HMC on every observation listed in `psis_failed_observation_ids`,
# initialising the chains from the PSIS-resampled draws stored for that
# observation. Failures are best-effort: the error message is printed and
# stored in `chees_abi_psis_failed`, and the loop moves on.
for observation_id in tqdm(psis_failed_observation_ids):
    try:
        with stats_logger.timer(observation_id):
            tic = time.time()
            buffer = 5 * K  # number of initial positions requested from the resampler
            result_record = {"time": {}}

            # BernoulliGLM reads its observation from the raw observables;
            # every other task uses the regular ones.
            observables_key = (
                "observables_raw" if task.name == "BernoulliGLM" else "observables"
            )
            observation = test_dataset[observables_key][observation_id]

            if dynamic_logp:
                observation_data = task.observation_to_pymc_data(observation)

                # The default argument binds this iteration's data to the function.
                def lp_fn(x, obs=observation_data):
                    return lp_fn_dynamic(x, obs)

            else:
                lp_fn = get_task_logp_func(task, observation=observation)

            # Draw `buffer` samples without replacement from the stored
            # samples, using the Pareto-smoothed log weights.
            psis_results = read_from_file(paths["psis_result"](observation_id))
            abi_samples = psis_results["abi_samples"]
            pareto_log_weights = psis_results["pareto_log_weights"]
            abi_psis_resamples_unique = _sir(
                abi_samples,
                log_weights=pareto_log_weights,
                num_samples=buffer,
                with_replacement=False,
            )
            initial_positions = abi_psis_resamples_unique[:buffer]

            result = run_chees_hmc(
                initial_positions,
                K,
                M,
                lp_fn,
                num_warmup,
                num_sampling,
                D,
                init_step_size=init_step_size,
            )

        # Outside the logger's timer: assemble the record and write it to disk.
        result_record.update(result)
        result_record["init_option"] = "abi_psis"
        result_record["time"]["time_chees_hmc"] = time.time() - tic
        result_save_path = paths["chees_hmc_result"](observation_id)
        save_to_file(result_record, result_save_path)
    except Exception as e:
        # Keep going with the remaining observations; remember what went wrong.
        error_message = str(e)
        print(error_message)
        chees_abi_psis_failed[observation_id] = error_message

In [None]:
# Hand the failures collected by the sampling loop (observation id -> exception
# text) to the stats logger under the "chees_abi_psis_failed" key.
stats_logger.update("chees_abi_psis_failed", chees_abi_psis_failed)

Check the nested $\hat{R}$ values

In [None]:
# Collect the nested R-hat stored with each ChEES-HMC result.
# `n_rhats_dict` maps observation id -> nested R-hat, and `n_rhats` stacks the
# same values in iteration order. Observations without a saved result (the run
# failed earlier) get an all-infinite placeholder of length `task.D`, so they
# can never pass a convergence threshold.
n_rhats = []
n_rhats_dict = {}
for observation_id in psis_failed_observation_ids:
    try:
        result = read_from_file(paths["chees_hmc_result"](observation_id))
        n_rhat = result["n_rhat"]
    except FileNotFoundError:
        print(
            f"ChEES-HMC failed to process id {observation_id} for test dataset `{test_dataset_name}`"
        )
        # `np.Inf` was removed in NumPy 2.0; `np.inf` exists in every version.
        n_rhat = np.full(task.D, np.inf)
    n_rhats.append(n_rhat)
    n_rhats_dict[observation_id] = n_rhat
# `np.stack` raises ValueError on an empty list, which happens when there are
# no PSIS failures to process; fall back to an empty (0, D) array in that case.
n_rhats = np.stack(n_rhats) if n_rhats else np.empty((0, task.D))

In [None]:
# Split the observations by convergence: a run is accepted only when the
# nested R-hat of every dimension is strictly below the threshold.
n_rhat_threshold = 1.01
chees_hmc_reject_inds = []
chees_hmc_accept_inds = []
for i, value in n_rhats_dict.items():
    # Written as a positive "all below threshold" test on purpose: any
    # comparison with NaN is False, so a NaN nested R-hat is rejected.
    # The previous `value.max() >= threshold` check was False for NaN and
    # therefore silently accepted such runs.
    if np.all(np.asarray(value) < n_rhat_threshold):
        chees_hmc_accept_inds.append(i)
    else:
        chees_hmc_reject_inds.append(i)

In [None]:
# Summary: first line is rejected/total, second line is the remainder/total.
num_psis_failed = len(psis_failed_observation_ids)
num_chees_rejected = len(chees_hmc_reject_inds)
print(f"{num_chees_rejected}/{num_psis_failed}")
print(f"{num_psis_failed - num_chees_rejected}/{num_psis_failed}")

In [None]:
# Pass both id lists to the stats logger, rejected ids first.
for stat_name, observation_ids in (
    ("chees_hmc_reject_inds", chees_hmc_reject_inds),
    ("chees_hmc_accept_inds", chees_hmc_accept_inds),
):
    stats_logger.update(stat_name, observation_ids)