diff --git a/tutorials/abide.ipynb b/tutorials/abide.ipynb
index a648f1b..388798a 100644
--- a/tutorials/abide.ipynb
+++ b/tutorials/abide.ipynb
@@ -11,19 +11,23 @@
"tags": []
},
"source": [
- "# Reducing Phenotypical Effect to Improve Multi-site Autism Classification Performance\n",
+ "# Reducing Phenotypic Effects to Improve Multi-site Autism Classification Performance\n",
"\n",
- "In this tutorial, we will show how to leverage patient's phenotypic information to reduce the site-dependencies of functional connectivity data using domain adaptation for improving multi-site autism classification performance.\n",
+ "In this tutorial, we demonstrate how to leverage **patient phenotypic information** to reduce **site-specific biases** in functional connectivity data using **domain adaptation**, with the goal of improving **multi-site autism classification**.\n",
"\n",
- "The basis of this notebook is to extend the original work by Kunda et al. from IEEE TMI 2022 that proposes a second-order functional connectivity named Tangent-Pearson, the Tangent correlation of the correlation and the application of domain adaptation for neuroimaging to reduce site-dependencies.\n",
+ "This notebook builds on the work of **Kunda et al. (IEEE TMI, 2022)**, which introduced a second-order functional connectivity representation called **Tangent-Pearson**, the tangent embedding of the Pearson correlation matrix. The original work also applied domain adaptation to reduce site dependencies in fMRI-derived features, using **site labels** as the domain information.\n",
"\n",
- "Kunda et al. previously applied domain adaptation by only leveraging the site labels. We will extend this work by applying domain adaptation using various phenotypic information like sex, handedness, age, and eye status.\n",
+ "We extend this approach by incorporating a **richer set** of phenotypic variables, such as sex, handedness, age, and eye status, into the domain adaptation framework. The aim is more effective harmonization of data collected at different imaging sites.\n",
"\n",
- "Our **objectives** are to:\n",
- "1. Load the ABIDE dataset with its available preprocessing pipelines and atlasses.\n",
- "2. Extracting functional connectivity from the time series extracted from the preprocessed scans.\n",
- "3. Preprocess the phenotypic information to be used for domain adaptation and obtaining the classification and site labels.\n",
- "4. Creating a pipeline to train and evaluate the performance."
+ "---\n",
+ "\n",
+ "**Objectives**\n",
+ "\n",
+ "1.\t**Load** the ABIDE dataset using different preprocessing pipelines and brain atlases.\n",
+ "2.\t**Extract** functional connectivity features from ROI-based time series.\n",
+ "3.\t**Preprocess** phenotypic variables for use in domain adaptation, and obtain class labels (ASD vs CONTROL) and site labels.\n",
+ "4.\t**Build** a training and evaluation pipeline to assess classification performance under various domain adaptation strategies.\n",
+ "5.\t**Interpret** the learned model by extracting weights for pairwise ROI feature importance and visualizing them using a connectome plot."
]
},
{
@@ -39,7 +43,29 @@
"source": [
"# Setup\n",
"\n",
- "As a starting point, we need to install some packages and have provided helper functions to assist in this tutorial."
+ "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "375bb16c-a668-4582-93c2-3642ef82baf2",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "os.environ[\"PYTHONWARNINGS\"] = \"ignore\""
]
},
{
@@ -55,12 +81,16 @@
"source": [
"## Packages\n",
"\n",
- "The packages that we will require for this tutorial includes PyKale and Nilearn. PyKale is an interdisciplinary machine learning library internally developed at the University of Sheffield and Nilearn is a neuroimaging library mainly intended for fMRI data analysis."
+ "The main packages required for this tutorial are PyKale and Nilearn.\n",
+ "\n",
+ "**PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.\n",
+ "\n",
+ "**Nilearn** is a Python library for neuroimaging analysis, widely used for processing and visualizing functional MRI (fMRI) data."
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"id": "07105910",
"metadata": {
"editable": true,
@@ -101,7 +131,7 @@
"source": [
"## Helper Functions\n",
"\n",
- "The helper functions is adopted from the [source code](https://github.com/zaRizk7/abide-demo) originally intended for demonstrating the use cases of containers to improve reproducibility and reusability in ML experiments."
+ "The helper functions used in this tutorial are adapted from the [source code](https://github.com/zaRizk7/abide-demo), which was originally developed to demonstrate the use of containers for improving reproducibility and reusability in machine learning experiments."
]
},
{
@@ -109,20 +139,25 @@
"id": "a6657912-3cd9-43e7-a73c-783ccbfc7442",
"metadata": {
"editable": true,
+ "jp-MarkdownHeadingCollapsed": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
- "### Feature Extractions\n",
+ "### Feature Extraction\n",
"\n",
- "Contains the imputation, categorical mapping, continuous standardization for the selected phenotypes and chaining functional connectivity extraction for the time series. "
+ "Includes functionality for:\n",
+ "- **Imputation** of missing phenotypic values,\n",
+ "- **Categorical encoding** of variables such as sex, handedness, and eye status,\n",
+ "- **Standardization** of continuous features like age and FIQ,\n",
+ "- **Chaining functional connectivity transformations** (e.g., Pearson, Tangent) to extract features from ROI-based time series."
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"id": "4c15c779-51c4-4942-90f4-63d0ca889a5f",
"metadata": {
"editable": true,
@@ -339,12 +374,14 @@
"source": [
"### Trainer\n",
"\n",
- "Contains the hyperparameter grid and a wrapper function to train the pipeline with or without domain adaptation."
+ "Defines the **hyperparameter search space** and provides a **wrapper function** to train the classification pipeline, with or without domain adaptation.\n",
+ "\n",
+ "The wrapper supports several search strategies (e.g., grid or randomized search) and integrates with cross-validation and, when enabled, MIDA-based feature transformation."
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "c7b8c1e8-263a-4692-9cec-32cfc3078a80",
"metadata": {
"editable": true,
@@ -395,11 +432,9 @@
"\n",
"MIDA_GRID = {\n",
" \"num_components\": [32, 64, 128, 256, None],\n",
- " \"kernel\": [\"linear\", \"rbf\"],\n",
" \"mu\": 1 / (2 * C),\n",
" \"eta\": 1 / (2 * C),\n",
" \"ignore_y\": [True, False],\n",
- " \"augment\": [True, False],\n",
"}\n",
"MIDA_GRID = {f\"domain_adapter__{key}\": value for key, value in MIDA_GRID.items()}\n",
"\n",
@@ -533,42 +568,107 @@
},
{
"cell_type": "markdown",
- "id": "c6c98c7e-9e0b-477e-b470-d89780867bd2",
+ "id": "36d5a108-f0a5-45d4-b517-2b48cea84e8c",
"metadata": {
"editable": true,
+ "jp-MarkdownHeadingCollapsed": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
- "# Pipeline"
+ "### Evaluation\n",
+ "\n",
+ "Contains a function to **aggregate the top-1 cross-validation scores** for each defined model, selecting the best-performing configuration per model based on a specified evaluation metric."
]
},
{
- "cell_type": "markdown",
- "id": "b43b0320-dffe-4233-bd34-9785770bbee6",
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5af2b149-a2c3-4dfc-bba9-038259e4fe8a",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
- "tags": []
+ "tags": [
+ "hide-input"
+ ]
},
+ "outputs": [],
"source": [
- "## Resting-state fMRI Preprocessing\n",
+ "import pandas as pd\n",
+ "from collections import defaultdict\n",
+ "from sklearn.utils._param_validation import validate_params, StrOptions\n",
+ "\n",
+ "# Mapping for model and score display names\n",
+ "MODEL = [\"baseline\", \"site_only\", \"all_phenotypes\"]\n",
+ "MODEL = {model: \" \".join(model.split(\"_\")).title() for model in MODEL}\n",
+ "\n",
+ "SCORE = [\"accuracy\", \"precision\", \"recall\", \"f1\"]\n",
+ "SCORE = {score: score.title() for score in SCORE}\n",
+ "SCORE[\"roc_auc\"] = \"AUROC\"\n",
+ "SCORE[\"matthews_corrcoef\"] = \"MCC\"\n",
+ "\n",
+ "\n",
+ "@validate_params(\n",
+ " {\"cv_results\": [dict], \"sort_by\": [StrOptions(set(SCORE))]},\n",
+ " prefer_skip_nested_validation=True,\n",
+ ")\n",
+ "def compile_results(cv_results, sort_by):\n",
+ " \"\"\"\n",
+ " Compile and summarize cross-validation results into a formatted DataFrame.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " cv_results : dict of str -> pd.DataFrame or dict of str -> dict of str -> list\n",
+ " Dictionary mapping model names to cross-validation results.\n",
+ " Each entry should be either a DataFrame or a dictionary of dictionaries of lists.\n",
+ " sort_by : str\n",
+ " Metric to use for selecting the best-performing model variant.\n",
+ " Available ones include: \"accuracy\", \"precision\", \"recall\", \"f1\", \"roc_auc\",\n",
+ " and \"matthews_corrcoef\".\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " compiled_results : pd.DataFrame\n",
+ " Summary table with models as rows and formatted score strings (mean ± std) as columns.\n",
+ " \"\"\"\n",
+ " compiled_results = defaultdict(list)\n",
"\n",
- "Usually, we need to preprocess the fMRI scans first before running the pipeline. However, ABIDE dataset provides several preprocessed subsets that can be downloaded directly. The ones we are going to focus on includes:\n",
- "- `atlas`: Brain atlas used for extracting the time series. Available ones are: `\"aal\"`, `\"cc200\"`, `\"cc400\"`, `\"dosenbach160\"`, `\"ez\"`, `\"ho\"`, and `\"tt\"`. Default: `\"cc200\"`.\n",
- "- `bp`: Band-pass filter signals between 0.01Hz and 0.1Hz. Default: `False`.\n",
- "- `gsr`: Applies global signal regression on the signals. Default: `False`.\n",
- "- `qc`: Only use scans that passes all quality checks. Default: `True`."
+ " for model in cv_results:\n",
+ " # Ensure results are in DataFrame format\n",
+ " if not isinstance(cv_results[model], pd.DataFrame):\n",
+ " cv_results[model] = pd.DataFrame(cv_results[model])\n",
+ "\n",
+ " # Extract all available test scores\n",
+ " scores = [\n",
+ " score.replace(\"rank_test_\", \"\")\n",
+ " for score in cv_results[model].columns\n",
+ " if \"rank_test\" in score\n",
+ " ]\n",
+ "\n",
+ " # Select the best row (lowest rank) based on the given metric\n",
+ " cv_result = cv_results[model].sort_values(f\"rank_test_{sort_by}\").iloc[0]\n",
+ "\n",
+ " compiled_results[\"Model\"].append(MODEL[model])\n",
+ "\n",
+ " for score in scores:\n",
+ " mean_score = cv_result[f\"mean_test_{score}\"]\n",
+ " std_score = cv_result[f\"std_test_{score}\"]\n",
+ " compiled_results[SCORE[score]].append(f\"{mean_score:.4f} ± {std_score:.4f}\")\n",
+ "\n",
+ " # Convert to DataFrame and index by model name\n",
+ " compiled_results = pd.DataFrame(compiled_results)\n",
+ " compiled_results = compiled_results.set_index(\"Model\")\n",
+ "\n",
+ " return compiled_results\n"
]
},
{
- "cell_type": "code",
- "execution_count": 4,
- "id": "af7f3a6f-5762-4191-9e0d-f3e3eaa33003",
+ "cell_type": "markdown",
+ "id": "e35c2ebb-0811-4746-8595-6610b9f55c02",
"metadata": {
"editable": true,
"slideshow": {
@@ -576,48 +676,198 @@
},
"tags": []
},
- "outputs": [],
"source": [
- "atlas = \"cc200\"\n",
- "bp = False\n",
- "gsr = False\n",
- "qc = True"
+ "### Visualization\n",
+ "\n",
+ "Provides utility functions to:\n",
+ "- Generate **pairwise ROI labels**,\n",
+ "- Extract the **top-p most important weights** and convert them into a **symmetric matrix**,\n",
+ "- Wrap the process of **plotting connectomes** for visual interpretation of model-derived feature importances."
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "bff0d762-27f6-4080-8918-38e4fb844a83",
+ "execution_count": 6,
+ "id": "21510b19-195d-4e90-9f79-7861a247e0d0",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
- "tags": []
+ "tags": [
+ "hide-input"
+ ]
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[get_dataset_dir] Dataset found in /home/zarizky/nilearn_data/ABIDE_pcp\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "from nilearn.datasets import fetch_abide_pcp\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from itertools import combinations\n",
+ "from matplotlib.pyplot import get_cmap\n",
+ "from nilearn.plotting import plot_connectome\n",
+ "from sklearn.utils._param_validation import Interval, validate_params, Real\n",
"\n",
- "dataset = fetch_abide_pcp(\n",
- " derivatives=[f\"rois_{atlas}\"],\n",
- " band_pass_filtering=bp,\n",
- " global_signal_regression=gsr,\n",
- " quality_checked=qc,\n",
- ")"
+ "\n",
+ "@validate_params(\n",
+ " {\"rois\": [\"array-like\"], \"sep\": [str]}, prefer_skip_nested_validation=True\n",
+ ")\n",
+ "def get_pairwise_rois(rois, sep=\"---\"):\n",
+ " \"\"\"\n",
+ " Generate all unique ROI pair labels (upper triangle only).\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " rois : array-like of str or int\n",
+ " List of region-of-interest (ROI) names.\n",
+ " sep : str, optional\n",
+ " Separator string used to join ROI pairs, by default '---'.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pairs : np.ndarray\n",
+ " Array of ROI pair strings in the format \"ROI1 --- ROI2\".\n",
+ " \"\"\"\n",
+ " pairs = [f\" {sep} \".join((a, b)) for a, b in combinations(rois, 2)]\n",
+ " return np.array(pairs)\n",
+ "\n",
+ "\n",
+ "@validate_params(\n",
+ " {\n",
+ " \"weights\": [\"array-like\"],\n",
+ " \"labels\": [\"array-like\"],\n",
+ " \"coords\": [\"array-like\"],\n",
+ " \"p\": [Interval(Real, 0, 1, closed=\"neither\")],\n",
+ " \"sep\": [str],\n",
+ " },\n",
+ " prefer_skip_nested_validation=True,\n",
+ ")\n",
+ "def get_top_symmetric_weight(weights, labels, coords, p=0.001, sep=\"---\"):\n",
+ " \"\"\"\n",
+ " Construct a symmetric weight matrix for top-p ROI pairs.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " weights : array-like\n",
+ " 1D array of weights corresponding to pairwise ROI combinations.\n",
+ " labels : array-like of str or int\n",
+ " List of all ROI names or indices in the original data.\n",
+ " coords : array-like\n",
+ " Coordinates of each ROI, shape (n_rois, 3).\n",
+ " p : float, optional\n",
+ " Proportion of top weights to retain (default is 0.001).\n",
+ " sep : str, optional\n",
+ " Separator used in ROI pair labels, by default '---'.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " sym_weights : array-like\n",
+ " Symmetric matrix of selected top weights (n_top_rois, n_top_rois).\n",
+ " top_roi_labels : array-like\n",
+ " Labels of ROIs included in the top-p weight pairs.\n",
+ " top_roi_coords : array-like\n",
+ " Coordinates corresponding to `top_roi_labels`.\n",
+ " \"\"\"\n",
+ " weights = pd.Series(np.copy(weights), index=get_pairwise_rois(labels, sep))\n",
+ " rank = weights.abs().nlargest(int(len(weights.index) * p))\n",
+ " weights = weights[rank.index]\n",
+ "\n",
+ " pairs = np.array([[roi.strip() for roi in col.split(sep)] for col in weights.index])\n",
+ " unique = np.unique(pairs)\n",
+ "\n",
+ " label_to_index = {label: idx for idx, label in enumerate(labels)}\n",
+ " indices = [label_to_index[roi] for roi in unique]\n",
+ "\n",
+ " top_labels = np.array(labels)[indices]\n",
+ " top_coords = np.array(coords)[indices]\n",
+ " mappings = {roi: idx for idx, roi in enumerate(top_labels)}\n",
+ "\n",
+ " sym_weights = np.zeros([len(indices)] * 2)\n",
+ " for (roi1, roi2), weight in zip(pairs, weights.values):\n",
+ " i, j = mappings[roi1], mappings[roi2]\n",
+ " sym_weights[i, j] = weight\n",
+ " sym_weights[j, i] = weight\n",
+ "\n",
+ " return sym_weights, top_labels, top_coords\n",
+ "\n",
+ "\n",
+ "@validate_params(\n",
+ " {\n",
+ " \"weights\": [\"array-like\"],\n",
+ " \"labels\": [\"array-like\"],\n",
+ " \"coords\": [\"array-like\"],\n",
+ " \"p\": [Interval(Real, 0, 1, closed=\"neither\")],\n",
+ " \"cmap\": [str],\n",
+ " \"marker_size\": [Real],\n",
+ " \"legend_params\": [dict],\n",
+ " },\n",
+ " prefer_skip_nested_validation=True,\n",
+ ")\n",
+ "def visualize_connectome(\n",
+ " weights, labels, coords, p=1e-3, cmap=\"tab20\", marker_size=100, legend_params={}\n",
+ "):\n",
+ " \"\"\"\n",
+ " Visualize the top-p weighted ROI connections as a symmetric connectome plot.\n",
+ "\n",
+ " This function selects the top proportion `p` of the largest (by absolute value)\n",
+ " weights between region pairs, constructs a symmetric connectivity matrix,\n",
+ " and plots the corresponding connectome using `nilearn.plotting.plot_connectome`.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " weights : array-like of shape (n_edges,)\n",
+ " Weights assigned to each unique pair of ROIs, typically from a model or analysis.\n",
+ " These are expected to align with the order of ROI pairs generated from `labels`.\n",
+ "\n",
+ " labels : array-like of shape (n_rois,)\n",
+ " List of ROI (region of interest) names corresponding to the original brain atlas.\n",
+ "\n",
+ " coords : array-like of shape (n_rois, 3)\n",
+ " 3D coordinates for each ROI, used to place nodes in the plot.\n",
+ "\n",
+ " p : float, optional\n",
+ " Proportion of the top-weighted connections (by absolute value) to include.\n",
+ " Must be in the open interval (0, 1). Default is 0.001 (0.1%).\n",
+ "\n",
+ " cmap : str, optional\n",
+ " Matplotlib colormap name used to assign colors to ROI nodes. Default is 'tab20'.\n",
+ "\n",
+ " marker_size : float, optional\n",
+ " Size of ROI node markers in the connectome plot. Default is 100.\n",
+ "\n",
+ " legend_params : dict, optional\n",
+ " Additional keyword arguments to pass to the plot's legend (e.g., location, fontsize).\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " proj : nilearn.plotting.displays._projectors.ConnectivityProjection\n",
+ " A `nilearn` projection object with the plotted connectome. This object supports\n",
+ " further customization (e.g., adding markers, saving the figure).\n",
+ "\n",
+ " Notes\n",
+ " -----\n",
+ " - This function assumes the weights are symmetric or that symmetry should be imposed.\n",
+ " - Useful for visualizing model interpretability or structural/functional connectivity.\n",
+ " \"\"\"\n",
+ " marker_colors = get_cmap(cmap)(np.arange(len(labels)))\n",
+ " sym_weights, labels, coords = get_top_symmetric_weight(weights, labels, coords, p)\n",
+ " proj = plot_connectome(sym_weights, coords, colorbar=True)\n",
+ "\n",
+ " for i in range(len(labels)):\n",
+ " proj.add_markers(\n",
+ " [coords[i]],\n",
+ " marker_color=marker_colors[i],\n",
+ " marker_size=marker_size,\n",
+ " label=labels[i],\n",
+ " )\n",
+ "\n",
+ " proj.axes[next(iter(proj.axes))].ax.legend(**legend_params)\n",
+ "\n",
+ " return proj\n"
]
},
{
"cell_type": "markdown",
- "id": "5498cc9d-2157-4372-8a55-53c39c492ed5",
+ "id": "c6c98c7e-9e0b-477e-b470-d89780867bd2",
"metadata": {
"editable": true,
"slideshow": {
@@ -626,20 +876,12 @@
"tags": []
},
"source": [
- "## Phenotype Preprocessing \n",
- "\n",
- "The phenotypic information comes with several missing data. To utilize it for modeling, we need to impute and encode the missing values. Categorical phenotypes that we will use are `SITE_ID`, `SEX`, `HANDEDNESS_CATEGORY`, `EYE_STATUS_AT_SCAN` which will be one-hot encoded while the continuous ones including `AGE_AT_SCAN` and `FIQ` will be optionally standardized by defining the argument options:\n",
- "- `standardize`: Standardization strategy for subject's age and FIQ. If set to `True` or `\"all\"`, standardizes the values over all subjects while `\"site\"` standardizes according to the sites. Default: `False`.\n",
- "\n",
- "There are several missing `HANDEDNESS_CATEGORY` values, we consider that the missing ones by default are right-handed subjects while for `FIQ`, we impute the missing values by setting them to `100`.\n",
- "\n",
- "The labels that assigns control and ASD group is the `DX_GROUP` phenotype that is binary encoded such that `CONTROL` and `ASD` is assigned to `0` and `1` respectively."
+ "# Pipeline"
]
},
{
- "cell_type": "code",
- "execution_count": 6,
- "id": "0996603e-3a81-4277-beab-20f189af09ed",
+ "cell_type": "markdown",
+ "id": "b43b0320-dffe-4233-bd34-9785770bbee6",
"metadata": {
"editable": true,
"slideshow": {
@@ -647,30 +889,61 @@
},
"tags": []
},
- "outputs": [],
"source": [
- "standardize = \"site\""
+ "## Resting-state fMRI Preprocessing\n",
+ "\n",
+ "Typically, raw fMRI scans require extensive preprocessing before they can be used in a machine learning pipeline. However, the **ABIDE** dataset provides several preprocessed derivatives, which can be downloaded directly from the [Preprocessed Connectomes Project (PCP)](https://preprocessed-connectomes-project.org/abide/), eliminating the need for manual preprocessing.\n",
+ "\n",
+ "In this tutorial, we focus on the following preprocessing options:\n",
+ "- `atlas`: The **brain atlas** used to **extract ROI time series**. Available options include: `\"aal\"`, `\"cc200\"`, `\"cc400\"`, `\"dosenbach160\"`, `\"ez\"`, `\"ho\"`, and `\"tt\"`. Default: `\"aal\"`\n",
+ "- `bp`: Whether to apply a **band-pass filter** (0.01–0.1 Hz) to the time series. Default: `False`\n",
+ "- `gsr`: Whether to apply **global signal regression** to remove shared global noise from the signals. Default: `False`\n",
+ "- `qc`: Whether to include **only scans that passed all quality checks** provided by the dataset curators. Default: `True`"
]
},
{
"cell_type": "code",
"execution_count": 7,
- "id": "2d4f160c-ed31-4f7e-9215-ba8f043576a3",
+ "id": "bff0d762-27f6-4080-8918-38e4fb844a83",
"metadata": {
"editable": true,
+ "scrolled": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[get_dataset_dir] Dataset found in /home/zarizky/nilearn_data/ABIDE_pcp\n"
+ ]
+ }
+ ],
"source": [
- "labels, sites, phenotypes = process_phenotypic_data(dataset[\"phenotypic\"], standardize)"
+ "from nilearn.datasets import fetch_abide_pcp\n",
+ "\n",
+ "# Define preprocessing options\n",
+ "atlas = \"aal\" # Brain atlas used to extract ROI time series (e.g., AAL atlas)\n",
+ "bp = False # Whether to apply band-pass filtering (0.01–0.1 Hz); disabled here\n",
+ "gsr = False # Whether to apply global signal regression (GSR); disabled here\n",
+ "qc = True # Include only subjects that passed all quality checks\n",
+ "\n",
+ "# Fetch the preprocessed ABIDE dataset using the specified preprocessing options\n",
+ "# This returns a dictionary-like object containing region-wise time series and associated metadata\n",
+ "dataset = fetch_abide_pcp(\n",
+ " derivatives=[f\"rois_{atlas}\"], # Select the atlas-specific ROI time series (e.g., 'rois_aal')\n",
+ " band_pass_filtering=bp, # Whether to apply band-pass filtering\n",
+ " global_signal_regression=gsr, # Whether to apply global signal regression\n",
+ " quality_checked=qc, # Whether to include only subjects that passed QC\n",
+ ")"
]
},
{
"cell_type": "markdown",
- "id": "c50bbd33-332b-4288-8def-beec31b83adf",
+ "id": "5498cc9d-2157-4372-8a55-53c39c492ed5",
"metadata": {
"editable": true,
"slideshow": {
@@ -679,15 +952,44 @@
"tags": []
},
"source": [
- "## Feature Extraction\n",
+ "## Phenotype Preprocessing \n",
+ "\n",
+ "The phenotypic information in the dataset contains several missing values. We impute and encode it to make it suitable for modeling.\n",
+ "\n",
+ "**Categorical Variables**\n",
+ "\n",
+ "The following categorical phenotypes are used and will be **one-hot encoded**:\n",
+ "- `SITE_ID`\n",
+ "- `SEX`\n",
+ "- `HANDEDNESS_CATEGORY`\n",
+ "- `EYE_STATUS_AT_SCAN`\n",
+ "\n",
+ "**Continuous Variables**\n",
+ "\n",
+ "The following continuous phenotypes will optionally be **standardized**:\n",
+ "- `AGE_AT_SCAN`\n",
+ "- `FIQ`\n",
+ "\n",
+ "The available options for `standardize` are:\n",
+ "- `\"all\"` or `True`: Standardize across all subjects\n",
+ "- `\"site\"`: Standardize within each site\n",
+ "- `False`: No standardization\n",
"\n",
- "- `measures`: Sequences of connectivity measure transformation to extract features from the time series. Available ones are `\"pearson\"`, `\"partial\"`, `\"tangent\"`, `\"covariance\"`, and `\"precision\"`. Default: `[\"pearson\"]`."
+ "**Handling Missing Values**\n",
+ "- `HANDEDNESS_CATEGORY`: Missing values are assumed to correspond to `right-handed` subjects.\n",
+ "- `FIQ`: Missing values are imputed with a default score of `100`.\n",
+ "\n",
+ "**Label Encoding**\n",
+ "\n",
+ "The diagnostic label `DX_GROUP` is used to assign the target class:\n",
+ "- `CONTROL` → `0`\n",
+ "- `ASD` → `1`"
]
},
{
"cell_type": "code",
"execution_count": 8,
- "id": "35b98cf5-7e46-4ab7-a746-e76c9794a97f",
+ "id": "2d4f160c-ed31-4f7e-9215-ba8f043576a3",
"metadata": {
"editable": true,
"slideshow": {
@@ -697,13 +999,24 @@
},
"outputs": [],
"source": [
- "measures = [\"pearson\"]"
+ "standardize = \"site\" # Standardize continuous phenotypes (e.g., age, FIQ) within each site\n",
+ "\n",
+ "# Process the phenotypic metadata from the ABIDE dataset\n",
+ "# This function handles:\n",
+ "# - Imputation of missing values (e.g., assuming right-handed for missing handedness)\n",
+ "# - One-hot encoding of categorical variables (e.g., sex, site, eye status)\n",
+ "# - Standardization of continuous variables based on the chosen strategy ('site' or 'all')\n",
+ "\n",
+ "# Returns:\n",
+ "# - `labels`: Binary class labels (0 = control, 1 = ASD)\n",
+ "# - `sites`: Site identifiers for domain adaptation\n",
+ "# - `phenotypes`: Feature matrix containing encoded and standardized phenotypic variables\n",
+ "labels, sites, phenotypes = process_phenotypic_data(dataset[\"phenotypic\"], standardize)"
]
},
{
- "cell_type": "code",
- "execution_count": 9,
- "id": "62495790-b83e-439f-9860-31af63da8c4f",
+ "cell_type": "markdown",
+ "id": "c50bbd33-332b-4288-8def-beec31b83adf",
"metadata": {
"editable": true,
"slideshow": {
@@ -711,14 +1024,27 @@
},
"tags": []
},
- "outputs": [],
"source": [
- "features = extract_functional_connectivity(dataset[f\"rois_{atlas}\"], measures)"
+ "## Feature Extraction\n",
+ "\n",
+ "Functional MRI (fMRI) time series often vary in temporal length. However, many machine learning models, including those used in this tutorial, require fixed-size input. To address this, a common approach in fMRI analysis is to compute the functional connectivity (e.g., correlation) between regions of interest (ROIs), resulting in a fixed-size feature representation.\n",
+ "\n",
+ "Specifically, we compute a connectivity matrix for each subject, and extract the upper or lower triangular part (excluding the diagonal) to obtain a feature vector suitable for model training.\n",
+ "\n",
+ "The available arguments for feature extraction are:\n",
+ "- `measures`: A sequence of connectivity transformations applied to the ROI time series. Supported options include: `\"pearson\"`, `\"partial\"`, `\"tangent\"`, `\"covariance\"`, and `\"precision\"`. Default: `[\"pearson\"]`.\n",
+ "\n",
+ "Multiple transformations can be chained to compute composite connectivity representations. For example, the **Tangent-Pearson** method proposed by *Kunda et al.* can be specified via `measures = [\"tangent\", \"pearson\"]`. This design also allows for future extensions to support higher-order connectivity features.\n",
+ "\n",
+ "```{warning}\n",
+ "Given the long runtime needed for Tangent-Pearson, we opt to use `\"pearson\"` by default.\n",
+ "```"
]
},
{
- "cell_type": "markdown",
- "id": "a4083ba3-67bf-4dc4-88f2-22b798cb8156",
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "62495790-b83e-439f-9860-31af63da8c4f",
"metadata": {
"editable": true,
"slideshow": {
@@ -726,13 +1052,21 @@
},
"tags": []
},
+ "outputs": [],
"source": [
- "## Modeling"
+ "measures = [\"pearson\"] # Use Pearson correlation to compute functional connectivity\n",
+ "\n",
+ "# Extract functional connectivity features from ROI time series using the specified measure\n",
+ "# - `dataset[f\"rois_{atlas}\"]`: Time series data extracted from the selected brain atlas (e.g., AAL)\n",
+ "# - `measures`: List of connectivity measures to apply; in this case, only Pearson correlation is used\n",
+ "# to compute pairwise correlations between ROI time series, resulting in a symmetric connectivity matrix\n",
+ "\n",
+ "features = extract_functional_connectivity(dataset[f\"rois_{atlas}\"], measures)"
]
},
{
"cell_type": "markdown",
- "id": "b1e3cbda-d9e9-4a81-81fc-49fa61c5b49e",
+ "id": "a4083ba3-67bf-4dc4-88f2-22b798cb8156",
"metadata": {
"editable": true,
"slideshow": {
@@ -741,13 +1075,18 @@
"tags": []
},
"source": [
- "### Random Seed"
+ "## Modeling\n",
+ "\n",
+ "We define and train machine learning models for classifying autism spectrum disorder (ASD) using functional connectivity features.\n",
+ "\n",
+ "We explore different configurations, including a baseline model, domain adaptation using site information, and an extended approach that incorporates additional phenotypic variables.\n",
+ "\n",
+ "Each model is evaluated using cross-validation, and we analyze the impact of domain adaptation on classification performance."
]
},
{
- "cell_type": "code",
- "execution_count": 10,
- "id": "b0081878-08f4-4614-a71f-486b919ea1fa",
+ "cell_type": "markdown",
+ "id": "b1e3cbda-d9e9-4a81-81fc-49fa61c5b49e",
"metadata": {
"editable": true,
"slideshow": {
@@ -755,14 +1094,15 @@
},
"tags": []
},
- "outputs": [],
"source": [
- "seed = 0"
+ "### Random Seed\n",
+ "\n",
+ "To ensure reproducibility across runs, we define a fixed random seed. This ensures that all operations involving randomness, such as cross-validation splits, model initialization, and hyperparameter search, produce consistent results."
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"id": "e6012573-fde9-44cd-b204-f67dbd61a9ae",
"metadata": {
"editable": true,
@@ -775,6 +1115,10 @@
"source": [
"from sklearn.utils.validation import check_random_state\n",
"\n",
+ "seed = 0 # Set a fixed seed for reproducibility\n",
+ "\n",
+ "# Convert the seed into a numpy-compatible RandomState instance\n",
+ "# This ensures consistent behavior across scikit-learn functions that rely on randomness\n",
"random_state = check_random_state(seed)"
]
},
@@ -789,30 +1133,21 @@
"tags": []
},
"source": [
- "### Cross-Validation Split"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "b2c71cfa-0222-4187-ae19-062c008ca841",
- "metadata": {
- "editable": true,
- "slideshow": {
- "slide_type": ""
- },
- "tags": []
- },
- "outputs": [],
- "source": [
- "split = \"skf\"\n",
- "num_folds = 10\n",
- "num_cv_repeats = 5"
+ "### Cross-Validation Split\n",
+ "\n",
+ "To evaluate model performance reliably, we define a cross-validation (CV) strategy. One option is **Repeated Stratified K-Fold**, which preserves the class distribution across folds and supports repeated trials for more stable estimates.\n",
+ "\n",
+ "The other option, used by default in this tutorial, is **Leave-P-Groups-Out (LPGO)** cross-validation. This strategy is particularly useful in multi-site studies, as it ensures that data from the same group (e.g., imaging site) are not shared between training and test sets, enabling a more realistic assessment of generalization under domain shift.\n",
+ "\n",
+ "For this tutorial, we specify several arguments:\n",
+ "- `split`: Defines the cross-validation strategy. Use `\"skf\"` for repeated stratified k-fold, which maintains label balance in each fold, or `\"lpgo\"` to evaluate generalization across sites by holding out entire groups (e.g., imaging sites). Default: `\"lpgo\"`\n",
+ "- `num_folds`: Sets how many folds to use for stratified k-fold (at least `2`) or how many groups to leave out in LPGO. Default: `1`\n",
+ "- `num_cv_repeats`: Determines how many times the k-fold procedure is repeated to obtain more stable estimates (ignored when using LPGO). Default: `1`"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 11,
"id": "1e1e926a-e7b0-41a5-ab1f-9d35c9cad8bd",
"metadata": {
"editable": true,
@@ -825,14 +1160,22 @@
"source": [
"from sklearn.model_selection import LeavePGroupsOut, RepeatedStratifiedKFold\n",
"\n",
+ "split = \"lpgo\" # Cross-validation split strategy\n",
+ "num_folds = 1 # Number of folds (or groups to leave out)\n",
+ "num_cv_repeats = 1 # Number of repetitions (used only for skf)\n",
+ "\n",
+ "# Start from repeated stratified k-fold, used when split == \"skf\":\n",
+ "# it maintains class distribution across folds and supports multiple repetitions\n",
"cv = RepeatedStratifiedKFold(\n",
- " n_splits=num_folds,\n",
- " n_repeats=num_cv_repeats,\n",
- " random_state=random_state,\n",
+ " n_splits=num_folds, # Number of stratified folds\n",
+ " n_repeats=num_cv_repeats, # Number of repeat rounds\n",
+ " random_state=random_state, # Ensures reproducibility\n",
")\n",
"\n",
+ "# Override with leave-p-groups-out if specified\n",
+ "# This strategy holds out `p` unique groups (e.g., sites) per fold, enabling group-level generalization\n",
"if split == \"lpgo\":\n",
- " cv = LeavePGroupsOut(num_folds)"
+ " cv = LeavePGroupsOut(num_folds) # Use group-based CV for domain adaptation or site bias evaluation"
]
},
{
@@ -846,34 +1189,24 @@
"tags": []
},
"source": [
- "### Model Definition"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "9a84d614-69b4-4bf7-beef-a2e1f9d1cf15",
- "metadata": {
- "editable": true,
- "slideshow": {
- "slide_type": ""
- },
- "tags": []
- },
- "outputs": [],
- "source": [
- "classifier = \"logistic\"\n",
- "mida = False\n",
- "search_strategy = \"random\"\n",
- "scoring = [\"accuracy\", \"roc_auc\"]\n",
- "num_solver_iterations = 100\n",
- "num_search_iterations = 10\n",
- "num_jobs = None"
+ "### Model Definition\n",
+ "\n",
+ "We define different model configurations used for classification. Each model shares the same base classifier (e.g., logistic regression), but differs in how domain adaptation is applied:\n",
+ "- **Baseline**: A standard model trained directly on functional connectivity features without domain adaptation.\n",
+ "- **Site Only**: A domain-adapted model that uses site labels as the adaptation factor, reducing site-specific bias.\n",
+ "- **All Phenotypes**: An extended domain-adapted model that incorporates multiple phenotypic variables (e.g., age, sex, handedness) to further reduce inter-site variability.\n",
+ "\n",
+ "We also specify the hyperparameter search strategy and other training parameters for each configuration, including:\n",
+ "- `classifier`: The base model to use for classification. Available options include `\"logistic\"` for logistic regression, `\"ridge\"` for ridge classifier, and `\"svm\"` for support vector machines. Default: `\"logistic\"`\n",
+ "- `scoring`: A list of performance metrics (e.g., accuracy, F1, AUROC) used during cross-validation.\n",
+ "- `num_solver_iterations`: Maximum number of iterations allowed for the solver to converge during model fitting.\n",
+ "- `num_search_iterations`: Number of hyperparameter combinations to evaluate in a randomized search.\n",
+ "- `num_jobs`: Number of CPU cores used in parallel for hyperparameter tuning and model training. Set to `-1` to use all available CPU cores, or to `-k` (with `k > 1`) to use all but `k - 1` cores, as with scikit-learn's `n_jobs`.\n",
+ "- `verbose`: Controls the verbosity of the training output. Higher values provide more detailed logs."
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 12,
"id": "654fef9c-a4cc-48f6-b626-4acfa2329291",
"metadata": {
"editable": true,
@@ -884,17 +1217,50 @@
},
"outputs": [],
"source": [
- "trainer = create_trainer(\n",
+ "from sklearn.base import clone\n",
+ "\n",
+ "# Define core training configuration\n",
+ "classifier = \"logistic\" # Classifier type ('logistic', 'ridge', or 'svm')\n",
+ "scoring = list(SCORE) # List of scoring metrics to evaluate during CV\n",
+ "num_solver_iterations = int(1e6) # Max iterations for the solver to converge\n",
+ "num_search_iterations = 100 # Number of parameter settings sampled in randomized search\n",
+ "num_jobs = -4 # Number of parallel jobs for training and CV (-4 leaves 3 cores free)\n",
+ "verbose = 1 # Verbosity level for output\n",
+ "\n",
+ "# Initialize dictionary to store trainer instances for each model configuration\n",
+ "trainers = {}\n",
+ "\n",
+ "# Create a baseline trainer without domain adaptation (MIDA disabled)\n",
+ "trainers[\"baseline\"] = create_trainer(\n",
+ " classifier, # Classifier to use\n",
+ " False, # Do not apply MIDA (no domain adaptation)\n",
+ " \"grid\", # Use grid search for hyperparameter tuning\n",
+ " cv, # Cross-validation strategy\n",
+ " scoring, # Evaluation metrics\n",
+ " num_solver_iterations, # Max solver iterations\n",
+ " num_search_iterations, # Max hyperparameter trials (ignored for grid search)\n",
+ " num_jobs, # Number of parallel jobs\n",
+ " random_state, # Random seed for reproducibility\n",
+ " verbose, # Verbosity level\n",
+ ")\n",
+ "\n",
+ "# Create a trainer with MIDA enabled, using site labels as domain adaptation factors\n",
+ "trainers[\"site_only\"] = create_trainer(\n",
" classifier,\n",
- " mida,\n",
- " search_strategy,\n",
+ " True, # Apply MIDA for domain adaptation\n",
+ " \"random\", # Use randomized search for hyperparameter tuning\n",
" cv,\n",
" scoring,\n",
" num_solver_iterations,\n",
" num_search_iterations,\n",
" num_jobs,\n",
" random_state,\n",
- ")"
+ " verbose,\n",
+ ")\n",
+ "\n",
+ "# Clone the 'site_only' trainer to create 'all_phenotypes' trainer\n",
+ "# This enables reusing the same training configuration, while modifying only the input domain factors\n",
+ "trainers[\"all_phenotypes\"] = clone(trainers[\"site_only\"])"
]
},
{
@@ -908,12 +1274,16 @@
"tags": []
},
"source": [
- "### Training"
+ "## Training and Cross-Validation\n",
+ "\n",
+ "We train each model configuration using the previously defined cross-validation strategy. The training process involves fitting the model on functional connectivity features and evaluating its performance using multiple scoring metrics (e.g., accuracy, F1-score, AUROC).\n",
+ "\n",
+ "For models with domain adaptation, we pass additional domain factors (such as site or phenotypic variables) to guide the alignment of feature representations. Cross-validation is performed to ensure robust performance estimates and to select the best hyperparameter configuration for each model."
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 13,
"id": "c53cab98-3265-4587-9d47-abd9462398b7",
"metadata": {
"editable": true,
@@ -924,502 +1294,41 @@
},
"outputs": [
{
- "name": "stderr",
+ "name": "stdout",
"output_type": "stream",
"text": [
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:86: UserWarning: The groups parameter is ignored by RepeatedStratifiedKFold\n",
- " warnings.warn(\n",
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:877: UserWarning: The groups parameter is ignored by StratifiedKFold\n",
- " warnings.warn(\n",
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:877: UserWarning: The groups parameter is ignored by StratifiedKFold\n",
- " warnings.warn(\n",
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:877: UserWarning: The groups parameter is ignored by StratifiedKFold\n",
- " warnings.warn(\n",
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:877: UserWarning: The groups parameter is ignored by StratifiedKFold\n",
- " warnings.warn(\n",
- "/home/zarizky/miniforge3/envs/workshop-notebooks/lib/python3.10/site-packages/sklearn/model_selection/_split.py:877: UserWarning: The groups parameter is ignored by StratifiedKFold\n",
- " warnings.warn(\n"
+ "Fitting 20 folds for each of 31 candidates, totalling 620 fits\n"
]
- },
- {
- "data": {
- "text/html": [
- "
RandomizedSearchCV(cv=RepeatedStratifiedKFold(n_repeats=5, n_splits=10,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " error_score='raise',\n",
- " estimator=LogisticRegression(max_iter=1000000,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " param_distributions={'C': array([3.05175781e-05, 6.10351562e-05, 1.22070312e-04, 2.44140625e-04,\n",
- " 4.88281250e-04, 9.765...\n",
- " 1.25000000e-01, 2.50000000e-01, 5.00000000e-01, 1.00000000e+00,\n",
- " 2.00000000e+00, 4.00000000e+00, 8.00000000e+00, 1.60000000e+01,\n",
- " 3.20000000e+01, 6.40000000e+01, 1.28000000e+02, 2.56000000e+02,\n",
- " 5.12000000e+02, 1.02400000e+03, 2.04800000e+03, 4.09600000e+03,\n",
- " 8.19200000e+03, 1.63840000e+04, 3.27680000e+04])},\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40,\n",
- " refit='accuracy', scoring=['accuracy', 'roc_auc']) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. RandomizedSearchCV(cv=RepeatedStratifiedKFold(n_repeats=5, n_splits=10,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " error_score='raise',\n",
- " estimator=LogisticRegression(max_iter=1000000,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " param_distributions={'C': array([3.05175781e-05, 6.10351562e-05, 1.22070312e-04, 2.44140625e-04,\n",
- " 4.88281250e-04, 9.765...\n",
- " 1.25000000e-01, 2.50000000e-01, 5.00000000e-01, 1.00000000e+00,\n",
- " 2.00000000e+00, 4.00000000e+00, 8.00000000e+00, 1.60000000e+01,\n",
- " 3.20000000e+01, 6.40000000e+01, 1.28000000e+02, 2.56000000e+02,\n",
- " 5.12000000e+02, 1.02400000e+03, 2.04800000e+03, 4.09600000e+03,\n",
- " 8.19200000e+03, 1.63840000e+04, 3.27680000e+04])},\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40,\n",
- " refit='accuracy', scoring=['accuracy', 'roc_auc']) "
- ],
- "text/plain": [
- "RandomizedSearchCV(cv=RepeatedStratifiedKFold(n_repeats=5, n_splits=10,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " error_score='raise',\n",
- " estimator=LogisticRegression(max_iter=1000000,\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40),\n",
- " param_distributions={'C': array([3.05175781e-05, 6.10351562e-05, 1.22070312e-04, 2.44140625e-04,\n",
- " 4.88281250e-04, 9.765...\n",
- " 1.25000000e-01, 2.50000000e-01, 5.00000000e-01, 1.00000000e+00,\n",
- " 2.00000000e+00, 4.00000000e+00, 8.00000000e+00, 1.60000000e+01,\n",
- " 3.20000000e+01, 6.40000000e+01, 1.28000000e+02, 2.56000000e+02,\n",
- " 5.12000000e+02, 1.02400000e+03, 2.04800000e+03, 4.09600000e+03,\n",
- " 8.19200000e+03, 1.63840000e+04, 3.27680000e+04])},\n",
- " random_state=RandomState(MT19937) at 0x7FA866A4DE40,\n",
- " refit='accuracy', scoring=['accuracy', 'roc_auc'])"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
- "fit_args = {\"x\" if mida else \"X\": features, \"y\": labels, \"groups\": sites}\n",
+ "# Define common training arguments for all models: features (X), labels (y), and group info (sites)\n",
+ "fit_args = {\"X\": features, \"y\": labels, \"groups\": sites}\n",
+ "\n",
+ "# Fit the baseline model using raw features (no domain adaptation)\n",
+ "trainers[\"baseline\"].fit(**fit_args)\n",
+ "\n",
+ "# Prepare arguments for MIDA-based models by renaming X -> x (MIDA uses 'x')\n",
+ "fit_args_mida = fit_args.copy()\n",
+ "fit_args_mida[\"x\"] = fit_args_mida.pop(\"X\")\n",
"\n",
- "if mida and site_only:\n",
- " fit_args[\"factors\"] = pd.get_dummies(groups)\n",
- "elif mida:\n",
- " fit_args[\"factors\"] = phenotypes\n",
+ "# Add one-hot encoded site information as domain factors for MIDA\n",
+ "fit_args_mida[\"factors\"] = pd.get_dummies(sites)\n",
"\n",
- "trainer.fit(**fit_args)"
+ "# Fit the 'site_only' model using domain adaptation with site labels only\n",
+ "trainers[\"site_only\"].fit(**fit_args_mida)\n",
+ "\n",
+ "# Update the MIDA input to include full phenotype metadata (e.g., age, sex, site)\n",
+ "fit_args_mida[\"factors\"] = phenotypes\n",
+ "\n",
+ "# Fit the 'all_phenotypes' model using full metadata as domain-relevant factors\n",
+ "trainers[\"all_phenotypes\"].fit(**fit_args_mida)\n",
+ "\n",
+ "# Collect cross-validation results from each trained model for later comparison\n",
+ "cv_results = {}\n",
+ "for model in trainers:\n",
+ " # Store each model's cv_results_ (e.g., scores, ranks) in a DataFrame\n",
+ " cv_results[model] = pd.DataFrame(trainers[model].cv_results_)"
]
},
{
@@ -1433,36 +1342,17 @@
"tags": []
},
"source": [
- "### Evaluation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "f6a1e8f1-048e-4b8f-8ad2-ab243ed872f3",
- "metadata": {
- "editable": true,
- "slideshow": {
- "slide_type": ""
- },
- "tags": []
- },
- "outputs": [],
- "source": [
- "cv_results = pd.DataFrame(trainer.cv_results_)\n",
- "cv_results = cv_results[\n",
- " [f\"{aggregate}_test_{score}\" for score in scoring for aggregate in [\"mean\", \"std\"]]\n",
- "]\n",
+ "## Evaluation\n",
+ "\n",
+ "We evaluate and compare the performance of different model configurations using cross-validation results. We aggregate the top-performing scores for each model based on a specified evaluation metric (e.g., accuracy), allowing us to assess the effectiveness of domain adaptation strategies.\n",
"\n",
- "cv_results = cv_results.sort_values(\"mean_test_accuracy\", ascending=False)\n",
- "cv_results = cv_results.round(4).reset_index(drop=True)\n",
- "cv_results.index.name = \"Rank\""
+ "By comparing models with and without domain adaptation, we can determine the impact of incorporating site and phenotypic information on multi-site autism classification performance. This analysis helps identify which configurations generalize best across heterogeneous imaging sites."
]
},
{
"cell_type": "code",
- "execution_count": 18,
- "id": "a623e47d-0846-4286-8d7d-05a8303d96d1",
+ "execution_count": 14,
+ "id": "ccffed05-868f-40ba-8ae6-8a2661875ad7",
"metadata": {
"editable": true,
"slideshow": {
@@ -1492,13 +1382,17 @@
" \n",
" \n",
" \n",
- " mean_test_accuracy \n",
- " std_test_accuracy \n",
- " mean_test_roc_auc \n",
- " std_test_roc_auc \n",
+ " Accuracy \n",
+ " Precision \n",
+ " Recall \n",
+ " F1 \n",
+ " AUROC \n",
+ " MCC \n",
" \n",
" \n",
- " Rank \n",
+ " Model \n",
+ " \n",
+ " \n",
" \n",
" \n",
" \n",
@@ -1507,114 +1401,201 @@
" \n",
" \n",
" \n",
- " 0 \n",
- " 0.6966 \n",
- " 0.0434 \n",
- " 0.7568 \n",
- " 0.0405 \n",
- " \n",
- " \n",
- " 1 \n",
- " 0.6950 \n",
- " 0.0421 \n",
- " 0.7548 \n",
- " 0.0399 \n",
- " \n",
- " \n",
- " 2 \n",
- " 0.6948 \n",
- " 0.0317 \n",
- " 0.7476 \n",
- " 0.0384 \n",
+ " Baseline \n",
+ " 0.6446 ± 0.0963 \n",
+ " 0.6352 ± 0.1622 \n",
+ " 0.5771 ± 0.1745 \n",
+ " 0.5851 ± 0.1298 \n",
+ " 0.6921 ± 0.1016 \n",
+ " 0.2889 ± 0.1881 \n",
" \n",
" \n",
- " 3 \n",
- " 0.6939 \n",
- " 0.0314 \n",
- " 0.7477 \n",
- " 0.0384 \n",
+ " Site Only \n",
+ " 0.6566 ± 0.0748 \n",
+ " 0.6355 ± 0.1207 \n",
+ " 0.5999 ± 0.0996 \n",
+ " 0.6112 ± 0.0920 \n",
+ " 0.6893 ± 0.1115 \n",
+ " 0.3090 ± 0.1516 \n",
" \n",
" \n",
- " 4 \n",
- " 0.6936 \n",
- " 0.0318 \n",
- " 0.7470 \n",
- " 0.0388 \n",
- " \n",
- " \n",
- " 5 \n",
- " 0.6934 \n",
- " 0.0325 \n",
- " 0.7482 \n",
- " 0.0385 \n",
- " \n",
- " \n",
- " 6 \n",
- " 0.6934 \n",
- " 0.0318 \n",
- " 0.7479 \n",
- " 0.0384 \n",
- " \n",
- " \n",
- " 7 \n",
- " 0.6886 \n",
- " 0.0373 \n",
- " 0.7515 \n",
- " 0.0388 \n",
- " \n",
- " \n",
- " 8 \n",
- " 0.6870 \n",
- " 0.0360 \n",
- " 0.7466 \n",
- " 0.0373 \n",
- " \n",
- " \n",
- " 9 \n",
- " 0.5612 \n",
- " 0.0225 \n",
- " 0.6736 \n",
- " 0.0655 \n",
+ " All Phenotypes \n",
+ " 0.6567 ± 0.0927 \n",
+ " 0.6216 ± 0.1328 \n",
+ " 0.5834 ± 0.1743 \n",
+ " 0.5909 ± 0.1595 \n",
+ " 0.6898 ± 0.0978 \n",
+ " 0.2909 ± 0.1869 \n",
" \n",
" \n",
"\n",
""
],
"text/plain": [
- " mean_test_accuracy std_test_accuracy mean_test_roc_auc \\\n",
- "Rank \n",
- "0 0.6966 0.0434 0.7568 \n",
- "1 0.6950 0.0421 0.7548 \n",
- "2 0.6948 0.0317 0.7476 \n",
- "3 0.6939 0.0314 0.7477 \n",
- "4 0.6936 0.0318 0.7470 \n",
- "5 0.6934 0.0325 0.7482 \n",
- "6 0.6934 0.0318 0.7479 \n",
- "7 0.6886 0.0373 0.7515 \n",
- "8 0.6870 0.0360 0.7466 \n",
- "9 0.5612 0.0225 0.6736 \n",
- "\n",
- " std_test_roc_auc \n",
- "Rank \n",
- "0 0.0405 \n",
- "1 0.0399 \n",
- "2 0.0384 \n",
- "3 0.0384 \n",
- "4 0.0388 \n",
- "5 0.0385 \n",
- "6 0.0384 \n",
- "7 0.0388 \n",
- "8 0.0373 \n",
- "9 0.0655 "
+ " Accuracy Precision Recall \\\n",
+ "Model \n",
+ "Baseline 0.6446 ± 0.0963 0.6352 ± 0.1622 0.5771 ± 0.1745 \n",
+ "Site Only 0.6566 ± 0.0748 0.6355 ± 0.1207 0.5999 ± 0.0996 \n",
+ "All Phenotypes 0.6567 ± 0.0927 0.6216 ± 0.1328 0.5834 ± 0.1743 \n",
+ "\n",
+ " F1 AUROC MCC \n",
+ "Model \n",
+ "Baseline 0.5851 ± 0.1298 0.6921 ± 0.1016 0.2889 ± 0.1881 \n",
+ "Site Only 0.6112 ± 0.0920 0.6893 ± 0.1115 0.3090 ± 0.1516 \n",
+ "All Phenotypes 0.5909 ± 0.1595 0.6898 ± 0.0978 0.2909 ± 0.1869 "
]
},
- "execution_count": 18,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "cv_results"
+ "# Compile the cross-validation results into a summary table,\n",
+ "# keeping each model's top-performing scores as ranked by test accuracy\n",
+ "compiled_results = compile_results(cv_results, \"accuracy\")\n",
+ "\n",
+ "# Display the compiled results DataFrame (models as rows, metrics as formatted strings)\n",
+ "display(compiled_results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02820ae0-31ab-42b7-8343-cd330f1331fa",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "source": [
+ "## Interpretation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0523ad8a-cf8f-4984-982a-a7d23dddeeb8",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "source": [
+ "We interpret the trained models by analyzing the learned weights associated with functional connectivity features. Specifically, we extract the top-weighted ROI pairs that contributed most to the classification decision.\n",
+ "\n",
+ "These weights are visualized as a **connectome plot**, allowing us to examine which brain region interactions are most informative for distinguishing individuals with autism from controls. This not only enhances the interpretability of the model but also provides potential insights into neurobiological patterns relevant to autism."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "7529a360-33d4-46fe-ba45-f414e7e8d6c4",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[get_dataset_dir] Dataset found in /home/zarizky/nilearn_data/aal_SPM12\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from nilearn.datasets import fetch_atlas_aal\n",
+ "from nilearn.plotting import find_parcellation_cut_coords\n",
+ "\n",
+ "# Retrieve the trained model configured with all phenotypic metadata\n",
+ "model = trainers[\"all_phenotypes\"]\n",
+ "\n",
+ "# Fetch the AAL (Automated Anatomical Labeling) atlas\n",
+ "# Provides ROI labels and brain region maps\n",
+ "atlas = fetch_atlas_aal()\n",
+ "\n",
+ "# Extract 3D MNI coordinates for each ROI in the AAL atlas\n",
+ "coords = find_parcellation_cut_coords(atlas.maps)\n",
+ "\n",
+ "# Extract weights from the MIDA-transformed feature space (domain-adapted representation)\n",
+ "mida_weights = model.best_mida_.orig_coef_\n",
+ "\n",
+ "# Extract weights from the final classifier (e.g., logistic regression coefficients)\n",
+ "classifier_weights = np.squeeze(model.best_estimator_.coef_, axis=0)\n",
+ "\n",
+ "# Compute final feature-level weights by linearly combining MIDA and classifier weights\n",
+ "# Resulting in one scalar per connectivity feature\n",
+ "weights = (mida_weights @ classifier_weights).T\n",
+ "\n",
+ "# Visualize the top 0.2% strongest ROI-to-ROI connections using a connectome plot\n",
+ "proj = visualize_connectome(\n",
+ " weights=weights,\n",
+ " labels=atlas.labels, # ROI names from the AAL atlas\n",
+ " coords=coords, # ROI spatial coordinates (MNI space)\n",
+ " p=0.002, # Visualize top 0.2% of weights by magnitude\n",
+ " legend_params=dict(\n",
+ " title=\"Region of Interest\", # Title shown above the legend\n",
+ " ncols=3, # Organize legend entries into 3 columns\n",
+ " loc=\"lower center\", # Legend anchor point\n",
+ " bbox_to_anchor=(1.5, -1.0), # Adjust legend position\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "# Display the resulting connectome plot\n",
+ "display(proj)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ea959ab0-f8f9-4bd6-866f-658963977abf",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "source": [
+ "This plot shows the **most discriminative ROI connections** for classifying ASD vs Control subjects.\n",
+ "- **Red edges** indicate connections **stronger in ASD**.\n",
+ "- **Blue edges** indicate connections **stronger in Control**.\n",
+ "- Color intensity reflects the **magnitude of contribution** to the model’s decision.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "**Key Patterns**:\n",
+ "- **Increased connectivity in ASD (Red)**:\n",
+ " - Between **frontal regions** (*Frontal_Inf_Tri_L*, *Frontal_Sup_R*) and **limbic structures** (*Hippocampus_R*, *Amygdala_L*), suggesting atypical emotional or executive function coupling.\n",
+ " - **Temporal and subcortical regions** (*Temporal_Pole_Mid_R*, *Putamen_R*) are also prominent, which may relate to altered language, reward, or sensorimotor integration in ASD.\n",
+ "- **Increased connectivity in Control (Blue)**:\n",
+ " - Involving **parietal and sensorimotor regions** (*Postcentral_R*, *Precuneus_R*, *Supramarginal_R*), indicating more typical integration of sensory and motor pathways.\n",
+ " - Also includes **default mode and association regions** (*Angular_L*, *Parietal_Inf_R*), which are often underconnected in ASD.\n",
+ "\n",
+ "The model distinguishes ASD from Control subjects by identifying **abnormal functional connections**, especially across **frontal–limbic**, **temporal–subcortical**, and **parietal–sensorimotor networks**, aligning with known neurodevelopmental differences in autism."
]
}
],