In [None]:
import os
import pickle
import warnings
from itertools import combinations

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from sklearn import metrics
from sklearn.model_selection import RepeatedStratifiedKFold
from tqdm import tqdm

from config import (
    ALERE_RESULTS_DIR,
    CLINICAL_MBV_FILES,
    CLINICAL_TEST_FILE,
    LOD_COL_FMT,
    META_ANALYSIS_FILE,
    OUTPUT_DIR,
    PROCESSED_DIR,
)
from utils import Barcode2, LogisticGAM, NestedCV  # noqa


## Preliminaries

In [None]:
def parse_dt(arg):
    """Parse one cell of a date column into a ``pd.Timestamp`` (or ``pd.NaT``).

    Two encodings are recognised:

    * Excel serial day numbers, given either as a number or as a numeric
      string (e.g. ``44197`` or ``"44197.0"``), converted with the
      ``1899-12-30`` origin.
    * Anything else (date strings, datetime objects), which is handed to
      ``pd.to_datetime``.

    Parameters
    ----------
    arg : scalar
        Raw cell value.

    Returns
    -------
    pd.Timestamp or pd.NaT
        ``pd.NaT`` for missing input and for values outside the representable
        range (including infinite serial numbers).

    Raises
    ------
    ValueError
        Propagated from ``pd.to_datetime`` when a string cannot be parsed.
    """
    if pd.isna(arg):
        return pd.NaT
    try:
        serial = None
        if isinstance(arg, (int, float)):
            serial = int(float(arg))
        elif isinstance(arg, str) and "-" not in arg:
            try:
                serial = int(float(arg))
            except ValueError:
                # Not a serial number after all (e.g. "2021/01/01"); fall
                # through to the generic parser instead of crashing.
                serial = None
        if serial is not None:
            return pd.to_datetime(serial, origin="1899-12-30", unit="D")
        return pd.to_datetime(arg)
    except (pd.errors.OutOfBoundsDatetime, OverflowError):
        # OverflowError: int(float("inf")) for an infinite serial number.
        return pd.NaT


# Keyword arguments shared by every clinical workbook read: the spellings that
# should be parsed as True / False / missing.
clinical_kwargs = {
    "true_values": ["HIV+", "Positive", "Pos", "Yes", "MTB pos", "MTB POS", "Detected"],
    # "MTB pos" used to be repeated in this list as well; it is a positive
    # result and is already listed under true_values, so the contradictory
    # entry is dropped to keep the two lists disjoint.
    "false_values": [
        "HIV-",
        "Negative",
        "Neg",
        "No",
        "MTB neg",
        "MTB NEG",
        "Not detected",
    ],
    "na_values": [
        "Not done",
        "Contaminated/lost",
        "Contaminated/Lost",
        "Not known",
        "Don't Know",
        "Not Applicable",
        "Not applicable",
        "Not Done/Not applicable",
        "Not done/not applicable",
        "Not done/Not applicable",
        "Indeterminate  (invalid/error/no result)",
        "Invalid/error/no result",
        "Indeterminate",
        "I-Indeterminate",
        "No result/Indeterminate",
        "ND",
    ],
}

# Read data files
# Every workbook is read with the shared True/False/NA spellings in
# clinical_kwargs and tagged with the cohort it belongs to.
mb_100 = pd.read_excel(
    CLINICAL_MBV_FILES[0], sheet_name="URINE SAMPLES", **clinical_kwargs
)
mb_100["Cohort"] = "training"
val_320 = pd.read_excel(CLINICAL_MBV_FILES[1], **clinical_kwargs)
val_320["Cohort"] = "validation"

# The test workbook's "age" column is renamed to "age of sample (Years)".
test_244 = pd.read_excel(CLINICAL_TEST_FILE, **clinical_kwargs).rename(
    columns={"age": "age of sample (Years)"}
)
test_244["Cohort"] = "test"

# Patients in the second workbook that mbv.csv lists as training-cohort S-C+
# are re-labelled from "validation" to "training".
mbv = pd.read_csv(PROCESSED_DIR / "mbv.csv", index_col=0)
val_320.loc[
    val_320["OS_PatientID"].isin(
        mbv["OS_PatientID"][(mbv["Cohort"] == "training") & (mbv["p_cat"] == "S-C+")]
    ),
    "Cohort",
] = "training"

In [None]:
# Stack the three cohorts row-wise, drop columns that are empty in every row,
# then drop the explicitly listed columns.
clinical = (
    pd.concat([mb_100, val_320, test_244], ignore_index=True)
    .dropna(axis=1, how="all")
    .drop(
        columns=[
            "SP1_SC_REP_RESULT",
            "SP2_SC_REP_RESULT",
            "SP3_LC_REP_RESULT",
            "OS_Specimen_Type",
            "OS_Requirement_Name",
            "SP2_LC_REP_RESULT",
            "SP_RDST_SP_TYPE",
            "SP_DST_SP_TYPE",
        ]
    )
)
# Columns whose name ends in "_D" are treated as dates and parsed cell by cell.
dt_cols = clinical.columns[clinical.columns.str.endswith("_D")]
for col in dt_cols:
    clinical[col] = clinical[col].map(parse_dt)
# Replace mistaken dates in TB_TX_HX_START_D and TB_TX_HX_END_D:
# a treatment date whose year is earlier than the year of birth is blanked.
for col in ["TB_TX_HX_START_D", "TB_TX_HX_END_D"]:
    clinical[col] = clinical[col].mask(clinical[col].dt.year < clinical["YOB"])
# Barcode2 is imported from the project's utils module; what standard_form()
# and any_aliquot() return is not visible in this notebook.
barcode_parsed = clinical["barcode"].map(Barcode2)
clinical = pd.concat(
    [
        clinical,
        pd.DataFrame(
            {
                # Year-level difference only (enrolment year minus year of birth).
                "Age at Enrol": clinical["ENROL_D"].dt.year - clinical["YOB"],
                "Barcode Parsed": barcode_parsed.map(lambda x: x.standard_form()),
                "Barcode Any Aliquot": barcode_parsed.map(lambda x: x.any_aliquot()),
            }
        ),
    ],
    axis=1,
)

# Harmonize values
# Result strings that were not already converted on read are mapped to booleans here.
RES = ["SP1_LC_REP_RESULT", "QFT_RES"]
clinical[RES] = clinical[RES].replace(
    {"Neg": False, "Pos": True, "Negative": False, "Positive": True}
)
XP_RIF = ["SP1_XP_RIF", "SP1_XP_REPEAT_RIF"]
clinical[XP_RIF] = clinical[XP_RIF].replace({"Not detected": False, "Detected": True})
# Categorical predictions: "TB" -> True, "NotTB" / "Non TB" -> False.
PREDICTIONS = [
    "hsMSD_PairO_Prediction",
    "hsMSD_PairF__Prediction",
    "R1 Categorical Prediction",
    "R2 Categorical Prediction",
]
clinical[PREDICTIONS] = clinical[PREDICTIONS].replace(
    {"TB": True, "NotTB": False, "Non TB": False}
)
# Columns ending in "GRADE", "Q" or "CD4" hold graded strings that are mapped
# to small integers. Any value that is neither missing nor a key of `smear`
# still raises KeyError, which doubles as a check for unexpected spellings.
SMEAR_GRADES = clinical.columns[clinical.columns.str.endswith(("GRADE", "Q", "CD4"))]
smear = {"Scanty": 0, "1+": 1, "2+": 2, "3+": 3, "<20": 0, np.nan: pd.NA}
# Missing cells are detected with pd.isna rather than looked up through the
# np.nan key: a dict only matches NaN by object identity, so a NaN that is
# not the np.nan singleton (e.g. one boxed from a float column) would raise.
clinical[SMEAR_GRADES] = (
    clinical[SMEAR_GRADES]
    .map(lambda x: pd.NA if pd.isna(x) else smear[x])
    .convert_dtypes()
)
# "TMF" entries in the numeric prediction become +inf so the column can be float.
clinical["R1 Numeric Prediction"] = (
    clinical["R1 Numeric Prediction"].replace("TMF", np.inf).astype(float)
)
# Unify alternative spellings of the follow-up status.
clinical["FU"] = clinical["FU"].replace(
    {
        "Followed up": "Followed Up",
        "No follow up required": "No FU Required",
        "Lost to follow up": "Lost to Follow Up",
    }
)
# Unify alternative spellings of the clinical diagnosis.
clinical["CX_DX"] = clinical["CX_DX"].replace(
    {
        "TB likely (pulmonary pleural or pericardial)": "TB likely (pulmonary; pleural; pericardial)",
        "Other Specify": "Other",
    }
)
# Non-date columns whose name contains "DST": expand the one-letter codes.
dsts = clinical.columns[clinical.columns.str.contains("DST")].difference(dt_cols)
clinical[dsts] = clinical[dsts].replace({"S": "S-Sensitive", "R": "R-Resistant"})
# The string "0" (not the number 0) is treated as missing in every column.
clinical = clinical.replace("0", pd.NA)

# Symptom columns; they share one ordered category scale (see ordered_categories).
SYMPTOMS = [
    "COUGH",
    "EXPECTORATION",
    "HEMOPTYSIS",
    "CHEST_PAIN",
    "DYSPNOE",
    "MALAISE",
    "FEVER",
    "SWEATS",
    "WT_LOSS",
    "LYMPH_NODE",
]
# Normalise the upper-case spelling "NO" to "No".
clinical[SYMPTOMS] = clinical[SYMPTOMS].replace({"NO": "No"})

# Ordered categories
# Category order of the patient classification, as used for the ordered dtype below.
p_cats = [
    "NonTB_NonLTBI",
    "NonTB_LTBI",
    "Likely_subcl_TB",
    "Clinical_TB",
    "S-C+",
    "S+C+",
]
# (column or list of columns, category levels in ascending order)
ordered_categories = [
    ("p_cat", p_cats),
    (SYMPTOMS, ["No", "Yes", "< 2 weeks", "< 2 months", "= 2 months", "≥ 2 months"]),
    ("Cohort", ["training", "validation", "test"]),
    (
        ["hsMSD_PairO_Level", "hsMSD_PairF_Level"],
        ["No LAM", "Low LAM", "Middle LAM", "High LAM"],
    ),
]
# Cast each column (group) to an ordered categorical with the listed levels.
for cols, cats in ordered_categories:
    cat_index = pd.Index(cats, dtype="str")
    clinical[cols] = clinical[cols].astype(pd.CategoricalDtype(cat_index, ordered=True))

# Groups of columns that become unordered categoricals; columns in the same
# group share one category set.
unordered_categories = [
    dsts,
    clinical.columns[clinical.columns.str.endswith("SPC")],
    ["FU_CXR", "FU_SYMPTOMS", "2M_FU_SYMPTOMS", "4M_FU_SYMPTOMS"],
    ["FU_TB_TX_STATUS", "2M_TB_TX_STATUS"],
    ["Country"],
    ["SEX"],
    ["FU"],
    ["CX_DX"],
    ["STUDY_NAME"],
    ["SP1_XP_REPEAT_RESULT"],
    ["TB_TX_ENROL_SCHEME"],
    ["FU_TB_DIAG"],
]
for cols in unordered_categories:
    # Categories are the distinct non-missing values observed across the group.
    cat_index = pd.Index(pd.unique(clinical[cols].values.ravel())).dropna()
    clinical[cols] = clinical[cols].astype(pd.CategoricalDtype(cat_index))

# Collapse to one row per patient: forward- then backward-fill missing values
# within each patient, then keep the first row per OS_PatientID.
# Two throw-away copies of the patient id serve as grouping keys, one per
# groupby step, while OS_PatientID itself stays in the frame for
# drop_duplicates.
clinical["PID_f"] = clinical["PID_b"] = clinical["OS_PatientID"]
clinical = (
    clinical.groupby("PID_f")
    .ffill()
    .groupby("PID_b")
    .bfill()
    .drop_duplicates(subset="OS_PatientID")
)
clinical.index = clinical["Barcode Any Aliquot"]

# Binary target: S+C+, S-C+ and Clinical_TB count as TB (1); the remaining
# categories, including Likely_subcl_TB, count as non-TB (0).
p_cat_binary = {
    "NonTB_NonLTBI": 0,
    "S+C+": 1,
    "S-C+": 1,
    "Clinical_TB": 1,
    "Likely_subcl_TB": 0,
    "NonTB_LTBI": 0,
}
# Dict lookup inside a lambda: a p_cat value outside the mapping raises KeyError.
clinical["y"] = clinical["p_cat"].map(lambda x: p_cat_binary[x])


# Assay measurements; the columns of X_med are used as model features below.
X_med = pd.read_csv(PROCESSED_DIR / "X_med.csv", index_col=0)
# Predictions for the test set, re-indexed by the aliquot-independent barcode
# so they align with `clinical`.
test_set_pred = pd.read_excel(OUTPUT_DIR / "test_set_predictions.xlsx", index_col=0)[
    ["Estimated TB Probability", "Predicted Diagnosis"]
]
# Convert the diagnosis label to a boolean (True when the label is "TB").
test_set_pred["Predicted Diagnosis"] = test_set_pred["Predicted Diagnosis"] == "TB"
test_set_pred.index = test_set_pred.index.map(lambda x: Barcode2(x).any_aliquot())
# One frame per CSV file in ALERE_RESULTS_DIR. The barcode column appears
# under three header spellings, all renamed to "AlereLAM Barcode".
# NOTE: os.scandir yields entries in arbitrary order, so the row order of the
# concatenated frame is not guaranteed to be stable across machines.
alere_dfs = [
    pd.read_csv(
        io,
        index_col=0,
        true_values=["Positive", "Positive "],
        false_values=["Negative", "Negative "],
    ).rename(
        columns={
            f"Sample Barcode{col}": "AlereLAM Barcode"
            for col in [" Number", " Number ", ""]
        }
        | {"Result": "AlereLAM Result", "Notes": "AlereLAM Notes"}
    )
    for io in os.scandir(ALERE_RESULTS_DIR)
    if io.name.endswith(".csv")
]

alere = pd.concat(alere_dfs, ignore_index=True)
# Barcode substitutions applied before parsing (recorded barcode -> replacement).
alere_replacements = {
    "FIND 05 61 0329 U23 06": "FIND 05 61 0309 U23 06",
    "FIND 05 61 3075 U23 01": "FIND 05 61 0375 U23 01",
    "FIND 05 01 0094 U23 01": "FIND 05 01 0091 U23 01",
    "FIND 05 06 0021 U23 01": "FIND 05 61 0021 U23 01",
    "FIND 05 61 0215 U23 01": "FIND 05 61 0125 U23 01",
}

# Index by the aliquot-independent barcode so rows align with `clinical`.
alere.index = (
    alere["AlereLAM Barcode"]
    .replace(alere_replacements)
    .map(lambda x: Barcode2(x).any_aliquot())
)


def process_notes(x):
    """Collapse a free-text AlereLAM note into a canonical label.

    Missing values are returned unchanged. A note containing "faint" becomes
    "faint line"; otherwise one containing "dark" becomes "dark urine"
    ("faint" wins when both occur). Matching is case-sensitive. Any other
    note is returned with surrounding whitespace stripped.
    """
    if pd.isna(x):
        return x
    for keyword, canonical in (("faint", "faint line"), ("dark", "dark urine")):
        if keyword in x:
            return canonical
    return x.strip()


# Canonicalise the free-text notes and store them as a categorical.
alere["AlereLAM Notes"] = (
    alere["AlereLAM Notes"].map(process_notes).convert_dtypes().astype("category")
)
clinical["R1 Numeric Prediction"] = clinical["R1 Numeric Prediction"].astype("Float64")
# Column-wise join on the shared barcode index, then infer nullable dtypes.
clinical = pd.concat([clinical, X_med, test_set_pred, alere], axis=1).convert_dtypes()

## GAM Spline models and cross-validation

In [None]:
def sens_at_spec(y_true, y_score, spec=1):
    """Return the highest sensitivity achievable at specificity >= ``spec``.

    Takes the ROC curve of ``y_score`` against ``y_true`` and returns the
    largest true-positive rate among points whose false-positive rate is at
    most ``1 - spec``; ``0.0`` when no point qualifies.
    """
    false_pos_rate, true_pos_rate, _ = metrics.roc_curve(y_true, y_score)
    max_allowed_fpr = 1 - spec
    eligible = true_pos_rate[false_pos_rate <= max_allowed_fpr]
    if eligible.size:
        return eligible.max()
    return 0.0


# Scorers keyed by name; "sens_at_perf_spec" is sens_at_spec with its default
# spec=1, i.e. sensitivity with no false positives.
scorers = {
    "roc_auc": metrics.get_scorer("roc_auc"),
    "sens_at_perf_spec": metrics.make_scorer(
        sens_at_spec, response_method=("decision_function", "predict_proba")
    ),
    "balanced_accuracy": metrics.get_scorer("balanced_accuracy"),
}


# Keep only samples for which every X_med feature was measured.
y_tvt = clinical["y"].loc[clinical[X_med.columns].notna().all(axis=1)]
# Features are scaled by LOD_COL_FMT["LOD_samples"] (from config), multiplied
# by 10 and arcsinh-transformed.
X_tvt = np.arcsinh(
    clinical.loc[y_tvt.index, X_med.columns] / LOD_COL_FMT["LOD_samples"] * 10
)

### Run nested cross-validation (warning: long)

In [None]:
# Nested cross-validation settings. NestedCV comes from the project's utils
# module; the meaning of n_trials / inner_repeats / negligible is defined there.
n_splits = 5
random_state = 0
inner_repeats = 4
outer_repeats = 20
n_trials = 200
negligible = 0.01

# Outer loop: stratified k-fold repeated `outer_repeats` times with a fixed seed.
outer_cv = RepeatedStratifiedKFold(
    n_splits=n_splits, n_repeats=outer_repeats, random_state=random_state
)
ncv_tvt = NestedCV(
    outer_cv,
    n_trials=n_trials,
    inner_repeats=inner_repeats,
    negligible=negligible,
    n_jobs=1,
)
ncv_tvt.fit(X_tvt, y_tvt.values.ravel())


### Alternative: load existing results

In [None]:
# NOTE: this rebinds `ncv_tvt`, replacing the object fitted in the cell above
# when both cells are run.
# SECURITY: pickle.load can execute arbitrary code while loading; only load a
# file that you produced yourself or otherwise trust.
with open(PROCESSED_DIR / "ncv_tvt.pkl", "rb") as f:
    ncv_tvt = pickle.load(f)

## Save results to Excel file with clinical data

In [None]:
# Append the per-sample mean of the nested-CV predicted probabilities
# (averaged across the columns of predicted_proba) to the clinical table.
clinical_2 = pd.concat(
    [
        clinical,
        ncv_tvt.predicted_proba.mean(axis=1).rename(
            "Average Predicted Probability in Nested CV"
        ),
    ],
    axis=1,
)
# Nullable booleans are written to the workbook as 0/1 integers.
bool_cols = clinical_2.columns[clinical_2.dtypes == "boolean"]
clinical_2[bool_cols] = clinical_2[bool_cols].astype("Int8")
# Two extra copies of the patient id are added under the _f / _b names.
clinical_2["OS_PatientID_f"] = clinical_2["OS_PatientID_b"] = clinical_2["OS_PatientID"]
clinical_2.to_excel(OUTPUT_DIR / "clinical_combined.xlsx")

## Figure 2

In [None]:
# Literature comparison data; the first two rows form a two-level column header.
meta_analysis = pd.read_excel(META_ANALYSIS_FILE, header=[0, 1])


def proportion_ci(p, n, z):
    """Normal-approximation bound for a proportion.

    Returns ``p + z * sqrt(p * (1 - p) / n)``; pass a negative ``z`` for the
    lower bound and a positive ``z`` for the upper bound. The result is not
    clipped to [0, 1]. Works element-wise on array-like ``p`` and ``n``.
    """
    standard_error = np.sqrt(p * (1 - p) / n)
    return p + z * standard_error


# For each of Sensitivity / Specificity, compute normal-approximation Low and
# High bounds (z = -1.96 / +1.96) from the Point estimate and N. They are used
# only to fill cells that are missing in the sheet; existing Low/High values
# are kept as they are.
fill_value = {
    (stat, col): proportion_ci(
        meta_analysis[(stat, "Point")], meta_analysis[(stat, "N")], z
    )
    for stat in ["Sensitivity", "Specificity"]
    for col, z in zip(["Low", "High"], [-1.96, 1.96])
}
meta_analysis.fillna(fill_value, inplace=True)


def _mean_and_t_interval(samples, t_crit):
    """Return ``(mean, lower, upper)`` of a t-based interval across axis 0 of ``samples``."""
    mean = np.mean(samples, axis=0)
    std_err = np.std(samples, axis=0, ddof=1) / np.sqrt(len(samples))
    return mean, mean - t_crit * std_err, mean + t_crit * std_err


def threshold_average_with_ci(
    predicted_proba, y_tvt, alpha=0.05, specific_threshold=0.72
):
    """Threshold-averaged ROC curve, AUC and operating point with t-based CIs.

    Each column of ``predicted_proba`` is treated as one fold: its non-missing
    scores are compared with the matching entries of ``y_tvt``. Per-fold FPR
    and TPR are interpolated onto a common grid of score thresholds in [0, 1]
    and then averaged across folds.

    Parameters
    ----------
    predicted_proba : pd.DataFrame
        One column of predicted scores per fold; missing entries are dropped
        per column.
    y_tvt : pd.Series
        True binary labels, indexed like ``predicted_proba``.
    alpha : float, default 0.05
        Two-sided significance level of the confidence intervals.
    specific_threshold : float, default 0.72
        Score threshold at which the operating-point TPR/FPR are evaluated
        (a score >= threshold is a positive call).

    Returns
    -------
    dict
        ``"roc_curve"`` (``mean_fpr``, ``mean_tpr``, ``ci_fpr``, ``ci_tpr``,
        ``thresholds``), ``"auc"``, ``"tpr_at_threshold"`` and
        ``"fpr_at_threshold"``; the last three each hold ``mean``,
        ``ci_lower`` and ``ci_upper`` (plus ``threshold`` for the operating
        point entries).

    Notes
    -----
    Intervals are ``mean ± t * std / sqrt(n_folds)`` with ``n_folds - 1``
    degrees of freedom, i.e. folds are treated as independent samples.
    """
    thresh_vals = np.arange(0, 1.00001, 0.00001)
    interpolated = {"fpr": [], "tpr": []}
    auc_list = []
    tpr_at_thresh_list = []
    fpr_at_thresh_list = []

    for fold in predicted_proba.columns:
        y_scores = predicted_proba[fold].dropna()
        y_true = y_tvt.loc[y_scores.index]

        fpr, tpr, thresholds = metrics.roc_curve(y_true, y_scores)
        auc_list.append(metrics.auc(fpr, tpr))

        # Reverse the arrays so the thresholds are increasing, as np.interp requires.
        increasing_thresholds = thresholds[::-1]
        interpolated["fpr"].append(
            np.interp(thresh_vals, increasing_thresholds, fpr[::-1])
        )
        interpolated["tpr"].append(
            np.interp(thresh_vals, increasing_thresholds, tpr[::-1])
        )

        # TPR and FPR at the specific threshold
        y_pred = (y_scores >= specific_threshold).astype(int)
        tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        tpr_at_thresh_list.append(tp / (tp + fn))
        fpr_at_thresh_list.append(fp / (fp + tn))

    n_folds = len(auc_list)
    t_crit = stats.t.ppf(1 - alpha / 2, df=n_folds - 1)

    mean_fpr, fpr_lower, fpr_upper = _mean_and_t_interval(interpolated["fpr"], t_crit)
    mean_tpr, tpr_lower, tpr_upper = _mean_and_t_interval(interpolated["tpr"], t_crit)
    avg_auc, auc_lower, auc_upper = _mean_and_t_interval(auc_list, t_crit)
    point_tpr, point_tpr_lower, point_tpr_upper = _mean_and_t_interval(
        tpr_at_thresh_list, t_crit
    )
    point_fpr, point_fpr_lower, point_fpr_upper = _mean_and_t_interval(
        fpr_at_thresh_list, t_crit
    )

    return {
        "roc_curve": {
            "mean_fpr": mean_fpr,
            "mean_tpr": mean_tpr,
            "ci_fpr": {"lower": fpr_lower, "upper": fpr_upper},
            "ci_tpr": {"lower": tpr_lower, "upper": tpr_upper},
            "thresholds": thresh_vals,
        },
        "auc": {
            "mean": avg_auc,
            "ci_lower": auc_lower,
            "ci_upper": auc_upper,
        },
        "tpr_at_threshold": {
            "threshold": specific_threshold,
            "mean": point_tpr,
            "ci_lower": point_tpr_lower,
            "ci_upper": point_tpr_upper,
        },
        "fpr_at_threshold": {
            "threshold": specific_threshold,
            "mean": point_fpr,
            "ci_lower": point_fpr_lower,
            "ci_upper": point_fpr_upper,
        },
    }


# Boolean row selectors over the modelled samples (index of y_tvt).
overall = pd.Series(data=True, index=y_tvt.index)
hiv_status = {
    "any HIV": overall,
    "HIV–": ~clinical.loc[y_tvt.index, "HIV_status"],
    "HIV+": clinical.loc[y_tvt.index, "HIV_status"],
}
# The smear subsets are defined by exclusion: "Smear+" drops the S-C+ and
# Clinical_TB rows, "Smear–" drops the S+C+ rows. All other categories
# (including every non-TB row) therefore belong to both subsets.
smear_status = {
    "overall": overall,
    "Smear+": ~clinical.loc[y_tvt.index, "p_cat"].isin(["S-C+", "Clinical_TB"]),
    "Smear–": clinical.loc[y_tvt.index, "p_cat"] != "S+C+",
}


In [None]:
# Preliminaries for New Figure 2

# Colour per legend entry; add_errorbar looks the colour up by assay label.
color_marker = {
    "TPP": "#491E5D",
    "AlereLAM": "#D98825",
    # "Truenat MTB": ("indianred", "o"),
    "Truenat Plus": "#fe0809",
    # "Xpert MTB/RIF": ("#3377BB", "o"),
    "Xpert Ultra": "#0077c8",
    # "EclLAM": ("#009E73", "o"),
    "Smear microscopy": "#CC79A7",
    # "Smear (ZN)": ("brown", "o"),
    # "Smear (FM)": ("brown", "o"),
    "Culture": "brown",
}

# Row selectors per figure panel; the keys match the subplot_mosaic labels.
# Same selection logic as hiv_status / smear_status, keyed by panel name.
masks = {
    "All": pd.Series(data=True, index=y_tvt.index),
    "HIV-": ~clinical.loc[y_tvt.index, "HIV_status"],
    "HIV+": clinical.loc[y_tvt.index, "HIV_status"],
    "S+": ~clinical.loc[y_tvt.index, "p_cat"].isin(["S-C+", "Clinical_TB"]),
    "S-": clinical.loc[y_tvt.index, "p_cat"] != "S+C+",
}

# Target-product-profile rectangles: label -> (bottom y, height, alpha).
tpp_sizes = {
    "Optimal": (0.95, 0.05, 1),
    "Low-complexity": (0.8, 0.15, 0.6),
    "Near point of care": (0.75, 0.05, 1 / 3),
    "Point of care": (0.65, 0.1, 0.2),
}

# Drawing order of the comparison assays; all negative, i.e. below the
# default zorder of the other artists.
comparison_zorder = {
    "Xpert Ultra": -1,
    "AlereLAM": -2,
    "Truenat Plus": -3,
    "Smear microscopy": -4,
}


def add_errorbar(row):
    """Build ``Axes.errorbar`` keyword arguments for one comparison assay.

    Parameters
    ----------
    row : mapping
        One row of the meta-analysis table, indexed by two-level keys:
        ``("Sensitivity" | "Specificity", "Point" | "Low" | "High" | "N")``
        and ``("Condition", "Assay")``.

    Returns
    -------
    dict
        ``x`` (1 - specificity), ``y`` (sensitivity), asymmetric ``xerr`` /
        ``yerr`` of shape (2, 1), plus marker, colour, label, marker size and
        zorder looked up from ``color_marker`` / ``comparison_zorder``.
    """
    x = 1 - row[("Specificity", "Point")]
    y = row[("Sensitivity", "Point")]
    # Matplotlib reads a (2, N) error array as [lower distances, upper
    # distances]; the rows were previously in the opposite order, which
    # mirrored every asymmetric interval around its point estimate.
    yerr = np.array(
        [y - row[("Sensitivity", "Low")], row[("Sensitivity", "High")] - y]
    ).reshape((2, 1))
    # On the 1 - specificity axis the bounds flip: the High specificity bound
    # is the lower x bound and the Low specificity bound is the upper x bound.
    xerr = np.array(
        [x - (1 - row[("Specificity", "High")]), (1 - row[("Specificity", "Low")]) - x]
    ).reshape((2, 1))
    label = row[("Condition", "Assay")]
    ms_spec_n = row[("Specificity", "N")]
    if ms_spec_n == 0:
        # Fallback sample size when the sheet records 0.
        # NOTE(review): the origin of 2858 is not visible here; confirm.
        ms_spec_n = 2858
    # Marker size scales with the geometric mean of the two Ns, capped at 16.
    markersize = min((row[("Sensitivity", "N")] * ms_spec_n) ** 0.5 / 200, 16)
    return dict(
        x=x,
        y=y,
        yerr=yerr,
        xerr=xerr,
        marker="o",
        color=color_marker[label],
        label=label,
        markersize=markersize,
        zorder=comparison_zorder[label],
    )


def plot_thresh_ci(predicted_proba, y_tvt, alpha=0.05, specific_threshold=0.72):
    """Assemble the plot inputs for one threshold-averaged ROC panel.

    Despite its name this function draws nothing: it calls
    ``threshold_average_with_ci`` and repackages the result into arguments
    that the caller passes to matplotlib.

    Parameters
    ----------
    predicted_proba, y_tvt, alpha, specific_threshold
        Forwarded unchanged to ``threshold_average_with_ci``.

    Returns
    -------
    roc_mean : tuple
        ``(mean_fpr, mean_tpr)``, suitable for ``ax.plot(*roc_mean)``.
    roc_ci : dict
        ``x`` / ``y1`` / ``y2`` for ``ax.fill_between`` (the TPR band drawn
        against the mean FPR). The FPR band is not returned.
    errorbar : dict
        ``x`` / ``y`` / ``xerr`` / ``yerr`` for ``ax.errorbar`` at the
        operating threshold, with errors as [[lower], [upper]] distances.
    auc_data : dict
        The ``"auc"`` entry of the underlying result (``mean``, ``ci_lower``,
        ``ci_upper``).
    """
    result = threshold_average_with_ci(
        predicted_proba, y_tvt, alpha=alpha, specific_threshold=specific_threshold
    )

    # Mean ROC curve and its TPR confidence band.
    roc_curve = result["roc_curve"]
    fpr = roc_curve["mean_fpr"]
    tpr = roc_curve["mean_tpr"]
    roc_mean = (fpr, tpr)
    roc_ci = {
        "x": fpr,
        "y1": roc_curve["ci_tpr"]["lower"],
        "y2": roc_curve["ci_tpr"]["upper"],
    }

    # Operating point at the specific threshold, with asymmetric error bars.
    tpr_point = result["tpr_at_threshold"]
    fpr_point = result["fpr_at_threshold"]
    errorbar = {
        "x": fpr_point["mean"],
        "y": tpr_point["mean"],
        "xerr": [
            [fpr_point["mean"] - fpr_point["ci_lower"]],
            [fpr_point["ci_upper"] - fpr_point["mean"]],
        ],
        "yerr": [
            [tpr_point["mean"] - tpr_point["ci_lower"]],
            [tpr_point["ci_upper"] - tpr_point["mean"]],
        ],
    }

    return roc_mean, roc_ci, errorbar, result["auc"]


# Shared text weight for panel and legend titles.
title_weight = "semibold"
# Keyword arguments applied to every legend in the figure.
legend_kws = {
    "handlelength": 1.2,
    "alignment": "left",
    "fancybox": False,
    "title_fontproperties": {"weight": title_weight},
    "framealpha": 0.5,
}

# Panel key -> (title text, panel letter).
# NOTE: the sample count in the "All" title is hard-coded; it is not derived
# from y_tvt and must be updated by hand if the cohort changes.
panel_titles = {
    "All": ("All Samples (N=576)", "A"),
    "S+": ("Smear+", "B"),
    "S-": ("Smear–", "C"),
    "HIV-": ("HIV–", "D"),
    "HIV+": ("HIV+", "E"),
}

In [None]:
# 3x3 mosaic: a large "All" panel (2x2), the two smear panels on the right,
# the two HIV panels below, and a cell reserved for the shared legends.
fig, axd = plt.subplot_mosaic(
    [["All", "All", "S+"], ["All", "All", "S-"], ["HIV-", "HIV+", "Legend"]],
    figsize=(10, 10),
    layout="constrained",
)

# Per-panel legend handles (AUC entries) and the three shared legend groups
# drawn in the "Legend" cell.
legend_handles = {key: [] for key in masks}
legend_legends = {
    "Nested cross-validation": [],
    "Comparisons from literature": [],
    "WHO target product profile": [],
}

# One ROC panel per subgroup: mean curve, TPR confidence band, and the
# operating point at the default threshold of plot_thresh_ci.
for a, mask in masks.items():
    roc_mean, roc_ci, errorbar, auc_data = plot_thresh_ci(
        ncv_tvt.predicted_proba.loc[mask], y_tvt.loc[mask]
    )
    axd[a].plot(
        *roc_mean, color="k", solid_joinstyle="miter", solid_capstyle="projecting"
    )
    # Empty plot: contributes only a legend handle carrying the AUC text.
    legend_handles[a].append(
        axd[a].plot([], [], "s-", color="k", label=f" AUC = {auc_data['mean']:.3f}")[0]
    )
    # The band itself; its legend label carries the AUC confidence interval.
    legend_handles[a].append(
        axd[a].fill_between(
            **roc_ci,
            color="k",
            alpha=0.2,
            label=f"({auc_data['ci_lower']:.3f}–{auc_data['ci_upper']:.3f})",
            lw=0,
        )
    )
    axd[a].errorbar(**errorbar, fmt="s", color="k")
    axd[a].set(
        xlim=(0, 1),
        ylim=(0, 1),
        xlabel="1–Specificity (False Positive Rate)",
        ylabel="Sensitivity (True Positive Rate)",
        aspect="equal",
    )
    # Minor ticks every 0.02, major ticks every 0.1.
    axd[a].set_xticks(np.arange(0, 1.01, 0.02), minor=True)
    axd[a].set_yticks(np.arange(0, 1.01, 0.02), minor=True)
    axd[a].set_xticks(np.arange(0, 1.01, 0.1))
    axd[a].set_yticks(np.arange(0, 1.01, 0.1))
    # The small panels label only every other major tick.
    if a == "All":
        labels = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    else:
        labels = [0, "", 0.2, "", 0.4, "", 0.6, "", 0.8, "", 1]
    axd[a].set_xticklabels(labels)
    axd[a].set_yticklabels(labels)

# Test Set
# Rows of the modelled samples that have a test-set probability estimate.
data = clinical.loc[y_tvt.index].dropna(subset="Estimated TB Probability")
# Test set ROC curve (drawn on the "S-" panel)
fpr, tpr, thresholds = metrics.roc_curve(data["y"], data["Estimated TB Probability"])
auc = metrics.auc(fpr, tpr)
axd["S-"].plot(
    fpr,
    tpr,
    color="#00664A",
    solid_joinstyle="miter",
    solid_capstyle="projecting",
)
# fpr / tpr are reused as scalars from here on: the operating point of the
# recorded binary diagnosis (share of positive calls among y == 0 / y == 1).
fpr = data["Predicted Diagnosis"][data["y"] == 0].mean()
tpr = data["Predicted Diagnosis"][data["y"] == 1].mean()
legend_handles["S-"].append(
    axd["S-"].plot(
        fpr, tpr, "s-", label=f"Blinded test\nset, AUC={auc:.3f}", color="#00664A"
    )[0]
)

# AlereLAM on test set (same rows, single operating point)
tpr = data["AlereLAM Result"][data["y"] == 1].mean()
fpr = data["AlereLAM Result"][data["y"] == 0].mean()
legend_handles["S-"].append(
    axd["S-"].plot(
        fpr,
        tpr,
        "D",
        label="AlereLAM in\nblinded test set",
        color=color_marker["AlereLAM"],
        zorder=-1,
    )[0]
)

# TPP
# Each profile is a thin rectangle at the left edge of the "All" panel,
# spanning [loc, loc + height] on the sensitivity axis.
for label, (loc, height, alpha) in tpp_sizes.items():
    legend_legends["WHO target product profile"].append(
        axd["All"].add_patch(
            matplotlib.patches.Rectangle(
                (0, loc),
                width=0.02,
                height=height,
                color=color_marker["TPP"],
                alpha=alpha,
                lw=0,
                label=label,
            )
        )
    )

# Comparison assays
# The ("Total", "axd") cell lists the panels a row is drawn on (comma
# separated); rows with ("Total", "legend") == 1 are additionally drawn on the
# "Legend" axes so that their artist can serve as the legend handle.
for row in meta_analysis[meta_analysis[("Total", "axd")].notna()].iterrows():
    kwargs = add_errorbar(row[1])
    axs = row[1][("Total", "axd")].split(", ")
    if row[1][("Total", "legend")] == 1:
        axs.append("Legend")
    for a in axs:
        eb = axd[a].errorbar(**kwargs)
        if a == "Legend":
            legend_legends["Comparisons from literature"].append(eb)

# Share of S+C+ among the positives (y == 1) of each subgroup, drawn as a
# marker at x = 0.02.
for a in ["All", "HIV-", "HIV+"]:
    empirical_spos = (
        clinical.loc[y_tvt.index]["p_cat"][masks[a] & y_tvt] == "S+C+"
    ).sum() / (masks[a] & y_tvt).sum()
    (artist,) = axd[a].plot(
        0.02,
        empirical_spos,
        "D",
        color=color_marker["Smear microscopy"],
        zorder=-3,
        label="Smear positivity rate in cohort",
    )
# Only the artist from the last iteration is needed as a legend handle.
legend_legends["Comparisons from literature"].append(artist)

# Empty artists on the "Legend" axes, used only as handles that explain the
# nested-CV line and band.
legend_legends["Nested cross-validation"] = [
    axd["Legend"].plot([], [], "s-", color="k", label="Mean (area under curve, AUC)")[
        0
    ],
    axd["Legend"].fill_between(
        [], [], [], color="k", alpha=0.2, label="95% confidence interval", lw=0
    ),
]

# One legend per group at a fixed position; add_artist keeps each earlier
# legend from being replaced by the next legend() call on the same axes.
locs = [(0, 0.74), (0, 0.29), (0, -0.09)]
for (title, handles), loc in zip(legend_legends.items(), locs):
    artist = axd["Legend"].legend(
        handles=handles, title=title, loc=loc, mode="expand", **legend_kws
    )
    axd["Legend"].add_artist(artist)


# Legends
# The legend cell shows no axes; its limits are moved to (-2, -1) so that the
# comparison markers drawn on it (coordinates within [0, 1]) fall outside the view.
axd["Legend"].axis("off")
axd["Legend"].set(xlim=(-2, -1), ylim=(-2, -1))
for a in masks:
    axd[a].legend(handles=legend_handles[a], loc="lower right", **legend_kws)

# Titles
# Title and panel letter are drawn inside the axes near the top; the small
# panels use twice the offsets of the large "All" panel.
for a, (title, l) in panel_titles.items():
    dy = 0.01
    dx = 0.09
    y = 1 - dy if title.startswith("All") else 1 - 2 * dy
    xl = dx if title.startswith("All") else 2 * dx
    axd[a].text(x=0.5, y=y, s=title, ha="center", va="top", weight=title_weight)
    axd[a].text(x=xl, y=y, s=f"({l})", ha="left", va="top", weight=title_weight)


plt.savefig(OUTPUT_DIR / "f2.pdf")

## Table of sensitivity and specificity

In [None]:
# Two-level columns: (score, statistic). "sensitivity" and "specificity" are
# placeholders here and are filled after the loop.
s_s_cols = pd.MultiIndex.from_product(
    [
        ["auc", "tpr_at_threshold", "fpr_at_threshold", "sensitivity", "specificity"],
        ["mean", "ci_lower", "ci_upper"],
    ]
)

# One row per (HIV subgroup, smear subgroup) combination.
sens_spec = pd.DataFrame(
    index=pd.MultiIndex.from_product([hiv_status, smear_status]),
    columns=s_s_cols,
)

for hiv_label, hiv_selector in hiv_status.items():
    for smear_label, smear_selector in smear_status.items():
        mask = hiv_selector & smear_selector
        y_true = y_tvt.loc[mask]
        d = threshold_average_with_ci(ncv_tvt.predicted_proba.loc[mask], y_true)
        # Copy only the scores the function reports (auc, tpr/fpr at threshold).
        for score, stat in s_s_cols:
            if score in d:
                sens_spec.loc[(hiv_label, smear_label), (score, stat)] = d[score][stat]

# Sensitivity is the TPR at the operating threshold; its interval carries over as is.
sens_spec["sensitivity"] = sens_spec["tpr_at_threshold"]
# Specificity is 1 - FPR, which reverses the interval: the lower specificity
# bound comes from the upper FPR bound and vice versa. Assigning 1 - FPR
# column by column without this swap would store ci_lower > ci_upper.
fpr_at_threshold = sens_spec["fpr_at_threshold"]
sens_spec[("specificity", "mean")] = 1 - fpr_at_threshold["mean"]
sens_spec[("specificity", "ci_lower")] = 1 - fpr_at_threshold["ci_upper"]
sens_spec[("specificity", "ci_upper")] = 1 - fpr_at_threshold["ci_lower"]
sens_spec = sens_spec.drop(["tpr_at_threshold", "fpr_at_threshold"], axis=1, level=0)
sens_spec.to_excel(OUTPUT_DIR / "sens_spec_table.xlsx")

## Table of demographics

In [None]:
# Table 1
# Countries in descending order of frequency.
countries = clinical["Country"].value_counts().index

# p_cat value -> row label used in the table.
tb_categories = {"S+C+": "S+C+", "S-C+": "S-C+", "Clinical_TB": "Clinically diagnosed"}
# Row key -> column summarised as "median (min–max)".
cont_cats = {
    ("HIV+", "CD4 count*\n(cells/μL)"): "HIV_CD4CNT",
    ("Age\n(years)",): "Age at Enrol",
}
# Row layout of the table, top to bottom.
table_index = pd.MultiIndex.from_tuples(
    [("N", "Initial"), ("N", "Excluded"), ("N", "Measured"), ("Male",), ("HIV+",)]
    + list(cont_cats)
    + [("Country", country) for country in countries]
    + [("TB",)]
    + [("TB", cat) for cat in tb_categories.values()]
    + [("Non-TB",), ("Non-TB", "Latent TB**")]
)

# One column per cohort (title-cased), plus an "All" column.
table_1_cols = list(
    clinical["Cohort"].unique().map(lambda s: s.title(), na_action=None)
) + ["All"]
table_1 = pd.DataFrame(columns=table_1_cols, index=table_index)

for cohort in table_1.columns:
    # Column names are title-cased cohort labels; lower-case them to match the data.
    df = clinical if cohort == "All" else clinical[clinical["Cohort"] == cohort.lower()]
    n = df.shape[0]
    table_1.loc[("N", "Initial"), cohort] = n

    # A sample counts as measured when "488 Ag85B 182" has a value; every
    # percentage after this point is relative to the measured samples.
    excluded = df["488 Ag85B 182"].isna().sum()
    table_1.loc[("N", "Excluded"), cohort] = f"{excluded}\n({excluded / n:.0%})"
    df = df[df["488 Ag85B 182"].notna()]

    measured = df.shape[0]
    table_1.loc[("N", "Measured"), cohort] = f"{measured}\n({measured / n:.0%})"

    male = (df["SEX"] == "Male").sum()
    table_1.loc["Male", cohort] = f"{male}\n({male / measured:.0%})"

    hiv = df["HIV_status"].sum()
    table_1.loc["HIV+", cohort] = f"{hiv}\n({hiv / measured:.0%})"

    for r, df_col in cont_cats.items():
        # Only strictly positive values enter the median / range.
        col = df[df_col][df[df_col] > 0]
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            table_1.loc[r, cohort] = f"{round(col.median())}\n({col.min()}–{col.max()})"

    for country in countries:
        c = (df["Country"] == country).sum()
        table_1.loc[("Country", country), cohort] = f"{c}\n({c / measured:.0%})"

    # isin() on the dict tests membership against its keys (the p_cat values).
    tb = df["p_cat"].isin(tb_categories).sum()
    table_1.loc["TB", cohort] = f"{tb}\n({tb / measured:.0%})"
    # Sub-category percentages are relative to the TB count, not to `measured`.
    for p_cat, row in tb_categories.items():
        n_p_cat = (df["p_cat"] == p_cat).sum()
        table_1.loc[("TB", row), cohort] = f"{n_p_cat}\n({n_p_cat / tb:.0%})"

    ntb = df["p_cat"].isin(["NonTB_LTBI", "NonTB_NonLTBI", "Likely_subcl_TB"]).sum()
    table_1.loc["Non-TB", cohort] = f"{ntb}\n({ntb / measured:.0%})"

    # Latent TB: NonTB_LTBI, plus Likely_subcl_TB rows with a positive QFT_RES.
    latent_tb = (
        (df["p_cat"] == "NonTB_LTBI")
        | ((df["p_cat"] == "Likely_subcl_TB") & df["QFT_RES"])
    ).sum()
    table_1.loc[("Non-TB", "Latent TB**"), cohort] = (
        f"{latent_tb}\n({latent_tb / ntb:.0%})"
    )

# Write Table 1 with a uniform cell format. The writer is used as a context
# manager so the workbook is finalised and the file handle released even if
# the formatting calls raise.
with pd.ExcelWriter(OUTPUT_DIR / "table_1.xlsx", engine="xlsxwriter") as writer:
    table_1.to_excel(writer, sheet_name="Sheet1")
    cell_format = writer.book.add_format(
        {"font_name": "Times New Roman", "font_size": 12, "align": "center", "border": 1}
    )
    # Apply the format from the first column through one past the data columns.
    writer.sheets["Sheet1"].set_column(0, len(table_1.columns) + 1, None, cell_format)

## Figure 3

In [None]:
def transform_ticks(divisor, ticks, labels=None):
    """Place tick values on an arcsinh-scaled axis.

    Returns keyword arguments for ``set_xticks`` / ``set_yticks``: positions
    are ``arcsinh(ticks / divisor)`` and labels default to the untransformed
    ``ticks`` when ``labels`` is not given.
    """
    tick_labels = ticks if labels is None else labels
    positions = np.arcsinh(np.asarray(ticks) / divisor)
    return {"ticks": positions, "labels": tick_labels}


fig, axs = plt.subplots(ncols=2, figsize=(9, 4.5), width_ratios=[1, 3])

# arcsinh(value / divisor) is used as the axis scale; Ag85B (first feature
# column) and the three LAM columns get different divisors.
ag85b_div = 0.04
lam_div = 4
X_trans = np.arcsinh(
    clinical.loc[y_tvt.index, X_med.columns]
    / pd.Series(index=X_med.columns, data=[ag85b_div] + [lam_div] * 3)
)

tb_colors = {False: "#192666", True: "#993300"}
tb_labels = ["Non-TB", "TB"]

# Long format: one row per (sample, assay); "x" combines assay and TB status
# into the swarm category.
swarm_df = pd.concat(
    [clinical.loc[X_trans.index, ["p_cat", "HIV_status", "y"]], X_trans], axis=1
).melt(id_vars=["HIV_status", "y", "p_cat"], var_name="Plex")
swarm_df["x"] = swarm_df.apply(lambda row: f"{row.Plex}\n{row.y}", axis=1)

# Seaborn orders string categories by first appearance in the data, but the
# tick labels below are hard-coded as (Non-TB, TB) per assay. Pin the category
# order explicitly (assay in column order, then y ascending) so that the
# labels cannot end up attached to the wrong swarm.
plex_position = {plex: pos for pos, plex in enumerate(X_trans.columns)}
x_levels = (
    swarm_df[["Plex", "y", "x"]]
    .drop_duplicates()
    .assign(plex_pos=lambda d: d["Plex"].map(plex_position))
    .sort_values(["plex_pos", "y"])
)
is_ag85b_level = x_levels["Plex"] == "488 Ag85B 182"
ag85b_order = x_levels.loc[is_ag85b_level, "x"].tolist()
lam_order = x_levels.loc[~is_ag85b_level, "x"].tolist()

kwargs = {"size": 2, "x": "x", "y": "value"}
# Both panels share one y-range so the two axes are visually comparable.
ylim = (swarm_df["value"].min() * 1.01, swarm_df["value"].max() * 1.01)
sns.swarmplot(
    swarm_df[swarm_df["Plex"] == "488 Ag85B 182"],
    ax=axs[0],
    order=ag85b_order,
    **kwargs,
)
axs[0].set_xticks([0, 1], labels=["Non-TB\nAg85B", "TB\nAg85B"])
axs[0].set_yticks(**transform_ticks(ag85b_div, [0, 0.1, 1, 10, 100]))
axs[0].set(ylabel="Ag85B (pg/mL)", ylim=ylim, xlabel="")

sns.swarmplot(
    swarm_df[swarm_df["Plex"] != "488 Ag85B 182"],
    ax=axs[1],
    order=lam_order,
    **kwargs,
)
# Label each swarm with TB status and the last token of the assay name.
axs[1].set_xticks(
    np.arange(6),
    labels=[
        f"{p_cat}\n{plex.split()[-1]}"
        for plex in X_trans.columns[1:]
        for p_cat in ["Non-TB", "TB"]
    ],
)
axs[1].set_yticks(**transform_ticks(lam_div, [0, 10, 100, 1000, 10000]))
axs[1].set(ylabel="LAM (pg/mL)", ylim=ylim, xlabel="")

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "f3.pdf")

## Exploratory data analysis

In [None]:
def get_var_type(s):
    """Classify a Series as "binary", "numeric" or "categorical".

    The checks run in this order: exactly two distinct non-missing values
    means "binary" regardless of dtype, then a numeric dtype means "numeric",
    then a categorical dtype means "categorical".

    Raises
    ------
    ValueError
        When none of the three cases applies.
    """
    if s.nunique() == 2:
        return "binary"
    if pd.api.types.is_numeric_dtype(s):
        return "numeric"
    if isinstance(s.dtype, pd.CategoricalDtype):
        return "categorical"
    raise ValueError(f"Unsupported data type {s.dtype} in column {s.name}")


def group(num_data, cat_data):
    """Split ``num_data`` into one subset per distinct value of ``cat_data``.

    Subsets are returned as a list, in order of first appearance of each
    value in ``cat_data``.
    """
    subsets = []
    for level in cat_data.unique():
        subsets.append(num_data[cat_data == level])
    return subsets


def _choose_mwu_method(n1, n2, has_ties):
    """Choose appropriate Mann-Whitney U method to avoid overflow."""
    if (n1 > 8 and n2 > 8) or has_ties or (n1 + n2) > 20:
        return "asymptotic"
    return "exact"


def mannwhitneyu_ab(num_data, bin_data):
    """Two-sided Mann-Whitney U test of ``num_data`` between the two levels of ``bin_data``.

    Returns
    -------
    (rrb, p_value)
        ``rrb`` is the rank-biserial effect size ``1 - 2U / (n1 * n2)``, where
        ``U`` is the statistic reported for the first group (the level that
        appears first in ``bin_data``).
    """
    samples = group(num_data, bin_data)
    sample_a, sample_b = samples[0], samples[1]
    size_a, size_b = len(sample_a), len(sample_b)

    # Ties exist when the pooled values contain duplicates.
    pooled = np.concatenate([sample_a, sample_b])
    tied = len(pooled) != len(np.unique(pooled))

    # Exact or asymptotic p-value, depending on sample sizes and ties.
    u_stat, p_value = stats.mannwhitneyu(
        sample_a,
        sample_b,
        alternative="two-sided",
        method=_choose_mwu_method(size_a, size_b, tied),
    )
    effect = 1 - (2 * u_stat) / (size_a * size_b)
    return effect, p_value


def mannwhitneyu_ba(bin_data, num_data):
    """Argument-swapped form of ``mannwhitneyu_ab`` for (binary, numeric) column pairs."""
    return mannwhitneyu_ab(num_data, bin_data)


def kruskal_ab(num_data, cat_data):
    """Kruskal-Wallis H test of ``num_data`` across the levels of ``cat_data``.

    Returns
    -------
    (effect, p_value)
        ``effect`` is the square root of eta-squared,
        ``(H - k + 1) / (n - k)`` floored at 0, for ``k`` non-empty groups and
        ``n`` observations. ``(nan, nan)`` when fewer than two groups are
        non-empty; ``(nan, p_value)`` when every group holds a single
        observation (``n == k``).
    """
    # Keep only groups with at least one observation.
    non_empty = [subset for subset in group(num_data, cat_data) if len(subset) > 0]
    n_groups = len(non_empty)
    if n_groups < 2:
        return np.nan, np.nan

    h_stat, p_value = stats.kruskal(*non_empty)

    n_obs = sum(len(subset) for subset in non_empty)
    if n_obs == n_groups:
        # The eta-squared denominator would be zero.
        return np.nan, p_value

    # Eta-squared for the Kruskal-Wallis test, floored at zero.
    eta_squared = max(0, (h_stat - n_groups + 1) / (n_obs - n_groups))
    return np.sqrt(eta_squared), p_value


def kruskal_ba(cat_data, num_data):
    """Argument-swapped form of ``kruskal_ab`` for (categorical, numeric) column pairs."""
    return kruskal_ab(num_data, cat_data)


def chi2(data_a, data_b):
    """Chi-square test of independence between two discrete Series.

    Parameters
    ----------
    data_a, data_b : pd.Series
        Aligned discrete variables; they are cross-tabulated with
        ``pd.crosstab``.

    Returns
    -------
    (cramers_v, p_value)
        Cramér's V, ``sqrt(chi2 / (n * (min(rows, cols) - 1)))``, and the test
        p-value. ``(nan, nan)`` for an empty table; V is 0 when the table has
        a single row or column.

    Notes
    -----
    ``n`` is the number of observations actually counted in the contingency
    table, so rows with a missing value in either variable do not inflate it.
    No check for small expected counts is performed; results for sparse
    tables should be read with that in mind.
    """
    contingency_table = pd.crosstab(data_a, data_b)
    n = contingency_table.to_numpy().sum()

    # Nothing to test on an empty table.
    if contingency_table.size == 0 or n == 0:
        return np.nan, np.nan

    x2, p_value, __, __ = stats.chi2_contingency(contingency_table)

    # Use Cramér's V (more appropriate than phi for non-2x2 tables)
    min_dim = min(contingency_table.shape) - 1
    cramers_v = np.sqrt(x2 / (n * min_dim)) if min_dim > 0 else 0

    return cramers_v, p_value


# Test dispatch table keyed by (type of column a, type of column b) as
# returned by get_var_type. Every entry is called as test(data_a, data_b) and
# unpacked as (statistic, p_value); the *_ba wrappers exist so that the
# numeric column can come second.
stat_tests = {
    ("numeric", "numeric"): stats.spearmanr,
    ("numeric", "binary"): mannwhitneyu_ab,
    ("numeric", "categorical"): kruskal_ab,
    ("binary", "numeric"): mannwhitneyu_ba,
    ("categorical", "numeric"): kruskal_ba,
    ("categorical", "categorical"): chi2,
    ("binary", "binary"): chi2,
    ("binary", "categorical"): chi2,
    ("categorical", "binary"): chi2,
}

# Combine all columns to consider
# Free-text columns (object / string dtype) are excluded, as are columns with
# a single distinct value.
clinical_hyp = clinical.select_dtypes(exclude=["object", "string"]).copy()
clinical_hyp = clinical_hyp[clinical_hyp.columns[clinical_hyp.nunique() > 1]]
# Dates are converted to ordinal day numbers so they are handled as numeric.
# NOTE: this rebinds dt_cols, which earlier held the "_D" columns of `clinical`.
dt_cols = clinical_hyp.select_dtypes("datetime64").columns
clinical_hyp[dt_cols] = (
    clinical_hyp[dt_cols]
    .map(lambda x: x.toordinal() if pd.notna(x) else pd.NA)
    .convert_dtypes()
)

# Step 2: Iterate Over All Unique Pairs of Columns
# One dict per successfully tested pair; turned into a DataFrame afterwards.
results = []

for col_a, col_b in tqdm(list(combinations(clinical_hyp.columns, 2))):
    # Step 3: Select and Perform Appropriate Test
    # Prepare data by dropping NA (pairwise deletion: rows missing either value)
    data = clinical_hyp[[col_a, col_b]].dropna()

    # Check both columns have variation after dropping NAs
    if all(data.nunique() > 1) and len(data) > 1:  # Need at least 2 observations
        # Types are determined on the NA-free pair, so a column can be
        # classified differently from one pairing to the next.
        var_type_a = get_var_type(data[col_a])
        var_type_b = get_var_type(data[col_b])

        appropriate_test = stat_tests[(var_type_a, var_type_b)]

        try:
            stat, p_value = appropriate_test(data[col_a], data[col_b])

            # Handle invalid results
            if pd.isna(stat) or pd.isna(p_value):
                continue

            # Step 4: Store the Results
            # A p-value of exactly 0 has no finite -log10; store NaN instead.
            log_p_value = -np.log10(p_value) if p_value > 0 else np.nan
            results.append(
                {
                    "col_a": col_a,
                    "col_b": col_b,
                    "test": appropriate_test.__name__,
                    # Squared test statistic / effect size; which quantity
                    # that is depends on the test used.
                    "r2": stat**2,
                    "p_value": p_value,
                    "log_p_value": log_p_value,
                    "n": len(data),
                }
            )
        except Exception as e:
            # Best effort: a failing pair is reported and skipped so that the
            # scan over the remaining pairs continues.
            print(f"Error testing {col_a} vs {col_b}: {e}")
            continue


# Step 5: Compile Results into a DataFrame
results_df = pd.DataFrame(results)

# Step 6: Adjust p-values for Multiple Comparisons
# Benjamini-Hochberg false-discovery-rate adjustment over all tested pairs.
p_values = results_df["p_value"].values
results_df["adjusted_p_value"] = stats.false_discovery_control(p_values, method="bh")
# Most significant pairs first; the remaining keys break ties.
results_df = results_df.convert_dtypes().sort_values(
    ["p_value", "log_p_value", "adjusted_p_value", "r2", "n"],
    ascending=[True, False, True, False, False],
)

results_df.to_excel(OUTPUT_DIR / "correlations.xlsx", index=False)