In [None]:
import json
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import RocCurveDisplay, get_scorer, make_scorer, roc_curve
from sklearn.model_selection import RepeatedStratifiedKFold

from config import CLINICAL_TEST_FILE, LOD_COL_FMT, OUTPUT_DIR, PROCESSED_DIR
from utils import (
    Barcode2,
    LogisticGAM,  # noqa
    NestedCV,
    StratifiedGroupKFoldFirst,  # noqa
    make_gam,
)

In [None]:
# Load the processed inputs: X, X_all (two-level row index), labels y, and the mbv table.
X = pd.read_csv(PROCESSED_DIR / "X.csv", index_col=0)
X_all = pd.read_csv(PROCESSED_DIR / "X_all.csv", index_col=[0, 1])
y = pd.read_csv(PROCESSED_DIR / "y.csv", index_col=0)

# Feature transform: divide each column by its sample LOD, multiply by 10, then arcsinh.
lod_samples = LOD_COL_FMT["LOD_samples"]
X_a: pd.DataFrame = np.arcsinh(X / lod_samples * 10)  # type: ignore
X_all_a = np.arcsinh(X_all / lod_samples * 10).loc[y.index]
mbv = pd.read_csv(PROCESSED_DIR / "mbv.csv", index_col=0).loc[X.index]

# Repeat each label of y for every X_all_a row whose level-0 key matches, then adopt X_all_a's index.
y_all = y.loc[X_all_a.index.get_level_values(0)].set_axis(X_all_a.index)

In [None]:
# Shared cross-validation / search settings used by the nested-CV cell below.
n_splits = 5  # folds per repeat of the outer RepeatedStratifiedKFold
outer_repeats = 10  # repeats of the outer CV
inner_repeats = 4  # passed to NestedCV as inner_repeats
n_trials = 100  # passed to NestedCV as n_trials
random_state = 0  # seed for the outer CV splitter
negligible = 0.01


def sens_at_spec(y_true, y_score, spec=1):
    """Best sensitivity (TPR) on the ROC curve among points with specificity >= ``spec``.

    Parameters
    ----------
    y_true : array-like
        Binary ground-truth labels, passed straight to ``roc_curve``.
    y_score : array-like
        Scores for the positive class, passed straight to ``roc_curve``.
    spec : float, default 1
        Minimum specificity; ROC points with ``fpr <= 1 - spec`` qualify.

    Returns
    -------
    float
        Maximum TPR over the qualifying ROC points, or 0.0 if no point qualifies.
    """
    fpr, tpr, __ = roc_curve(y_true, y_score)
    max_fpr = 1 - spec
    # Allow a tiny tolerance: in floating point 1 - 0.9 == 0.09999999999999998, so a
    # plain `<=` would drop an ROC point at exactly fpr == 0.1 (1 FP out of 10 negatives).
    # The tolerance is far below 1 / n_negatives, so no genuinely worse point sneaks in.
    within_bound = (fpr <= max_fpr) | np.isclose(fpr, max_fpr, rtol=0, atol=1e-12)
    valid_tpr = tpr[within_bound]
    return valid_tpr.max() if valid_tpr.size else 0.0


# Scoring functions evaluated by NestedCV, keyed by the name reported in its results.
scorers = dict(
    roc_auc=get_scorer("roc_auc"),
    # Sensitivity at the default spec=1 of sens_at_spec; accepts either score API of the estimator.
    sens_at_perf_spec=make_scorer(
        sens_at_spec,
        response_method=("decision_function", "predict_proba"),
    ),
    balanced_accuracy=get_scorer("balanced_accuracy"),
)

## Nested CV (warning: long-running)

In [None]:
# Using median only.
# Nested cross-validation driven by the shared settings from the config cell
# (previously n_trials / inner_repeats were repeated here as literals, so editing
# the config cell had no effect on this run).
outer_cv = RepeatedStratifiedKFold(
    n_splits=n_splits, n_repeats=outer_repeats, random_state=random_state
)
ncv_med = NestedCV(
    outer_cv, n_trials=n_trials, scorers=scorers, inner_repeats=inner_repeats
)
ncv_med.fit(X_a, y.values.ravel())

# Uncomment to persist these results for the "load existing results" cell below
# (this overwrites any previously saved ncv.pkl).
# with open(PROCESSED_DIR / "ncv.pkl", "wb") as f:
#     pickle.dump(ncv_med, f)

## Alternative: load previously saved nested-CV results (instead of re-running the cell above)

In [None]:
# Replace ncv_med with the NestedCV object pickled by an earlier run.
# pickle.load can execute arbitrary code, so only load this trusted local file.
ncv_path = PROCESSED_DIR / "ncv.pkl"
with open(ncv_path, "rb") as fh:
    ncv_med = pickle.load(fh)

## Plot partial dependence

In [None]:
# Smoothing penalties: 10th percentile of each lam_* column of the nested-CV chosen_params;
# spline counts and mono flags are fixed per term (bracketed comments presumably record
# other candidate values that were considered).
lam_cols = [f"lam_{i}" for i in range(4)]
train_on_mbv_params = {
    **ncv_med.chosen_params[lam_cols].quantile(0.1).to_dict(),
    "n_splines_0": 9,  # [7, 9],
    "n_splines_1": 15,  # [10, 13, 14, 15],
    "n_splines_2": 13,
    "n_splines_3": 14,  # [10, 12, 14],
    "mono_0": 1,
    "mono_1": 1,
    "mono_2": 0,
    "mono_3": 0,
}
train_on_mbv = make_gam(train_on_mbv_params, cols=range(X_a.shape[1]))

# Fit on all of X_a / y with every warning raised during fitting suppressed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    train_on_mbv.fit(X_a, y.values.ravel())
fig, axs = plt.subplots(ncols=4, figsize=(16, 4), sharey=True)

# Tick positions per panel in pg/mL; mapped below into the arcsinh(x / LOD * 10)
# feature space the model was fit in.
pgml_ticks = [
    [0, 0.1, 1, 10, 100],
    [0, 1, 10, 100, 1000, 10000],
    [0, 10, 100, 1000, 10000],
    [0, 1, 10, 100, 1000],
]
# Requested x-range per panel in pg/mL, intersected below with the extent of the model grid.
xlims = [(-0.15, 10000), (-4, 10000), (-20, 10000), (0, 10000)]

# Probability threshold for a "TB" call; the grey reference line shows its log-odds
# split evenly across the plotted terms.
tb_threshold = 0.72
threshold_logodds_per_term = np.log(tb_threshold / (1 - tb_threshold)) / len(axs)

for i, ax in enumerate(axs):
    plex = X_a.columns[i]
    color = LOD_COL_FMT.loc[plex, "color"]
    XX = train_on_mbv.generate_X_grid(term=i, n=1000)
    pdep = train_on_mbv.partial_dependence(term=i, X=XX)
    if i == 0:
        # Offset term 0 by coef_[0].
        # NOTE(review): presumably re-centres this curve for display — confirm intent.
        pdep = pdep - train_on_mbv.coef_[0]
    ax.plot(XX[:, i], pdep, color="k", lw=1)

    lod = LOD_COL_FMT.loc[plex, "LOD_samples"]
    if int(lod) == lod:
        lod = int(lod)  # show e.g. "10" rather than "10.0" in the label
    # A concentration equal to the LOD maps to arcsinh(LOD / LOD * 10) = arcsinh(10).
    ax.axvline(
        np.arcsinh(10),
        ls=(1, (3, 3)),
        color=color,
        lw=1,
        label=f"LOD: {lod} pg/mL",
        zorder=-1,
    )
    ax.axhline(threshold_logodds_per_term, color="k", lw=1, alpha=0.2, zorder=-2)
    trans_ticks = np.arcsinh(np.array(pgml_ticks[i]) / lod * 10)
    xlim_trans = np.arcsinh(np.array(xlims[i]) / lod * 10)
    xlim = (max(XX[:, i].min(), xlim_trans[0]), min(XX[:, i].max(), xlim_trans[1]))
    ax.set(
        xlim=xlim,
        ylim=(-5.3, 7.3),
        title=plex,
        xticks=trans_ticks,
        xticklabels=pgml_ticks[i],
        xlabel=f"{plex.split()[1]} concentration (pg/mL)",
    )
    # ax.legend(loc="lower right", framealpha=1)  # uncomment to show the LOD labels
axs[0].set(ylabel="partial log-odds")

# Save this specific figure rather than whatever pyplot considers "current".
fig.savefig(OUTPUT_DIR / "si_partial_dependence.pdf")

## Save the model trained on the model-building and validation sets

In [None]:
# To reuse a previously saved model instead of the one fitted above:
# with open(PROCESSED_DIR / "train_on_mbv.pkl", "rb") as f:
#     train_on_mbv = pickle.load(f)

with open(PROCESSED_DIR / "train_on_mbv.pkl", "wb") as f:
    pickle.dump(train_on_mbv, f)


def _json_default(obj):
    """json.dump fallback: turn NumPy arrays/scalars nested anywhere into plain Python values."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Shallow copy of the fitted model's attributes. Values that get converted are replaced
# with new objects below, so the live `train_on_mbv` is left untouched.
train_on_mbv_d = train_on_mbv.__dict__.copy()
for key in ["distribution", "link", "terms"]:
    train_on_mbv_d[key] = str(train_on_mbv_d[key])
train_on_mbv_d["callbacks"] = [str(cb) for cb in train_on_mbv_d["callbacks"]]
train_on_mbv_d["logs_"] = dict(train_on_mbv_d["logs_"])
for key, value in train_on_mbv_d.items():
    if isinstance(value, np.ndarray):
        train_on_mbv_d[key] = value.tolist()
# Build a NEW dict: the shallow copy shares the model's statistics_ dict, so converting it
# in place would also turn train_on_mbv.statistics_'s arrays into lists.
train_on_mbv_d["statistics_"] = {
    key: value.tolist() if isinstance(value, np.ndarray) else value
    for key, value in train_on_mbv_d["statistics_"].items()
}

with open(PROCESSED_DIR / "train_on_mbv_params.json", "w") as f:
    json.dump(train_on_mbv_d, f, default=_json_default)

## Make and save predictions for the test set

In [None]:
X_med = pd.read_csv(PROCESSED_DIR / "X_med.csv", index_col=0)
X_med_a = np.arcsinh(X_med / LOD_COL_FMT["LOD_samples"] * 10)

# Test set = every X_med sample the model was not fit on (i.e. not in X_a).
# Compute the ids once so the raw and transformed frames share one row order,
# which keeps the positional probability assignment below aligned.
test_ids = X_med.index.difference(X_a.index)
# Select the training feature columns by name so the model sees them in fit order.
X_test_a = X_med_a.loc[test_ids, X_a.columns]
X_test = X_med.loc[test_ids].copy()
X_test.columns = X_test.columns.map(lambda col: f"{col} (pg/mL)")

y_test_pred = train_on_mbv.predict_proba(X_test_a)[:, 1]

X_test["Estimated TB Probability"] = y_test_pred

tb_threshold = 0.72  # predetermined probability cut-off for a "TB" call
X_test["Predicted Diagnosis"] = X_test["Estimated TB Probability"].map(
    lambda x: "TB" if x > tb_threshold else "Non TB"
)

barcode_map = pd.read_csv(PROCESSED_DIR / "barcodes.csv", index_col=0)
X_test.index = pd.Index(barcode_map.loc[X_test.index].values[:, 0], name="barcode")

predictions_sheet = X_test.reset_index()
# `set_column` is an XlsxWriter worksheet method, so request that engine explicitly.
# The context manager writes and closes the workbook even if a later step raises.
with pd.ExcelWriter(OUTPUT_DIR / "test_set_predictions.xlsx", engine="xlsxwriter") as writer:
    predictions_sheet.to_excel(writer, index=False, sheet_name="Predictions")
    # Width 20 for every written column (barcode + data columns), whatever their count.
    writer.sheets["Predictions"].set_column(0, predictions_sheet.shape[1] - 1, 20)

X_test.index.name = "Barcode"
X_test.index = X_test.index.map(lambda x: Barcode2(x).any_aliquot())

## (After unblinding) Read in the true clinical categories for the test set

In [None]:
# Re-key the clinical table with the same Barcode2(...).any_aliquot() normalisation
# applied to X_test's index, so the two tables can be matched by label.
clinical_test = pd.read_excel(CLINICAL_TEST_FILE)
clinical_test = clinical_test.set_axis(
    clinical_test["barcode"].map(lambda code: Barcode2(code).any_aliquot()), axis=0
)

### Plot ROC curve

In [None]:
# Evaluate only the test samples that also have a clinical reference category.
X_test_roc = X_test.loc[X_test.index.intersection(clinical_test.index)]
y_test_pred = X_test_roc["Estimated TB Probability"]
# Collapse the clinical p_cat categories into a binary reference label (1 = TB, 0 = non-TB).
p_cat_to_label = {
    "Clinical_TB": 1,
    "Likely_subcl_TB": 0,
    "NonTB_LTBI": 0,
    "NonTB_NonLTBI": 0,
    "S-C+": 1,
}
p_cat_test = clinical_test.loc[X_test_roc.index, "p_cat"]
y_test_roc = p_cat_test.map(p_cat_to_label)
# Fail loudly instead of producing NaN labels or labels misaligned with y_test_pred.
assert len(y_test_roc) == len(X_test_roc), "clinical_test has duplicate barcodes for some test samples"
unmapped = sorted(p_cat_test[y_test_roc.isna().to_numpy()].astype(str).unique())
assert not unmapped, f"unmapped p_cat categories: {unmapped}"
# Explicit figure/axes instead of relying on plt.gca() after the display call.
roc_fig, ax = plt.subplots()
RocCurveDisplay.from_predictions(
    y_test_roc, y_test_pred, name="ROC Curve", color="k", lw=1, ax=ax
)

# Operating point of the predetermined threshold encoded in "Predicted Diagnosis".
predicted_tb = X_test_roc["Predicted Diagnosis"] == "TB"
predicted_non_tb = X_test_roc["Predicted Diagnosis"] == "Non TB"
is_tb = y_test_roc == 1
is_non_tb = y_test_roc == 0
sens = (predicted_tb & is_tb).sum() / is_tb.sum()
spec = (predicted_non_tb & is_non_tb).sum() / is_non_tb.sum()

ax.plot(
    [1 - spec],
    [sens],
    "o",
    color="tab:red",
    label=f"Predetermined threshold\n{sens:.0%} sensitivity\n{spec:.0%} specificity",
)
ax.plot([0, 1], [0, 1], "k--", alpha=0.2)  # chance diagonal
ax.set(
    xlim=(0, 1),
    ylim=(0, 1),
    aspect="equal",
    xticks=np.arange(0, 1.01, 0.1),
    yticks=np.arange(0, 1.01, 0.1),
    xlabel="1 - Specificity (False Positive Rate)",
    ylabel="Sensitivity (True Positive Rate)",
    title="Test Set",
)
ax.legend()
roc_fig.tight_layout()
plt.show()
def _actual_p_cat(barcode):
    """Clinical p_cat for ``barcode``, or pd.NA when clinical_test has no such row."""
    if barcode not in clinical_test.index:
        return pd.NA
    return clinical_test.loc[barcode, "p_cat"]


X_test["Actual p_cat"] = X_test.index.map(_actual_p_cat)