In [None]:
import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import pandas as pd
from estival.sampling import tools as esamp
from tbdynamics.calib_utils import plot_output_ranges
from tbdynamics.inputs import load_targets
import arviz as az
from tbdynamics.calib_utils import get_bcm
from numpyro.distributions.transforms import AffineTransform, ComposeTransform
from numpyro.distributions import TransformedDistribution
import numpyro.distributions.constraints as constraints

In [None]:
# Run directory for calibration run r1107, resolved relative to the notebook's working directory.
OUT_PATH = Path.cwd().joinpath("runs", "r1107")

In [None]:
# Calibration targets used as reference points in the output plots.
targets = load_targets()

# Quantile levels for the summary bands (2.5%, 25%, median, 75%, 97.5%).
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]

# Per-sample model outputs stored under the 'spaghetti' key, reduced to the quantiles above.
spaghetti = pd.read_hdf(OUT_PATH.joinpath('results.hdf'), key='spaghetti')
quantile_outputs = esamp.quantiles_for_results(spaghetti, quantiles)

In [None]:
# plot_spaghetti(spaghetti, ['total_population','notification'], 2, targets)

In [None]:
# plot_spaghetti(spaghetti, ['prevalence_pulmonary','incidence'], 2, targets)

In [None]:
# Quantile bands for total population and notifications.
# NOTE(review): positional args presumably (quantile outputs, targets, output names,
# quantiles, an option value of 1, start year, end year) — confirm against plot_output_ranges.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['total_population', 'notification'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Quantile bands for pulmonary prevalence, incidence and percentage latent.
# NOTE(review): positional args presumably (quantile outputs, targets, output names,
# quantiles, an option value of 1, start year, end year) — confirm against plot_output_ranges.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['prevalence_pulmonary', 'incidence', 'percentage_latent'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Quantile bands for the raw mortality output.
# NOTE(review): positional args presumably (quantile outputs, targets, output names,
# quantiles, an option value of 1, start year, end year) — confirm against plot_output_ranges.
plot_output_ranges(
    quantile_outputs,
    targets,
    ['mortality_raw'],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Load the saved calibration trace (ArviZ InferenceData) from the run directory.
calib_nc_path = OUT_PATH / 'calib_full_out.nc'
idata = az.from_netcdf(calib_nc_path)

In [None]:
# Restrict all downstream diagnostics to chains 3-6 (label-based selection on the chain coordinate).
# NOTE(review): the reason the other chains are excluded is not recorded here — document it.
SELECTED_CHAINS = list(range(3, 7))
idata = idata.sel(chain=SELECTED_CHAINS)

In [None]:
# Trace plots; figure height scales with the number of posterior variables (3.1 in per variable).
# The trailing ';' suppresses the repr of the returned Axes array so only the figure is shown.
az.plot_trace(idata, figsize=(16, 3.1 * len(idata.posterior)));

In [None]:
# Posterior summary statistics and convergence diagnostics for the selected chains.
az.summary(data=idata)

In [None]:
import estival.priors as esp
from numpyro import distributions as dist
from typing import List
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def convert_prior_to_numpyro(prior):
    """Translate a single estival prior into the equivalent Numpyro distribution.

    Args:
        prior: An ``esp.UniformPrior``, ``esp.TruncNormalPrior`` or ``esp.GammaPrior``.

    Returns:
        Tuple ``(distribution, bounds)``. ``bounds`` is the ``(low, high)``
        truncation range for truncated-normal priors and ``None`` for the others.

    Raises:
        TypeError: If the prior's type has no mapping here.
    """
    if isinstance(prior, esp.UniformPrior):
        return dist.Uniform(low=prior.start, high=prior.end), None

    if isinstance(prior, esp.TruncNormalPrior):
        trunc_low, trunc_high = prior.trunc_range[0], prior.trunc_range[1]
        truncated = dist.TruncatedNormal(
            loc=prior.mean, scale=prior.stdev, low=trunc_low, high=trunc_high
        )
        return truncated, (trunc_low, trunc_high)

    if isinstance(prior, esp.GammaPrior):
        # The estival prior carries a scale; Numpyro's Gamma takes rate = 1 / scale.
        return dist.Gamma(concentration=prior.shape, rate=1.0 / prior.scale), None

    raise TypeError(f"Unsupported prior type: {type(prior).__name__}")

def convert_all_priors_to_numpyro(priors):
    """Convert a ``{name: estival prior}`` mapping to ``{name: Numpyro distribution}``.

    The truncation bounds returned alongside each distribution by
    ``convert_prior_to_numpyro`` are discarded.
    """
    return {
        name: convert_prior_to_numpyro(est_prior)[0]
        for name, est_prior in priors.items()
    }

In [None]:
# def convert_all_priors_to_numpyro(priors):
#     numpyro_priors = {}
#     for key, prior in priors.items():
#         numpyro_priors[key] = convert_prior_to_numpyro(prior)
#     return numpyro_priors

In [None]:
# Parameter values passed to get_bcm when building the model's priors.
# NOTE(review): presumably held fixed (not calibrated) for this run — confirm.
params = dict(
    start_population_size=2300000.0,
    seed_time=1830.0,
    seed_num=100.0,
    seed_duration=20.0,
)

In [None]:
def normalize_prior_to_posterior(prior, posterior_samples, x_vals):
    """Rescale a prior density so its area over ``x_vals`` matches the posterior's.

    Args:
        prior: Distribution object exposing ``log_prob`` (e.g. a Numpyro distribution).
        posterior_samples: Posterior draws for one variable.
        x_vals: Increasing grid of points; the prior is evaluated on it and it is
            also used as the histogram bin edges for the posterior.

    Returns:
        Tuple ``(normalized_prior_density, bin_centers, posterior_density)``:
        the prior density on ``x_vals`` scaled to the posterior histogram's area,
        the histogram bin centres, and the histogram density in each bin.
        If the prior's area is zero or not finite, the prior density is returned unscaled.
    """
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    prior_density = np.exp(np.asarray(prior.log_prob(x_vals)))
    # density=True normalises over the samples that fall inside the x_vals range only.
    posterior_density, bin_edges = np.histogram(posterior_samples, bins=x_vals, density=True)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    area_prior = trapezoid(prior_density, x_vals)
    area_posterior = trapezoid(posterior_density, bin_centers)

    # Only rescale when the prior area is a usable positive number (avoid /0 and nan/inf).
    if np.isfinite(area_prior) and area_prior > 0:
        scaling_factor = area_posterior / area_prior
    else:
        scaling_factor = 1
    normalized_prior_density = prior_density * scaling_factor

    return normalized_prior_density, bin_centers, posterior_density

def plot_post_prior_comparison(
    idata: az.InferenceData, 
    req_vars: list, 
    priors: list,
):
    """Overlay each variable's prior density on its posterior density plot.

    For each variable, the prior is evaluated over the central 95% interval
    (2.5th-97.5th percentile) of its posterior draws. It is then rescaled with
    ``normalize_prior_to_posterior`` so its area matches the posterior
    histogram's area over that interval.

    Args:
        idata: Arviz inference data from calibration.
        req_vars: Names of the posterior variables to plot, in panel order.
        priors: Numpyro prior distributions. Either a list aligned
            index-for-index with ``req_vars``, or a mapping from variable name
            to distribution.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    from collections.abc import Mapping

    if not req_vars:
        # Nothing requested; avoid asking plot_density for an empty grid.
        return

    num_cols = 2
    num_rows = (len(req_vars) + num_cols - 1) // num_cols  # ceil(n / 2): enough rows for every variable

    axes = az.plot_density(
        idata,
        var_names=req_vars,
        shade=0.3,
        grid=(num_rows, num_cols),
    )

    # zip() stops at the shorter sequence, so spare panels (odd variable count) are skipped.
    # NOTE(review): assumes one panel per variable (scalar parameters); a vector-valued
    # variable would get several panels from plot_density and shift this alignment.
    for i_var, (var_name, ax) in enumerate(zip(req_vars, np.ravel(axes))):
        prior = priors[var_name] if isinstance(priors, Mapping) else priors[i_var]
        posterior_samples = idata.posterior[var_name].values.ravel()
        low, high = np.percentile(posterior_samples, [2.5, 97.5])
        if not high > low:
            # Degenerate (constant) posterior: no interval to evaluate the prior over.
            continue
        x_vals = np.linspace(low, high, 100)

        normalized_prior_density, _, _ = normalize_prior_to_posterior(prior, posterior_samples, x_vals)
        ax.fill_between(x_vals, normalized_prior_density, color="k", alpha=0.2, linewidth=2)

    plt.show()


In [None]:
# Build the calibration object from the parameter values above and convert its priors to Numpyro.
bcm = get_bcm(params)
numpyro_priors = convert_all_priors_to_numpyro(bcm.priors)

In [None]:
# Compare every parameter that has a prior, in the priors' own ordering.
req_vars = list(numpyro_priors)

In [None]:
# Priors listed in the same order as req_vars, as the plotting helper pairs them by position.
prior_list = [numpyro_priors[name] for name in req_vars]
plot_post_prior_comparison(idata, req_vars, prior_list)