In [1]:
import numpy as np
import pandas as pd

from sklearn.metrics import (
    confusion_matrix,
    recall_score,
    precision_score,
    f1_score,
    balanced_accuracy_score,
    roc_auc_score,
    average_precision_score
)
from typing import Dict, List, Tuple
from scipy import stats
from utisl import *

In [2]:
# Load the follow-up calibrated prediction results for the chosen model.
# NOTE(review): in the visible notebook both `model_name` and
# `all_predictions` are reassigned further down (before any use), so this
# load only matters if something outside this view reads it — confirm.
model_name = "LGBM"
all_predictions = pd.read_csv(
    f"results/next_year_prediction_results_{model_name}_followup_calibrated.csv",
    index_col=0,
)

# `data_extraction` is not defined in this notebook; presumably it comes from
# the star import at the top — TODO confirm. It returns six objects that are
# unpacked positionally, so the order below must match the helper's return order.
(
    data,
    baseline_data,
    baseline_data_cont,
    follow_up_predictors_total,
    follow_up_predictors_vars_by_year,
    follow_up_predictors_cont,
) = data_extraction()

In [3]:
def calculate_metrics(y_true, y_pred, y_prob):
    """
    Calculate comprehensive binary classification metrics.

    Args:
        y_true: True labels (0/1).
        y_pred: Predicted labels (0/1, after thresholding).
        y_prob: Prediction probabilities for the positive class.

    Returns:
        dict: Nested dictionary with three groups:
            "Threshold Metrics"   - confusion-matrix counts and the rates
                                    derived from ``y_pred``.
            "Probability Metrics" - ranking metrics derived from ``y_prob``
                                    (ROC AUC is NaN when ``y_true`` holds a
                                    single class).
            "Additional Metrics"  - predicted positive/negative rates, false
                                    discovery rate and false omission rate.
    """
    # Pin the label order: without `labels`, confusion_matrix returns a 1x1
    # matrix when only one class occurs in y_true/y_pred, and unpacking four
    # values from it fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    n_total = tp + fp + tn + fn

    metrics = {
        "Threshold Metrics": {
            "True Positives": int(tp),
            "True Negatives": int(tn),
            "False Positives": int(fp),
            "False Negatives": int(fn),
            "Sensitivity/Recall": recall_score(y_true, y_pred),
            "Specificity": tn / (tn + fp) if (tn + fp) > 0 else 0,
            "Precision/PPV": precision_score(y_true, y_pred) if (tp + fp) > 0 else 0,
            "NPV": tn / (tn + fn) if (tn + fn) > 0 else 0,
            "F1 Score": f1_score(y_true, y_pred),
            "BAC": balanced_accuracy_score(y_true, y_pred),
        },
        "Probability Metrics": {
            # ROC AUC is undefined when only one class is present.
            "ROC AUC": roc_auc_score(y_true, y_prob)
            if len(np.unique(y_true)) > 1
            else np.nan,
            "PR AUC": average_precision_score(y_true, y_prob),
        },
    }

    # Additional metrics.
    # "Positive Rate" / "Negative Rate" are the shares of samples *predicted*
    # positive / negative, not the outcome prevalence.
    metrics["Additional Metrics"] = {
        "Positive Rate": (tp + fp) / n_total,
        "Negative Rate": (tn + fn) / n_total,
        "False Discovery Rate": fp / (fp + tp) if (fp + tp) > 0 else 0,
        "False Omission Rate": fn / (fn + tn) if (fn + tn) > 0 else 0,
    }

    return metrics


def find_threshold_for_recall(probs, labels, target_recall=0.80):
    """
    Find the probability threshold that achieves the target recall.

    Candidate thresholds are the distinct values in ``probs``, scanned from
    highest to lowest; a sample is predicted positive when ``prob >= threshold``.
    The first (i.e. highest) threshold whose recall is at least
    ``target_recall`` is returned.

    Args:
        probs: Predicted probabilities (array-like, e.g. ndarray, list, Series).
        labels: True binary labels, where 1 marks the positive class.
        target_recall: Minimum recall the threshold must reach.

    Returns:
        The selected threshold (an element of ``probs``). If no threshold
        reaches the target - including when there are no positive labels, in
        which case recall is treated as 0 - the smallest probability is returned.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    # Work on plain arrays so indexing is positional even for pandas inputs.
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    if probs.size == 0:
        raise ValueError("Cannot find a threshold: 'probs' is empty.")

    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    is_positive = labels[order] == 1
    n_positives = int(is_positive.sum())

    # Index of the last sample in each run of equal probabilities: thresholding
    # at that value flags every sample up to and including this index.
    group_end = np.flatnonzero(np.r_[sorted_probs[1:] != sorted_probs[:-1], True])
    tp_at_threshold = np.cumsum(is_positive)[group_end]

    if n_positives > 0:
        recalls = tp_at_threshold / n_positives
    else:
        # Recall is undefined without positives; treat it as 0.
        recalls = np.zeros(group_end.size)

    reached = np.flatnonzero(recalls >= target_recall)
    if reached.size:
        return sorted_probs[group_end[reached[0]]]

    return sorted_probs[-1]


def analyze_multiple_recalls(labels, probs, recall_targets):
    """
    Build one summary row per recall target.

    For every value in ``recall_targets`` the matching probability cutoff is
    looked up with ``find_threshold_for_recall``, hard predictions are derived
    with ``probs >= cutoff`` and scored with ``calculate_metrics``.

    Returns:
        pd.DataFrame: one row per target with the requested recall, the cutoff
        found, the recall actually achieved and the main threshold/probability
        metrics.
    """

    def _row_for(target):
        # Cutoff that reaches this recall, and the hard calls it implies.
        cutoff = find_threshold_for_recall(probs, labels, target_recall=target)
        hard_calls = (probs >= cutoff).astype(int)

        report = calculate_metrics(labels, hard_calls, probs)
        at_cutoff = report["Threshold Metrics"]
        ranking = report["Probability Metrics"]

        return {
            "Recall_Target": target,
            "Threshold_Found": cutoff,
            "Actual_Recall": at_cutoff["Sensitivity/Recall"],
            "Precision": at_cutoff["Precision/PPV"],
            "Specificity": at_cutoff["Specificity"],
            "F1": at_cutoff["F1 Score"],
            "BAC": at_cutoff["BAC"],
            "ROC_AUC": ranking["ROC AUC"],
            "PR_AUC": ranking["PR AUC"],
        }

    return pd.DataFrame([_row_for(target) for target in recall_targets])


def decision_curve_analysis(labels, probs, thresholds=None):
    """
    Calculate net benefit across a list of threshold probabilities
    for decision curve analysis.

    Net Benefit = (TP/N) - (FP/N)*(p_t/(1-p_t))
    Treat All => NB = prevalence - (1-prevalence)*(p_t/(1-p_t))
    Treat None => NB = 0

    Args:
        labels: True binary labels (array-like of 0/1).
        probs: Predicted probabilities (array-like, same length as ``labels``).
        thresholds: Iterable of threshold probabilities p_t. Defaults to
            0.0, 0.05, ..., 1.0 when omitted or None. Values equal to 0 or 1
            are skipped.

    Returns:
        pd.DataFrame with columns: threshold, tp, fp, net_benefit_model,
        net_benefit_treat_all, net_benefit_treat_none. The columns are present
        even when every threshold was skipped.
    """
    if thresholds is None:
        thresholds = np.linspace(0.0, 1.0, 21)  # e.g. 0.0, 0.05, ..., 1.0

    # Plain arrays: accepts lists and avoids index alignment between Series.
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    n = len(labels)
    prevalence = np.mean(labels)
    is_event = labels == 1
    is_non_event = labels == 0

    columns = [
        "threshold",
        "tp",
        "fp",
        "net_benefit_model",
        "net_benefit_treat_all",
        "net_benefit_treat_none",
    ]

    results = []
    for pt in thresholds:
        if pt == 0 or pt == 1:
            # pt == 1 would divide by zero in the odds term; pt == 0 is skipped too.
            continue

        # Samples flagged as positive at this threshold
        flagged = probs >= pt
        tp = np.sum(flagged & is_event)
        fp = np.sum(flagged & is_non_event)

        # Odds of the threshold probability: the weight given to a false positive.
        odds = pt / (1 - pt)

        results.append(
            {
                "threshold": pt,
                "tp": tp,
                "fp": fp,
                "net_benefit_model": (tp / n) - (fp / n) * odds,
                "net_benefit_treat_all": prevalence - (1 - prevalence) * odds,
                "net_benefit_treat_none": 0,
            }
        )

    # Explicit columns keep the schema stable when `results` is empty.
    return pd.DataFrame(results, columns=columns)


def evaluate_multiple_thresholds(labels, probs, thresholds=None):
    """
    Score the predictions at a series of fixed probability cutoffs.

    Each cutoff turns ``probs`` into hard predictions (``probs >= cutoff``),
    which are scored with ``calculate_metrics``. When ``thresholds`` is None a
    default grid of 0.01, 0.02, 0.05, 0.1, 0.2, 0.3 and 0.5 is used.

    Returns:
        pd.DataFrame: one row per cutoff with confusion-matrix counts and the
        derived threshold / probability metrics.
    """
    cutoffs = [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.5] if thresholds is None else thresholds

    # Score every cutoff first ...
    scored = [
        (cutoff, calculate_metrics(labels, (probs >= cutoff).astype(int), probs))
        for cutoff in cutoffs
    ]

    # ... then flatten each nested report into a single table row.
    table_rows = []
    for cutoff, report in scored:
        at_cutoff = report["Threshold Metrics"]
        ranking = report["Probability Metrics"]
        extra = report["Additional Metrics"]
        table_rows.append(
            {
                "Threshold": cutoff,
                "TP": at_cutoff["True Positives"],
                "FP": at_cutoff["False Positives"],
                "TN": at_cutoff["True Negatives"],
                "FN": at_cutoff["False Negatives"],
                "Recall": at_cutoff["Sensitivity/Recall"],
                "Specificity": at_cutoff["Specificity"],
                "Precision": at_cutoff["Precision/PPV"],
                "NPV": at_cutoff["NPV"],
                "F1": at_cutoff["F1 Score"],
                "BAC": at_cutoff["BAC"],
                "ROC_AUC": ranking["ROC AUC"],
                "PR_AUC": ranking["PR AUC"],
                "PosRate": extra["Positive Rate"],
            }
        )

    return pd.DataFrame(table_rows)


def evaluate_top_k(labels, probs, k=None, fraction=None):
    """
    Evaluate precision and recall in the top-k or top fraction of predicted risks.

    Samples are ranked by ``probs`` (highest first) and the leading ``top_n``
    samples form the subset. ``fraction`` takes precedence over ``k`` when both
    are given. The subset size is clamped to ``[0, n_total]`` so the size in
    the descriptor always matches the subset actually evaluated.

    Args:
        labels: True binary labels (array-like of 0/1).
        probs: Predicted probabilities (array-like, same length as ``labels``).
        k: Absolute number of top-ranked samples to keep.
        fraction: Share of samples to keep (e.g. 0.05 for the top 5%); the
            count is rounded up.

    Returns:
        dict with keys "Subset" (description), "Precision_top", "Recall_top",
        "Subset_size" and "Total_positives".

    Raises:
        ValueError: If neither ``k`` nor ``fraction`` is given.
    """
    # Rank without mutating in place.
    ranked = pd.DataFrame({"label": labels, "prob": probs}).sort_values(
        "prob", ascending=False
    )

    n_total = len(ranked)
    n_positives = ranked["label"].sum()

    if fraction is not None:
        requested = int(np.ceil(fraction * n_total))
        # Clamp: a fraction > 1 must not report more samples than exist.
        top_n = max(0, min(requested, n_total))
        descriptor = f"Top {fraction*100:.1f}% (N={top_n})"
    elif k is not None:
        # Clamp: a negative k would make head() return "all but the last |k|".
        top_n = max(0, min(k, n_total))
        descriptor = f"Top K={k} (N={top_n})"
    else:
        raise ValueError("Either 'k' or 'fraction' must be specified.")

    subset = ranked.head(top_n)

    n_true_positives = subset["label"].sum()
    precision_top = n_true_positives / len(subset) if len(subset) > 0 else 0
    recall_top = n_true_positives / n_positives if n_positives > 0 else 0

    return {
        "Subset": descriptor,
        "Precision_top": precision_top,
        "Recall_top": recall_top,
        "Subset_size": len(subset),
        "Total_positives": n_positives,
    }


#######################################
# MAIN ANALYSIS LOGIC
#######################################


def analyze_predictions(df, outcome_type, year_type, recall_targets):
    """
    Analyze predictions for a specific outcome type with:
      1) Single-threshold metrics at each recall target
      2) Multiple threshold metrics (fixed probability cutoffs)
      3) Decision Curve Analysis
      4) Top-k fraction approach (top 5%, 10% and 20%)

    All results are printed and also returned.

    Args:
        df: DataFrame with at least the columns "outcome_type", "label" and "prob".
        outcome_type: Value of ``df["outcome_type"]`` to analyze.
        year_type: Free-form tag (e.g. "Y1") stored on the result rows and
            shown in the printed header.
        recall_targets: Non-empty sequence of recall targets.

    Returns:
        tuple of six items:
            outcome_df          - filtered copy of ``df`` with added columns
                                  "prediction", "threshold_used",
                                  "misclassified" and "year_type"
            single_thr_metrics  - ``calculate_metrics`` output at the reference
                                  recall target, plus a "Threshold" entry
            multi_recalls_df    - one row per recall target
            multi_thr_df        - one row per fixed probability cutoff
            dca_df              - decision curve analysis table
            top_fraction_results - dict mapping each fraction (0.05, 0.10,
                                  0.20) to its ``evaluate_top_k`` result
    """
    # Filter data for the specific outcome type
    outcome_df = df[df["outcome_type"] == outcome_type].copy()
    labels = outcome_df["label"].values.astype(float)
    probs = outcome_df["prob"].values

    # (1) Multiple Recalls
    multi_recalls_df = analyze_multiple_recalls(
        labels, probs, recall_targets=recall_targets
    )

    # Reference recall target for the detailed single-threshold report.
    # The second entry is used when available, which matches the notebook's
    # recorded output (60% for targets starting 0.5, 0.6, ...); a one-element
    # list falls back to its only entry instead of raising IndexError.
    # NOTE(review): earlier comments called this the "first" target - confirm
    # which entry is intended.
    target_recall = recall_targets[1] if len(recall_targets) > 1 else recall_targets[0]
    threshold = find_threshold_for_recall(probs, labels, target_recall)
    preds = (probs >= threshold).astype(int)
    single_thr_metrics = calculate_metrics(labels, preds, probs)
    single_thr_metrics["Threshold"] = threshold

    # (2) Evaluate multiple probability thresholds
    multi_thr_df = evaluate_multiple_thresholds(labels, probs)

    # (3) Decision Curve Analysis
    dca_df = decision_curve_analysis(labels, probs, thresholds=np.linspace(0, 1, 21))

    # (4) Top-k fraction approach. Computed once and returned, rather than
    # returning the interactive-shell variable `_`.
    top_fraction_results = {
        frac: evaluate_top_k(labels, probs, fraction=frac)
        for frac in (0.05, 0.10, 0.20)
    }

    # Add predictions & misclassification for final DataFrame
    outcome_df["prediction"] = preds
    outcome_df["threshold_used"] = threshold
    outcome_df["misclassified"] = outcome_df["prediction"] != outcome_df["label"]
    outcome_df["year_type"] = year_type

    ###########################################
    # Print out the key results
    ###########################################
    print(f"\n{'='*50}")
    print(f"Results for {outcome_type} - {year_type}")
    print(f"{'='*50}")

    # Print multi-recall results
    print("\n--- Multiple Recall Targets ---")
    print(multi_recalls_df.round(4).to_string(index=False))

    # Print single threshold metrics at the reference recall target
    print(f"\nSingle Threshold for {target_recall*100:.1f}% recall:")
    print(f"Found threshold = {threshold:.4f}")
    for cat_name, cat_vals in single_thr_metrics.items():
        # Only the nested metric groups are printed; the scalar "Threshold"
        # entry was already shown above.
        if not isinstance(cat_vals, dict):
            continue
        print(f"\n{cat_name}:")
        print("-" * 30)
        for metric_name, value in cat_vals.items():
            print(f"{metric_name:<25} {value:.4f}")

    print("\n--- Multiple Operating Thresholds (Fixed Probability) ---")
    print(multi_thr_df.round(4).to_string(index=False))

    print("\n--- Decision Curve Analysis ---")
    print(
        dca_df[
            [
                "threshold",
                "net_benefit_model",
                "net_benefit_treat_all",
                "net_benefit_treat_none",
            ]
        ]
        .round(4)
        .to_string(index=False)
    )

    for frac, top_result in top_fraction_results.items():
        print(f"\n--- Top-Fraction Approach (e.g., Top {frac:.0%}) ---")
        for k, v in top_result.items():
            if isinstance(v, float):
                print(f"{k:<20}: {v:.4f}")
            else:
                print(f"{k:<20}: {v}")

    return (
        outcome_df,
        single_thr_metrics,
        multi_recalls_df,
        multi_thr_df,
        dca_df,
        top_fraction_results,
    )


def analyze_all_outcomes(df, year_type, recall_targets=None):
    """
    Analyze predictions for all outcome types with comprehensive metrics.

    Runs ``analyze_predictions`` once per distinct value of
    ``df["outcome_type"]``.

    Args:
        df: Predictions DataFrame containing an "outcome_type" column (plus
            whatever ``analyze_predictions`` requires).
        year_type: Tag forwarded to ``analyze_predictions``.
        recall_targets: Recall targets to evaluate. Defaults to
            [0.5, 0.6, 0.7, 0.8, 0.9] when omitted or None.

    Returns:
        tuple: (results_df, all_metrics) where ``results_df`` is the
        concatenation of the per-outcome result frames (an empty DataFrame if
        ``df`` has no outcome types) and ``all_metrics`` maps each outcome type
        to a dict with keys "SingleThrMetrics", "MultiRecallsDF",
        "MultiThreshDF", "DecisionCurveDF" and "TopFractionResults".
    """
    # Build the default per call instead of sharing one mutable list.
    if recall_targets is None:
        recall_targets = [0.5, 0.6, 0.7, 0.8, 0.9]

    all_results = []
    all_metrics = {}
    for outcome in df["outcome_type"].unique():
        (
            outcome_results,
            single_thr_metrics,
            multi_recalls_df,
            multi_thr_df,
            dca_df,
            top_k_dict,
        ) = analyze_predictions(df, outcome, year_type, recall_targets=recall_targets)

        all_results.append(outcome_results)
        all_metrics[outcome] = {
            "SingleThrMetrics": single_thr_metrics,
            "MultiRecallsDF": multi_recalls_df,
            "MultiThreshDF": multi_thr_df,
            "DecisionCurveDF": dca_df,
            "TopFractionResults": top_k_dict,
        }

    # pd.concat raises on an empty list; return an empty frame instead.
    if not all_results:
        return pd.DataFrame(), all_metrics

    return pd.concat(all_results), all_metrics


#######################################
# EXAMPLE USAGE
#######################################

model_name = "LGBM"
dataset = "real777"

# Load your predictions
all_predictions = pd.read_csv(
    f"{dataset}/results/next_year_prediction_results_{model_name}_multi_year_training_best_model_calibrated.csv",
    index_col=0,
)

# Decide whether to use calibrated or uncalibrated probabilities
all_predictions["prob"] = all_predictions["uncali_prob"]
# all_predictions["prob"] = all_predictions["cali_prob_full"]

# Row selectors for the two follow-up windows: "Y1" is year == 0,
# "Y2+" is every later year.
year_masks = {
    "Y1": all_predictions["year"] == 0,
    "Y2+": all_predictions["year"] > 0,
}

# Collect per-subset frames and concatenate once at the end, instead of
# growing a DataFrame inside the loop.
result_frames = []
# Per-subset metrics keyed by (outcome_type, year_type), kept for later inspection.
all_metrics = {}

# We analyze "death" and "graft_loss" separately, for "Y1" vs "Y2+"
for outcome_type in ["death", "graft_loss"]:
    for y_type, year_mask in year_masks.items():
        sub_predictions = all_predictions.loc[
            (all_predictions["outcome_type"] == outcome_type) & year_mask
        ]
        if len(sub_predictions) == 0:
            continue

        # Analyze across multiple recall targets: 50%, 60%, 70%, 80%, 90%
        results_df, metrics_dict = analyze_all_outcomes(
            sub_predictions, year_type=y_type, recall_targets=[0.5, 0.6, 0.7, 0.8, 0.9]
        )
        result_frames.append(results_df)
        all_metrics[(outcome_type, y_type)] = metrics_dict

all_results = pd.concat(result_frames) if result_frames else pd.DataFrame()


Results for death - Y1

--- Multiple Recall Targets ---
 Recall_Target  Threshold_Found  Actual_Recall  Precision  Specificity     F1    BAC  ROC_AUC  PR_AUC
           0.5           0.0250         0.5048     0.0449       0.7125 0.0825 0.6086   0.6558  0.0447
           0.6           0.0207         0.6000     0.0427       0.6401 0.0798 0.6200   0.6558  0.0447
           0.7           0.0165         0.7048     0.0395       0.5406 0.0747 0.6227   0.6558  0.0447
           0.8           0.0121         0.8000     0.0345       0.4003 0.0661 0.6001   0.6558  0.0447
           0.9           0.0092         0.9048     0.0319       0.2645 0.0616 0.5847   0.6558  0.0447

Single Threshold for 60.0% recall:
Found threshold = 0.0207

Threshold Metrics:
------------------------------
True Positives            63.0000
True Negatives            2509.0000
False Positives           1411.0000
False Negatives           42.0000
Sensitivity/Recall        0.6000
Specificity               0.6401
Precision/PPV


Results for graft_loss - Y1

--- Multiple Recall Targets ---
 Recall_Target  Threshold_Found  Actual_Recall  Precision  Specificity     F1    BAC  ROC_AUC  PR_AUC
           0.5           0.0448         0.5031     0.0813       0.7588 0.1399 0.6309   0.6855  0.0912
           0.6           0.0429         0.6012     0.0763       0.6914 0.1355 0.6463   0.6855  0.0912
           0.7           0.0403         0.7055     0.0682       0.5915 0.1245 0.6485   0.6855  0.0912
           0.8           0.0366         0.8037     0.0570       0.4356 0.1064 0.6196   0.6855  0.0912
           0.9           0.0318         0.9018     0.0490       0.2571 0.0929 0.5795   0.6855  0.0912

Single Threshold for 60.0% recall:
Found threshold = 0.0429

Threshold Metrics:
------------------------------
True Positives            98.0000
True Negatives            2657.0000
False Positives           1186.0000
False Negatives           65.0000
Sensitivity/Recall        0.6012
Specificity               0.6914
Precisio