In [None]:
from typing import Any, Dict

import numpy as np
from summer2 import CompartmentalModel
from summer2.functions.time import get_sigmoidal_interpolation_function
from summer2.parameters import Parameter, Function, Time

from tbdynamics.tools.utils import triangle_wave_func
from tbdynamics.tools.inputs import get_birth_rate, get_death_rate, process_death_rate
from tbdynamics.constants import COMPARTMENTS, INFECTIOUS_COMPARTMENTS, AGE_STRATA
from tbdynamics.camau.outputs import request_model_outputs
from tbdynamics.camau.strats import get_organ_strat, get_act3_strat, get_age_strat
from tbdynamics.tools.detect import get_detection_func
from tbdynamics.camau.model import seed_infectious, add_latency_flow, add_infection_flow, add_treatment_related_outcomes

# Dummy rate used when a flow is first declared; the real values are set later by
# the stratifications (see the "Adjust later ..." comments next to the flow
# definitions in build_model).
PLACEHOLDER_PARAM: float = 0.0

In [2]:
def build_model(
    fixed_params: Dict[str, Any],
    matrix: np.ndarray,
    covid_effects: Dict[str, bool],
    implement_act3: bool = True
) -> CompartmentalModel:
    """
    Build a compartmental model of TB transmission.

    The model includes birth, infection, latency, self-recovery, detection,
    treatment and TB-death flows. It is stratified by age and organ status, and
    optionally by ACT3 trial arm.

    Args:
        fixed_params: Fixed (non-calibrated) parameters. This function reads
            "time_start", "time_end" and "time_step" directly. The whole dict is
            also passed to the age, organ and ACT3 stratification builders.
        matrix: Age-mixing matrix passed to the age stratification.
        covid_effects: COVID-19 effect switches. "contact_reduction" is passed
            to the infection flow. "detection_reduction" is passed to the
            detection function and to the output requests.
        implement_act3: If True (the default), also stratify by ACT3 trial arm.

    Returns:
        A configured CompartmentalModel instance. The initial susceptible
        population is the runtime Parameter "start_population_size", which
        presumably must be supplied in the parameters used to run the model.
    """
    model = CompartmentalModel(
        times=(fixed_params["time_start"], fixed_params["time_end"]),
        compartments=COMPARTMENTS,
        infectious_compartments=INFECTIOUS_COMPARTMENTS,
        timestep=fixed_params["time_step"],
    )

    # Demographic inputs; death_df is consumed by the age stratification below.
    birth_rates = get_birth_rate()
    death_rates = get_death_rate()
    death_df = process_death_rate(death_rates, AGE_STRATA, birth_rates.index)
    model.set_initial_population({"susceptible": Parameter("start_population_size")})
    seed_infectious(model)
    crude_birth_rate = get_sigmoidal_interpolation_function(
        birth_rates.index, birth_rates.values
    )
    model.add_crude_birth_flow("birth", crude_birth_rate, "susceptible")

    # NOTE(review): the universal death flow is disabled, but death_df is still
    # built and passed to get_age_strat (the original comment said the rate is
    # "adjusted later in age strat"). Confirm that background mortality is
    # applied elsewhere; otherwise the model may have no non-TB deaths.
    # model.add_universal_death_flows("universal_death", PLACEHOLDER_PARAM)
    add_infection_flow(model, covid_effects["contact_reduction"])
    add_latency_flow(model)
    model.add_transition_flow(
        "self_recovery", PLACEHOLDER_PARAM, "infectious", "recovered"
    )  # Adjust later in organ strat
    model.add_transition_flow(
        "detection", PLACEHOLDER_PARAM, "infectious", "on_treatment"
    )
    add_treatment_related_outcomes(model)
    model.add_death_flow(
        "infect_death", PLACEHOLDER_PARAM, "infectious"
    )  # Adjust later organ strat

    age_strat = get_age_strat(death_df, fixed_params, matrix)
    model.stratify_with(age_strat)

    detection_func = get_detection_func(covid_effects["detection_reduction"])

    organ_strat = get_organ_strat(fixed_params, detection_func)
    model.stratify_with(organ_strat)
    if implement_act3:
        act3_strat = get_act3_strat(COMPARTMENTS, fixed_params)
        model.stratify_with(act3_strat)

    request_model_outputs(model, covid_effects["detection_reduction"])

    return model


In [None]:
# Parameter values for a single model run.
# NOTE(review): these high-precision values look like a calibration point
# estimate; record where they came from (run id / date) in a markdown cell.
params = dict(
    contact_rate=0.012679436231766055,
    rr_infection_latent=0.3294932508745867,
    rr_infection_recovered=0.30030780867913565,
    smear_positive_death_rate=0.3783502503918692,
    smear_negative_death_rate=0.03100638773988429,
    smear_positive_self_recovery=0.18819938594539462,
    smear_negative_self_recovery=0.14983813581079153,
    time_to_screening_end_asymp=1.798644977945414,
    early_prop_adjuster=0.07113833294294106,
    late_reactivation_adjuster=1.3209884326164134,
    detection_reduction=0.28243429069025955,
    total_population_dispersion=2401.833907561055,
    notif_dispersion=0.07279813928865936,
    latent_dispersion=11.438199850872712,
)

In [None]:
# Build the model by *calling* build_model. The original `model = build_model`
# only bound the function object, so `model` was never a CompartmentalModel.
# TODO(review): `fixed_params`, `matrix` and `covid_effects` are not defined
# anywhere in this notebook; load them in an earlier cell before running this one.
model = build_model(fixed_params, matrix, covid_effects)