In [None]:
import os
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from prettytable import PrettyTable
import seaborn as sns

from typing import List, Dict, Tuple

# Metrics reported for every epoch; each name must be a key returned by get_metrics.
# Names containing "loss" are minimised during best-epoch selection, all others maximised.
METRICS = [
    "roc_auc",
    "f1",
    "precision",
    "recall",
    "specificity",
    "jw5",
    "jw6",
    "loss",
]

In [None]:
# Paths are read from the environment (optionally populated from a .env file).
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")
if project_path is None:
    # Fail fast with an actionable message instead of a TypeError from os.path.join(None, ...).
    raise RuntimeError(
        "PROJECTPATH is not set; define it in the environment or in a .env file."
    )

experiment_name = "test_cardiomegaly_cls"
label = "Cardiomegaly"
results_path = os.path.join(project_path, "runs", experiment_name, "results", label)

In [None]:
def get_logits(
    result_path: str, valid_name: str, epoch: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load saved logits and ground-truth labels for one split and epoch.

    Reads ``<result_path>/<valid_name>/predictions/epoch_<EE>.csv`` (epoch
    zero-padded to two digits), which must contain ``logits`` and ``labels``
    columns.

    Args:
        result_path: Results directory of the experiment/label.
        valid_name: Split sub-directory name, e.g. ``"fold_0"`` or ``"test"``.
            Folds are 0-indexed.
        epoch: 0-indexed epoch number.

    Returns:
        ``(logits, labels)`` as numpy arrays in file row order.
    """
    epoch_predictions_path = os.path.join(
        result_path, valid_name, "predictions", f"epoch_{epoch:02d}.csv"
    )
    epoch_predictions = pd.read_csv(epoch_predictions_path)
    logits = epoch_predictions["logits"].to_numpy()
    labels = epoch_predictions["labels"].to_numpy()

    return logits, labels

In [None]:
def get_threshold(
    logits: np.ndarray | List[np.ndarray], labels: np.ndarray | List[np.ndarray]
):
    if not isinstance(logits, np.ndarray):
        logits = np.concatenate(logits)
        labels = np.concatenate(labels)

    probabilities = 1 / (1 + np.exp(-logits))

    thresholds = np.linspace(0, 1, 100)
    best_threshold = 0.5
    best_youden = 0

    for threshold in thresholds:
        predictions = (probabilities >= threshold).astype(int)

        sensitivity = np.sum((predictions == 1) & (labels == 1)) / np.sum(labels == 1)
        specificity = np.sum((predictions == 0) & (labels == 0)) / np.sum(labels == 0)
        youden = sensitivity + specificity - 1

        if youden > best_youden:
            best_youden = youden
            best_threshold = threshold

    return best_threshold

In [None]:
def _safe_divide(numerator: float, denominator: float) -> float:
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return float(numerator / denominator) if denominator else 0.0


def get_metrics(
    logits: np.ndarray, labels: np.ndarray, threshold=None
) -> Dict[str, float]:
    """
    Compute classification metrics for one set of logits and binary labels.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        labels: Binary ground truth (0/1) aligned with ``logits``.
        threshold: Probability threshold applied to sigmoid(logits). If ``None``
            it is chosen with ``get_threshold`` (Youden's J) on these same data.

    Returns:
        Dict with ``roc_auc``, ``f1``, ``precision``, ``recall``,
        ``specificity``, ``jw5`` (mean of precision and recall), ``jw6``
        (0.4 * precision + 0.6 * recall), ``loss`` (mean binary cross-entropy
        computed from the logits) and the ``threshold`` actually used.
        Ratios with a zero denominator (e.g. precision when nothing is
        predicted positive) are reported as 0.0 rather than NaN, so that NaN
        never wins an argmax during epoch selection.
    """
    roc_auc = roc_auc_score(labels, logits)

    if threshold is None:
        threshold = get_threshold(logits, labels)

    # exp overflows to inf for very negative logits; the resulting 0.0 is correct.
    with np.errstate(over="ignore"):
        probabilities = 1 / (1 + np.exp(-logits))
    predictions = probabilities >= threshold

    positives = labels == 1
    negatives = labels == 0
    true_pos = np.sum(predictions & positives)
    true_neg = np.sum(~predictions & negatives)

    final_precision = _safe_divide(true_pos, np.sum(predictions))
    final_recall = _safe_divide(true_pos, np.sum(positives))
    specificity = _safe_divide(true_neg, np.sum(negatives))
    best_f1 = _safe_divide(
        2 * final_precision * final_recall, final_precision + final_recall
    )
    jw5 = (final_precision + final_recall) / 2
    jw6 = final_precision * 0.4 + final_recall * 0.6

    # Binary cross-entropy on logits in the numerically stable form
    # log(1 + exp(x)) - y * x  ==  -[y log sigmoid(x) + (1 - y) log(1 - sigmoid(x))].
    loss = float(np.mean(np.logaddexp(0, logits) - labels * logits))

    return {
        "roc_auc": roc_auc,
        "f1": best_f1,
        "precision": final_precision,
        "recall": final_recall,
        "specificity": specificity,
        "jw5": jw5,
        "jw6": jw6,
        "loss": loss,
        "threshold": threshold,
    }

In [None]:
def get_cv_results(
    results_path: str, n_folds: int = 5, n_epochs: int = 10, default_threshold=None
) -> Tuple[Dict[str, Dict[int, List[float]]], Dict[str, List[Tuple[float, float]]]]:
    """
    Parse cross-validation predictions and compute per-fold metrics for every epoch.

    For each epoch the predictions of all folds are loaded with ``get_logits``.
    The decision threshold is ``default_threshold`` or, when that is ``None``,
    the Youden-optimal threshold on the pooled folds of that epoch. The same
    threshold is then applied to every fold, so a tuned threshold is selected
    on the very data it is evaluated on.

    Args:
        results_path: Results directory containing ``fold_<i>`` sub-directories.
        n_folds: Number of folds (``fold_0`` ... ``fold_{n_folds - 1}``).
        n_epochs: Number of epochs (0-indexed).
        default_threshold: Fixed probability threshold, or ``None`` to tune it.

    Returns:
        epochs_performance: ``{metric: {epoch: [value for each fold]}}`` for every
            name in ``METRICS`` plus ``"threshold"``.
        confidence_intervals: ``{metric: [(mean, std) for each epoch]}`` for every
            name in ``METRICS``. ``std`` is ``np.std`` across folds (ddof=0), i.e.
            fold-to-fold spread, not a standard error of the mean.
    """
    epochs_performance = {
        m: {ep: [] for ep in range(n_epochs)} for m in METRICS + ["threshold"]
    }

    for ep in range(n_epochs):
        # Transpose [(logits, labels) per fold] into (all logits, all labels).
        logits, labels = zip(
            *[get_logits(results_path, f"fold_{i}", ep) for i in range(n_folds)]
        )

        threshold = (
            get_threshold(logits, labels)
            if default_threshold is None
            else default_threshold
        )

        for fold_logits, fold_labels in zip(logits, labels):
            performance_metrics = get_metrics(fold_logits, fold_labels, threshold)
            for metric_name, metric_value in performance_metrics.items():
                epochs_performance[metric_name][ep].append(metric_value)

    confidence_intervals = {
        m: [(np.mean(perf), np.std(perf)) for perf in epochs_performance[m].values()]
        for m in METRICS
    }

    return epochs_performance, confidence_intervals

In [None]:
def get_test_results(
    results_path: str, n_epochs: int = 10, default_threshold=None
) -> Dict[str, Dict[int, float]]:
    """
    Compute held-out test metrics for every epoch.

    Predictions are read from the ``test`` split via ``get_logits``.

    Args:
        results_path: Results directory containing the ``test`` sub-directory.
        n_epochs: Number of epochs (0-indexed) to evaluate.
        default_threshold: Fixed probability threshold. If ``None``, the
            threshold is tuned per epoch with ``get_threshold`` on the test
            predictions themselves, which leaks test labels into the metrics.

    Returns:
        ``{metric: {epoch: value}}`` for every name in ``METRICS`` plus
        ``"threshold"`` (a single dict, not a tuple).
    """
    epochs_performance = {
        m: {ep: None for ep in range(n_epochs)} for m in METRICS + ["threshold"]
    }

    for ep in range(n_epochs):
        logits, labels = get_logits(results_path, "test", ep)

        threshold = (
            get_threshold(logits, labels)
            if default_threshold is None
            else default_threshold
        )

        performance_metrics = get_metrics(logits, labels, threshold)
        for metric_name, metric_value in performance_metrics.items():
            epochs_performance[metric_name][ep] = metric_value

    return epochs_performance

In [None]:
def best_epoch_from_metric(
    epochs_performance: Dict[str, Dict[int, List[float]]],
    selection_metric: str = "roc_auc",
) -> int:
    """
    Return the 0-indexed epoch whose mean ``selection_metric`` is best.

    The values stored for each epoch (one per fold for CV results, or a single
    scalar for test results) are averaged with ``np.mean``. Metrics whose name
    contains ``"loss"`` are minimised; all others are maximised. Epochs whose
    mean is NaN are never selected. If every epoch is NaN, epoch 0 is returned.
    """
    epoch_means = np.array(
        [np.mean(perf) for perf in epochs_performance[selection_metric].values()],
        dtype=float,
    )

    # Flip the sign for losses so "best" is always the maximum (ties keep the first epoch).
    scores = -epoch_means if "loss" in selection_metric else epoch_means
    # np.argmax treats NaN as the maximum, which would select a broken epoch.
    scores = np.where(np.isnan(scores), -np.inf, scores)

    return int(np.argmax(scores))

In [None]:
def generate_simple_folds_plot(
    cv_performance: Dict[str, List[float]],
    label: str,
    metric: str,
    num_folds=5,
    num_epochs=10,
    cm=plt.cm.tab10,
) -> None:
    """
    Plot a single metric across epochs with one line per cross-validation fold.

    Args:
        cv_performance: ``{metric: {epoch: [value per fold]}}`` as produced by
            ``get_cv_results``.
        label: Name of the predicted label, used in the title.
        metric: Key of ``cv_performance`` to plot.
        num_folds: Number of fold lines to draw.
        num_epochs: Number of epochs; x-axis ticks run from 1 to ``num_epochs``.
        cm: Colormap; fold ``k`` is drawn with ``cm(k / num_folds)``.
    """
    plt.figure(figsize=(8, 6))

    x_ticks = range(1, num_epochs + 1)

    for fold in range(num_folds):
        fold_curve = [cv_performance[metric][epoch][fold] for epoch in range(num_epochs)]
        plt.plot(
            x_ticks,
            fold_curve,
            color=cm(fold / num_folds),
            label=f"Fold {fold + 1}",
        )

    plt.xlabel("Epoch")
    plt.ylabel(metric)
    # Legend sits below the axes, one column per fold.
    plt.legend(
        loc="lower center",
        ncol=num_folds,
        bbox_to_anchor=(0.5, -0.25),
        frameon=False,
    )
    plt.xticks(x_ticks)
    plt.title(f"Training Curve {metric} for {label}")
    plt.grid(True, linestyle="--", alpha=0.5)

    plt.show()

In [None]:
def generate_folds_plot_all_metrics(
    cv_performance: Dict[str, List[float]],
    best_epoch_by_metric: Dict[str, int],
    label: str,
    num_folds=5,
    num_epochs=10,
    cm=plt.cm.tab10,
) -> None:
    """
    Draw one stacked panel per metric in ``METRICS`` showing every fold's curve.

    Each panel also has a dashed vertical line at that metric's best epoch.

    Args:
        cv_performance: ``{metric: {epoch: [value per fold]}}`` as produced by
            ``get_cv_results``.
        best_epoch_by_metric: 0-indexed best epoch per metric; shifted by +1 to
            match the 1-indexed x-axis.
        label: Name of the predicted label, used in the figure title.
        num_folds: Number of fold lines per panel.
        num_epochs: Number of epochs on the x-axis.
        cm: Colormap; fold ``k`` is drawn with ``cm(k / num_folds)``.
    """
    # squeeze=False always yields a 2-D array, so a single metric needs no special case.
    fig, axes = plt.subplots(
        len(METRICS), 1, figsize=(12, 3 * len(METRICS)), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    x_ticks = range(1, num_epochs + 1)

    for ax, metric in zip(axes, METRICS):
        best_epoch = best_epoch_by_metric[metric] + 1
        per_epoch = cv_performance[metric]

        for fold in range(num_folds):
            ax.plot(
                x_ticks,
                [per_epoch[epoch][fold] for epoch in range(num_epochs)],
                color=cm(fold / num_folds),
                label=f"Fold {fold + 1}",
                linewidth=2,
                marker="o",
                markersize=4,
            )

        ax.axvline(
            x=best_epoch,
            color="black",
            linestyle="--",
            alpha=0.5,
            linewidth=2,
            label="Best Epoch",
        )

        ax.set_xticks(x_ticks)
        ax.set_ylabel(metric.upper())
        ax.set_title(f"{metric.upper()} Training Curve", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.5)

    # One shared legend (taken from the first panel) below all subplots.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=len(legend_labels),
        bbox_to_anchor=(0.5, 0.0),
        frameon=False,
    )

    axes[-1].set_xlabel("Epoch", fontsize=12)
    fig.suptitle(
        f"{num_folds}-Fold Validation Performance for {label}",
        fontsize=14,
        fontweight="bold",
    )
    plt.tight_layout(rect=[0, 0.05, 1, 0.97])
    plt.show()

In [None]:
def generate_ci_plot_all_metrics(
    confidence_intervals: Dict[str, List[Tuple[float, float]]],
    best_epoch_by_metric: Dict[str, int],
    label: str,
    num_epochs=10,
    z_score=1.96,
    ci_label: str = "95% CI",
    full_y: bool = False,
) -> None:
    """
    Plot the per-epoch mean of every metric in ``METRICS`` with a shaded band.

    The band spans ``mean +/- z_score * std`` using the ``(mean, std)`` pairs
    in ``confidence_intervals``. NOTE(review): with pairs from
    ``get_cv_results`` the std is the spread across folds, not a standard
    error, so the band reflects fold variability rather than a CI of the mean.

    Args:
        confidence_intervals: ``{metric: [(mean, std) per epoch]}``.
        best_epoch_by_metric: 0-indexed best epoch per metric (drawn at +1).
        label: Name of the predicted label, used in the figure title.
        num_epochs: Number of epochs on the x-axis.
        z_score: Multiplier applied to the std for the band half-width.
        ci_label: Legend text for the shaded band.
        full_y: Currently unused. NOTE(review): presumably meant to force a
            full y-axis range; confirm the intent or remove it.
    """
    # squeeze=False always yields a 2-D array, so a single metric needs no special case.
    fig, axes = plt.subplots(
        len(METRICS), 1, figsize=(12, 3 * len(METRICS)), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    x_ticks = range(1, num_epochs + 1)

    for ax, metric in zip(axes, METRICS):
        best_epoch = best_epoch_by_metric[metric] + 1

        stats = confidence_intervals[metric]
        means = np.array([entry[0] for entry in stats])
        half_width = np.array([entry[1] * z_score for entry in stats])

        ax.fill_between(
            x_ticks,
            means - half_width,
            means + half_width,
            alpha=0.2,
            color="navy",
            label=ci_label,
        )

        ax.plot(
            x_ticks,
            means,
            color="navy",
            linewidth=2,
            marker="o",
            markersize=4,
            label="Mean",
            alpha=0.8,
        )

        ax.axvline(
            x=best_epoch,
            color="black",
            linestyle="--",
            alpha=0.5,
            linewidth=2,
            label="Best Epoch",
        )

        ax.set_xticks(x_ticks)
        ax.set_ylabel(metric.upper())
        ax.set_title(f"{metric.upper()} Confidence Intervals", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.5)

    # One shared legend (taken from the first panel) below all subplots.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=len(legend_labels),
        bbox_to_anchor=(0.5, 0.0),
        frameon=False,
    )

    axes[-1].set_xlabel("Epoch", fontsize=12)
    fig.suptitle(
        f"Cross Validation Confidence Intervals for {label}",
        fontsize=14,
        fontweight="bold",
    )
    plt.tight_layout(rect=[0, 0.05, 1, 0.97])
    plt.show()

In [None]:
def generate_test_plot_all_metrics(
    test_performance: Dict[str, List[float]],
    label: str,
    num_folds=5,
    num_epochs=10,
) -> None:
    """
    Draw one stacked panel per metric in ``METRICS`` with the test value per epoch.

    Args:
        test_performance: ``{metric: {epoch: value}}`` as produced by
            ``get_test_results``.
        label: Name of the predicted label, used in the figure title.
        num_folds: Accepted for signature parity with the CV plots; not used.
        num_epochs: Number of epochs; x-axis ticks run from 1 to ``num_epochs``.
    """
    # squeeze=False always yields a 2-D array, so a single metric needs no special case.
    fig, axes = plt.subplots(
        len(METRICS), 1, figsize=(12, 3 * len(METRICS)), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    x_ticks = range(1, num_epochs + 1)

    for ax, metric in zip(axes, METRICS):
        per_epoch = test_performance[metric]
        ax.plot(
            x_ticks,
            [per_epoch[epoch] for epoch in range(num_epochs)],
            color="navy",
            label="Test",
            linewidth=2,
            marker="o",
            markersize=4,
        )

        ax.set_xticks(x_ticks)
        ax.set_ylabel(metric.upper())
        ax.set_title(f"{metric.upper()} Training Curve", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.5)

    # One shared legend (taken from the first panel) below all subplots.
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc="lower center",
        ncol=len(legend_labels),
        bbox_to_anchor=(0.5, 0.0),
        frameon=False,
    )

    axes[-1].set_xlabel("Epoch", fontsize=12)
    fig.suptitle(f"Test Performance for {label}", fontsize=14, fontweight="bold")
    plt.tight_layout(rect=[0, 0.05, 1, 0.97])
    plt.show()

In [None]:
def generate_density_plot(results_path, valid_name, epoch, label):
    """
    Plot class-conditional densities of predicted probabilities for one split/epoch.

    Args:
        results_path: Results directory passed to ``get_logits``.
        valid_name: Split sub-directory, e.g. ``"fold_0"`` or ``"test"``.
        epoch: 0-indexed epoch whose predictions are plotted.
        label: Name of the predicted label, used in the title.
    """
    logits, labels = get_logits(results_path, valid_name, epoch)

    # The predictions file stores raw logits; convert them so the x-axis really
    # shows probabilities, as the axis label and title state.
    with np.errstate(over="ignore"):
        probabilities = 1 / (1 + np.exp(-logits))

    prob_pos = probabilities[labels == 1]
    prob_neg = probabilities[labels == 0]

    plt.figure(figsize=(8, 6))

    # clip=(0, 1) keeps the KDE from spilling density outside the valid range.
    sns.kdeplot(
        prob_neg, label="Negative", fill=True, alpha=0.5, color="tab:blue", clip=(0, 1)
    )
    sns.kdeplot(
        prob_pos, label="Positive", fill=True, alpha=0.5, color="tab:orange", clip=(0, 1)
    )

    plt.xlim(0, 1)
    plt.xlabel("Predicted Probability")
    plt.ylabel("Density")
    plt.title(f"Density of Predicted Probabilities for {label}")
    plt.legend(
        loc="lower center",
        ncol=2,
        bbox_to_anchor=(0.5, -0.25),
        frameon=False,
    )
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout(rect=[0, 0.07, 1, 0.97])

    plt.show()

In [None]:
# First epoch (index 0) of the held-out test split.
generate_density_plot(results_path=results_path, valid_name="test", epoch=0, label=label)

In [None]:
# Fixed threshold applied to sigmoid(logits). Set to None to tune it per epoch with
# Youden's J instead: on the pooled folds for CV results, and on the test set
# itself for test results (which leaks test labels).
default_threshold = 0.5

cv_performance, confidence_intervals = get_cv_results(
    results_path, default_threshold=default_threshold
)
test_performance = get_test_results(
    results_path, default_threshold=default_threshold
)

# Best epoch per metric from the mean over CV folds (0-indexed; the plots show +1).
best_epoch_by_metric = {
    metric: best_epoch_from_metric(cv_performance, selection_metric=metric)
    for metric in METRICS
}

table = PrettyTable(["Metric", "Best Epoch"])
for metric_name, epoch_idx in best_epoch_by_metric.items():
    table.add_row([metric_name, epoch_idx])
print(table)

In [None]:
# get_cv_results applies one threshold to every fold of an epoch, so the first
# fold's entry represents the whole epoch.
table = PrettyTable(["Epoch", "Best Threshold"])
for epoch_idx, fold_thresholds in cv_performance["threshold"].items():
    table.add_row([epoch_idx, f"{fold_thresholds[0]:.2f}"])
print(table)

In [None]:
generate_folds_plot_all_metrics(
    cv_performance=cv_performance,
    best_epoch_by_metric=best_epoch_by_metric,
    label=label,
)

In [None]:
generate_ci_plot_all_metrics(
    confidence_intervals=confidence_intervals,
    best_epoch_by_metric=best_epoch_by_metric,
    label=label,
)

In [None]:
generate_test_plot_all_metrics(test_performance=test_performance, label=label)