In [None]:
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import arviz as az
from tbdynamics.calibration.az_aux import (
    tabulate_calib_results,
    plot_post_prior_comparison,
    plot_trace,
)
from tbdynamics.calibration.utils import get_bcm, get_all_priors
from tbdynamics.constants import params_name

In [None]:
# Directory holding this calibration run's outputs, resolved against the kernel's working directory.
OUT_PATH = Path.cwd() / 'runs/r2608'
# Draws whose `draw` coordinate label is below this value are discarded as burn-in
# (label-based selection via `sel`).
BURN_IN_DRAWS = 50000

idata = az.from_netcdf(OUT_PATH / 'calib_full_out.nc')
burnt_idata = idata.sel(draw=np.s_[BURN_IN_DRAWS:])

# Fail loudly here if the burn-in removed every draw; otherwise the plotting cells below
# would fail with much less obvious errors on an empty posterior.
assert burnt_idata.posterior.sizes["draw"] > 0, (
    f"Burn-in of {BURN_IN_DRAWS} draws leaves no posterior samples in {OUT_PATH}"
)

In [None]:
# Parameter values handed to `get_bcm`; presumably these are fixed (non-calibrated)
# model inputs — confirm against tbdynamics.calibration.utils.
params = dict(
    start_population_size=2_000_000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

# Both COVID-effect flags are switched on for this model build.
covid_effects = dict(detection_reduction=True, contact_reduction=True)

# Model object whose `.priors` are compared against the posterior in a later cell.
bcm = get_bcm(params, covid_effects)

In [None]:
# Presumably plots each parameter's post-burn-in posterior against its prior from `bcm.priors`,
# using `params_name` for readable labels (helper from tbdynamics.calibration.az_aux).
fig = plot_post_prior_comparison(burnt_idata, bcm.priors, params_name)

In [None]:
# Tabulate the post-burn-in calibration results with descriptive names from `params_name`.
# Left as the cell's last expression so the notebook renders whatever the helper returns.
tabulate_calib_results(burnt_idata, params_name)

In [None]:
import matplotlib.pyplot as plt
def plot_trace1(
    idata: az.InferenceData,
    params_name: dict,
    exclude_substring: str = "_dispersion",
    row_height: float = 3.1,
):
    """
    Plot trace plots for the posterior of an InferenceData object with descriptive titles.

    Posterior variables whose name contains `exclude_substring` (by default '_dispersion')
    are dropped before plotting. `az.plot_trace` lays each remaining variable out as one row
    of two panels. The left panel is titled with the descriptive name from `params_name`
    (falling back to the raw variable name), and the right panel's title is cleared.

    Args:
        idata: InferenceData object from ArviZ containing calibration outputs.
        params_name: Dictionary mapping parameter names to descriptive titles.
        exclude_substring: Variables whose names contain this substring are not plotted.
            Pass an empty string to keep every posterior variable.
        row_height: Figure height (inches) allotted per plotted variable.

    Returns:
        trace_fig: The matplotlib figure containing the trace plots.

    Raises:
        ValueError: If no posterior variables remain after filtering.
    """
    filtered_posterior = idata.posterior.drop_vars(
        [
            var
            for var in idata.posterior.data_vars
            if exclude_substring and exclude_substring in var
        ]
    )
    var_names = list(filtered_posterior.data_vars)
    if not var_names:
        raise ValueError(
            "No posterior variables left to plot after excluding names containing "
            f"'{exclude_substring}'."
        )

    # Work with the axes array that az.plot_trace returns instead of plt.gcf() and flat
    # slicing of fig.axes; atleast_2d guards against a squeezed 1-D result for one variable.
    axes = np.atleast_2d(
        az.plot_trace(filtered_posterior, figsize=(16, row_height * len(var_names)))
    )
    trace_fig = axes[0, 0].get_figure()

    # One row per variable: descriptive title on the left panel, no title on the right.
    for var_name, (left_ax, right_ax) in zip(var_names, axes):
        left_ax.set_title(params_name.get(var_name, var_name), fontsize=14, loc="center")
        right_ax.set_title("")

    # Lay out this specific figure rather than whichever figure pyplot considers current.
    trace_fig.tight_layout()

    return trace_fig

In [None]:
# Trace plots of the post-burn-in posterior, titled with the descriptive parameter names.
tracing = plot_trace1(idata=burnt_idata, params_name=params_name)

In [None]:
# Save the trace figure as a 300-dpi PNG under docs/ (relative to the working directory).
# Create the directory first so a fresh checkout without docs/ does not raise FileNotFoundError.
TRACE_FIG_PATH = Path('docs') / 'param_traces.png'
TRACE_FIG_PATH.parent.mkdir(parents=True, exist_ok=True)
tracing.savefig(TRACE_FIG_PATH, dpi=300, bbox_inches='tight', format='png', pad_inches=0)