In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

from tqdm import tqdm

os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
from sbi_mcmc.tasks import *
from sbi_mcmc.utils.bf_utils import compute_summary_statistics
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.utils import *

In [None]:
# Fetch the experiment context for the "ood" job and bind the entries this
# notebook works with to top-level names.
stuff = get_stuff(job="ood")
task, paths, test_dataset, test_dataset_name, stats_logger = (
    stuff[key]
    for key in (
        "task",
        "paths",
        "test_dataset",
        "test_dataset_name",
        "stats_logger",
    )
)

Out-of-distribution (OOD) test: the summary statistics of the test dataset are compared against those of the training dataset.

In [None]:
# Load the saved Keras model from the location stored under "save_model_path".
# NOTE(review): presumably this is the trained approximator whose summary
# outputs are compared below — confirm against the training notebook.
approximator = keras.saving.load_model(paths["save_model_path"])

In [None]:
# Upper bound on the number of training simulations passed to
# `compute_summary_statistics` for the training reference set.
MAX_TRAIN_SIMS = 10000

diagnostic_dir = paths["inference_diagnostic_dir"]
# Make sure the output directory exists before anything is written into it;
# the figure-saving code does the same for its own sub-directory.
diagnostic_dir.mkdir(parents=True, exist_ok=True)

s_train_file = diagnostic_dir / "train_summary_outputs.pkl"
s_test_file = diagnostic_dir / f"{test_dataset_name}_summary_outputs.pkl"

train_dataset = read_from_file(paths["dataset_dir"] / "train_dataset.pkl")
# Use at most MAX_TRAIN_SIMS simulations for the train summary outputs.
# Only the "observables" entry is truncated; any other entries of the loaded
# dataset keep their original length.
train_dataset["observables"] = train_dataset["observables"][:MAX_TRAIN_SIMS]

# Summary outputs of the (truncated) training set, timed and saved to disk.
with stats_logger.timer("compute_train_summary_outputs"):
    s_train = compute_summary_statistics(approximator, train_dataset)
save_to_file(s_train, s_train_file)

# Summary outputs of the test set, timed under a dataset-specific key and saved.
with stats_logger.timer(f"compute_test_summary_outputs_{test_dataset_name}"):
    s_test = compute_summary_statistics(approximator, test_dataset)
save_to_file(s_test, s_test_file)

print(s_train.shape, s_test.shape)

In [None]:
import seaborn as sns
from sklearn.covariance import EmpiricalCovariance

# Name of the test statistic used to score summary outputs against the
# training summaries. "Mahalanobis" is the only implemented option.
summary_test = "Mahalanobis"
with stats_logger.timer(f"compute_{summary_test}_{test_dataset_name}"):
    if summary_test != "Mahalanobis":
        raise NotImplementedError

    # Fit location/covariance on the training summaries, then score both the
    # training and the test summaries against that same fit.
    # NOTE: scikit-learn's `mahalanobis` returns *squared* Mahalanobis distances.
    cov = EmpiricalCovariance().fit(s_train)
    mahalanobis_from_sims, test_mahalanobis_from_sims = (
        cov.mahalanobis(s_train),
        cov.mahalanobis(s_test),
    )

    # Generic aliases consumed by the thresholding and plotting code.
    train_t_statistics, test_t_statistics = (
        mahalanobis_from_sims,
        test_mahalanobis_from_sims,
    )

In [None]:
import matplotlib.pyplot as plt

with stats_logger.timer(f"thresholding_{test_dataset_name}"):
    # The rejection threshold is the `quantile`-level quantile of the training
    # statistics; test samples strictly above it are flagged.
    quantile = 0.95
    cutoff = np.quantile(train_t_statistics, quantile)
    print(f"Cutoff: {cutoff}")

    # Share of test samples whose statistic exceeds the cutoff, in percent.
    above_cutoff = test_t_statistics > cutoff
    pct_above = np.mean(above_cutoff) * 100
    print(f"Test samples above cutoff: {pct_above:.2f}%")

    # Indices of flagged (above cutoff) and accepted (at or below cutoff) samples.
    ood_failed_inds = list(np.where(above_cutoff)[0])
    ood_accept_inds = list(np.where(test_t_statistics <= cutoff)[0])
    print(f"{len(ood_failed_inds)}/{len(test_t_statistics)}")

# Record the threshold, the index split, and the raw statistics under a key
# specific to this statistic and test dataset.
ood_record = {
    "cutoff": cutoff,
    "quantile": quantile,
    "ood_failed_inds": ood_failed_inds,
    "ood_accept_inds": ood_accept_inds,
    "train_t_statistics": train_t_statistics,
    "test_t_statistics": test_t_statistics,
}
stats_logger.update(f"{summary_test}_{test_dataset_name}", ood_record)
# Overlay the train and test distributions of the test statistic and mark the
# rejection cutoff. The explicit figure/axes interface is used so the plot does
# not depend on pyplot's implicit "current figure" state.
fig, ax = plt.subplots()
ax.hist(train_t_statistics, bins=50, alpha=0.5, label="Train", density=True)
ax.hist(test_t_statistics, bins=50, alpha=0.5, label="Test", density=True)

ax.axvline(
    cutoff,
    color="red",
    label=f"{quantile * 100:.0f}% quantile of {summary_test} distance in train set ($H_0$)",
)

ax.legend()
ax.set_xlabel("Test statistic")
# Both histograms are drawn with density=True, so the y-axis is a density.
ax.set_ylabel("Density")
ax.set_title(f"{summary_test} OOD test: {test_dataset_name}")

# Save the figure
figure_path = (
    paths["inference_diagnostic_dir"]
    / "ood"
    / f"{summary_test}_distribution_{test_dataset_name}.png"
)
# Create the "ood" sub-directory (and any missing parents) before writing.
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, dpi=300)
plt.show()