In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. the local `sbi_mcmc` package) are picked up before each cell
# executes without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Backend switches are exported before the JAX / Keras imports further down,
# so those libraries see them when they are first imported.
os.environ.update(
    {
        "JAX_ENABLE_X64": "False",
        "KERAS_BACKEND": "jax",
    }
)

# Extend the module search path with the sibling BayesFlow directory and the
# parent directory (presumably the repository root -- confirm).
sys.path.extend(["../BayesFlow/", "../"])

import pytensor
from jax import config

# Request single precision from both PyTensor and JAX.
pytensor.config.floatX = "float32"
# NOTE(review): `floatX` is a PyTensor setting; it does not look like a JAX
# config option, so this assignment may have no effect -- confirm.
# Also note that the name `config` (JAX's config object here) is rebound to the
# experiment configuration in a later cell.
config.floatX = "float32"
# Explicitly disable 64-bit mode in JAX (mirrors the JAX_ENABLE_X64 env var set above).
config.update("jax_enable_x64", False)

import warnings
from pathlib import Path
from pprint import pprint

import bayesflow as bf
import keras
import matplotlib.pyplot as plt
import numpy as np
from sbi_mcmc.utils.experiment_utils import *
from sbi_mcmc.utils.utils import *

warnings.filterwarnings(
    "ignore", message="The figure layout has changed to tight"
)

In [None]:
from sbi_mcmc.tasks import (
    BernoulliGLMTask,
    CustomDDM,
    GeneralizedExtremeValue,
    PsychometricTask,
)

# With `task = None`, `get_stuff` creates the task according to the
# `config.yaml` file. To pin a task by hand, assign an instance instead,
# e.g. `task = GeneralizedExtremeValue()`.
task = None
stuff = get_stuff(task, job="training")

# NOTE: `config` is now the experiment configuration; it replaces the JAX
# `config` object imported at the top of the notebook.
paths, task, config = stuff["paths"], stuff["task"], stuff["config"]
SMOKE_TEST = config.get("smoke_test", False)
print(task.var_names)

# Logger used by the `result_logger.timer(...)` blocks in the cells below;
# it is pointed at a pickle file inside the training result directory.
training_record_path = paths["training_result_dir"] / "training_record.pkl"
result_logger = PickleStatLogger(training_record_path, overwrite=False)
print(
    f"save_dir: {paths['save_dir']},\nresult_logger: {result_logger.filepath}"
)

In [None]:
# Task-specific settings: how many simulations each dataset holds and the
# epochs / batch size used for training.
task_config = get_task_configs(task)
dataset_size_dict, epochs, batch_size = (
    task_config[key] for key in ("dataset_size_dict", "epochs", "batch_size")
)
pprint(task_config)

Generate or read prior simulations

In [None]:
import logging

def silence_task_loggers(prefix="sbi_mcmc.tasks.", level=logging.CRITICAL):
    """Set ``level`` on every registered logger whose name starts with ``prefix``.

    The logging registry (``logging.root.manager.loggerDict``) holds real
    ``logging.Logger`` objects as well as ``logging.PlaceHolder`` entries for
    intermediate names that never had a logger created for them (e.g.
    ``a.b`` after only ``getLogger("a.b.c")`` was called). Placeholders have
    no ``setLevel`` method, so they are skipped instead of raising
    ``AttributeError``.

    Args:
        prefix: Name prefix selecting the loggers to adjust.
        level: Logging level to apply to each matching logger.
    """
    # Iterate over a snapshot so the loop is independent of registry changes.
    for name, logger in list(logging.root.manager.loggerDict.items()):
        if name.startswith(prefix) and isinstance(logger, logging.Logger):
            logger.setLevel(level)


# These two loggers are created explicitly so they are silenced even if their
# modules have not registered them yet.
logging.getLogger("sbi_mcmc.tasks.tasks").setLevel(logging.CRITICAL)
logging.getLogger("sbi_mcmc.tasks.ddm").setLevel(logging.CRITICAL)

# Everything else already registered below `sbi_mcmc.tasks.` is silenced too.
silence_task_loggers()
# Datasets named in REGENERATE are simulated afresh and passed to
# `save_to_file`; all others are read back from `paths["dataset_dir"]`.
REGENERATE = ["train", "val", "diagnostic", "lc2st_cal"]
# REGENERATE = []
if SMOKE_TEST:
    assert not REGENERATE, "REGENERATE should be empty for smoke test. "

prior_simulations_N = {}
for dataset_name in ("train", "val", "diagnostic", "lc2st_cal"):
    filepath = paths["dataset_dir"] / f"{dataset_name}_dataset.pkl"
    num_simulations = dataset_size_dict[dataset_name]
    if dataset_name not in REGENERATE:
        print(f"Loading {dataset_name} dataset from {filepath}")
        simulations = read_from_file(filepath)
    else:
        # Only the sampling call is timed; saving happens outside the timer.
        with result_logger.timer(f"simulation_{dataset_name}"):
            simulations = task.sample(num_simulations)
        save_to_file(simulations, filepath)
    size_mb = filepath.stat().st_size / 1e6
    print(f"{filepath.name}: {size_mb:.2f} MB")
    prior_simulations_N[dataset_name] = simulations

# The size check is skipped in smoke-test mode.
if not SMOKE_TEST:
    check_dataset_size_consistency(prior_simulations_N, dataset_size_dict)

print("=== \nTrain dataset:")
for key, array in prior_simulations_N["train"].items():
    print(key, array.shape)
check_prior_simulations(prior_simulations_N, task)

# The remaining cells refer to the datasets under this shorter name.
prior_simulations = prior_simulations_N

In [None]:
# `get_bf_configs` returns a pair: keyword arguments for the BayesFlow workflow
# and an info dict (read later for the model path and the flow type).
bf_configs = get_bf_configs(task, smoke_test=SMOKE_TEST)
bf_workflow_kwargs, bf_info = bf_configs
amortized_training_workflow = bf.BasicWorkflow(**bf_workflow_kwargs)

In [None]:
# # For resuming training from a checkpoint
# amortized_training_workflow.approximator = keras.saving.load_model(
#     bf_info["save_model_path"]
# )

In [None]:
# Offline training on the pre-simulated "train" split, validated on "val";
# the whole fit is timed under the "training" key.
with result_logger.timer("training"):
    fit_kwargs = dict(
        epochs=epochs,
        batch_size=batch_size,
        validation_data=prior_simulations["val"],
    )
    history = amortized_training_workflow.fit_offline(
        prior_simulations["train"], **fit_kwargs
    )

In [None]:
# Plot the training history and save it next to the other figures.
fig = bf.diagnostics.plots.loss(history)
# Make sure the output directory exists before saving, as the diagnostics
# cell does for its own sub-directories; otherwise `savefig` fails on a
# fresh output tree.
paths["figure_dir"].mkdir(parents=True, exist_ok=True)
fig.savefig(paths["figure_dir"] / "loss.png", dpi=300)

In [None]:
# Load the saved approximator from disk. In smoke-test mode every
# "smoke_test" component is dropped from the path, so the model is loaded
# from the path without that component.
model_path = bf_info["save_model_path"]
if SMOKE_TEST:
    kept_parts = tuple(
        part for part in model_path.parts if part != "smoke_test"
    )
    model_path = Path(*kept_parts)
approximator = keras.saving.load_model(model_path)
print(f"Approximator loaded from {model_path}, ")

In [None]:
# Draw 1000 posterior samples for each simulation in the diagnostic split;
# the sampling is timed under "diagnostic_post_draws".
with result_logger.timer("diagnostic_post_draws"):
    diagnostic_conditions = prior_simulations["diagnostic"]
    post_draws = approximator.sample(
        num_samples=1000, conditions=diagnostic_conditions
    )

In [None]:
# Ground-truth parameters of the diagnostic split, mapped to the constrained
# space and stored under "parameters_original".
with result_logger.timer("diagnostic_transform_prior_sims_params"):
    diagnostic_sims = prior_simulations["diagnostic"]
    diagnostic_sims["parameters_original"] = (
        task.transform_to_constrained_space(diagnostic_sims["parameters"])
    )

# Posterior draws get the same transform, applied one leading-axis slice at a
# time (presumably one slice per diagnostic dataset -- TODO confirm).
with result_logger.timer("diagnostic_transform_post_draws_params"):
    abi_samples = post_draws["parameters"]
    abi_samples_constrained = np.zeros_like(abi_samples)
    for i, dataset_draws in enumerate(abi_samples):
        abi_samples_constrained[i] = task.transform_to_constrained_space(
            dataset_draws
        )
    post_draws["parameters_original"] = abi_samples_constrained

In [None]:
from bayesflow.diagnostics import plots as bf_plots

# Math-text axis labels per task, keyed by `task.name`. Tasks without an entry
# fall back to their raw variable names where this dict is looked up below.
display_names = {
    "CustomDDM(dt-0.0001)": (
        "$v_1$",
        "$v_2$",
        "$a_1$",
        "$a_2$",
        r"$\tau_c$",
        r"$\tau_n$",
    ),
    # The subscript is wrapped in literal braces (`{{`/`}}` in the f-string) so
    # that two-digit indices render correctly: `\theta_{10}` instead of
    # `\theta_10`, which would only subscript the leading "1".
    "BernoulliGLM": [rf"$\theta_{{{i + 1}}}$" for i in range(10)],
    "psychometric_curve_overdispersion": [
        r"$\tilde{m}$",
        "$w$",
        r"$\gamma$",
        r"$\lambda$",
        r"$\eta$",
    ],
    "GEV": [r"$\mu$", r"$\sigma$", r"$\xi$"],
}

# Diagnostic plots to produce; the key doubles as the output file stem.
plot_fns = dict(
    recovery=bf_plots.recovery,
    calibration_ecdf=bf_plots.calibration_ecdf,
    z_score_contraction=bf_plots.z_score_contraction,
    calibration_histogram=bf_plots.calibration_histogram,
)
# Extra keyword arguments per plot, looked up as "<plot name>_kwargs".
kwargs = {"calibration_ecdf_kwargs": dict(difference=True)}

# Axis labels: the task-specific display names if defined, otherwise the
# task's flattened variable names.
variable_names = display_names.get(task.name, task.var_info.var_names_flatten)

# Every diagnostic is produced twice: once on the constrained-space
# parameters ("parameters_original") and once on the raw "parameters".
for constrained_space in (True, False):
    param_key = "parameters_original" if constrained_space else "parameters"
    space_dirname = (
        "constrained_space" if constrained_space else "unconstrained_space"
    )
    figure_dir = paths["figure_dir"] / space_dirname
    figure_dir.mkdir(parents=True, exist_ok=True)

    figures = {}
    for k, plot_fn in plot_fns.items():
        extra_kwargs = kwargs.get(f"{k}_kwargs", {})
        figures[k] = plot_fn(
            estimates=post_draws[param_key],
            targets=prior_simulations["diagnostic"][param_key],
            variable_names=variable_names,
            **extra_kwargs,
        )
        # One PDF per diagnostic, tagged with the flow type.
        filepath = figure_dir / f"{k}({bf_info['flow_type']}).pdf"
        figures[k].savefig(filepath)
        print(f"Saved {filepath}")