In [None]:
from typing import Any, Dict
import numpy as np
import arviz as az
from summer2 import CompartmentalModel
from summer2.functions.time import get_sigmoidal_interpolation_function
from summer2.parameters import Parameter

from tbdynamics.tools.inputs import get_birth_rate, get_death_rate, process_death_rate
from tbdynamics.constants import COMPARTMENTS, INFECTIOUS_COMPARTMENTS, AGE_STRATA, QUANTILES
from tbdynamics.camau.outputs import request_model_outputs
from tbdynamics.camau.strats import get_organ_strat, get_act3_strat, get_age_strat
from tbdynamics.tools.detect import get_detection_func
from tbdynamics.camau.model import (
    seed_infectious,
    add_latency_flow,
    add_infection_flow,
    add_treatment_related_outcomes,
)
from tbdynamics.camau.constants import ACT3_STRATA
from tbdynamics.tools.inputs import get_mix_from_strat_props, load_params, matrix
from tbdynamics.settings import CM_PATH, OUT_PATH, DOCS_PATH
from tbdynamics.camau.calibration.utils import get_bcm
import estival.sampling.tools as esamp
import pandas as pd

# Make pandas' `.plot()` accessor render with Plotly rather than its default backend.
pd.options.plotting.backend = "plotly"
import plotly.graph_objects as go

# Dummy rate used when a flow is first declared in build_model; the inline notes
# there say the real rates are supplied later by the stratifications.
PLACEHOLDER_PARAM = 0.0

In [2]:
def build_model(
    fixed_params: Dict[str, Any],
    matrix: np.ndarray,
    covid_effects: Dict[str, bool],
    implement_act3: bool = True,
) -> CompartmentalModel:
    """
    Builds a compartmental model for TB transmission, incorporating infection dynamics,
    treatment, and stratifications for age, organ status, and (optionally) ACT3 trial arms.

    Construction order: base compartments and flows are declared first (several with
    PLACEHOLDER_PARAM as their rate), then the age, organ and ACT3 stratifications are
    applied in that order, and finally the derived outputs are requested.

    Args:
        fixed_params: Fixed parameter dictionary. This function reads the keys
            "time_start", "time_end" and "time_step" directly; the dictionary is also
            forwarded to the age, organ and ACT3 stratification builders.
        matrix: Age-mixing matrix for contact patterns, forwarded to the age
            stratification. Note that this argument shadows the `matrix` name
            imported at module level inside this function.
        covid_effects: Switches for COVID-19 effects. Must contain the keys
            "contact_reduction" (passed to the infection flow) and
            "detection_reduction" (passed to the detection function and to the
            output requests).
        implement_act3: If True (default), additionally stratify the model by
            ACT3 trial arm; if False, that stratification is skipped.

    Returns:
        A configured CompartmentalModel instance.

    Raises:
        KeyError: If a required key is missing from `fixed_params` or `covid_effects`.
    """
    model = CompartmentalModel(
        times=(fixed_params["time_start"], fixed_params["time_end"]),
        compartments=COMPARTMENTS,
        infectious_compartments=INFECTIOUS_COMPARTMENTS,
        timestep=fixed_params["time_step"],
    )

    # Demographic inputs; the death rates are reshaped for the age strata using the
    # birth-rate index.
    birth_rates = get_birth_rate()
    death_rates = get_death_rate()
    death_df = process_death_rate(death_rates, AGE_STRATA, birth_rates.index)

    # The whole initial population starts in the susceptible compartment.
    model.set_initial_population({"susceptible": Parameter("start_population_size")})
    seed_infectious(model)

    # Births enter the susceptible compartment at a rate interpolated sigmoidally
    # between the tabulated birth-rate points.
    crude_birth_rate = get_sigmoidal_interpolation_function(
        birth_rates.index, birth_rates.values
    )
    model.add_crude_birth_flow("birth", crude_birth_rate, "susceptible")

    # A universal death flow is currently NOT added; the disabled call is kept for
    # reference (its original note: "Adjust later in age strat").
    # model.add_universal_death_flows(
    #     "universal_death", PLACEHOLDER_PARAM
    # )
    add_infection_flow(model, covid_effects["contact_reduction"])
    add_latency_flow(model)
    model.add_transition_flow(
        "self_recovery", PLACEHOLDER_PARAM, "infectious", "recovered"
    )  # Adjust later in organ strat
    # NOTE(review): presumably also adjusted by the organ stratification, which
    # receives the detection function below - confirm in get_organ_strat.
    model.add_transition_flow(
        "detection", PLACEHOLDER_PARAM, "infectious", "on_treatment"
    )
    add_treatment_related_outcomes(model)
    model.add_death_flow(
        "infect_death", PLACEHOLDER_PARAM, "infectious"
    )  # Adjust later organ strat

    # Stratifications are applied in a fixed order: age, then organ, then ACT3.
    age_strat = get_age_strat(death_df, fixed_params, matrix)
    model.stratify_with(age_strat)

    detection_func = get_detection_func(covid_effects["detection_reduction"])

    organ_strat = get_organ_strat(fixed_params, detection_func)
    model.stratify_with(organ_strat)
    if implement_act3:
        act3_strat = get_act3_strat(COMPARTMENTS, fixed_params)
        model.stratify_with(act3_strat)

    request_model_outputs(model, covid_effects["detection_reduction"])

    return model


In [3]:
# Fixed (non-calibrated) model parameters read from the project's YAML file.
params_file = CM_PATH / "params.yml"
fixed_params = load_params(params_file)

# COVID-19 switches: detection reduction is on, contact reduction is off.
covid_effects = dict(
    detection_reduction=True,
    contact_reduction=False,
)

In [4]:
# Load the stored calibration output (an ArviZ inference-data NetCDF file) from
# under OUT_PATH. NOTE(review): 'r0204' is presumably a run identifier - confirm.
idata_raw = az.from_netcdf(OUT_PATH / 'camau/r0204/calib_full_out.nc')

In [5]:
# Tunable sampling settings, named here instead of being buried in the calls.
BURN_IN_DRAWS = 50000  # draws with a `draw` coordinate below this are dropped
N_POSTERIOR_SAMPLES = 500  # size of the posterior subset used for model runs
SAMPLE_SEED = 42  # fixes which draws are picked, so re-runs give the same figures

# Keep only draws whose `draw` coordinate is BURN_IN_DRAWS or later (burn-in removal).
burnt_idata = idata_raw.sel(draw=np.s_[BURN_IN_DRAWS:])

# az.extract shuffles before subsetting; without an explicit generator the subset
# changes on every kernel restart, so a seeded generator is passed.
idata_extract = az.extract(
    burnt_idata,
    num_samples=N_POSTERIOR_SAMPLES,
    rng=np.random.default_rng(SAMPLE_SEED),
)

In [6]:
# Parameter values passed to get_bcm as fixed inputs.
# Entries left commented out are not supplied here.
# NOTE(review): presumably those are the parameters sampled during calibration -
# confirm against get_bcm.
params = dict(
    start_population_size=30000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
    # contact_rate=0.02,
    rr_infection_latent=0.1890473700762809,
    rr_infection_recovered=0.17781844797545143,
    smear_positive_death_rate=0.3655528915762244,
    smear_negative_death_rate=0.027358324164819155,
    smear_positive_self_recovery=0.18600338108638945,
    smear_negative_self_recovery=0.11333894801537307,
    screening_scaleup_shape=0.3,
    screening_inflection_time=1993,
    # time_to_screening_end_asymp=2.1163556520843936,
    acf_sensitivity=0.90,
    # prop_mixing_same_stratum=0.6920672992582717,
    # early_prop_adjuster=-0.017924441638418186,
    # late_reactivation_adjuster=1.1083422207175728,
    detection_reduction=0.30,
    # total_population_dispersion=3644.236227852164,
    # notif_dispersion=88.37092488550051,
    # latent_dispersion=7.470896188551709,
)

In [7]:
# Build the calibration object from the fixed parameter values and COVID switches
# defined in the cells above, using the project helper get_bcm.
# NOTE(review): presumably an estival BayesianCompartmentalModel - confirm in get_bcm.
bcm = get_bcm(params, covid_effects)

In [8]:
# Evaluate the model for the extracted posterior samples and keep the `.results`
# attribute of the returned object (potentially slow: one run per sample - TODO confirm).
base_results = esamp.model_results_for_samples(idata_extract, bcm).results

In [None]:
# Summarise the sampled results at the QUANTILES levels. The plotting cells below
# index the result by output name and then select a quantile column (e.g. 0.500).
base_quantiles = esamp.quantiles_for_results(base_results, QUANTILES)

In [13]:
# Median (0.500 quantile) early-activation output for every age stratum,
# restricted to the 2010-2025 window, drawn on a single figure.
fig = go.Figure()

# A plain loop is used because add_trace is called only for its side effect;
# the series is looked up once and reused for both axes.
for age in AGE_STRATA:
    median_series = base_quantiles[f"early_activation_age_{age}"].loc[2010:2025, 0.500]
    fig.add_trace(
        go.Scatter(
            x=median_series.index,
            y=median_series.values,
            mode="lines",
            name=f"Age {age}",
        )
    )

fig.update_layout(
    title="Early Activation by Age Group",
    xaxis_title="Time",
    yaxis_title="Rate",
    legend_title="Age group",
    template="simple_white",
    width=800,
    height=600,
)

fig.show()


In [11]:
# Median (0.500 quantile) late-activation output for every age stratum over the
# full simulated time range, drawn on a single figure.
fig = go.Figure()

# A plain loop is used because add_trace is called only for its side effect;
# the series is looked up once and reused for both axes.
for age in AGE_STRATA:
    median_series = base_quantiles[f"late_activation_age_{age}"].loc[:, 0.500]
    fig.add_trace(
        go.Scatter(
            x=median_series.index,
            y=median_series.values,
            mode="lines",
            name=f"Age {age}",
        )
    )

fig.update_layout(
    title="Late Activation by Age Group",
    xaxis_title="Time",
    yaxis_title="Rate",
    legend_title="Age group",
    template="simple_white",
    width=800,
    height=600,
)

fig.show()
