In [None]:
from model.calibration import fit_model, check_fit
from model.optimisation import optimise_interventions
from plotting.plots import plot_future_trajectories, make_intervention_piechart


def run_analysis(target_incidence=100, minimised_outcomes=("incidence_per100k", "cumulative_future_deaths")):
    """Calibrate the model to a target incidence, then optimise interventions for each outcome.

    Args:
        target_incidence: passed to ``fit_model`` as the calibration target.
        minimised_outcomes: indicators to minimise; one optimisation is run per indicator.
            (A tuple default avoids the shared-mutable-default pitfall; any iterable works.)

    Returns:
        ``(derived_outputs, opti_decision_vars, mle_params)`` where
        - ``derived_outputs`` maps "baseline" and each minimised indicator to that run's derived outputs,
        - ``opti_decision_vars`` maps each minimised indicator to the parameters returned by
          ``optimise_interventions``,
        - ``mle_params`` is the calibrated parameter set returned by ``fit_model``.
    """
    bcm, mle_params = fit_model(target_incidence=target_incidence)
    check_fit(bcm, mle_params)

    # Baseline run: calibrated parameters with the three decision variables set to 0.
    no_intervention = {"decision_var_trans": 0., "decision_var_cdr": 0., "decision_var_pt": 0.}
    derived_outputs = {"baseline": bcm.run(mle_params | no_intervention).derived_outputs}

    opti_decision_vars = {}
    for minimised_indicator in minimised_outcomes:
        opti_bcm, opti_params = optimise_interventions(mle_params, minimised_indicator=minimised_indicator)
        derived_outputs[minimised_indicator] = opti_bcm.run(opti_params).derived_outputs
        opti_decision_vars[minimised_indicator] = opti_params

    return derived_outputs, opti_decision_vars, mle_params

In [None]:
from pathlib import Path
import pickle

ANALYSIS_OUTPUT_NAMES = ("master_derived_outputs", "master_opti_decision_vars", "master_mle_params")


def store_analysis_outputs(master_derived_outputs, master_opti_decision_vars, master_mle_params, folder_name="test"):
    """Pickle the three master result dictionaries into ``./store/<folder_name>/``.

    Each object is written to ``<name>.pickle``, where ``<name>`` is the matching entry of
    ``ANALYSIS_OUTPUT_NAMES``. Existing files with those names are overwritten.
    """
    folder = Path.cwd() / "store" / folder_name
    # parents=True: also create ./store itself on the first run instead of raising FileNotFoundError.
    folder.mkdir(parents=True, exist_ok=True)

    outputs = (master_derived_outputs, master_opti_decision_vars, master_mle_params)
    for filename, data in zip(ANALYSIS_OUTPUT_NAMES, outputs):
        with open(folder / f"{filename}.pickle", "wb") as f:
            pickle.dump(data, f)


def load_analysis_outputs(folder_name="test"):
    """Load results written by ``store_analysis_outputs`` from ``./store/<folder_name>/``.

    Returns:
        dict mapping each name in ``ANALYSIS_OUTPUT_NAMES`` whose pickle file exists to the
        unpickled object. Missing files are skipped, so a folder lacking one file still loads the others.

    Raises:
        FileNotFoundError: if the folder does not exist.
    """
    folder = Path.cwd() / "store" / folder_name
    if not folder.is_dir():
        raise FileNotFoundError(f"No stored analysis outputs at {folder}")

    # Security: pickle.load can execute arbitrary code - only load folders this notebook wrote.
    loaded = {}
    for filename in ANALYSIS_OUTPUT_NAMES:
        path = folder / f"{filename}.pickle"
        if path.exists():
            with open(path, "rb") as f:
                loaded[filename] = pickle.load(f)
    return loaded




In [None]:
# Calibrate and optimise at each baseline incidence level; results are collected keyed by incidence.
master_derived_outputs, master_opti_decision_vars, master_mle_params = {}, {}, {}
for incidence in [50, 100, 200, 500, 1000]:
    print(f"Running for inc={incidence}")
    derived_outputs, opti_decision_vars, mle_params = run_analysis(
        target_incidence=incidence,
        minimised_outcomes=["incidence_per100k", "cumulative_incidence", "tb_deaths", "cumulative_future_deaths"],
    )
    master_derived_outputs[incidence], master_opti_decision_vars[incidence], master_mle_params[incidence] = (
        derived_outputs, opti_decision_vars, mle_params,
    )


In [None]:
# store_analysis_outputs takes all three master dicts; master_mle_params was previously omitted (TypeError).
store_analysis_outputs(master_derived_outputs, master_opti_decision_vars, master_mle_params, folder_name="test")

In [None]:
# Reload previously pickled results from ./store/test.
# NOTE(review): requires `load_analysis_outputs` to be defined or imported before this cell runs.
stored_outputs = load_analysis_outputs("test")
master_derived_outputs = stored_outputs["master_derived_outputs"]
master_opti_decision_vars = stored_outputs["master_opti_decision_vars"]

In [None]:
from matplotlib import pyplot as plt
import numpy as np

from plotting.plots import output_names, intervention_names

# Legend titles per scenario; 'baseline' is the run with no intervention.
sc_titles = dict(
    baseline='no intervention',
    incidence_per100k='minimising incidence',
    cumulative_future_deaths='minimising cumulative deaths',
    cumulative_incidence='minimising cumulative incidence',
    tb_deaths='minimising mortality',
)

# Multi-line variants of the titles for narrow bar-chart tick labels (optimised scenarios only).
sc_titles_split = dict(
    incidence_per100k='minimising\nincidence',
    cumulative_future_deaths='minimising\ncumulative\ndeaths',
    cumulative_incidence='minimising\ncumulative\nincidence',
    tb_deaths='minimising\nmortality',
)

# Colour per scenario, shared by line and bar plots.
sc_colors = dict(
    baseline='black',
    incidence_per100k='tomato',
    cumulative_future_deaths='cornflowerblue',
    cumulative_incidence='forestgreen',
    tb_deaths='orange',
)

# Matplotlib line style per scenario for trajectory plots.
ls = dict(
    baseline='--',
    incidence_per100k='-',
    cumulative_future_deaths=':',
    cumulative_incidence='-.',
    tb_deaths='-',
)

def make_multi_analysis_figure(master_derived_outputs, master_opti_decision_vars):
    """Summarise several analyses (one per baseline incidence level) in a single grid figure.

    Each row corresponds to one key of ``master_derived_outputs`` and has five panels:
    analysis title | optimal intervention plan | incidence trajectories |
    cumulative TB deaths | cumulative paediatric TB deaths.

    Args:
        master_derived_outputs: dict keyed by baseline incidence; each value maps a scenario name
            ("baseline" or a minimised indicator) to that run's derived outputs, looked up by year
            via ``.loc``.
        master_opti_decision_vars: dict with the same keys; each value maps a minimised indicator to
            its ``{"decision_var_<name>": value}`` dict. Values are plotted as ``100 * value`` (%).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``master_derived_outputs`` is empty.

    Note:
        Sets the global matplotlib font family to Times New Roman as a side effect.
    """
    plt.rcParams['font.family'] = 'Times New Roman'

    n_analyses = len(master_derived_outputs)
    if n_analyses == 0:
        raise ValueError("No analyses to plot")
    # squeeze=False keeps `axs` two-dimensional, so a single analysis row works too.
    fig, axs = plt.subplots(
        n_analyses, 5, figsize=(15, 3.5 * n_analyses),
        gridspec_kw={'width_ratios': [1, 6, 5, 2, 2]}, squeeze=False,
    )

    # Analysis title  |  Optimal plan  |  Incidence trajectories  |  Cum TB deaths  |  Cum Paed TB deaths
    for i_row, inc in enumerate(master_derived_outputs):
        derived_outputs, opti_decision_vars = master_derived_outputs[inc], master_opti_decision_vars[inc]

        # Title
        ax = axs[i_row][0]
        if i_row == 0:
            ax.set_title("Baseline TB incidence\n(/100,000 persons/year)")
        ax.text(0.5, 0.5, inc, rotation=0, ha='center', va='center', fontsize=15)
        ax.axis("off")

        # Optimal intervention plan: one group per decision variable, one bar per minimised indicator.
        ax = axs[i_row][1]
        # Take the decision-variable order from the first optimisation and look every scenario's
        # values up by name, so bars stay aligned with their labels.
        dec_var_names = list(next(iter(opti_decision_vars.values())))
        labels = [intervention_names[dec_var.split("decision_var_")[1]].replace(" ", "\n") for dec_var in dec_var_names]
        group_width = 0.35
        bar_width = group_width / len(opti_decision_vars)
        x_positions = np.arange(len(labels))
        for i_sc, (minimised_indicator, opti_decision_var_dict) in enumerate(opti_decision_vars.items()):
            ax.bar(
                x_positions + i_sc * bar_width,
                [100. * opti_decision_var_dict[dec_var] for dec_var in dec_var_names],
                color=sc_colors[minimised_indicator], width=bar_width, edgecolor='grey',
                label=sc_titles_split[minimised_indicator],
            )

        # Bar heights are in %, so the full-coverage reference line sits at 100 (not 1).
        ax.axhline(y=100., color='grey', linestyle='--')
        ax.set_title("Optimal intervention plan")
        # Bars are centre-aligned, so each group's centre is (group_width - bar_width) / 2 right of x.
        ax.set_xticks(x_positions + (group_width - bar_width) / 2, labels)
        ax.set_ylabel('Intervention coverage (%)')
        if i_row == 0:
            ax.legend(loc='upper left')

        # Incidence trajectories
        ax = axs[i_row][2]
        xmin = 2023
        ymax = 0.
        output = "incidence_per100k"
        for sc_name, derived_df in derived_outputs.items():
            trajectory = derived_df[output].loc[xmin:]
            trajectory.plot(label=sc_titles[sc_name], ax=ax, color=sc_colors[sc_name], linestyle=ls[sc_name])
            ymax = max(ymax, trajectory.max())
        ax.set_ylabel(output_names[output])
        # Headroom above the highest curve, presumably to leave room for the legend.
        ax.set_ylim((0, 1.55 * ymax))
        ax.legend()

        # Columns 3 and 4: cumulative TB deaths and cumulative paediatric TB deaths in 2040,
        # one bar per optimised scenario (baseline excluded).
        opti_scenarios = [sc_name for sc_name in derived_outputs if sc_name != 'baseline']
        for j, output in enumerate(["cumulative_future_deaths", "cumulative_future_paed_deaths"]):
            ax = axs[i_row][3 + j]
            ax.bar(
                [sc_titles_split[sc_name] for sc_name in opti_scenarios],
                [derived_outputs[sc_name][output].loc[2040] for sc_name in opti_scenarios],
                color=[sc_colors[sc_name] for sc_name in opti_scenarios],
            )
            ax.set_ylabel(output_names[output])
            ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()
    return fig

fig = make_multi_analysis_figure(master_derived_outputs, master_opti_decision_vars)
# Save via the returned figure handle rather than relying on pyplot's "current figure".
fig.savefig('multi.png', dpi=300)


In [None]:
# Inspect the optimised decision variables of the analysis run with target_incidence=200.
master_opti_decision_vars[200]

In [None]:
def get_main_numbers(master_derived_outputs, low_inc=200, high_inc=1000, year=2040):
    """Print, and return, headline comparisons between the two optimisation objectives.

    For each of ``low_inc`` and ``high_inc`` (keys of ``master_derived_outputs``), compares the run
    that minimised "incidence_per100k" with the run that minimised "cumulative_future_deaths" at ``year``:
      * how much higher cumulative TB deaths are when incidence is minimised instead, and
      * how much higher TB incidence is when cumulative deaths are minimised instead.

    Args:
        master_derived_outputs: dict keyed by baseline incidence; each value maps a scenario name to
            that run's derived outputs, which must contain "cumulative_future_deaths" and
            "incidence_per100k" looked up by year via ``.loc``.
        low_inc: first baseline incidence level to report.
        high_inc: second baseline incidence level to report.
        year: index value at which outputs are compared. Defaults to 2040.

    Returns:
        dict mapping each incidence level to
        ``{"mortality_perc_greater": ..., "incidence_perc_greater": ...}`` (unrounded; the printed
        messages show the rounded values).
    """
    def perc_greater(value, reference):
        # Percentage by which `value` exceeds `reference`.
        return 100. * (value - reference) / reference

    summary = {}
    for incidence in [low_inc, high_inc]:
        print(f"incidence: {incidence}")
        min_mortality_run = master_derived_outputs[incidence]["cumulative_future_deaths"]
        min_incidence_run = master_derived_outputs[incidence]["incidence_per100k"]

        mortality_perc = perc_greater(
            min_incidence_run["cumulative_future_deaths"].loc[year],
            min_mortality_run["cumulative_future_deaths"].loc[year],
        )
        print(f"optimising for TB incidence led to an estimated {round(mortality_perc)}% higher cumulative TB mortality compared to minimising cumulative mortality")

        incidence_perc = perc_greater(
            min_mortality_run["incidence_per100k"].loc[year],
            min_incidence_run["incidence_per100k"].loc[year],
        )
        print(f"TB incidence in {year} was only {round(incidence_perc)}% higher when minimising cumulative mortality")

        summary[incidence] = {"mortality_perc_greater": mortality_perc, "incidence_perc_greater": incidence_perc}

    return summary


In [None]:
# Headline comparisons for baseline incidences of 100 and 500 (the function defaults are 200 and 1000).
get_main_numbers(master_derived_outputs, low_inc=100, high_inc=500)

In [None]:
# NOTE(review): identical to the previous cell (same incidence levels); one of the two can likely be removed.
get_main_numbers(master_derived_outputs, 100, 500)

In [None]:
# List the scenario names stored for the analysis run with target_incidence=100.
master_derived_outputs[100].keys()

In [None]:
from matplotlib import pyplot as plt

# Styling for the abstract figure. These get their own names: re-binding sc_titles, sc_titles_split,
# sc_colors and ls here would silently replace the fuller dictionaries used by
# make_multi_analysis_figure (which also cover 'cumulative_incidence' and 'tb_deaths' and draw
# 'cumulative_future_deaths' dotted), breaking that figure if it is re-run after this cell.
abstract_sc_titles = {'baseline': 'no intervention', 'incidence_per100k': 'minimising incidence', 'cumulative_future_deaths': 'minimising cumulative deaths'}
abstract_sc_titles_split = {'incidence_per100k': 'minimising\nincidence', 'cumulative_future_deaths': 'minimising\ncumulative\ndeaths'}
abstract_sc_colors = {'baseline': 'black', 'incidence_per100k': 'tomato', 'cumulative_future_deaths': 'cornflowerblue'}
abstract_ls = {'baseline': '--', 'incidence_per100k': '-', 'cumulative_future_deaths': '-'}


def plot_abstract_figure(derived_outputs, output="incidence_per100k", filename='abstract_figure.png'):
    """Draw and save the two-panel abstract figure.

    Left panel: ``output`` from 2023 onwards for each scenario. Right panel: cumulative TB deaths
    in 2040 for each optimised (non-baseline) scenario.

    Only scenarios with an entry in ``abstract_sc_titles`` are drawn. ``derived_outputs`` may hold
    further optimisations (e.g. 'cumulative_incidence', 'tb_deaths'), which this figure has no
    styling for and previously caused a KeyError.

    Args:
        derived_outputs: maps scenario name -> derived outputs for one analysis.
        output: output plotted in the left panel.
        filename: path the figure is saved to (dpi=100).

    Returns:
        The matplotlib Figure.
    """
    fig, (traj_ax, bar_ax) = plt.subplots(1, 2, figsize=(6, 3.5), gridspec_kw={'width_ratios': [5, 2]})

    scenarios = [sc_name for sc_name in derived_outputs if sc_name in abstract_sc_titles]
    opti_scenarios = [sc_name for sc_name in scenarios if sc_name != 'baseline']

    # Trajectories
    xmin = 2023
    ymax = 0.
    for sc_name in scenarios:
        trajectory = derived_outputs[sc_name][output].loc[xmin:]
        trajectory.plot(label=abstract_sc_titles[sc_name], ax=traj_ax, color=abstract_sc_colors[sc_name], linestyle=abstract_ls[sc_name])
        ymax = max(ymax, trajectory.max())
    traj_ax.set_ylabel(output_names[output])
    traj_ax.set_ylim((0, 1.55 * ymax))
    traj_ax.legend()

    # Cumulative deaths bar plot
    bar_ax.bar(
        [abstract_sc_titles_split[sc_name] for sc_name in opti_scenarios],
        [derived_outputs[sc_name]['cumulative_future_deaths'].loc[2040] for sc_name in opti_scenarios],
        color=[abstract_sc_colors[sc_name] for sc_name in opti_scenarios],
    )
    bar_ax.set_ylabel('TB deaths over 2025-2040')
    bar_ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()
    fig.savefig(filename, dpi=100)
    return fig


plot_abstract_figure(master_derived_outputs[100])


In [None]:
# run_analysis returns (derived_outputs, opti_decision_vars, mle_params); unpacking only two raised ValueError.
derived_outputs, opti_decision_vars, mle_params = run_analysis(target_incidence=50)

In [None]:
# Display the optimised decision variables (keyed by minimised indicator) from the target_incidence=50 run above.
opti_decision_vars

In [None]:
# For each minimised indicator, print its name and then the sum of its optimised decision-variable values.
for indicator_name, decision_values in opti_decision_vars.items():
    print(indicator_name, sum(decision_values.values()), sep="\n")

In [None]:
# Axis labels per derived output.
# NOTE(review): this rebinds the `output_names` imported from plotting.plots earlier in the notebook,
# so later cells use these labels instead of the imported ones.
output_names = dict(
    incidence_per100k="TB incidence (/100,000/y)",
    ltbi_prevalence_perc="LTBI prevalence (%)",
    cumulative_future_deaths="Cumulative TB deaths",
    cumulative_future_paed_deaths="Cumulative paediatric TB deaths",
)

from matplotlib import pyplot as plt


def plot_optimised_trajectories(derived_outputs, output="incidence_per100k", ax=None, xmin=2020):
    """Plot ``output`` from ``xmin`` onwards for every scenario in ``derived_outputs`` on one axes.

    Args:
        derived_outputs: maps scenario name -> derived outputs containing ``output``.
        output: name of the output to plot; also used to look up the y-label in ``output_names``.
        ax: axes to draw on; a new figure/axes is created when omitted.
        xmin: first index value (year) plotted. Defaults to 2020.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    ymax = 0.
    for sc_name, derived_df in derived_outputs.items():
        trajectory = derived_df[output].loc[xmin:]
        # Pass ax explicitly: without it pandas draws on the *current* axes, which ignores a
        # caller-supplied `ax`.
        trajectory.plot(label=sc_name, ax=ax)
        ymax = max(ymax, trajectory.max())

    ax.set_ylabel(output_names[output])
    ax.set_ylim((0, 1.2 * ymax))
    ax.legend()
    return ax


# Incidence trajectories for the single target_incidence=50 run (trailing ';' suppresses the return repr).
plot_optimised_trajectories(derived_outputs, output="incidence_per100k");

In [None]:
# Incidence and cumulative-death trajectories for the target_incidence=100 analysis, one figure each.
for output in ["incidence_per100k", "cumulative_future_deaths"]:
    plot_optimised_trajectories(master_derived_outputs[100], output=output)

In [None]:
# Incidence and cumulative-death trajectories for the target_incidence=500 analysis, one figure each.
for output in ["incidence_per100k", "cumulative_future_deaths"]:
    plot_optimised_trajectories(master_derived_outputs[500], output=output)

In [None]:
from plotting.plots import make_intervention_piechart
# One pie chart of the optimised decision variables per minimised indicator, titled by indicator.
for scenario_name, scenario_vars in opti_decision_vars.items():
    pie_ax = make_intervention_piechart(scenario_vars)
    pie_ax.set_title(scenario_name)

In [None]:
# Relative (%) difference in an output between two optimisation scenarios.
def get_relative_diff(derived_outputs, output, scenarios=("incidence_per100k", "cumulative_future_deaths"), year=2040):
    """Return the percentage difference in ``output`` at ``year`` between two scenarios.

    Computes ``100 * (v1 - v0) / v0`` where ``v0`` is the value of ``output`` at ``year`` in
    ``derived_outputs[scenarios[0]]`` (the reference) and ``v1`` the value in
    ``derived_outputs[scenarios[1]]``. Also prints ``output`` followed by the value for every
    entry of ``scenarios``.

    Args:
        derived_outputs: maps scenario name -> derived outputs containing ``output``.
        output: name of the output to compare.
        scenarios: (reference scenario, compared scenario); any further entries are only printed.
            A tuple default avoids the shared-mutable-default pitfall; lists are accepted too.
        year: index value at which outputs are compared. Defaults to 2040.
    """
    reference_value = derived_outputs[scenarios[0]][output].loc[year]
    compared_value = derived_outputs[scenarios[1]][output].loc[year]

    print(output)
    for sc in scenarios:
        print(f"Minimising {sc}: {derived_outputs[sc][output].loc[year]}")

    return 100 * (compared_value - reference_value) / reference_value



# Incidence penalty of minimising deaths, then mortality penalty of minimising incidence (blank line between).
print(get_relative_diff(derived_outputs, output="incidence_per100k", scenarios=["cumulative_future_deaths", "incidence_per100k"]))
print()
print(get_relative_diff(derived_outputs, output="cumulative_future_deaths", scenarios=["incidence_per100k", "cumulative_future_deaths"]))

