In [None]:
import os
import random
import pandas as pd
import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
from sklearn.metrics import precision_score, recall_score, roc_curve, auc

# Metrics tracked per fold and epoch. This order is also the key order of the
# dicts built in get_cv_results.
METRICS = [
    "roc_auc",    # read from each fold's summary.csv
    "pr_auc",     # read from each fold's summary.csv
    "f1",         # recomputed from saved logits by get_pr_f1
    "precision",  # recomputed from saved logits by get_pr_f1
    "recall",     # recomputed from saved logits by get_pr_f1
]

In [None]:
# Load PROJECTPATH / DATAPATH from the environment (or a .env file).
load_dotenv()
project_path = os.getenv("PROJECTPATH")
# NOTE(review): data_path is not referenced by any visible cell.
data_path = os.getenv("DATAPATH")

# Fail fast with an actionable message. Without this, os.path.join(None, ...)
# raises an opaque TypeError, and an empty value silently yields a relative path.
if not project_path:
    raise RuntimeError(
        "PROJECTPATH is not set; define it in the environment or in a .env file."
    )

experiment_name = "test_cardiomegaly_cls"
label = "Cardiomegaly"
# Per-label results directory that holds the fold_{i} subdirectories.
results_path = os.path.join(project_path, "runs", experiment_name, "results", label)

In [None]:
# Identifiers of the model run and checkpoint being evaluated.
# NOTE(review): neither name is used by the visible cells (results_path does
# not include them); confirm whether they should feed into the path.
run_name, checkpoint_name = "vitb_CT-RATE", "training_99999"

In [None]:
def get_logits(result_path, fold_idx, epoch):
    """
    Load the saved logits and ground-truth labels for one fold and epoch.

    Reads ``{result_path}/fold_{fold_idx}/predictions/epoch_{epoch:02d}.csv``,
    which must contain ``logits`` and ``labels`` columns.

    Args:
        result_path: Results directory containing the ``fold_*`` subdirectories.
        fold_idx: Fold index, used verbatim in the directory name.
        epoch: Epoch index, used verbatim (zero-padded to two digits) in the file name.

    Returns:
        ``(logits, labels)`` as NumPy arrays, in file row order.

    Note:
        Indices are not shifted here. ``get_cv_results`` calls this with
        0-based indices (``range(n_folds)`` / ``range(n_epochs)``), so the files
        on disk are expected to be named accordingly.
    """
    epoch_predictions_path = os.path.join(
        result_path, f"fold_{fold_idx}", "predictions", f"epoch_{epoch:02d}.csv"
    )
    # Only the two columns we need are parsed.
    epoch_predictions = pd.read_csv(epoch_predictions_path, usecols=["logits", "labels"])
    logits = epoch_predictions["logits"].to_numpy()
    labels = epoch_predictions["labels"].to_numpy()

    return logits, labels

In [None]:
def get_pr_f1(logits, labels, n_thresholds: int = 100):
    """
    Compute precision, recall and F1 at the Youden-optimal decision threshold.

    Logits are mapped to probabilities with a sigmoid. Then ``n_thresholds``
    evenly spaced thresholds in [0, 1] are scanned, and the one maximizing
    Youden's J (sensitivity + specificity - 1) is kept. The threshold falls
    back to 0.5 if no threshold gives J > 0, or if ``labels`` contains only
    one class (J is undefined then).

    Args:
        logits: Array-like of raw model scores, one per sample.
        labels: Array-like of binary ground-truth labels (1 = positive, 0 = negative).
        n_thresholds: Number of grid points scanned in [0, 1] (default 100).

    Returns:
        Dict with keys ``"precision"``, ``"recall"`` and ``"f1"``. F1 is 0.0
        when precision and recall are both 0.
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)

    # exp(-logit) overflows to inf for very negative logits. The resulting
    # probability (0.0) is still correct, so only the warning is silenced.
    with np.errstate(over="ignore"):
        probabilities = 1 / (1 + np.exp(-logits))

    # Class counts do not depend on the threshold; compute them once.
    is_positive = labels == 1
    is_negative = labels == 0
    n_positive = np.sum(is_positive)
    n_negative = np.sum(is_negative)

    best_threshold = 0.5
    best_youden = 0
    # Sensitivity/specificity are 0/0 unless both classes are present.
    if n_positive > 0 and n_negative > 0:
        for threshold in np.linspace(0, 1, n_thresholds):
            predicted_positive = probabilities >= threshold
            sensitivity = np.sum(predicted_positive & is_positive) / n_positive
            specificity = np.sum(~predicted_positive & is_negative) / n_negative
            youden = sensitivity + specificity - 1

            # Strict '>' keeps the first (lowest) threshold among ties.
            if youden > best_youden:
                best_youden = youden
                best_threshold = threshold

    final_predictions = (probabilities >= best_threshold).astype(int)
    # zero_division=0 returns the same value as the default, minus the warning.
    final_precision = precision_score(labels, final_predictions, zero_division=0)
    final_recall = recall_score(labels, final_predictions, zero_division=0)
    # Guard the 0/0 case (no true positives at the chosen threshold).
    denominator = final_precision + final_recall
    best_f1 = 2 * final_precision * final_recall / denominator if denominator > 0 else 0.0

    return {"precision": final_precision, "recall": final_recall, "f1": best_f1}

In [None]:
def get_cv_results(results_path: str, n_folds: int = 5, n_epochs: int = 10):
    """
    Parse cross-validation results into per-epoch, per-fold metric values.

    Precision, recall and F1 are recomputed for every fold and epoch from the
    saved logits (via ``get_logits`` / ``get_pr_f1``). ROC-AUC and PR-AUC are
    read from each fold's ``summary.csv``.

    Args:
        results_path: Results directory containing ``fold_{i}`` subdirectories
            (folds indexed from 0).
        n_folds: Number of cross-validation folds.
        n_epochs: Number of epochs per fold. Each ``summary.csv`` must have
            exactly this many rows, in epoch order.

    Returns:
        epochs_performance: ``{metric: {epoch: [fold_0_value, fold_1_value, ...]}}``
            for every metric in ``METRICS``. Epochs are 0-based and list entries
            follow fold order.
        confidence_intervals: ``{metric: [(mean, std), ...]}`` with one tuple per
            epoch, aggregated across folds. Despite the name this is not a
            confidence interval: ``std`` is the population standard deviation
            (``np.std`` with ``ddof=0``).
    """
    threshold_metrics = ["precision", "recall", "f1"]
    summary_metrics = ["roc_auc", "pr_auc"]

    epochs_performance = {m: {ep: [] for ep in range(n_epochs)} for m in METRICS}

    for fold_idx in range(n_folds):
        summary_path = os.path.join(results_path, f"fold_{fold_idx}", "summary.csv")
        summary_df = pd.read_csv(summary_path)
        assert len(summary_df) == n_epochs, (
            f"{summary_path} has {len(summary_df)} rows, expected {n_epochs}"
        )

        for ep in range(n_epochs):
            logits, labels = get_logits(results_path, fold_idx, ep)
            pr_f1 = get_pr_f1(logits, labels)
            for m in threshold_metrics:
                epochs_performance[m][ep].append(pr_f1[m])

            # Positional lookup, so a non-default index can never misalign
            # summary rows with epochs.
            summary_row = summary_df.iloc[ep]
            for m in summary_metrics:
                epochs_performance[m][ep].append(summary_row[m])

    confidence_intervals = {
        m: [(np.mean(perf), np.std(perf)) for perf in epochs_performance[m].values()]
        for m in METRICS
    }

    return epochs_performance, confidence_intervals

In [None]:
def best_epoch_from_metric(epochs_performance, selection_metric: str = "roc_auc"):
    """
    Return the epoch whose fold-averaged ``selection_metric`` is highest.

    Args:
        epochs_performance: ``{metric: {epoch: [per-fold values]}}``, as returned
            by ``get_cv_results``.
        selection_metric: Metric to maximize (higher is assumed to be better).

    Returns:
        The epoch key with the largest mean across folds. Ties resolve to the
        first such epoch in dict order, and epochs whose mean is NaN are skipped.

    Raises:
        ValueError: If there are no epochs or every epoch's mean is NaN.
    """
    per_epoch = epochs_performance[selection_metric]
    epochs = list(per_epoch.keys())
    epoch_means = np.array([np.mean(per_epoch[ep]) for ep in epochs], dtype=float)

    if epoch_means.size == 0 or np.all(np.isnan(epoch_means)):
        raise ValueError(f"No valid '{selection_metric}' values to select an epoch from.")

    # np.argmax would treat NaN as the maximum; nanargmax ignores it.
    # Map the position back to the epoch key rather than assuming keys == positions.
    return epochs[int(np.nanargmax(epoch_means))]

In [None]:
# Aggregate all folds/epochs, then pick the epoch with the best mean ROC-AUC.
# The values passed here are get_cv_results' defaults, spelled out for readability.
epochs_performance, confidence_intervals = get_cv_results(
    results_path, n_folds=5, n_epochs=10
)
best_epoch = best_epoch_from_metric(epochs_performance, selection_metric="roc_auc")

In [None]:
def generate_val_epoch_plot(train_results, label, metric, num_folds=5, num_epochs=10):
    """
    Draw one curve per fold of ``metric`` against epoch, then show the figure.

    Args:
        train_results: ``{metric: {epoch: [per-fold values]}}``, e.g. the first
            value returned by ``get_cv_results``.
        label: Target class name, used in the title.
        metric: Key of ``train_results`` to plot.
        num_folds: Number of folds to draw (folds ``0..num_folds-1``).
        num_epochs: Number of epochs on the x-axis (epochs ``0..num_epochs-1``).
    """
    _, ax = plt.subplots()

    for fold in range(num_folds):
        curve = [train_results[metric][epoch][fold] for epoch in range(num_epochs)]
        # The legend shows 1-based fold numbers; the x-axis stays 0-based.
        ax.plot(curve, label=f"Fold {fold + 1}")

    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.legend()
    ax.set_title(f"Training Curve {metric} for {label}")

    plt.show()

In [None]:
# Per-fold validation ROC-AUC across epochs.
generate_val_epoch_plot(epochs_performance, label, metric="roc_auc")