In [None]:
import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import pandas as pd
from estival.sampling import tools as esamp
from tbdynamics.calib_utils import plot_output_ranges
from tbdynamics.inputs import load_targets
import arviz as az
from tbdynamics.calib_utils import get_bcm

In [None]:
# Directory holding the outputs of calibration run r1206, relative to the working directory.
OUT_PATH = Path.cwd().joinpath("runs", "r1206")

In [None]:
# Quantile levels (median plus 50% and 95% bands) used to summarise sampled outputs.
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]

# Sampled model trajectories stored under the "spaghetti" key of the run's HDF file.
spaghetti = pd.read_hdf(OUT_PATH / 'results.hdf', key='spaghetti')
quantile_outputs = esamp.quantiles_for_results(spaghetti, quantiles)

# Observed targets the outputs are compared against in the plots below.
targets = load_targets()

In [None]:
# plot_spaghetti(spaghetti, ['total_population','notification'], 2, targets)

In [None]:
# plot_spaghetti(spaghetti, ['prevalence_pulmonary','incidence'], 2, targets)

In [None]:
# Quantile ranges for total population and notifications versus targets.
# NOTE(review): trailing args (1, 2010, 2025) look like a column count and a year window — confirm.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['total_population', 'notification'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Quantile ranges for pulmonary prevalence, incidence and latent percentage versus targets.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['prevalence_pulmonary', 'incidence', 'percentage_latent'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Quantile range for the case detection rate (note the later 2015 start).
plot_output_ranges(
    quantile_outputs,
    targets,
    ['case_detection_rate'],
    quantiles,
    1,
    2015,
    2025,
)

In [None]:
# Quantile range for raw mortality versus targets.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['mortality_raw'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Load the calibration inference data stored in the run directory.
idata = az.from_netcdf(OUT_PATH.joinpath('calib_full_out.nc'))

In [None]:
# Discard burn-in: keep draws labelled 5000 and above. `sel` selects by label,
# so re-running this cell keeps the same draws instead of trimming further.
idata = idata.sel(draw=slice(5_000, None))

In [None]:
# Trace plots; figure height scales with the number of posterior data variables.
az.plot_trace(
    idata,
    figsize=(16, len(idata.posterior) * 3.1),
)

In [None]:
# Tabular posterior summary (point estimates, intervals and convergence diagnostics).
az.summary(data=idata)

In [None]:
import estival.priors as esp
from numpyro import distributions as dist
from typing import List
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def convert_prior_to_numpyro(prior):
    """Convert an estival prior into the equivalent numpyro distribution.

    Supported mappings:
        - ``esp.UniformPrior``     -> ``dist.Uniform(low=start, high=end)``
        - ``esp.TruncNormalPrior`` -> ``dist.Normal(mean, stdev)`` truncated to ``trunc_range``
        - ``esp.GammaPrior``       -> ``dist.Gamma(concentration=shape, rate=1 / scale)``

    Args:
        prior: An estival prior object.

    Returns:
        A numpyro distribution corresponding to ``prior``.

    Raises:
        TypeError: If ``prior`` is not one of the supported types.
    """
    import math

    def _finite_or_none(bound):
        # numpyro treats None as "no truncation on this side"; map open (infinite)
        # bounds to None rather than asking it to evaluate CDFs at +/-inf.
        if bound is None or not math.isfinite(bound):
            return None
        return bound

    if isinstance(prior, esp.UniformPrior):
        return dist.Uniform(low=prior.start, high=prior.end)

    if isinstance(prior, esp.TruncNormalPrior):
        base_normal = dist.Normal(loc=prior.mean, scale=prior.stdev)
        low, high = prior.trunc_range
        # Pass the base distribution positionally: numpyro names this parameter
        # `base_dist`, so the former `base_distribution=` keyword raised TypeError.
        return dist.TruncatedDistribution(
            base_normal,
            low=_finite_or_none(low),
            high=_finite_or_none(high),
        )

    if isinstance(prior, esp.GammaPrior):
        # The estival prior exposes (shape, scale); numpyro's Gamma takes (concentration, rate).
        return dist.Gamma(concentration=prior.shape, rate=1.0 / prior.scale)

    raise TypeError(f"Unsupported prior type: {type(prior).__name__}")

In [None]:
def convert_all_priors_to_numpyro(priors):
    """Convert every estival prior in a mapping to its numpyro counterpart.

    Args:
        priors: Mapping from name to estival prior.

    Returns:
        A new dict with the same keys, in the same order, whose values are the
        results of ``convert_prior_to_numpyro`` for each prior.
    """
    return {name: convert_prior_to_numpyro(est_prior) for name, est_prior in priors.items()}

In [None]:
# numpyro versions of the calibration model's priors, keyed like `get_bcm().priors`.
numpyro_priors = convert_all_priors_to_numpyro(priors=get_bcm().priors)

In [None]:
def plot_post_prior_comparison(
    idata: az.InferenceData,
    req_vars: List[str],
    priors: List[dist.Distribution],
):
    """Plot posterior densities with the corresponding prior densities overlaid.

    Posterior densities are drawn by ``az.plot_density`` on a two-column grid.
    Each prior's density is evaluated across its panel's x-range, rescaled so
    its peak reaches the panel's upper y-limit, and shaded grey. The overlay
    therefore shows the prior's shape, not its absolute density scale.

    Args:
        idata: Arviz inference data from calibration.
        req_vars: Variables to plot, in panel order.
        priors: Numpyro prior distributions, one per entry of ``req_vars``
            and in the same order.

    Returns:
        None. The figure is shown with ``plt.show()``.

    Raises:
        ValueError: If fewer priors than requested variables are supplied.
    """
    if len(priors) < len(req_vars):
        raise ValueError(
            f"Expected a prior for each of the {len(req_vars)} requested variables, "
            f"got {len(priors)}"
        )

    # Ceiling division: enough two-panel rows to hold every requested variable.
    num_rows = (len(req_vars) + 1) // 2

    axes = az.plot_density(
        idata,
        var_names=req_vars,
        shade=0.3,
        grid=(num_rows, 2),
    )

    # Assumes one panel per variable, in req_vars order (scalar parameters);
    # any spare grid slot after the last variable is left untouched.
    for ax, prior in zip(axes.ravel()[: len(req_vars)], priors):
        x_vals = np.linspace(*ax.get_xlim(), 50)
        y_vals = _scaled_prior_density(prior, x_vals, ax.get_ylim()[1])
        if y_vals is not None:
            ax.fill_between(x_vals, y_vals, color="k", alpha=0.2, linewidth=2)

    plt.show()


def _scaled_prior_density(prior, x_vals, y_top):
    """Evaluate ``prior``'s density on ``x_vals``, scaled so its peak equals ``y_top``.

    Non-finite values (NaN or inf, which a log-density can yield at or beyond
    its support boundary) are treated as zero. Otherwise a single NaN can make
    the normalising maximum NaN and blank the whole curve. Returns None when
    no positive finite density remains, so the caller skips the overlay
    instead of dividing by zero.
    """
    density = np.exp(np.asarray(prior.log_prob(x_vals), dtype=float))
    density = np.where(np.isfinite(density), density, 0.0)
    peak = density.max()
    if peak <= 0.0:
        return None
    return density * (y_top / peak)

In [None]:
# Overlay each converted prior on the posterior density of the same-named variable.
plot_post_prior_comparison(
    idata,
    list(numpyro_priors),
    list(numpyro_priors.values()),
)