In [None]:
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import arviz as az
from tbdynamics.calibration.az_aux import (
    tabulate_calib_results,
    plot_post_prior_comparison,
    plot_trace,
)
from tbdynamics.calibration.utils import get_bcm
from tbdynamics.constants import params_name
from tbdynamics.inputs import get_death_rate, process_universal_death_rate, get_birth_rate
import numpyro.distributions as dist
import estival.priors as esp
import matplotlib.pyplot as plt
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm, gaussian_kde
import pandas as pd


In [2]:
def plot_derived_comparison(priors, posterior_metrics, params_name):
    """
    Plot prior vs. posterior densities of the derived outputs (disease duration and CFR).

    NOTE(review): later cells in this notebook define functions with the same
    name; whichever cell ran last is the one actually in effect.

    Args:
        priors: Mapping with keys ``'prior_<var>'``; each value must expose
            ``.sample(key, shape)`` (it is called with a JAX PRNG key and ``(1000,)``).
        posterior_metrics: Mapping with keys ``'post_<var>'``; each value must
            expose ``.flatten()`` and hold the posterior samples for that metric.
        params_name: Mapping from variable name to a descriptive panel title;
            the raw variable name is used when a title is missing.

    Returns:
        The matplotlib figure holding one panel per derived variable.
    """
    # Derived parameters to compare
    derived_vars = ['duration_positive', 'cfr_positive', 'duration_negative', 'cfr_negative']
    num_vars = len(derived_vars)
    num_rows = (num_vars + 1) // 2  # Two columns; round up so every variable gets a panel

    # plt.subplots always creates a fresh figure, so no plt.clf() is needed.
    # Calling plt.clf() first would clear (or create) an unrelated current
    # figure and leave a stray blank figure in the notebook output.
    fig, axs = plt.subplots(num_rows, 2, figsize=(28, 6.2 * num_rows))
    axs = axs.ravel()

    for i_ax, ax in enumerate(axs):
        if i_ax >= num_vars:
            ax.axis("off")  # Hide the spare panel when the number of variables is odd
            continue

        var_name = derived_vars[i_ax]

        # Posterior: KDE evaluated over the observed sample range
        posterior_samples = posterior_metrics[f'post_{var_name}'].flatten()
        x_vals_posterior = np.linspace(
            np.min(posterior_samples), np.max(posterior_samples), 100
        )
        posterior_density = gaussian_kde(posterior_samples)(x_vals_posterior)

        # Prior: draw a fixed-seed sample so the shaded area is reproducible,
        # then convert to NumPy before handing it to SciPy
        prior_dist = priors[f'prior_{var_name}']
        prior_samples = np.asarray(
            prior_dist.sample(random.PRNGKey(0), (1000,))
        ).flatten()
        x_vals_prior = np.linspace(np.min(prior_samples), np.max(prior_samples), 100)
        prior_density = gaussian_kde(prior_samples)(x_vals_prior)

        # Prior as a shaded area, posterior as a line on top
        ax.fill_between(
            x_vals_prior,
            prior_density,
            color="k",
            alpha=0.2,
            linewidth=2,
            label="Prior",
        )
        ax.plot(
            x_vals_posterior,
            posterior_density,
            color="b",
            linewidth=1,
            linestyle="solid",
            label="Posterior",
        )

        # Use the descriptive name when available, otherwise the variable name
        ax.set_title(params_name.get(var_name, var_name), fontsize=30, fontname='Arial')
        ax.tick_params(axis='both', labelsize=24)

        # A single legend on the first panel is enough
        if i_ax == 0:
            ax.legend(fontsize=24)

    # Lay out this figure explicitly instead of relying on pyplot's "current figure"
    fig.tight_layout(h_pad=1.0, w_pad=5)
    return fig

In [2]:
# Calibration output directory, resolved relative to the parent of the working directory.
OUT_PATH = Path.cwd().parent / 'runs/best2208'
idata = az.from_netcdf(OUT_PATH / 'calib_full_out.nc')
# Keep only draws from index 50000 onwards (earlier draws are treated as burn-in).
burnt_idata = idata.sel(draw=slice(50000, None))

In [3]:
# Switches passed to get_bcm alongside the fixed parameters below.
covid_effects = {"detection_reduction": True, "contact_reduction": False}

# Fixed (non-calibrated) parameter values handed to get_bcm.
params = {
    "start_population_size": 2_000_000.0,
    "seed_time": 1805.0,
    "seed_num": 1.0,
    "seed_duration": 1.0,
}

bcm = get_bcm(params, covid_effects)

In [None]:
# Prior-vs-posterior comparison for the calibrated parameters, using only the
# post-burn-in draws (`burnt_idata`) and the priors attached to `bcm`.
fig = plot_post_prior_comparison(burnt_idata, bcm.priors, params_name)

In [6]:
# Processed death-rate input. Later cells read `universal_death[2023]` and use
# it as the natural death rate when deriving disease duration and CFR.
universal_death = process_universal_death_rate(get_death_rate())

In [9]:
# Subset of the calibration priors needed for the derived metrics:
# death and self-recovery rates for smear-positive and smear-negative TB.
priors = {
    prior_name: bcm.priors[prior_name]
    for prior_name in (
        'smear_positive_death_rate',
        'smear_positive_self_recovery',
        'smear_negative_death_rate',
        'smear_negative_self_recovery',
    )
}
# posterior_metrics = process_idata_for_derived_metrics(burnt_idata)

In [8]:
def plot_derived_comparison(priors, posterior_metrics, params_name):
    """
    Plot prior vs. posterior densities of the derived outputs (disease duration and CFR).

    NOTE(review): an earlier cell already defines a function with this name and
    a later cell redefines it again with a different ``priors`` contract;
    whichever cell ran last is the one actually in effect.

    Args:
        priors: Mapping with keys ``'prior_<var>'``; each value must expose
            ``.sample(key, shape)`` (it is called with a JAX PRNG key and ``(1000,)``).
        posterior_metrics: Mapping with keys ``'post_<var>'``; each value must
            expose ``.flatten()`` and hold the posterior samples for that metric.
        params_name: Mapping from variable name to a descriptive panel title;
            the raw variable name is used when a title is missing.

    Returns:
        The matplotlib figure holding one panel per derived variable.
    """
    # Derived parameters to compare
    derived_vars = ['duration_positive', 'cfr_positive', 'duration_negative', 'cfr_negative']
    num_vars = len(derived_vars)
    num_rows = (num_vars + 1) // 2  # Two columns; round up so every variable gets a panel

    # plt.subplots always creates a fresh figure, so no plt.clf() is needed.
    # Calling plt.clf() first would clear (or create) an unrelated current
    # figure and leave a stray blank figure in the notebook output.
    fig, axs = plt.subplots(num_rows, 2, figsize=(28, 6.2 * num_rows))
    axs = axs.ravel()

    for i_ax, ax in enumerate(axs):
        if i_ax >= num_vars:
            ax.axis("off")  # Hide the spare panel when the number of variables is odd
            continue

        var_name = derived_vars[i_ax]

        # Posterior: KDE evaluated over the observed sample range
        posterior_samples = posterior_metrics[f'post_{var_name}'].flatten()
        x_vals_posterior = np.linspace(
            np.min(posterior_samples), np.max(posterior_samples), 100
        )
        posterior_density = gaussian_kde(posterior_samples)(x_vals_posterior)

        # Prior: draw a fixed-seed sample so the shaded area is reproducible,
        # then convert to NumPy before handing it to SciPy
        prior_dist = priors[f'prior_{var_name}']
        prior_samples = np.asarray(
            prior_dist.sample(random.PRNGKey(0), (1000,))
        ).flatten()
        x_vals_prior = np.linspace(np.min(prior_samples), np.max(prior_samples), 100)
        prior_density = gaussian_kde(prior_samples)(x_vals_prior)

        # Prior as a shaded area, posterior as a line on top
        ax.fill_between(
            x_vals_prior,
            prior_density,
            color="k",
            alpha=0.2,
            linewidth=2,
            label="Prior",
        )
        ax.plot(
            x_vals_posterior,
            posterior_density,
            color="b",
            linewidth=1,
            linestyle="solid",
            label="Posterior",
        )

        # Use the descriptive name when available, otherwise the variable name
        ax.set_title(params_name.get(var_name, var_name), fontsize=30, fontname='Arial')
        ax.tick_params(axis='both', labelsize=24)

        # A single legend on the first panel is enough
        if i_ax == 0:
            ax.legend(fontsize=24)

    # Lay out this figure explicitly instead of relying on pyplot's "current figure"
    fig.tight_layout(h_pad=1.0, w_pad=5)
    return fig

In [None]:
# Tabulate the calibration results from the post-burn-in draws; as the cell's
# last expression, the returned object is shown as the cell output.
tabulate_calib_results(burnt_idata, params_name)

In [11]:
# Trace plots are built from the full `idata` (not `burnt_idata`), so the
# burn-in draws are included. NOTE(review): presumably intentional — confirm.
tracing = plot_trace(idata, params_name)

In [None]:
# Display the object returned by plot_trace as this cell's output.
tracing

In [13]:
# tracing.savefig('../docs/param_traces.png', dpi=300, bbox_inches='tight', format='png', pad_inches=0)

In [None]:
def calculate_derived_metrics(
    death_rate, recovery_rate, natural_death_rate, time_period
):
    """
    Derive disease duration and case fatality rate (CFR) from rate parameters.

    All operations are element-wise, so scalars and NumPy arrays both work.

    Args:
        death_rate: The TB-related death rate (μT).
        recovery_rate: The self-recovery rate (γ).
        natural_death_rate: The natural death rate (μ).
        time_period: The time period T used in the CFR formula.

    Returns:
        A ``(disease_duration, cfr)`` tuple where
        ``disease_duration = 1 / (μT + γ + μ)`` and
        ``cfr = 1 - γ/(γ+μT)·exp(-μ·T) - μT/(γ+μT)·exp(-(γ+μT+μ)·T)``.
    """
    # Combined exit rates, reused by both outputs
    tb_exit_rate = recovery_rate + death_rate
    total_exit_rate = tb_exit_rate + natural_death_rate

    # The two terms subtracted from 1 in the CFR formula
    recovery_term = (recovery_rate / tb_exit_rate) * np.exp(
        -natural_death_rate * time_period
    )
    death_term = (death_rate / tb_exit_rate) * np.exp(
        -total_exit_rate * time_period
    )

    return 1 / total_exit_rate, 1 - recovery_term - death_term


# Helper function to sample from truncated normal distribution
def sample_truncated_normal(
    mean, stdev, trunc_range, num_samples=100000, random_state=None
):
    """
    Sample from a normal distribution truncated to ``trunc_range``.

    Args:
    - mean: Mean of the underlying (untruncated) normal distribution.
    - stdev: Standard deviation of the underlying normal distribution.
    - trunc_range: Tuple containing (lower bound, upper bound) for truncation.
    - num_samples: Number of samples to draw.
    - random_state: Optional seed or NumPy random generator forwarded to
      SciPy. The default ``None`` keeps the previous behaviour (NumPy's global
      random state); pass an int for reproducible draws.

    Returns:
    - samples: NumPy array of ``num_samples`` sampled values.
    """
    lower, upper = trunc_range
    # scipy's truncnorm takes its bounds in standard-deviation units
    # relative to `loc`, not on the original scale.
    a = (lower - mean) / stdev
    b = (upper - mean) / stdev
    return truncnorm(a, b, loc=mean, scale=stdev).rvs(
        size=num_samples, random_state=random_state
    )


# Function to sample from the priors and plot the densities
def sample_priors_and_plot_derived(priors, num_samples=10000):
    """
    Sample from the priors, calculate derived metrics (disease duration and CFR), and plot the density.

    The natural death rate is read from the notebook-level ``universal_death``
    (entry ``2023``) and the CFR time period is fixed at 1.

    Args:
    - priors: A dictionary with keys ``smear_<case>_death_rate`` and
      ``smear_<case>_self_recovery`` for ``<case>`` in positive/negative. Each
      value must expose ``mean``, ``stdev`` and ``trunc_range``.
    - num_samples: Number of samples to draw from each prior.

    Returns:
    - A dictionary mapping ``duration_<case>`` / ``cfr_<case>`` to NumPy arrays
      of derived values. A 2x2 density figure is shown and a summary table
      (mean, 2.5% and 97.5% quantiles) is printed as side effects.
    """
    # Titles are listed in the same order in which the metrics are inserted below
    plot_titles = [
        "Duration (Smear Positive)",
        "CFR (Smear Positive)",
        "Duration (Smear Negative)",
        "CFR (Smear Negative)",
    ]

    natural_death_rate = universal_death[2023]

    derived_metrics = {}
    for case in ("positive", "negative"):
        death_prior = priors[f"smear_{case}_death_rate"]
        recovery_prior = priors[f"smear_{case}_self_recovery"]

        death_samples = sample_truncated_normal(
            death_prior.mean,
            death_prior.stdev,
            death_prior.trunc_range,
            num_samples,
        )
        recovery_samples = sample_truncated_normal(
            recovery_prior.mean,
            recovery_prior.stdev,
            recovery_prior.trunc_range,
            num_samples,
        )

        # calculate_derived_metrics is element-wise, so one vectorised call
        # replaces the per-sample Python loop.
        duration, cfr = calculate_derived_metrics(
            death_samples, recovery_samples, natural_death_rate, 1
        )
        derived_metrics[f"duration_{case}"] = np.asarray(duration)
        derived_metrics[f"cfr_{case}"] = np.asarray(cfr)

    # One density panel per derived metric
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    axs = axs.ravel()

    results = []
    for ax, title, metric_values in zip(axs, plot_titles, derived_metrics.values()):
        x_vals = np.linspace(np.min(metric_values), np.max(metric_values), 1000)
        ax.plot(x_vals, gaussian_kde(metric_values)(x_vals), color="blue")
        ax.set_title(title)

        # Mean and 95% interval (2.5% and 97.5% quantiles)
        mean_val = np.mean(metric_values)
        lower, upper = np.percentile(metric_values, [2.5, 97.5])
        results.append(
            {
                "Metric": title,
                "Mean": f"{mean_val:.3f}",
                # Was `:.f`, an invalid format spec that raised ValueError
                "2.5% Quantile": f"{lower:.3f}",
                "97.5% Quantile": f"{upper:.3f}",
            }
        )

    plt.tight_layout()
    plt.show()
    results_df = pd.DataFrame(results)
    print("\nDerived Metrics with Mean and 95% CI:")
    print(results_df)

    return derived_metrics


# Example usage:
# Priors for the four rate parameters that feed the derived metrics.
priors = {
    prior_name: bcm.priors[prior_name]
    for prior_name in (
        "smear_positive_death_rate",
        "smear_positive_self_recovery",
        "smear_negative_death_rate",
        "smear_negative_self_recovery",
    )
}

# Draw from those priors, derive duration/CFR, plot the densities and print the summary table.
derived_metrics = sample_priors_and_plot_derived(priors)

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm, gaussian_kde
import pandas as pd
import jax.random as random

# Helper functions
def calculate_derived_metrics(death_rate, recovery_rate, natural_death_rate, time_period):
    """Return ``(disease_duration, cfr)`` computed element-wise from the given rates.

    ``disease_duration`` is the reciprocal of the summed death, recovery and
    natural death rates; ``cfr`` is one minus the two exponential terms below.
    """
    tb_exit_rate = recovery_rate + death_rate
    total_exit_rate = tb_exit_rate + natural_death_rate

    recovery_term = (recovery_rate / tb_exit_rate) * np.exp(-natural_death_rate * time_period)
    death_term = (death_rate / tb_exit_rate) * np.exp(-total_exit_rate * time_period)

    return 1 / total_exit_rate, 1 - recovery_term - death_term

def sample_truncated_normal(mean, stdev, trunc_range, num_samples=100000, random_state=None):
    """Sample from a normal distribution truncated to ``trunc_range``.

    Args:
        mean: Mean of the underlying (untruncated) normal distribution.
        stdev: Standard deviation of the underlying normal distribution.
        trunc_range: ``(lower, upper)`` truncation bounds on the original scale.
        num_samples: Number of samples to draw.
        random_state: Optional seed or NumPy random generator forwarded to
            SciPy. The default ``None`` keeps the previous behaviour (NumPy's
            global random state); pass an int for reproducible draws.

    Returns:
        NumPy array of ``num_samples`` sampled values.
    """
    lower, upper = trunc_range
    # scipy's truncnorm takes its bounds in standard-deviation units relative to `loc`.
    a = (lower - mean) / stdev
    b = (upper - mean) / stdev
    return truncnorm(a, b, loc=mean, scale=stdev).rvs(size=num_samples, random_state=random_state)

# Integrated function to sample priors, calculate derived metrics, and plot both prior and posterior
def plot_derived_comparison(
    priors,
    posterior_metrics,
    params_name,
    num_samples=10000,
    natural_death_rate=0.01,
    time_period=1,
):
    """
    Plot comparison of derived outputs (disease duration and CFR) between priors and posteriors.

    For each derived variable, the prior density is built by sampling the
    death-rate and self-recovery priors and passing the samples through
    ``calculate_derived_metrics``. The posterior density comes straight from
    ``posterior_metrics``.

    Args:
        priors: Dictionary with keys ``smear_<case>_death_rate`` and
            ``smear_<case>_self_recovery``; each value must expose ``mean``,
            ``stdev`` and ``trunc_range``.
        posterior_metrics: Dictionary with keys ``post_<var>`` holding posterior
            samples of each derived metric.
        params_name: Dictionary mapping parameter names to descriptive titles.
        num_samples: Number of samples to draw from each prior.
        natural_death_rate: Natural death rate used for the prior-side
            calculation. Defaults to the previously hard-coded 0.01; pass the
            same value used for the posterior metrics so both sides are comparable.
        time_period: Time period used in the CFR formula (previously fixed at 1).

    Returns:
        The figure object. The figure is also shown and the table of prior
        derived metrics (mean, 2.5% and 97.5% quantiles) is printed. End the
        calling cell with ``;`` or assign the result to avoid the notebook
        rendering the returned figure a second time.
    """
    # Derived parameters to compare, with matching table labels
    derived_vars = ['duration_positive', 'cfr_positive', 'duration_negative', 'cfr_negative']
    plot_titles = [
        "Duration (Smear Positive)",
        "CFR (Smear Positive)",
        "Duration (Smear Negative)",
        "CFR (Smear Negative)",
    ]
    num_vars = len(derived_vars)
    num_rows = (num_vars + 1) // 2  # Even distribution across two columns

    fig, axs = plt.subplots(num_rows, 2, figsize=(28, 6.2 * num_rows))
    axs = axs.ravel()

    results = []

    for i_ax, ax in enumerate(axs):
        if i_ax >= num_vars:
            ax.axis("off")  # Turn off empty subplots if there are extra axes
            continue

        var_name = derived_vars[i_ax]

        # Posterior density over the observed sample range
        posterior_samples = np.asarray(posterior_metrics[f'post_{var_name}']).flatten()
        x_vals_posterior = np.linspace(
            np.min(posterior_samples), np.max(posterior_samples), 100
        )
        posterior_density = gaussian_kde(posterior_samples)(x_vals_posterior)

        # Sample the two rate priors relevant to this variable
        case = 'positive' if 'positive' in var_name else 'negative'
        death_prior = priors[f'smear_{case}_death_rate']
        recovery_prior = priors[f'smear_{case}_self_recovery']
        death_samples = sample_truncated_normal(
            death_prior.mean, death_prior.stdev, death_prior.trunc_range, num_samples
        )
        recovery_samples = sample_truncated_normal(
            recovery_prior.mean, recovery_prior.stdev, recovery_prior.trunc_range, num_samples
        )

        # calculate_derived_metrics is element-wise, so one vectorised call
        # replaces the per-sample Python loop.
        duration, cfr = calculate_derived_metrics(
            death_samples, recovery_samples, natural_death_rate, time_period
        )
        prior_metric = np.asarray(duration if 'duration' in var_name else cfr)

        # Prior density over the sampled range
        x_vals_prior = np.linspace(np.min(prior_metric), np.max(prior_metric), 100)
        prior_density = gaussian_kde(prior_metric)(x_vals_prior)

        # Plot prior and posterior distributions
        ax.fill_between(x_vals_prior, prior_density, color="k", alpha=0.2, linewidth=2, label="Prior")
        ax.plot(x_vals_posterior, posterior_density, color="b", linewidth=1, linestyle="solid", label="Posterior")

        # Set the title using the descriptive name from params_name
        ax.set_title(params_name.get(var_name, var_name), fontsize=30, fontname='Arial')
        ax.tick_params(axis='both', labelsize=24)

        # Mean and 95% interval (2.5% and 97.5% quantiles) of the prior-derived metric
        mean_val = np.mean(prior_metric)
        lower, upper = np.percentile(prior_metric, [2.5, 97.5])
        results.append({
            "Metric": plot_titles[i_ax],
            "Mean": f"{mean_val:.3f}",
            "2.5% Quantile": f"{lower:.3f}",
            "97.5% Quantile": f"{upper:.3f}"
        })

        # Add legend to the first subplot
        if i_ax == 0:
            ax.legend(fontsize=24)

    # Adjust padding and spacing
    fig.tight_layout()
    plt.show()

    # Create a DataFrame to display the results
    results_df = pd.DataFrame(results)
    print("\nDerived Metrics with Mean and 95% CI:")
    print(results_df)

    return fig


# Example usage:
# Priors for the four rate parameters that feed the derived metrics.
priors = {
    prior_name: bcm.priors[prior_name]
    for prior_name in (
        "smear_positive_death_rate",
        "smear_positive_self_recovery",
        "smear_negative_death_rate",
        "smear_negative_self_recovery",
    )
}

# Assuming posterior_metrics contains the posterior samples for the derived metrics


In [7]:
def process_idata_for_derived_metrics(idata):
    """
    Turn posterior rate samples into posterior samples of the derived metrics.

    The natural death rate is read from the notebook-level ``universal_death``
    (entry ``2023``) and the CFR time period is fixed at 1.

    Args:
        idata: ArviZ InferenceData whose ``posterior`` group contains
            ``smear_<case>_death_rate`` and ``smear_<case>_self_recovery``
            for ``<case>`` in positive/negative.

    Returns:
        A dictionary with flattened arrays under ``post_duration_<case>`` and
        ``post_cfr_<case>`` for both smear-positive and smear-negative cases.
    """
    posterior = idata.posterior
    natural_death_rate = universal_death[2023]

    metrics = {}
    for case in ("positive", "negative"):
        # Raw NumPy arrays of the sampled rates for this smear status
        death_rate = posterior[f"smear_{case}_death_rate"].values
        recovery_rate = posterior[f"smear_{case}_self_recovery"].values

        duration, cfr = calculate_derived_metrics(
            death_rate, recovery_rate, natural_death_rate, 1
        )

        # Flatten so the metrics are returned as 1-D sample vectors
        metrics[f"post_duration_{case}"] = duration.flatten()
        metrics[f"post_cfr_{case}"] = cfr.flatten()

    return metrics

In [None]:
# Derive the posterior metrics from the post-burn-in draws only, so this
# comparison uses the same sample set as the other posterior summaries in the
# notebook (the full `idata` still contains the burn-in draws).
posterior_metrics = process_idata_for_derived_metrics(burnt_idata)

# Plot comparison of priors and posteriors
plot_derived_comparison(priors, posterior_metrics, params_name)