In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time

from tqdm import tqdm

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["JAX_PLATFORMS"] = "cpu"
# The backend must be selected before `import keras` below.
os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
from omegaconf import OmegaConf
from sbi_mcmc.tasks import *
from sbi_mcmc.tasks.tasks_utils import get_task_logp_func
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling
from sbi_mcmc.utils.utils import *
from tqdm.autonotebook import tqdm

In [None]:
# Resolve everything the PSIS stage needs in a single call, then expose the
# individual entries as notebook-level names used by the cells below.
stuff = get_stuff(job="psis")
task, paths = stuff["task"], stuff["paths"]
test_dataset, test_dataset_name = (
    stuff["test_dataset"],
    stuff["test_dataset_name"],
)
stats_logger, config = stuff["stats_logger"], stuff["config"]

# Observation indices that will be sent through PSIS; filled in further down.
step_1_failed_inds = set()

In [None]:
# Load the saved ABI batch results; the PSIS loop reads its
# "abi_samples_batch" and "log_pdfs_abi_batch" entries per observation.
batch_results = read_from_file(paths["abi_result"])

In [None]:
# Load the out-of-distribution statistics and collect the observation indices
# stored under the Mahalanobis entry for the current test dataset.
ood_stats = read_from_file(paths["ood_stats"])
ood_failed_inds = sorted(
    ood_stats[f"Mahalanobis_{test_dataset_name}"]["ood_failed_inds"]
)
# Build the union from an explicit set instead of using `|=`: the PSIS loop
# cell rebinds `step_1_failed_inds` to a sorted list, and `list |= set` raises
# TypeError if this cell is re-run afterwards. The result is always a set.
step_1_failed_inds = set(step_1_failed_inds) | set(ood_failed_inds)
print(len(step_1_failed_inds))

In [None]:
# Optional mode (off unless the config sets "dynamic_logp"): build one
# log-density function up front from a PyMC model; the observation is then
# supplied at call time inside the PSIS loop.
dynamic_logp = config.get("dynamic_logp", False)
if dynamic_logp:
    print("Dynamic logp")
    _pymc_model = task.setup_pymc_model()
    lp_fn_dynamic = get_task_logp_func(
        task, static=False, pymc_model=_pymc_model
    )

In [None]:
check_strict = False  # If False, replace NaN log densities with -inf

# Run Pareto-smoothed importance sampling (PSIS) for every flagged
# observation, in sorted order, and save one result file per observation.
step_1_failed_inds = sorted(step_1_failed_inds)
for observation_id in tqdm(step_1_failed_inds):
    tic = time.time()
    with stats_logger.timer(observation_id):
        # BernoulliGLM reads the "observables_raw" entry; every other task
        # reads "observables".
        if task.name == "BernoulliGLM":
            observation = test_dataset["observables_raw"][observation_id]
        else:
            observation = test_dataset["observables"][observation_id]

        # In the non-dynamic mode the log-density function is rebuilt for
        # each observation.
        if not dynamic_logp:
            lp_fn = get_task_logp_func(task, observation=observation)
        result_record = {"time": {}}
        abi_samples = batch_results["abi_samples_batch"][observation_id]
        log_pdfs_abi = batch_results["log_pdfs_abi_batch"][observation_id]
        assert abi_samples.shape[0] >= config["target_num_draws"]
        assert log_pdfs_abi.shape[0] >= config["target_num_draws"]
        assert not np.isnan(abi_samples).any(), "NaN in abi_samples. Exiting."

        # Evaluate the task log density at the ABI samples.
        with stats_logger.timer(f"{observation_id}_logp-task"):
            if dynamic_logp:
                log_pdfs_task = lp_fn_dynamic(
                    abi_samples, task.observation_to_pymc_data(observation)
                )
            else:
                log_pdfs_task = lp_fn(abi_samples)

        nan_mask = np.asarray(np.isnan(log_pdfs_task))
        if nan_mask.any():
            if check_strict:
                raise ValueError("NaN in log_pdfs_task. Exiting.")
            print(
                f"observation_id: {observation_id}. {int(nan_mask.sum())} NaNs in log_pdfs_task. Replacing with -inf."
            )
            # Overwrite only the NaN entries. `np.nan_to_num(..., nan=-np.inf)`
            # would also turn existing +/-inf values into the largest/smallest
            # finite floats, which lets the all-inf check below pass wrongly.
            log_pdfs_task = np.array(log_pdfs_task)  # writable copy
            log_pdfs_task[nan_mask] = -np.inf
        assert not np.isinf(log_pdfs_task).all(), (
            "log_pdfs_task are all inf. Exiting."
        )
        assert log_pdfs_abi.ndim == log_pdfs_task.ndim == 1

        # Importance resampling of the ABI samples, with the task log density
        # and the ABI log density passed in that order.
        with stats_logger.timer(f"{observation_id}_sir"):
            abi_psis_resamples, k_stat, pareto_log_weights = (
                sampling_importance_resampling(
                    log_pdfs_task,
                    log_pdfs_abi,
                    abi_samples,
                    return_weights=True,
                    num_samples=config["target_num_draws"],
                )
            )
        result_record["log_pdfs_task"] = log_pdfs_task
        result_record["log_pdfs_abi"] = log_pdfs_abi
        result_record["pareto_log_weights"] = pareto_log_weights
        result_record["abi_samples"] = abi_samples
        result_record["abi_psis_resamples"] = abi_psis_resamples
        result_record["pareto_k"] = k_stat
        # Wall-clock time since the top of this iteration.
        result_record["time"]["time_psis(exclude_abi_logp)"] = (
            time.time() - tic
        )

    # Logging and saving happen outside the per-observation timer.
    stats_logger.update("pareto_k", {observation_id: k_stat})
    result_save_path = paths["psis_result"](observation_id)
    save_to_file(result_record, result_save_path)

In [None]:
# Bucket every recorded Pareto-k value into three ranges. Each entry is
# [count, list of observation ids]. The keys are stored by the final cell, so
# they are kept unchanged; note that the "k>0.7" bucket also receives k == 0.7.
psis_counts = {"0.5<=k<0.7": [0, []], "k<0.5": [0, []], "k>0.7": [0, []]}
for obs_id, pareto_k in stats_logger.data["pareto_k"].items():
    if pareto_k >= 0.7:
        bucket = "k>0.7"
    elif pareto_k >= 0.5:
        bucket = "0.5<=k<0.7"
    elif pareto_k < 0.5:
        bucket = "k<0.5"
    else:
        # Reached only when every comparison is False, e.g. for NaN.
        raise ValueError(
            f"Pareto k for observation {obs_id} is not a comparable number "
            f"(got {pareto_k!r})."
        )
    psis_counts[bucket][0] += 1
    psis_counts[bucket][1].append(obs_id)

# k >= 0.7 counts as a PSIS failure; everything below 0.7 is accepted.
psis_failed_observation_ids = psis_counts["k>0.7"][1]
psis_accept_inds = psis_counts["k<0.5"][1] + psis_counts["0.5<=k<0.7"][1]
print(len(psis_failed_observation_ids))

In [None]:
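## Visualize the results (optional; disabled — uncomment the lines below to compare ABI samples with their PSIS resamples for one observation)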
# from sbi_mcmc.utils.plot_utils import corner_plot

# # observation_id = sorted(step_1_accept_inds)[2]
# # observation_id = sorted(step_1_failed_inds)[1]
# # observation_id = ood_failed_inds[1]
# observation_id = psis_failed_observation_ids[1]
# print(f"observation_id: {observation_id}")
# psis_results = read_from_file(paths["psis_result"](observation_id))
# print(psis_results["pareto_k"])
# abi_samples = psis_results["abi_samples"]
# abi_psis_resamples = psis_results["abi_psis_resamples"]
# transform = None
# transform = task.transform_to_constrained_space
# corner_plot(
#     abi_samples,
#     abi_psis_resamples,
#     labels=["ABI", "ABI(PSIS)"],
#     transform=transform,
#     var_names=task.var_info.var_names_flatten,
# );

In [None]:
# Record the PSIS summary: the per-bucket counts plus the rejected and
# accepted observation ids.
# NOTE(review): the `None` key presumably stores these entries at the top
# level of the logger — confirm against `stats_logger.update`.
stats_logger.update(
    None,
    dict(
        psis_counts=psis_counts,
        reject_inds=psis_failed_observation_ids,
        accept_inds=psis_accept_inds,
    ),
)