In [None]:
import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import pandas as pd
from estival.sampling import tools as esamp
from tbdynamics.calib_utils import plot_output_ranges, plot_quantiles_for_case_notifications
from tbdynamics.inputs import load_targets
import arviz as az
from tbdynamics.calib_utils import get_bcm
from scipy.stats import gaussian_kde

In [None]:
# Output directory of calibration run r1207, relative to the notebook's working directory
OUT_PATH = Path.cwd() / "runs" / "r1207"

In [None]:
# Calibration targets, plus quantiles of the sampled model outputs stored in results.hdf
targets = load_targets()
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]  # central 95% and 50% ranges plus the median
spaghetti = pd.read_hdf(OUT_PATH / "results.hdf", key="spaghetti")
quantile_outputs = esamp.quantiles_for_results(spaghetti, quantiles)

In [None]:
# plot_spaghetti(spaghetti, ['total_population','notification'], 2, targets)

In [None]:
# plot_spaghetti(spaghetti, ['prevalence_pulmonary','incidence'], 2, targets)

In [None]:
# Population and notifications against their targets.
# NOTE(review): trailing positional args presumably mean (n columns, start year, end year) — confirm against plot_output_ranges
plot_output_ranges(
    quantile_outputs,
    targets,
    ["total_population", "notification"],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Pulmonary prevalence, incidence and latent percentage over the same window as the population plot
plot_output_ranges(
    quantile_outputs,
    targets,
    ["prevalence_pulmonary", "incidence", "percentage_latent"],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# Raw mortality output over the same plotting window
plot_output_ranges(
    quantile_outputs,
    targets,
    ["mortality_raw"],
    quantiles,
    1,
    2010,
    2025,
)

In [None]:
# NOTE(review): modelled `incidence_raw` quantiles are drawn against the *notification* targets — confirm this pairing is intended
notification_target = pd.Series(targets["notification"])
plot_quantiles_for_case_notifications(
    quantile_outputs["incidence_raw"],
    notification_target,
    quantiles,
    plot_end_date=2025,
)

In [None]:
# Load the full calibration trace (ArviZ InferenceData) saved with the run outputs
calib_trace_path = OUT_PATH / "calib_full_out.nc"
idata = az.from_netcdf(calib_trace_path)

In [None]:
# Keep chains 0-3 and only draws labelled >= 50,000 (presumably discarding burn-in — confirm).
# `sel` is label-based, so re-running this cell selects the same subset again.
idata = idata.sel(chain=list(range(4)), draw=slice(50_000, None))

In [None]:
# Figure height grows with the number of posterior variables so each trace row stays readable
trace_height = 3.1 * len(idata.posterior)
az.plot_trace(idata, figsize=(16, trace_height))

In [None]:
# Summary statistics and convergence diagnostics for the retained draws
az.summary(data=idata)

In [None]:
import estival.priors as esp
from numpyro import distributions as dist
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Fixed parameter values passed to get_bcm below.
# The commented-out entries look like previously fitted point estimates kept for reference
# (assumption — confirm before re-enabling any of them).
params = dict(
    start_population_size=2300000.0,
    seed_time=1830.0,
    seed_num=100.0,
    seed_duration=20.0,
    # contact_rate=0.02977583831288669,
    # rr_infection_latent=0.20344010763518713,
    # rr_infection_recovered=0.40580870889350107,
    # progression_multiplier=0.8810860029360905,
    # smear_positive_death_rate=0.4313851033562638,
    # smear_negative_death_rate=0.03350161278620193,
    # smear_positive_self_recovery=0.28604824197753914,
    # smear_negative_self_recovery=0.15805647865361552,
    screening_scaleup_shape=0.25,
    # screening_inflection_time=1995.1487440977369,
    # time_to_screening_end_asymp=1.111127871433536,
    # detection_reduction=0.24351558358481182,
    # contact_reduction=0.3813306382411676,
)

In [None]:
def convert_prior_to_numpyro(prior):
    """
    Map an estival prior onto the matching Numpyro distribution and its plotting bounds.

    Args:
        prior: An estival UniformPrior, TruncNormalPrior, GammaPrior or BetaPrior.

    Returns:
        A ``(distribution, bounds)`` pair. ``bounds`` is a ``(low, high)`` tuple,
        or ``None`` for GammaPrior.

    Raises:
        TypeError: If ``prior`` is not one of the supported prior classes.
    """
    if isinstance(prior, esp.UniformPrior):
        low, high = prior.start, prior.end
        return dist.Uniform(low=low, high=high), (low, high)

    if isinstance(prior, esp.TruncNormalPrior):
        low, high = prior.trunc_range[0], prior.trunc_range[1]
        truncated = dist.TruncatedNormal(loc=prior.mean, scale=prior.stdev, low=low, high=high)
        return truncated, (low, high)

    if isinstance(prior, esp.GammaPrior):
        # Numpyro's Gamma takes a rate; the estival prior carries a scale, so invert it
        return dist.Gamma(concentration=prior.shape, rate=1.0 / prior.scale), None

    if isinstance(prior, esp.BetaPrior):
        return dist.Beta(concentration1=prior.a, concentration0=prior.b), (0, 1)

    raise TypeError(f"Unsupported prior type: {type(prior).__name__}")

def convert_all_priors_to_numpyro(priors):
    """
    Convert every estival prior in a mapping to its Numpyro distribution.

    Plotting bounds returned by ``convert_prior_to_numpyro`` are discarded.

    Args:
        priors: Mapping of parameter name -> estival prior object.

    Returns:
        Dict of parameter name -> Numpyro distribution, in the same key order.
    """
    return {name: convert_prior_to_numpyro(prior)[0] for name, prior in priors.items()}

def normalize_prior_to_posterior(prior, posterior_samples, x_vals_prior, x_vals_posterior, verbose=True):
    """
    Evaluate prior and posterior densities and rescale each to unit area over its own grid.

    The prior density is ``exp(prior.log_prob(x))`` evaluated on ``x_vals_prior``. The
    posterior density is a Gaussian KDE of ``posterior_samples`` evaluated on
    ``x_vals_posterior``. Each curve is divided by its own trapezoidal area, so both
    integrate to one and can be overlaid on a common density scale.

    Args:
        prior: Distribution-like object exposing ``log_prob(x)`` (e.g. a Numpyro distribution).
        posterior_samples: 1-D array of posterior samples.
        x_vals_prior: X values at which to evaluate the prior density.
        x_vals_posterior: X values at which to evaluate the posterior density.
        verbose: If True (default), print the areas after normalisation as a sanity check.

    Returns:
        Tuple of (normalized prior density, x values for prior,
        normalized posterior density, x values for posterior).
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name is deprecated
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    def _unit_area(density, x_vals):
        # Rescale to unit area; leave degenerate curves (zero/non-finite area) unscaled instead of emitting NaN/inf
        area = trapezoid(density, x_vals)
        return density / area if np.isfinite(area) and area > 0 else density

    prior_density = np.exp(np.asarray(prior.log_prob(x_vals_prior), dtype=float))
    # Densities can be infinite/NaN at open support edges (e.g. Beta with a < 1 at 0);
    # one such point would make the whole area non-finite, so zero those points out.
    prior_density = np.where(np.isfinite(prior_density), prior_density, 0.0)
    posterior_density = gaussian_kde(posterior_samples)(x_vals_posterior)

    normalized_prior_density = _unit_area(prior_density, x_vals_prior)
    normalized_posterior_density = _unit_area(posterior_density, x_vals_posterior)

    if verbose:
        new_area_prior = trapezoid(normalized_prior_density, x_vals_prior)
        new_area_posterior = trapezoid(normalized_posterior_density, x_vals_posterior)
        print(f"Normalized Prior Area: {new_area_prior}, Normalized Posterior Area: {new_area_posterior}")

    return normalized_prior_density, x_vals_prior, normalized_posterior_density, x_vals_posterior

def plot_post_prior_comparison(idata, req_vars, priors):
    """
    Plot each requested parameter's posterior density against its prior, two panels per row.

    Args:
        idata: ArviZ InferenceData from calibration; ``idata.posterior[var]`` is read
            for every ``var`` in ``req_vars``.
        req_vars: Names of the parameters to plot.
        priors: Dictionary mapping parameter names to estival prior objects.

    Returns:
        None. The figure is displayed with ``plt.show()``. It is not returned, because
        returning it as a cell's last expression would render it a second time in Jupyter.
    """
    num_vars = len(req_vars)
    if num_vars == 0:
        # plt.subplots(0, 2) raises; there is nothing to draw
        return

    num_rows = (num_vars + 1) // 2  # two columns, rounding up
    fig, axs = plt.subplots(num_rows, 2, figsize=(10, 5 * num_rows))
    axs = axs.ravel()

    for ax, var_name in zip(axs, req_vars):
        posterior_samples = idata.posterior[var_name].values.flatten()
        x_vals_posterior = np.linspace(np.min(posterior_samples), np.max(posterior_samples), 100)

        numpyro_prior, prior_bounds = convert_prior_to_numpyro(priors[var_name])
        # np.linspace cannot span an infinite interval. Fall back to the posterior range when
        # the prior has no bounds (Gamma) or infinite ones (e.g. an open truncation range).
        if prior_bounds is not None and np.all(np.isfinite(prior_bounds)):
            low_prior, high_prior = prior_bounds
            x_vals_prior = np.linspace(low_prior, high_prior, 100)
        else:
            x_vals_prior = x_vals_posterior

        normalized_prior_density, x_vals_prior, normalized_posterior_density, x_vals_posterior = normalize_prior_to_posterior(
            numpyro_prior, posterior_samples, x_vals_prior, x_vals_posterior)

        ax.fill_between(x_vals_prior, normalized_prior_density, color="k", alpha=0.2, linewidth=2, label='Normalized Prior')
        ax.plot(x_vals_posterior, normalized_posterior_density, color="b", linewidth=1, linestyle='dashed', label='Normalized Posterior')
        ax.set_title(f'{var_name}')
        ax.legend()

    # An odd number of variables leaves the last panel empty; remove it rather than show blank axes
    for spare_ax in axs[num_vars:]:
        fig.delaxes(spare_ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Priors declared by the model built from `params`; compare every one against its posterior
priors = get_bcm(params).priors
req_vars = list(priors)

In [None]:
# Posterior vs prior density for each calibrated parameter
plot_post_prior_comparison(idata, req_vars=req_vars, priors=priors)

In [None]:
# Per-variable rank plots (bars), laid out in two columns
variables = list(idata.posterior.data_vars)
n_variables = len(variables)
nrows = -(-n_variables // 2)  # ceil(n_variables / 2)

fig, axes = plt.subplots(nrows=nrows, ncols=2, figsize=(12, nrows * 4))
axes = axes.flatten()

for ax, var in zip(axes, variables):
    az.plot_rank(idata, var_names=[var], kind="bars", ax=ax)

# An odd variable count leaves one trailing panel unused; drop it
for spare_ax in axes[n_variables:]:
    fig.delaxes(spare_ax)

plt.tight_layout()
plt.show()