In [None]:
# -*- authors : Vincent Roduit -*-
# -*- date : 2025-12-11 -*-
# -*- Last revision: 2025-12-11 by Vincent Roduit -*-
# -*- python version : 3.13.7. -*-
# -*- Description: Notebook to review WMH segmentation results (model comparison, lesion analysis, uncertainty) -*-

# <center> inAGE - imaging neuroscience of AGEing </center>
## <center> White Matter Hyperintensity detection  </center>
---

In [None]:
import logging
import sys
import warnings
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from analysis.model_comparisons.model_comparison import (
    clean_summary_table,
    fetch_models_results,
    format_table_results,
)
from analysis.segmentation_difficulties.lesion_analysis import (
    compare_qmaps_distributions,
    compute_pca,
)
from analysis.segmentation_difficulties.seg_diff import compute_lesion_stats
from analysis.uncertainty_quantification.uq_correlation import (
    calculate_correlation,
    correct_correlation,
)
from data_func.utils import filter_csv_description
from misc.constants import (
    DATA_CSV,
    DEFAULT_MAP_ORDER,
    FAZEKAS,
    ID,
    MODEL,
    RESULTS_DIR,
    SPLIT,
    TEST,
    UQ_RESULTS_DIR,
)
from viz.models_comparison import compare_models_metrics
from viz.rc_viz import (
    plot_maps,
    plot_partial_corr_with_ci,
    plot_pearson_corr,
    plot_rc_curves,
)
from viz.seg_grading import raters_boxplot, raters_diff
from viz.stats import plot_embedding, plot_pca_loadings, plot_qmaps_detection

from analysis.seg_grading.seg_grading import compare_seg_grading

# Hide DeprecationWarnings from every module, and seaborn's
# PendingDeprecationWarnings, so they do not clutter cell outputs.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=PendingDeprecationWarning, module="seaborn")

# auto reload modules when they have changed, so edits to the project
# packages imported above are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# ==================== LOGGING CONFIGURATION ====================
# Set LOG_LEVEL below to choose how much output the notebook prints:
# - logging.DEBUG: everything, including detailed processing steps
# - logging.INFO: general information about pipeline execution (default)
# - logging.WARNING: warnings and errors only
# - logging.ERROR: errors only
# - logging.CRITICAL: critical errors only

LOG_LEVEL = logging.INFO  # <-- Change this to adjust logging level

# logging.basicConfig is a no-op when the root logger already has handlers
# (as can happen inside a Jupyter kernel), so detach any existing ones first.
root_logger = logging.getLogger()
while root_logger.handlers:
    root_logger.removeHandler(root_logger.handlers[0])

log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=log_format,
    level=LOG_LEVEL,
)

# Logger used by the notebook cells; keep matplotlib.category at WARNING so
# its INFO-level messages stay hidden.
notebook_logger = logging.getLogger("notebook")
notebook_logger.setLevel(LOG_LEVEL)
logging.getLogger("matplotlib.category").setLevel(logging.WARNING)

In [None]:
# Full dataset description, plus its test split filtered with
# max_qc_grade=3 (exact filtering semantics live in filter_csv_description).
MAX_QC_GRADE = 3
dataset_csv = Path(DATA_CSV)
df_dataset = pd.read_csv(dataset_csv)
df_test = filter_csv_description(df_dataset, split=TEST, max_qc_grade=MAX_QC_GRADE)

In [None]:
PRINT_OPTION = "display"


def print_table(df: pd.DataFrame, print_option: str = PRINT_OPTION) -> None:
    """Print table either for display or latex.

    Args:
        df (pd.DataFrame): DataFrame to plot
        print_option (str): The print option. Default to "display.

    """
    match print_option:
        case "Latex":
            print(df.to_latex())  # noqa:T201
        case _:
            display(df)

# Results

In [None]:
# Raw metric column names -> short labels used in the report tables.
# ("Precision" keeps its original name, so it has no entry here.)
cols_rename = dict(
    [
        ("Dice original", "DSC"),
        ("WDC", "WDC"),
        ("Avg. Symmetric Surface Distance", "ASSD"),
        ("Volume Relative Difference", "VRD"),
        ("Volumetric Similarity", "VS"),
        ("Hausdorff Distance 95", "HD95"),
        ("Recall", "SEN"),
    ]
)

# Metric columns (using the short labels) kept in the results tables, in order.
cols_to_keep = "DSC Precision SEN VS ASSD VRD HD95 WDC".split()

## Random Forest & LGBM: Comparison of trials

In [None]:
# Load and clean the per-trial summary table of each tree-based model
# (RandomForest first, then LGBM).
df_metrics_rf, df_metrics_lgbm = (
    clean_summary_table(pd.read_csv(RESULTS_DIR / model_dir / "pr_metrics.csv"))
    for model_dir in ("RandomForest", "LGBM")
)

In [None]:
# Side-by-side view of both models' trial summaries: rows keyed by "Model",
# outer column level naming the algorithm.
df_lgbm, df_rf = (
    summary.set_index("Model") for summary in (df_metrics_lgbm, df_metrics_rf)
)
df_combined = pd.concat(dict(LGBM=df_lgbm, RF=df_rf), axis=1)
print_table(df_combined)

## Best model: train vs. test metrics (overfitting check)

In [None]:
# LGBM metrics tagged with each case's split, summarised per split so train
# and test performance can be compared.
# NOTE(review): merge() without `on` joins on whichever of ID/SPLIT both
# frames share — confirm lgbm_metrics.csv has no conflicting SPLIT column.
df_lgbm_metrics = (
    pd.read_csv(RESULTS_DIR / "LGBM" / "lgbm_metrics.csv")
    .merge(df_dataset[[ID, SPLIT]])
    .rename(columns=cols_rename)
)
df_lgbm_metrics_formated = format_table_results(
    df_lgbm_metrics,
    cols_to_keep,
    cols_rename,
    [SPLIT],
)
print_table(df_lgbm_metrics_formated)

## Model comparison with other segmentation methods

In [None]:
# One entry per compared method:
# (internal model name, file-name pattern of its predictions, display label)
_model_specs = [
    ("MDGRU", "*_MDGRU.nii.gz", "MD-GRU"),
    ("NNUNET", "*_NNUNET.nii.gz", "nnUnet"),
    ("PGS", "*_PGS.nii.gz", "PGS"),
    ("samseg", "*_pred.nii.gz", "Samseg"),
    ("segcsvd", "thr_*.nii.gz", "segcsvd"),
    ("shiva", "*_pred.nii.gz", "Shiva"),
    ("whitenet", "*.nii.gz", "WHITE-Net"),
    ("lgbm_post", "*_pred.nii.gz", "LGBM"),
]
pred_patterns = {name: pattern for name, pattern, _label in _model_specs}
model_renames = {name: label for name, _pattern, label in _model_specs}

# NOTE(review): only "lgbm_post" is passed via `models`; confirm how
# fetch_models_results handles the other entries of pred_patterns.
df_metrics = fetch_models_results(
    df_test,
    csv_results_name="final_results",
    segmentations_folder="segmentation",
    pred_patterns=pred_patterns,
    models=["lgbm_post"],
).rename(columns=cols_rename)

df_metrics_formated = format_table_results(
    df_metrics,
    cols_to_keep,
    cols_rename,
    [MODEL],
    model_renames,
)
print_table(df_metrics_formated)

In [None]:
# Same summary as above, keyed by both model and Fazekas score.
by_model_and_fazekas = [MODEL, FAZEKAS]
df_metrics_formated = format_table_results(
    df_metrics,
    cols_to_keep,
    cols_rename,
    by_model_and_fazekas,
    model_renames,
)
print_table(df_metrics_formated)

In [None]:
# Per-model summary restricted to subjects with a Fazekas score above 0.
# Filter through the FAZEKAS constant (as the per-Fazekas table does) rather
# than a hard-coded column name inside a query string.
df_metrics_fazekas_pos = df_metrics[df_metrics[FAZEKAS] > 0]
df_metrics_formated = format_table_results(
    df_metrics_fazekas_pos, cols_to_keep, cols_rename, [MODEL], model_renames
)
print_table(df_metrics_formated)

In [None]:
# Rename models to their display labels on a copy. Overwriting
# df_metrics[MODEL] in place made this cell fail on a second run (already
# renamed values such as "LGBM" are not keys of model_renames). It also changed
# the frame that earlier cells pass to format_table_results with model_renames.
# Unknown model names still raise KeyError, as before.
df_metrics_named = df_metrics.copy()
df_metrics_named[MODEL] = df_metrics_named[MODEL].apply(lambda x: model_renames[x])
compare_models_metrics(
    df_metrics=df_metrics_named,
    save_dir=RESULTS_DIR / "models_comp",
    palette_name="Lupi",
)

## Lesion analysis

In [None]:
# Per-patient and per-lesion statistics used by the lesion analysis below.
# NOTE(review): computed on the full df_dataset (all splits), not df_test —
# confirm this is intended.
df_patients_stats, df_lesions_stats = compute_lesion_stats(df_dataset)

In [None]:
# q-map feature columns: the per-patient stats columns listed in
# DEFAULT_MAP_ORDER, kept in the DataFrame's own column order.
feature_cols = list(
    filter(lambda col: col in DEFAULT_MAP_ORDER, df_patients_stats.columns)
)
df_stats = compare_qmaps_distributions(
    df_patients_stats,
    feature_cols,
    target_col="detected",
)
print_table(df_stats)

In [None]:
# q-map feature values split by the "detected" flag; assigned to `fig` so the
# return value is not echoed as cell output.
qmaps_plot_kwargs = {
    "detected_col": "detected",
    "file_name": "lesion_diff_qmaps",
    "palette": "Kippenberger",
}
fig = plot_qmaps_detection(df_patients_stats, feature_cols, **qmaps_plot_kwargs)

In [None]:
# PCA over the per-lesion q-map features. The third argument names the
# "dominant_region" column (its exact role is defined in compute_pca).
pca_label_col = "dominant_region"
pca_outputs = compute_pca(df_lesions_stats, feature_cols, pca_label_col)
df_clean_pca, embeddings_pca, pca = pca_outputs

In [None]:
# Loadings matrix: one row per input feature, one column per principal
# component (PC1, PC2, ...).
pc_names = [f"PC{k}" for k in range(1, pca.n_components_ + 1)]
loadings = pd.DataFrame(pca.components_.T, index=feature_cols, columns=pc_names)

plot_pca_loadings(loadings)

In [None]:
def get_detection_class(
    percentage_detec: float, thresholds: tuple[float, float] = (0.4, 0.8)
) -> str:
    """Bin a detected fraction into a coarse detection category.

    Args:
        percentage_detec (float): Detected fraction of the lesion. The labels
            express thresholds as percentages, so values are expected in [0, 1].
        thresholds (tuple[float, float]): Lower and upper bin edges, as
            fractions. Default to (0.4, 0.8).

    Returns:
        str: "0%" for exactly 0, "0-<low>%" up to and including the lower
        threshold, "<low>-<high>%" up to and including the upper threshold,
        and "><high>%" above it. With the defaults: "0%", "0-40%", "40-80%",
        ">80%".

    """
    # Derive labels from the thresholds so non-default bins are labelled
    # correctly; ":g" renders 0.4 * 100 (= 40.00000000000001) as "40".
    low_pct = f"{thresholds[0] * 100:g}"
    high_pct = f"{thresholds[1] * 100:g}"

    if percentage_detec == 0:
        return "0%"
    if percentage_detec <= thresholds[0]:
        return f"0-{low_pct}%"
    if percentage_detec <= thresholds[1]:
        return f"{low_pct}-{high_pct}%"
    # NOTE(review): NaN compares False above and therefore lands in this top
    # bin — confirm percentage_detected never contains NaN.
    return f">{high_pct}%"


# Attach the detection category and sort by it. Descending string order puts
# ">80%" first, then "40-80%", "0-40%" and "0%" ('>' > '4' > '0', '-' > '%').
df_clean_pca = df_clean_pca.assign(
    detection_class=df_clean_pca["percentage_detected"].map(get_detection_class)
).sort_values(by="detection_class", ascending=False)

In [None]:
# PCA embedding of the lesions, grouped by "dominant_region".
# NOTE(review): df_clean_pca was re-sorted by detection_class after compute_pca,
# but embeddings_pca was not — confirm plot_embedding aligns rows by index
# rather than by position.
plot_embedding(df_clean_pca, embeddings_pca, "dominant_region", palette="Kippenberger")

## Retention Curve Analysis

In [None]:
# Retention curves for the train and test uncertainty results; plot_rc_curves
# also returns the AUC statistics tabulated in the next cell.
h5_paths = [UQ_RESULTS_DIR / f"{split}_rc.h5" for split in ("train", "test")]
df_auc_stats = plot_rc_curves(
    h5_paths,
    segmentation_metric="DSC",
    palette="Green_Orange_Teal",
)
# Saves the current matplotlib figure, assumed to be the one just drawn.
plt.savefig(UQ_RESULTS_DIR / "rc.png")

In [None]:
# Per-scale tables of the retention-curve AUC statistics, one row per
# uncertainty measure and one column per split; LaTeX sources are kept in
# latex_tables[scale].
latex_tables = {}
df_copy = deepcopy(df_auc_stats)
# Typeset measure names as inline LaTeX math.
df_copy["Measure"] = df_copy["Measure"].apply(lambda x: f"${x}$")
for scale, group in df_copy.groupby("Scale"):
    notebook_logger.info("Scale: %s", scale)
    # Sort after pivoting: pivot_table re-sorts its index, so any row order
    # set beforehand is lost. Sorting here gives the LaTeX export and the
    # displayed table the same ordering.
    pivot = group.pivot_table(
        index="Measure", columns="Split", values="Value", aggfunc="first"
    ).sort_values(by="Train", ascending=True)
    latex_tables[scale] = pivot.to_latex(float_format="%.2f")
    print_table(pivot[["Train", "Test"]])

In [None]:
# Entropy-based uncertainty measure analysed against per-patient DSC.
entropy = "avg. mean entropy_of_expected"

# Per-patient uncertainty metrics joined with the dataset description
# (merge on all shared columns).
df_patient_metrics = pd.read_csv(UQ_RESULTS_DIR / "patient_metrics.csv").merge(
    df_dataset
)

# Raw correlations for every measure, then the correction computed by
# correct_correlation for the chosen entropy measure.
df_corr = calculate_correlation(df_patient_metrics, segmentation_metric="DSC")
print_table(df_corr)

df_correct_corr = correct_correlation(
    df_patient_metrics,
    entropy_metric=entropy,
    segmentation_metric="DSC",
)
print_table(df_correct_corr)

In [None]:
# Entropy measures ordered by their train-split rho_median, highest first.
order = df_corr[df_corr["split"] == "train"].sort_values(
    "rho_median", ascending=False
)["entropy_measure_latex"]

df_corr_format = df_corr.pivot_table(
    index="entropy_measure_latex",
    columns="split",
    values=["value", "p_value"],
    aggfunc="first",
).loc[order]

# pivot_table yields (metric, split) column pairs; flatten them to
# "<split>_<metric>" and fix the column order explicitly (train first).
df_corr_format.columns = [
    f"{split}_{metric}" for metric, split in df_corr_format.columns
]
df_corr_format = df_corr_format[
    ["train_value", "train_p_value", "test_value", "test_p_value"]
]
print_table(df_corr_format)

In [None]:
# Correlation plot of the selected entropy measure against per-patient DSC.
plot_pearson_corr(
    entropy, df_patient_metrics, segmentation_metric="DSC", palette="Kippenberger"
)

In [None]:
# Partial correlation of the entropy measure with DSC, with confidence
# intervals (controlled covariates are chosen inside the helper — TODO confirm).
plot_partial_corr_with_ci(entropy, df_patient_metrics, dice_col="DSC")

In [None]:
# Two example patients for the qualitative map figure, highest DSC first.
example_ids = ["PR05868", "PR06000"]
df_plot_data = df_patient_metrics.loc[
    df_patient_metrics["id"].isin(example_ids),
    ["id", "DSC", "avg. mean entropy_of_expected", "Fazekas"],
].sort_values(by="DSC", ascending=False)

df_plot_data = df_plot_data.set_axis(np.arange(len(df_plot_data)))
# NOTE(review): slice indices are bound by DSC rank, not by patient id; if the
# DSC ordering of these two patients changes, their slices swap silently.
df_plot_data["slice_idx"] = [183, 182]

In [None]:
# Maps for the example patients in df_plot_data, using the "lgbm_post" model.
plot_maps(df_plot_data, model="lgbm_post", segmentation_metric="DSC")

# Segmentation comparison between ground truth (GT) and prediction

In [None]:
# Compare the grading of ground-truth vs. predicted segmentations. Returns a
# fitted model (its summary is printed at the end of the notebook) and a
# results frame passed to the plotting helpers below as long-format data.
df_mixed_model, df_results = compare_seg_grading(df_dataset)

In [None]:
# Plot of the grading differences per rater (helper expects long-format data).
raters_diff(df_long=df_results)

In [None]:
# Boxplots of the grading results, presumably one box per rater.
raters_boxplot(df_results)

In [None]:
# Text summary of the model fitted by compare_seg_grading.
print(df_mixed_model.summary())  # noqa:T201