In [None]:
import os
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

from typing import List, Dict, Tuple

# Metric names tracked for every fold and epoch; they mirror the keys of the
# dictionary returned by get_metrics.
METRICS = [
    "roc_auc",
    "f1",
    "precision",
    "recall",
    "specificity",
    "jw5",
    "jw6",
]

In [None]:
# Load variables from a .env file (if present) so PROJECTPATH / DATAPATH can be
# read through os.getenv below.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")

# Fail fast with a readable message: os.path.join(None, ...) would otherwise
# raise an opaque TypeError when results_path is built.
if project_path is None:
    raise EnvironmentError(
        "PROJECTPATH is not set; define it in the environment or in a .env file."
    )

experiment_name = "test_cardiomegaly_cls"
label = "Cardiomegaly"
# Results directory layout: <PROJECTPATH>/runs/<experiment_name>/results/<label>
results_path = os.path.join(project_path, "runs", experiment_name, "results", label)

In [None]:
def get_logits(result_path: str, fold_idx: int, epoch: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load the stored predictions of a single fold at a single epoch.

    The CSV is read from
    ``<result_path>/fold_<fold_idx>/predictions/epoch_<epoch, zero-padded to 2 digits>.csv``
    and must contain a ``logits`` and a ``labels`` column.

    Args:
        result_path: Root directory holding one sub-directory per fold.
        fold_idx: 0-indexed fold number.
        epoch: 0-indexed epoch number.

    Returns:
        The ``logits`` column and the ``labels`` column, in that order, as arrays.
    """
    fold_dir = f"fold_{fold_idx}"
    csv_name = f"epoch_{epoch:02d}.csv"
    csv_file = os.path.join(result_path, fold_dir, "predictions", csv_name)

    frame = pd.read_csv(csv_file)

    return frame["logits"].values, frame["labels"].values

In [None]:
def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    if denominator == 0:
        return 0.0
    return float(numerator / denominator)


def get_metrics(logits: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """
    Calculate performance metrics based on logits and true labels.

    ROC AUC is computed directly on the logits. All other metrics are computed
    after turning the logits into sigmoid probabilities and binarising them at
    the threshold that maximises Youden's J (sensitivity + specificity - 1)
    over 100 evenly spaced thresholds in [0, 1]. If no threshold reaches J > 0,
    the default threshold of 0.5 is used.

    Ratios whose denominator is zero (e.g. precision when nothing is predicted
    positive, or F1 when precision + recall is 0) are reported as 0.0 instead
    of NaN, so that averages across folds stay finite.

    Args:
        logits: Raw model outputs, one per sample.
        labels: Binary ground-truth labels (compared against 1 and 0).

    Returns:
        Dictionary with the keys "roc_auc", "f1", "precision", "recall",
        "specificity", "jw5" (plain mean of precision and recall) and "jw6"
        (0.4 * precision + 0.6 * recall).
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)

    roc_auc = roc_auc_score(labels, logits)

    # Sigmoid. exp() overflows to inf for very negative logits, which still
    # yields the correct probability of 0, so the overflow warning is silenced.
    with np.errstate(over="ignore"):
        probabilities = 1 / (1 + np.exp(-logits))

    # Class membership does not depend on the threshold: compute it once.
    is_positive = labels == 1
    is_negative = labels == 0
    n_positive = np.sum(is_positive)
    n_negative = np.sum(is_negative)

    best_threshold = 0.5
    best_youden = 0
    for threshold in np.linspace(0, 1, 100):
        predicted_positive = probabilities >= threshold

        sensitivity = _safe_ratio(np.sum(predicted_positive & is_positive), n_positive)
        specificity = _safe_ratio(np.sum(~predicted_positive & is_negative), n_negative)
        youden = sensitivity + specificity - 1

        # Strict comparison: on ties the lowest threshold seen first is kept.
        if youden > best_youden:
            best_youden = youden
            best_threshold = threshold

    predicted_positive = probabilities >= best_threshold
    true_positives = np.sum(predicted_positive & is_positive)
    true_negatives = np.sum(~predicted_positive & is_negative)

    final_precision = _safe_ratio(true_positives, np.sum(predicted_positive))
    final_recall = _safe_ratio(true_positives, n_positive)
    specificity = _safe_ratio(true_negatives, n_negative)
    jw5 = (final_precision + final_recall) / 2
    jw6 = final_precision*0.4 + final_recall*0.6
    best_f1 = _safe_ratio(2 * final_precision * final_recall, final_precision + final_recall)

    return {
        "roc_auc": roc_auc,
        "f1": best_f1,
        "precision": final_precision,
        "recall": final_recall,
        "specificity": specificity,
        "jw5": jw5,
        "jw6": jw6,
    }

In [None]:
def get_cv_results(
        results_path: str,
        n_folds: int = 5,
        n_epochs: int = 10
    ) -> Tuple[Dict[str, Dict[int, List[float]]], Dict[str, List[Tuple[float, float]]]]:
    """
    Parse cross-validation results into per-epoch metric values and summary statistics.

    For every fold and every epoch, the stored predictions are loaded with
    get_logits and scored with get_metrics.

    Args:
        results_path: Path to results directory.
        n_folds: Number of folds (folds are addressed as 0 .. n_folds - 1).
        n_epochs: Number of epochs (epochs are addressed as 0 .. n_epochs - 1).

    Returns:
        epochs_performance: Nested dictionary ``metric name -> epoch index ->
            list of values``, with one value per fold in fold order.
        confidence_intervals: Dictionary ``metric name -> list`` with one
            ``(mean, standard deviation)`` pair per epoch, in epoch order,
            computed across folds. Despite the name, these are mean/std pairs,
            not interval bounds.
    """
    epochs_performance = {
        metric: {ep: [] for ep in range(n_epochs)} for metric in METRICS
    }

    for fold_idx in range(n_folds):
        for ep in range(n_epochs):
            logits, labels = get_logits(results_path, fold_idx, ep)

            fold_metrics = get_metrics(logits, labels)
            for metric_name, metric_value in fold_metrics.items():
                epochs_performance[metric_name][ep].append(metric_value)

    # One (mean, std) pair per epoch; dict insertion order keeps epochs sorted.
    confidence_intervals = {
        metric: [
            (np.mean(fold_values), np.std(fold_values))
            for fold_values in epochs_performance[metric].values()
        ]
        for metric in METRICS
    }

    return epochs_performance, confidence_intervals

In [None]:
def best_epoch_from_metric(
        epochs_performance: Dict[str, Dict[int, List[float]]],
        selection_metric: str = "roc_auc"
    ) -> int:
    """
    Select the best epoch based on the specified metric.

    Args:
        epochs_performance: Nested dictionary ``metric name -> epoch -> list of
            per-fold values``.
        selection_metric: Metric whose mean across folds is maximised.

    Returns:
        0-based position of the epoch with the highest mean value, as a plain
        Python int. On ties the earliest epoch is returned.
    """
    epoch_means = [
        np.mean(fold_values)
        for fold_values in epochs_performance[selection_metric].values()
    ]

    # np.argmax yields a NumPy integer; convert so the declared return type holds.
    return int(np.argmax(epoch_means))

In [None]:
def generate_simple_folds_plot(
        epochs_performance: Dict[str, Dict[int, List[float]]],
        label: str,
        metric: str,
        num_folds: int = 5,
        num_epochs: int = 10
    ) -> None:
    """
    Plot one curve per fold showing how a metric evolves over the epochs.

    Args:
        epochs_performance: Nested dictionary ``metric name -> epoch -> list of
            per-fold values``; it is indexed as ``[metric][epoch][fold]``.
        label: Label name shown in the plot title.
        metric: Metric to plot; also used as the y-axis label.
        num_folds: Number of folds to draw (fold indices 0 .. num_folds - 1).
        num_epochs: Number of epochs to draw (epochs 0 .. num_epochs - 1).
    """
    plt.figure()

    for fold_idx in range(num_folds):
        # Collect this fold's value at every epoch to form its curve.
        fold_curve = [
            epochs_performance[metric][epoch][fold_idx] for epoch in range(num_epochs)
        ]
        # Folds are 0-indexed in the data but numbered from 1 in the legend.
        plt.plot(fold_curve, label=f"Fold {fold_idx + 1}")

    plt.xlabel("Epoch")
    plt.ylabel(metric)
    plt.legend()
    plt.title(f"Training Curve {metric} for {label}")

    plt.show()

In [None]:
def generate_ci_plot(
        confidence_intervals:Dict[str, List[Tuple[float, float]]],
        label:str,
        metric:str,
        num_epochs=10,
        z_score=1.96
    ) -> None:
    """
    Plot the per-epoch mean of a metric with a shaded band of +/- z_score * std.

    Args:
        confidence_intervals: Dictionary ``metric name -> list`` holding one
            ``(mean, standard deviation)`` pair per epoch.
        label: Label name shown in the plot title.
        metric: Metric to plot; also used as the y-axis label.
        num_epochs: Number of epochs placed on the x-axis (0 .. num_epochs - 1).
        z_score: Multiplier applied to the standard deviation to size the band.
    """
    plt.figure()

    per_epoch_stats = confidence_intervals[metric]
    means = [entry[0] for entry in per_epoch_stats]
    half_widths = [entry[1] * z_score for entry in per_epoch_stats]

    # Band limits around the mean curve.
    center = np.array(means)
    margin = np.array(half_widths)
    lower_band = center - margin
    upper_band = center + margin

    epochs = range(num_epochs)
    plt.plot(epochs, means, label=f"Mean {metric}")
    plt.ylim([0, 1])
    plt.xticks(epochs)
    plt.fill_between(epochs, lower_band, upper_band, alpha=0.2)
    plt.xlabel("Epoch")
    plt.ylabel(metric)
    plt.title(f"Confidence Interval for {metric} - {label}")
    plt.legend()
    plt.show()
    

In [None]:
# Score every fold/epoch found under results_path (default fold and epoch
# counts), then pick the epoch with the highest mean ROC AUC across folds.
epochs_performance, confidence_intervals = get_cv_results(results_path=results_path)
best_epoch = best_epoch_from_metric(
    epochs_performance=epochs_performance,
    selection_metric="roc_auc",
)

In [None]:
# Per-fold ROC AUC curves over the epochs for the configured label.
generate_simple_folds_plot(
    epochs_performance=epochs_performance,
    label=label,
    metric="roc_auc",
)