In [None]:
# -*- authors : Vincent Roduit -*-
# -*- date : 2025-10-14 -*-
# -*- Last revision: 2025-10-14 by Vincent Roduit -*-
# -*- python version : 3.13.7. -*-
# -*- Description: Notebook to describe and visualize the WMH dataset -*-

# <center> inAGE - imaging neuroscience of AGEing </center>
## <center> White Matter Hyperintensity detection  </center>
---

In [None]:
import logging
import sys
import warnings
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from data_func.utils import filter_csv_description
from misc.constants import DATA_CSV, FAZEKAS, RESULTS_DIR, GENERAL_RESULTS_DIR
from models.model_utils import split_test_set
from viz.data_viz import (
    plot_distribution_data,
    plot_lesion_size_distribution,
    plot_map_example,
    plot_mask_analysis,
    plot_processing_hist,
    plot_processing_img,
    plot_representative_slices,
    plot_volume_fazekas,
)
from viz.feature_plots import plot_features

# Silence PendingDeprecationWarning messages that originate from seaborn modules
# so they do not clutter the cell outputs.
warnings.filterwarnings("ignore", category=PendingDeprecationWarning, module="seaborn")

# Reload imported modules automatically before each cell execution, so edits to
# the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# ==================== LOGGING CONFIGURATION ====================
# Modify the LOG_LEVEL variable below to control logging output:
# - logging.DEBUG: Shows all messages including detailed processing steps
# - logging.INFO: Shows general information about pipeline execution (default)
# - logging.WARNING: Shows only warnings and errors
# - logging.ERROR: Shows only errors
# - logging.CRITICAL: Shows only critical errors

LOG_LEVEL = logging.INFO  # <-- Change this to adjust logging level

# Configure the root logger for Jupyter notebooks.
# ``force=True`` removes *and closes* every handler already attached to the
# root logger before installing the new one, so re-running this cell neither
# duplicates log lines nor leaves stale handlers open.
logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Logger used by the cells of this notebook.
notebook_logger = logging.getLogger("notebook")
notebook_logger.setLevel(LOG_LEVEL)

# Keep matplotlib's own loggers at WARNING regardless of LOG_LEVEL.
logging.getLogger("matplotlib.category").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

In [None]:
PRINT_OPTION = "display"


def print_table(df: pd.DataFrame, print_option: str = PRINT_OPTION) -> None:
    """Print a table either as a rich notebook display or as LaTeX source.

    Args:
        df (pd.DataFrame): DataFrame to print.
        print_option (str): Output format. ``"latex"`` is matched
            case-insensitively (so the historical ``"Latex"`` spelling keeps
            working) and prints the LaTeX source of the frame. Any other
            value renders the frame with ``display``. Defaults to
            ``PRINT_OPTION``.

    """
    # Non-string options (e.g. None) fall back to the rich display, as before.
    wants_latex = isinstance(print_option, str) and print_option.lower() == "latex"
    if wants_latex:
        print(df.to_latex())  # noqa:T201
    else:
        # ``display`` is not imported here: IPython/Jupyter injects it as a builtin.
        display(df)

In [None]:
# Load the dataset description table and log its size before filtering.
df_dataset_raw = pd.read_csv(Path(DATA_CSV))
notebook_logger.info("Original dataset contains %d subjects", len(df_dataset_raw))

# Apply the project's filtering step (described in the log message as removing
# missing data and QC failures); ``split=None`` is passed through unchanged.
# The raw frame is kept under its own name so this cell's lineage stays explicit.
df_dataset = filter_csv_description(df_dataset_raw, split=None)
notebook_logger.info(
    "After filtering missing data and QC, dataset contains %d subjects",
    len(df_dataset),
)

In [None]:
# Per-Fazekas summary: mean age, group size, share of male subjects (in %),
# and the mean of the "split == train" mask.
# NOTE(review): ``N_train`` is a fraction (mean of a boolean mask), not a
# count, despite its name - confirm whether ``.sum()`` was intended.
summary_spec = {
    "Age": ("Age", "mean"),
    "N": ("Gender", "size"),
    "Male": ("Gender", lambda gender: (gender == "Male").mean() * 100),
    "N_train": ("split", lambda split: (split == "train").mean()),
}
stats_by_fazekas = df_dataset.groupby(["Fazekas"]).agg(**summary_spec)

# Cast everything to float, round to two decimals, and transpose the table
# so that the Fazekas groups become the columns.
df_dataset_stats = stats_by_fazekas.astype(float).round(2).T
print_table(df_dataset_stats)

# Fazekas description

## Distribution of data

In [None]:
# Draw the data-distribution plot, then write the current matplotlib figure to disk.
plot_distribution_data(df_dataset)
distribution_png = Path(RESULTS_DIR / "general" / "fazekas_distribution.png")
plt.savefig(distribution_png)

## Cross validation test set

In [None]:
# Keep only test-split subjects with max_grade < 3, on an independent copy so
# that split_test_set cannot touch df_dataset.
df_test = split_test_set(
    df_dataset.query("split == 'test' and max_grade < 3").copy()
)

# Count of subjects per "group" value, coloured by Fazekas.
fig, ax = plt.subplots(figsize=(16, 9))
sns.countplot(data=df_test, x="group", hue="Fazekas", ax=ax)
ax.set_title("Splits for 5-fold cross-validation")

# Save BEFORE showing: with the inline backend plt.show() closes the figure,
# so a later plt.savefig() would write an empty image.
fig.savefig(Path(RESULTS_DIR / "general" / "cv_split_test.png"), bbox_inches="tight")
plt.show()
plt.close(fig)

# Maps examples

In [None]:
# Draw the example maps, then write the current matplotlib figure to disk.
plot_map_example()
map_examples_png = Path(RESULTS_DIR / "general" / "map_examples.png")
plt.savefig(map_examples_png)

# Representative slices for different Fazekas grades

In [None]:
# (subject id, slice index) pairs, kept together so the two parallel lists
# passed to plot_representative_slices cannot drift out of alignment.
representative_slices = [
    ("PR05769", 175),
    ("PR05786", 155),
    ("PR05741", 185),
    ("PR06040", 180),
    ("PR05739", 175),
    ("PR05887", 185),
    ("PR05868", 185),
]
subjects = [subject_id for subject_id, _ in representative_slices]
slice_idx = [slice_number for _, slice_number in representative_slices]

In [None]:
# Plot the chosen slice of each listed subject, then save the current figure
# with a tight bounding box.
plot_representative_slices(df_dataset, subjects, slice_idx)
representative_png = Path(
    RESULTS_DIR / "general" / "representative_slices_fazekas.png"
)
plt.savefig(representative_png, bbox_inches="tight")

# Analyze the tissue segmentation for WMH

In [None]:
# plot_mask_analysis returns two objects that later cells reuse:
# the lesion table and the grouped WMH table.
mask_analysis_result = plot_mask_analysis(df_dataset)
df_dataset_lesions, df_wmh_grouped = mask_analysis_result

wm_mask_png = Path(GENERAL_RESULTS_DIR / "wm_mask_analysis.png")
plt.savefig(wm_mask_png)

In [None]:
# Show the grouped WMH table ordered by mask type, then percentage, both descending.
wmh_sort_keys = ["mask_type", "percentage"]
df_wmh_grouped.sort_values(by=wmh_sort_keys, ascending=False)

# Lesion sizes for different Fazekas grades

In [None]:
# Only rows with a strictly positive lesion size are plotted.
nonzero_lesions = df_dataset_lesions.query("lesion_size > 0")
plot_lesion_size_distribution(nonzero_lesions)

lesion_size_png = Path(RESULTS_DIR / "general" / "lesion_size_distribution.png")
plt.savefig(lesion_size_png)

In [None]:
# First collapse to one row per "id" (summed lesion size, first Fazekas value),
# then show descriptive statistics of those totals for each Fazekas value.
lesion_totals_by_id = df_dataset_lesions.groupby("id").agg(
    {"lesion_size": "sum", "Fazekas": "first"}
)
lesion_totals_by_id.groupby("Fazekas").describe()

# Subject-wise lesion size per Fazekas

In [None]:
# One row per "id": summed lesion size and the first FAZEKAS value, with "id"
# restored as a regular column.
# NOTE(review): df_subjects_lesion_sum is not referenced by any later cell
# visible in this notebook - confirm whether the volume plot should use it.
lesions_grouped_by_id = df_dataset_lesions.groupby("id")
lesion_sum_by_id = lesions_grouped_by_id.agg({"lesion_size": "sum", FAZEKAS: "first"})
df_subjects_lesion_sum = lesion_sum_by_id.reset_index()

In [None]:
# NOTE(review): this passes df_dataset_lesions, not the per-subject sums
# (df_subjects_lesion_sum) - confirm this is the intended input.
plot_volume_fazekas(df_dataset_lesions, palette="Blues")
wmh_volume_png = Path(RESULTS_DIR / "general" / "wmh_volume_fazekas.png")
plt.savefig(wmh_volume_png)

# Features Summary

In [None]:
# Draw the feature plot, then write the current matplotlib figure to disk.
plot_features()
feature_maps_png = Path(RESULTS_DIR / "general" / "example_feature_maps.png")
plt.savefig(feature_maps_png)

# Processing effects

In [None]:
# Draw the processing-effect image plot, then write the current figure to disk.
plot_processing_img()
processing_img_png = Path(RESULTS_DIR / "general" / "processing_effect_on_img.png")
plt.savefig(processing_img_png)

In [None]:
# Draw the processing-effect histogram plot, then write the current figure to disk.
plot_processing_hist()
processing_hist_png = Path(RESULTS_DIR / "general" / "processing_effect_on_hist.png")
plt.savefig(processing_hist_png)