diff --git a/README.md b/README.md
index a6ae46b..d73f9ba 100644
--- a/README.md
+++ b/README.md
@@ -11,22 +11,30 @@ This repository contains materials and resources for workshops conducted in 2025
## How to Contribute
1. Fork the repository.
-2. Create a new branch for your feature or bug fix:
+2. Create a new branch for your tutorial or bug fix:
```bash
- git checkout -b my-feature-branch
+ git checkout -b my-tutorial-branch
```
3. Make your changes and commit them with clear messages:
```bash
- git commit -m "Add feature X"
+ git commit -m "Add function ... to simplify tutorial ... content"
```
4. Push your branch to your forked repository:
```bash
- git push origin my-feature-branch
+ git push origin my-tutorial-branch
```
5. Open a pull request to the main repository.
Please ensure your contributions adhere to the repository's coding standards and include appropriate documentation.
+## Building the Book
+
+To build the book locally during development, run the following command from the project's root directory:
+
+```bash
+jupyter-book build .
+```
+
## Pre-commit Hooks
This repository uses pre-commit hooks to ensure code quality and consistency. To set up pre-commit hooks locally, follow these steps:
diff --git a/tutorials/brain-disorder-diagnosis/config.py b/tutorials/brain-disorder-diagnosis/config.py
index 8392c67..36bde8b 100644
--- a/tutorials/brain-disorder-diagnosis/config.py
+++ b/tutorials/brain-disorder-diagnosis/config.py
@@ -1,14 +1,12 @@
-import os
from yacs.config import CfgNode
-DEFAULT_DIR = os.path.join(os.getcwd(), "data")
_C = CfgNode()
# Dataset configuration
_C.DATASET = CfgNode()
# Path to the dataset directory
-_C.DATASET.PATH = DEFAULT_DIR
+_C.DATASET.PATH = "data"
# Name of the brain atlas to use
# Available options:
# - "aal" (AAL)
diff --git a/tutorials/brain-disorder-diagnosis/data.py b/tutorials/brain-disorder-diagnosis/data.py
index e563eac..b396973 100644
--- a/tutorials/brain-disorder-diagnosis/data.py
+++ b/tutorials/brain-disorder-diagnosis/data.py
@@ -27,8 +27,8 @@
}
)
],
- "vectorize": [bool],
- "verbose": [bool],
+ "vectorize": ["boolean"],
+ "verbose": ["boolean"],
},
prefer_skip_nested_validation=False,
)
diff --git a/tutorials/brain-disorder-diagnosis/notebook.ipynb b/tutorials/brain-disorder-diagnosis/notebook.ipynb
index 2ea8c10..d17502c 100644
--- a/tutorials/brain-disorder-diagnosis/notebook.ipynb
+++ b/tutorials/brain-disorder-diagnosis/notebook.ipynb
@@ -1,742 +1,742 @@
{
- "nbformat": 4,
- "nbformat_minor": 5,
- "metadata": {
- "kernelspec": {
- "display_name": "embc25",
- "language": "python",
- "name": "python3"
- }
- },
- "cells": [
- {
- "metadata": {},
- "source": [
- "# Brain Disorder Diagnosis\n",
- "\n",
- "In this tutorial, we demonstrate how to leverage **patient phenotypic information** to reduce **site-specific biases** in functional connectivity data using **domain adaptation**, with the goal of improving **multi-site autism classification**.\n",
- "\n",
- "This notebook builds on the work of **Kunda et al. (IEEE TMI, 2022)**, which introduced a second-order functional connectivity representation called **Tangent-Pearson**, the tangent embedding of the Pearson correlation matrix. The original work also applied domain adaptation to reduce site dependencies in fMRI-derived features, using **site labels** as the domain information.\n",
- "\n",
- "We extend this approach by incorporating a **richer set** of phenotypic variables, such as sex, handedness, age, and eye status into the domain adaptation framework. This enables more effective harmonization across data collected from different imaging sites.\n",
- "\n",
- "---\n",
- "\n",
- "**Objectives**\n",
- "\n",
- "1.\t**Load** the ABIDE dataset using different preprocessing pipelines and brain atlases.\n",
- "2.\t**Preprocess** phenotypic variables for use in domain adaptation, and obtain class labels (ASD vs CONTROL) and site labels.\n",
- "3.\t**Extract** functional connectivity **embedding** from ROI-based time series.\n",
- "4.\t**Build** a **training** and **evaluation** pipeline to assess classification performance under various domain adaptation strategies.\n",
- "5.\t**Interpret** the learned model by extracting weights for pairwise ROI feature importance and visualizing them using a connectome plot."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "## Setup\n",
- "\n",
- "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.\n",
- "\n",
- "In addition, several helper scripts are provided to modularize the code and simplify the workflow. These can be inspected directly as `.py` files in the notebook’s current directory. The helper scripts include:\n",
- "\n",
- "- **`config.py`**: Defines the base configuration settings, which can be customized and overridden using external `.yml` files.\n",
- "- **`data.py`**: Provides data loading functions and utilities to automatically download any required datasets.\n",
- "- **`parsing.py`**: Contains utilities to compile and summarize evaluation results from the training process.\n",
- "- **`preprocess.py`**: Handles phenotype preprocessing, including missing value imputation, categorical variable encoding, and feature extraction from fMRI time series data."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": [
- "hide-input"
- ]
- },
- "source": [
- "import os\n",
- "import site\n",
- "import sys\n",
- "import warnings\n",
- "\n",
- "warnings.filterwarnings(\"ignore\")\n",
- "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n",
- "\n",
- "if \"google.colab\" in str(get_ipython()):\n",
- " sys.path.insert(0, site.getusersitepackages())\n",
- " !git clone --single-branch https://github.com/pykale/embc-mmai25.git\n",
- " %cp -r /content/embc-mmai25/tutorials/brain-disorder-diagnosis/* /content/\n",
- " %rm -r /content/embc-mmai25"
- ],
- "cell_type": "code",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "## Packages\n",
- "\n",
- "The main packages required for this tutorial are:\n",
- "\n",
- "- **pykale**: An open-source interdisciplinary machine learning library developed at the University of Sheffield. It focuses on applications in biomedical and scientific domains, providing tools for multimodal learning, domain adaptation, and model interpretability.\n",
- "\n",
- "- **gdown**: A utility package that simplifies downloading files and folders directly from Google Drive.\n",
- "\n",
- "- **nilearn**: A Python library for neuroimaging analysis. It offers convenient tools for processing, analyzing, and visualizing functional MRI (fMRI) data.\n",
- "\n",
- "- **yacs**: A lightweight configuration management library used to store and organize experiment settings in a hierarchical and human-readable format."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": [
- "hide-input"
- ]
- },
- "source": [
- "!pip install --quiet --user \\\n",
- " git+https://github.com/pykale/pykale@main \\\n",
- " gdown==5.2.0 nilearn==0.10.4 yacs==0.1.8 \\\n",
- " && echo \"pykale, gdown, nilearn, and yacs installed successfully ✅\" \\\n",
- " || echo \"Failed to install pykale, gdown, nilearn, and yacs ❌\""
- ],
- "cell_type": "code",
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "pykale, gdown, nilearn, and yacs installed successfully ✅\n"
- ]
+ "nbformat": 4,
+ "nbformat_minor": 5,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "embc",
+ "language": "python",
+ "name": "python3"
}
- ],
- "execution_count": null
- },
- {
- "metadata": {},
- "source": [
- "## Configuration\n",
- "\n",
- "To minimize the footprint of the notebook when specifying configurations, we provide a `config.py` file that defines default parameters. These can be customized by supplying a `.yml` configuration file, such as `experiments/base.yml` as an example.\n",
- "\n",
- "Please refer to these files for detailed instructions on how to customize the experiment settings. \n",
- "We provide detailed descriptions of each configurable option in the following sections."
- ],
- "cell_type": "markdown"
},
- {
- "metadata": {
- "tags": [
- "hide-input"
- ]
- },
- "source": [
- "from config import get_cfg_defaults\n",
- "\n",
- "cfg = get_cfg_defaults()\n",
- "cfg.merge_from_file(\"experiments/base.yml\")\n",
- "cfg.freeze()\n",
- "print(cfg)"
- ],
- "cell_type": "code",
- "outputs": [
+ "cells": [
{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "CROSS_VALIDATION:\n",
- " NUM_FOLDS: 10\n",
- " NUM_REPEATS: 1\n",
- " SPLIT: skf\n",
- "DATASET:\n",
- " ATLAS: hcp-ica\n",
- " FC: tangent-pearson\n",
- " PATH: /home/zarizky/projects/embc-mmai25/tutorials/brain-disorder-diagnosis/data\n",
- "PHENOTYPE:\n",
- " STANDARDIZE: site\n",
- "RANDOM_STATE: 0\n",
- "TRAINER:\n",
- " CLASSIFIER: lr\n",
- " NONLINEAR: False\n",
- " NUM_SEARCH_ITER: 20\n",
- " NUM_SOLVER_ITER: 100\n",
- " N_JOBS: -1\n",
- " PRE_DISPATCH: 2*n_jobs\n",
- " REFIT: accuracy\n",
- " SCORING: ['accuracy', 'roc_auc']\n",
- " SEARCH_STRATEGY: random\n",
- " VERBOSE: 0\n"
- ]
- }
- ],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "# Data Loading and Preprocessing\n",
- "\n",
- "Typically, raw fMRI scans require extensive preprocessing before they can be used in a machine learning pipeline. However, the **ABIDE** dataset provides several preprocessed derivatives, which can be downloaded directly from the [Preprocessed Connectomes Project (PCP)](https://preprocessed-connectomes-project.org/abide/), eliminating the need for manual preprocessing.\n",
- "\n",
- "Given the long runtime required to extract the functional connectivity embedding, we will omit this step from this notebook and provide pre-computed embeddings through the provided `load_data` function with the associated atlas.\n",
- "\n",
- "For users interested in computing the time series and functional connectivity embeddings from scratch, assuming preprocessed images are available, please refer to:\n",
- "\n",
- "- [`NiftiLabelsMasker` (Deterministic / 3D Atlas)](https://nilearn.github.io/stable/modules/generated/nilearn.maskers.NiftiLabelsMasker.html)\n",
- "- [`NiftiMapsMasker` (Probabilistic / 4D Atlas)](https://nilearn.github.io/stable/modules/generated/nilearn.maskers.NiftiMapsMasker.html)\n",
- "- `extract_functional_connectivity` function implemented in `preprocess.py`.\n",
- "\n",
- "In this tutorial, we focus on the following preprocessing options:\n",
- "\n",
- "- **`path`** (or `data_dir`): Directory where the preprocessed dataset is located.\n",
- " - *Default:* Current working directory + `/data`\n",
- "\n",
- "- **`atlas`**: The brain atlas used to extract ROI time series.\n",
- " - Available options:\n",
- " - `\"aal\"`: AAL Atlas\n",
- " - `\"cc200\"`: Cameron Craddock 200\n",
- " - `\"cc400\"`: Cameron Craddock 400\n",
- " - `\"difumo64\"`: DiFuMo 64\n",
- " - `\"dos160\"`: Dosenbach 160\n",
- " - `\"hcp-ica\"`: HCP-ICA\n",
- " - `\"ho\"`: Harvard-Oxford\n",
- " - `\"tt\"`: Talairach-Tournoux \n",
- " - *Default:* `\"cc200\"`\n",
- "\n",
- "- **`fc`**: The functional connectivity measure used to compute pairwise associations between ROIs.\n",
- " - Available options:\n",
- " - `\"pearson\"`: Pearson correlation\n",
- " - `\"partial\"`: Partial correlation\n",
- " - `\"tangent\"`: Tangent embedding\n",
- " - `\"precision\"`: Precision (inverse covariance)\n",
- " - `\"covariance\"`: Covariance\n",
- " - `\"tangent-pearson\"`: Tangent-Pearson hybrid connectivity \n",
- " - *Default:* `\"tangent-pearson\"`"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "from data import load_data\n",
- "\n",
- "fc, phenotypes, rois, coords = load_data(\n",
- " cfg.DATASET.PATH, cfg.DATASET.ATLAS, cfg.DATASET.FC\n",
- ")"
- ],
- "cell_type": "code",
- "outputs": [
+ "metadata": {},
+ "source": [
+ "# Brain Disorder Diagnosis\n",
+ "\n",
+ "In this tutorial, we demonstrate how to leverage **patient phenotypic information** to reduce **site-specific biases** in functional connectivity data using **domain adaptation**, with the goal of improving **multi-site autism classification**.\n",
+ "\n",
+ "This notebook builds on the work of **Kunda et al. (IEEE TMI, 2022)**, which introduced a second-order functional connectivity representation called **Tangent-Pearson**, the tangent embedding of the Pearson correlation matrix. The original work also applied domain adaptation to reduce site dependencies in fMRI-derived features, using **site labels** as the domain information.\n",
+ "\n",
+ "We extend this approach by incorporating a **richer set** of phenotypic variables, such as sex, handedness, age, and eye status, into the domain adaptation framework. This enables more effective harmonization across data collected from different imaging sites.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "**Objectives**\n",
+ "\n",
+ "1.\t**Load** the ABIDE dataset using different preprocessing pipelines and brain atlases.\n",
+ "2.\t**Preprocess** phenotypic variables for use in domain adaptation, and obtain class labels (ASD vs CONTROL) and site labels.\n",
+ "3.\t**Extract** functional connectivity **embedding** from ROI-based time series.\n",
+ "4.\t**Build** a **training** and **evaluation** pipeline to assess classification performance under various domain adaptation strategies.\n",
+ "5.\t**Interpret** the learned model by extracting weights for pairwise ROI feature importance and visualizing them using a connectome plot."
+ ],
+ "cell_type": "markdown"
+ },
{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "✔ File found: /home/zarizky/projects/embc-mmai25/tutorials/brain-disorder-diagnosis/data/abide/fc/hcp-ica/tangent-pearson.npy\n",
- "✔ File found: /home/zarizky/projects/embc-mmai25/tutorials/brain-disorder-diagnosis/data/abide/phenotypes.csv\n",
- "✔ Atlas folder found: /home/zarizky/projects/embc-mmai25/tutorials/brain-disorder-diagnosis/data/atlas/deterministic/hcp-ica\n"
- ]
- }
- ],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "## Phenotype Preprocessing\n",
- "\n",
- "The phenotypic information in the dataset contains several missing values. We impute and encode these variables to make them suitable for modeling. The `preprocess_phenotypic_data` function handles this functionality for us.\n",
- "\n",
- "### Categorical Variables\n",
- "\n",
- "The following categorical phenotypes are used and will be **one-hot encoded**:\n",
- "\n",
- "- `SITE_ID`\n",
- "- `SEX`\n",
- "- `HANDEDNESS_CATEGORY`\n",
- "- `EYE_STATUS_AT_SCAN`\n",
- "\n",
- "### Continuous Variables\n",
- "\n",
- "The following continuous phenotypes will optionally be **standardized**:\n",
- "\n",
- "- `AGE_AT_SCAN`\n",
- "- `FIQ`\n",
- "\n",
- "Standardization options for continuous phenotypes (`standardize` argument):\n",
- "\n",
- "- `\"all\"` or `True`: Standardize across all subjects.\n",
- "- `\"site\"`: Standardize within each site.\n",
- "- `False`: No standardization.\n",
- "\n",
- "### Handling Missing Values\n",
- "\n",
- "- `HANDEDNESS_CATEGORY`: Missing values are assumed to correspond to `right-handed` subjects.\n",
- "- `FIQ`: Missing values are imputed with a default score of `100`.\n",
- "\n",
- "### Label Encoding\n",
- "\n",
- "The diagnostic label `DX_GROUP` is used to assign the target class:\n",
- "\n",
- "- `CONTROL` → `0`\n",
- "- `ASD` → `1`"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "from preprocess import preprocess_phenotypic_data\n",
- "\n",
- "labels, sites, phenotypes = preprocess_phenotypic_data(\n",
- " phenotypes, cfg.PHENOTYPE.STANDARDIZE\n",
- ")"
- ],
- "cell_type": "code",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "# Modeling\n",
- "\n",
- "We define and train machine learning models for classifying autism spectrum disorder (ASD) using functional connectivity features.\n",
- "\n",
- "We explore different configurations including a baseline model, domain adaptation using site information, and an extended approach that incorporates additional phenotypic variables.\n",
- "\n",
- "Each model is evaluated using cross-validation, and we analyze the impact of domain adaptation on classification performance."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "### Cross-Validation Split\n",
- "\n",
- "To evaluate model performance reliably, we define a cross-validation (CV) strategy. By default, we use **Repeated Stratified K-Fold**, which preserves class distribution across folds and supports repeated trials for more stable estimates.\n",
- "\n",
- "Alternatively, we can also use **Leave-P-Groups-Out (LPGO)** cross-validation. This strategy is particularly useful in multi-site studies, as it ensures that data from the same group (e.g., imaging site) are not shared between training and test sets, enabling more realistic generalization assessment under domain shift.\n",
- "\n",
- "In this tutorial, we specify the following arguments:\n",
- "\n",
- "- **`split`**: Defines the cross-validation strategy.\n",
- " - Available options: \n",
- " - `\"skf\"`: Stratified K-Fold to maintain label balance in each fold.\n",
- " - `\"lpgo\"`: Leave-P-Groups-Out to evaluate generalization across sites by holding out entire groups (e.g., imaging sites).\n",
- " - *Default:* `\"skf\"`\n",
- "\n",
- "- **`num_folds`**: The number of folds for Stratified K-Fold or the number of groups to leave out in LPGO.\n",
- " - *Default:* `10`\n",
- "\n",
- "- **`num_repeats`**: The number of times the k-fold procedure is repeated to obtain more stable estimates (ignored when using LPGO).\n",
- " - *Default:* `5`\n",
- "\n",
- "- **`random_state`**: Seed for random number generators for reproducibility.\n",
- " - *Default:* `None`"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "from sklearn.model_selection import LeavePGroupsOut, RepeatedStratifiedKFold\n",
- "\n",
- "# Define the default cross-validation strategy:\n",
- "# Repeated stratified k-fold maintains class distribution across folds and supports multiple repetitions\n",
- "cv = RepeatedStratifiedKFold(\n",
- " # Number of stratified folds\n",
- " n_splits=cfg.CROSS_VALIDATION.NUM_FOLDS,\n",
- " # Number of repeat rounds\n",
- " n_repeats=cfg.CROSS_VALIDATION.NUM_REPEATS,\n",
- " # Ensures reproducibility, intentionally set to the seed to have the same splits across runs\n",
- " random_state=cfg.RANDOM_STATE,\n",
- ")\n",
- "\n",
- "# Override with leave-p-proups-out if specified\n",
- "# This strategy holds out `p` unique groups (e.g., sites) per fold, enabling group-level generalization\n",
- "if cfg.CROSS_VALIDATION.SPLIT == \"lpgo\":\n",
- " # Use group-based CV for domain adaptation or site bias evaluation\n",
- " cv = LeavePGroupsOut(cfg.CROSS_VALIDATION.NUM_FOLDS)"
- ],
- "cell_type": "code",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "### Model Definition\n",
- "\n",
- "We define several model configurations used for classification. Each model shares the same base classifier but differs in how domain adaptation is applied:\n",
- "\n",
- "- **Baseline**: A standard model trained directly on functional connectivity features without domain adaptation.\n",
- "- **Site Only**: A domain-adapted model that uses site labels as the adaptation factor to reduce site-specific bias.\n",
- "- **All Phenotypes**: An extended domain-adapted model that incorporates multiple phenotypic variables (e.g., age, sex, handedness) to further reduce inter-site variability.\n",
- "\n",
- "We also specify the hyperparameter search strategy and other training parameters for each configuration, including:\n",
- "\n",
- "- **`classifier`**: The base model used for classification.\n",
- " - Available options:\n",
- " - `\"lda\"`: Linear Discriminant Analysis\n",
- " - `\"lr\"`: Logistic Regression\n",
- " - `\"linear_svm\"`: Linear Support Vector Machine\n",
- " - `\"svm\"`: Kernel Support Vector Machine\n",
- " - `\"ridge\"`: Ridge Classifier (L2-regularized linear model)\n",
- " - `\"auto\"`: Automatically selects an appropriate model based on data characteristics.\n",
- " - *Default:* `\"lr\"`\n",
- "\n",
- "- **`nonlinear`**: Whether to apply non-linear transformations (non-interpretable).\n",
- " - *Type:* `boolean`\n",
- " - *Default:* `False`\n",
- "\n",
- "- **`search_strategy`**: The hyperparameter search method.\n",
- " - Available options:\n",
- " - `\"random\"`: Randomly search over finite iterations.\n",
- " - `\"grid\"`: Search over all possible combinations.\n",
- " - *Default:* `\"random\"`\n",
- "\n",
- "- **`num_search_iterations`**: The number of hyperparameter combinations to evaluate in randomized search.\n",
- " - *Default:* `1,000`\n",
- "\n",
- "- **`num_solver_iterations`**: The maximum number of iterations allowed for solver convergence.\n",
- " - *Default:* `1,000,000`\n",
- "\n",
- "- **`scoring`**: A list of performance metrics used during cross-validation.\n",
- " - Available options:\n",
- " - `\"accuracy\"`: Accuracy\n",
- " - `\"precision\"`: Precision\n",
- " - `\"recall\"`: Recall\n",
- " - `\"f1\"`: F1-Score\n",
- " - `\"roc_auc\"`: Area Under ROC Curve (AUROC)\n",
- " - `\"matthews_corrcoef\"`: Matthews Correlation Coefficient (MCC)\n",
- " - *Default:* `[\"accuracy\", \"roc_auc\"]`\n",
- "\n",
- "- **`refit`**: The metric used to refit the best model after hyperparameter tuning.\n",
- " - *Default:* `\"accuracy\"`\n",
- "\n",
- "- **`num_jobs`**: The number of CPU cores used for training and hyperparameter search.\n",
- " - Set to `-1` for all CPUs, `-k` for all but `k` CPUs.\n",
- " - *Default:* `-1`\n",
- "\n",
- "- **`pre_dispatch`**: Controls job pre-dispatching for parallel execution.\n",
- " - *Default:* `\"2*n_jobs\"`\n",
- "\n",
- "- **`verbose`**: Controls verbosity of training output.\n",
- " - *Default:* `0`\n",
- "\n",
- "- **`random_state`**: Seed for random number generators for reproducibility.\n",
- " - *Default:* `None`"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "from kale.pipeline.multi_domain_adapter import AutoMIDAClassificationTrainer as Trainer\n",
- "from sklearn.base import clone\n",
- "\n",
- "# Configuration with cv and random_state/seed included\n",
- "trainer_cfg = {k.lower(): v for k, v in cfg.TRAINER.items()}\n",
- "trainer_cfg = {**trainer_cfg, \"cv\": cv, \"random_state\": cfg.RANDOM_STATE}\n",
- "\n",
- "# Initialize dictionary for different trainers\n",
- "trainers = {}\n",
- "\n",
- "# Create a baseline trainer without domain adaptation (MIDA disabled)\n",
- "trainers[\"baseline\"] = Trainer(use_mida=False, **trainer_cfg)\n",
- "\n",
- "# Create a trainer with MIDA enabled, using site labels as domain adaptation factors\n",
- "trainers[\"site_only\"] = Trainer(use_mida=True, **trainer_cfg)\n",
- "\n",
- "# Clone the 'site_only' trainer to create 'all_phenotypes' trainer\n",
- "# This enables reusing the same training configuration, while modifying only the input domain factors\n",
- "trainers[\"all_phenotypes\"] = clone(trainers[\"site_only\"])"
- ],
- "cell_type": "code",
- "outputs": [],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "# Training\n",
- "\n",
- "We train each model configuration using the previously defined cross-validation strategy. The training process involves fitting the model on functional connectivity features and evaluating its performance using multiple scoring metrics (e.g., accuracy, F1-score, AUROC).\n",
- "\n",
- "For models with domain adaptation, we pass additional domain factors (such as site or phenotypic variables) to guide the alignment of embedding. Cross-validation is performed to ensure robust performance estimates and to select the best hyperparameter configuration for each model."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "import pandas as pd\n",
- "from tqdm import tqdm\n",
- "\n",
- "# Define common training arguments for all models: features (X), labels (y), and group info (sites)\n",
- "fit_args = {\"x\": fc, \"y\": labels, \"groups\": sites}\n",
- "\n",
- "cv_results = {}\n",
- "for model in (pbar := tqdm(trainers)):\n",
- " args = clone(fit_args, safe=False)\n",
- " if model == \"site_only\":\n",
- " args[\"group_labels\"] = sites\n",
- " if model == \"all_phenotypes\":\n",
- " args[\"group_labels\"] = phenotypes\n",
- "\n",
- " pbar.set_description(f\"Fitting {model} model\")\n",
- " trainers[model].fit(**args)\n",
- " cv_results[model] = pd.DataFrame(trainers[model].cv_results_)"
- ],
- "cell_type": "code",
- "outputs": [
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Setup\n",
+ "\n",
+ "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.\n",
+ "\n",
+ "In addition, several helper scripts are provided to modularize the code and simplify the workflow. These can be inspected directly as `.py` files in the notebook\u2019s current directory. The helper scripts include:\n",
+ "\n",
+ "- **`config.py`**: Defines the base configuration settings, which can be customized and overridden using external `.yml` files.\n",
+ "- **`data.py`**: Provides data loading functions and utilities to automatically download any required datasets.\n",
+ "- **`parsing.py`**: Contains utilities to compile and summarize evaluation results from the training process.\n",
+ "- **`preprocess.py`**: Handles phenotype preprocessing, including missing value imputation, categorical variable encoding, and feature extraction from fMRI time series data."
+ ],
+ "cell_type": "markdown"
+ },
{
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "Fitting all_phenotypes model: 100%|██████████| 3/3 [00:21<00:00, 7.23s/it]\n"
- ]
- }
- ],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "# Evaluation\n",
- "\n",
- "We evaluate and compare the performance of different model configurations using cross-validation results. We aggregate the top-performing scores for each model based on a specified evaluation metric (e.g., accuracy), allowing us to assess the effectiveness of domain adaptation strategies.\n",
- "\n",
- "By comparing models with and without domain adaptation, we can determine the impact of incorporating site and phenotypic information on multi-site autism classification performance. This analysis helps identify which configurations generalize best across heterogeneous imaging sites."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "from parsing import compile_results\n",
- "\n",
- "# Compile the cross-validation results into a summary table,\n",
- "# sorting by the model with the highest test accuracy across CV folds\n",
- "compiled_results = compile_results(cv_results, \"accuracy\")\n",
- "\n",
- "# Display the compiled results DataFrame (models as rows, metrics as formatted strings)\n",
- "display(compiled_results)"
- ],
- "cell_type": "code",
- "outputs": [
+ "metadata": {
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "source": [
+ "import os\n",
+ "import site\n",
+ "import sys\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n",
+ "\n",
+ "if \"google.colab\" in str(get_ipython()):\n",
+ " sys.path.insert(0, site.getusersitepackages())\n",
+ " !git clone --single-branch https://github.com/pykale/embc-mmai25.git\n",
+ " %cp -r /content/embc-mmai25/tutorials/brain-disorder-diagnosis/* /content/\n",
+ " %rm -r /content/embc-mmai25"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
{
- "output_type": "display_data",
- "data": {
- "text/html": [
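- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>Accuracy</th>\n",
- " <th>AUROC</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>Model</th>\n",
- " <th></th>\n",
- " <th></th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>Baseline</th>\n",
- " <td>0.6629 ± 0.0523</td>\n",
- " <td>0.7105 ± 0.0556</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>Site Only</th>\n",
- " <td>0.6609 ± 0.0509</td>\n",
- " <td>0.7127 ± 0.0596</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>All Phenotypes</th>\n",
- " <td>0.6474 ± 0.0597</td>\n",
- " <td>0.7057 ± 0.0514</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"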
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Packages\n",
+ "\n",
+ "The main packages required for this tutorial are:\n",
+ "\n",
+ "- **pykale**: An open-source interdisciplinary machine learning library developed at the University of Sheffield. It focuses on applications in biomedical and scientific domains, providing tools for multimodal learning, domain adaptation, and model interpretability.\n",
+ "\n",
+ "- **gdown**: A utility package that simplifies downloading files and folders directly from Google Drive.\n",
+ "\n",
+ "- **nilearn**: A Python library for neuroimaging analysis. It offers convenient tools for processing, analyzing, and visualizing functional MRI (fMRI) data.\n",
+ "\n",
+ "- **yacs**: A lightweight configuration management library used to store and organize experiment settings in a hierarchical and human-readable format."
],
- "text/plain": [
- " Accuracy AUROC\n",
- "Model \n",
- "Baseline 0.6629 ± 0.0523 0.7105 ± 0.0556\n",
- "Site Only 0.6609 ± 0.0509 0.7127 ± 0.0596\n",
- "All Phenotypes 0.6474 ± 0.0597 0.7057 ± 0.0514"
- ]
- },
- "metadata": {}
- }
- ],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "# Interpretation"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "We interpret the trained models by analyzing the learned weights associated with functional connectivity features. Specifically, we extract the top-weighted ROI pairs that contributed most to the classification decision.\n",
- "\n",
- "These weights are visualized as a **connectome plot**, allowing us to examine which brain region interactions are most informative for distinguishing individuals with autism from controls. This not only enhances the interpretability of the model but also provides potential insights into neurobiological patterns relevant to autism."
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "import numpy as np\n",
- "from kale.interpret.visualize import visualize_connectome\n",
- "\n",
- "# Fetch coefficients to visualize feature importance\n",
- "coef = trainers[\"site_only\"].coef_.ravel()\n",
- "# check if coef != features, assumes augmented features with phenotypes/sites\n",
- "if coef.shape[0] != fc.shape[1]:\n",
- " coef, _ = np.split(coef, [fc.shape[1]])\n",
- "\n",
- "# Visualize the coefficients as a connectome plot\n",
- "proj = visualize_connectome(\n",
- " trainers[\"baseline\"].coef_.ravel(),\n",
- " rois,\n",
- " coords,\n",
- " 0.015, # Take top 1.5% of connections\n",
- " legend_params={\n",
- " \"bbox_to_anchor\": (2.75, -0.1), # Align legend outside the plot\n",
- " \"ncol\": 2,\n",
- " },\n",
- ")\n",
- "\n",
- "# Display the resulting connectome plot\n",
- "display(proj)"
- ],
- "cell_type": "code",
- "outputs": [
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "source": [
+ "!pip install --quiet --user \\\n",
+ " git+https://github.com/pykale/pykale@main \\\n",
+ " gdown==5.2.0 nilearn==0.10.4 yacs==0.1.8 \\\n",
+ " && echo \"pykale, gdown, nilearn, and yacs installed successfully \u2705\" \\\n",
+ " || echo \"Failed to install pykale, gdown, nilearn, and yacs \u274c\""
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "pykale, gdown, nilearn, and yacs installed successfully \u2705\n"
+ ]
+ }
+ ],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Configuration\n",
+ "\n",
+ "To keep configuration code out of the notebook, we provide a `config.py` file that defines default parameters. These defaults can be overridden by supplying a `.yml` configuration file, such as `experiments/base.yml`.\n",
+ "\n",
+ "Please refer to these files for details on how to customize the experiment settings.\n",
+ "Each configurable option is described in the sections that follow."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "source": [
+ "from config import get_cfg_defaults\n",
+ "\n",
+ "cfg = get_cfg_defaults()\n",
+ "cfg.merge_from_file(\"experiments/base.yml\")\n",
+ "cfg.freeze()\n",
+ "print(cfg)"
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "CROSS_VALIDATION:\n",
+ " NUM_FOLDS: 10\n",
+ " NUM_REPEATS: 1\n",
+ " SPLIT: skf\n",
+ "DATASET:\n",
+ " ATLAS: hcp-ica\n",
+ " FC: tangent-pearson\n",
+ " PATH: data\n",
+ "PHENOTYPE:\n",
+ " STANDARDIZE: site\n",
+ "RANDOM_STATE: 0\n",
+ "TRAINER:\n",
+ " CLASSIFIER: lr\n",
+ " NONLINEAR: False\n",
+ " NUM_SEARCH_ITER: 20\n",
+ " NUM_SOLVER_ITER: 100\n",
+ " N_JOBS: -1\n",
+ " PRE_DISPATCH: 2*n_jobs\n",
+ " REFIT: accuracy\n",
+ " SCORING: ['accuracy', 'roc_auc']\n",
+ " SEARCH_STRATEGY: random\n",
+ " VERBOSE: 0\n"
+ ]
+ }
+ ],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Data Loading and Preprocessing\n",
+ "\n",
+ "Typically, raw fMRI scans require extensive preprocessing before they can be used in a machine learning pipeline. However, the **ABIDE** dataset provides several preprocessed derivatives, which can be downloaded directly from the [Preprocessed Connectomes Project (PCP)](https://preprocessed-connectomes-project.org/abide/), eliminating the need for manual preprocessing.\n",
+ "\n",
+ "Because extracting the functional connectivity embeddings takes a long time, we omit that step from this notebook and instead load pre-computed embeddings for the selected atlas through the `load_data` function.\n",
+ "\n",
+ "For users interested in computing the time series and functional connectivity embeddings from scratch, assuming preprocessed images are available, please refer to:\n",
+ "\n",
+ "- [`NiftiLabelsMasker` (Deterministic / 3D Atlas)](https://nilearn.github.io/stable/modules/generated/nilearn.maskers.NiftiLabelsMasker.html)\n",
+ "- [`NiftiMapsMasker` (Probabilistic / 4D Atlas)](https://nilearn.github.io/stable/modules/generated/nilearn.maskers.NiftiMapsMasker.html)\n",
+ "- `extract_functional_connectivity` function implemented in `preprocess.py`.\n",
+ "\n",
+ "In this tutorial, we focus on the following preprocessing options:\n",
+ "\n",
+ "- **`path`** (or `data_dir`): Directory where the preprocessed dataset is located.\n",
+ " - *Default:* `\"data\"` (resolved relative to the current working directory)\n",
+ "\n",
+ "- **`atlas`**: The brain atlas used to extract ROI time series.\n",
+ " - Available options:\n",
+ " - `\"aal\"`: AAL Atlas\n",
+ " - `\"cc200\"`: Cameron Craddock 200\n",
+ " - `\"cc400\"`: Cameron Craddock 400\n",
+ " - `\"difumo64\"`: DiFuMo 64\n",
+ " - `\"dos160\"`: Dosenbach 160\n",
+ " - `\"hcp-ica\"`: HCP-ICA\n",
+ " - `\"ho\"`: Harvard-Oxford\n",
+ " - `\"tt\"`: Talairach-Tournoux \n",
+ " - *Default:* `\"cc200\"`\n",
+ "\n",
+ "- **`fc`**: The functional connectivity measure used to compute pairwise associations between ROIs.\n",
+ " - Available options:\n",
+ " - `\"pearson\"`: Pearson correlation\n",
+ " - `\"partial\"`: Partial correlation\n",
+ " - `\"tangent\"`: Tangent embedding\n",
+ " - `\"precision\"`: Precision (inverse covariance)\n",
+ " - `\"covariance\"`: Covariance\n",
+ " - `\"tangent-pearson\"`: Tangent-Pearson hybrid connectivity \n",
+ " - *Default:* `\"tangent-pearson\"`"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "from data import load_data\n",
+ "\n",
+ "fc, phenotypes, rois, coords = load_data(\n",
+ " cfg.DATASET.PATH, cfg.DATASET.ATLAS, cfg.DATASET.FC\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u2714 File found: data/abide/fc/hcp-ica/tangent-pearson.npy\n",
+ "\u2714 File found: data/abide/phenotypes.csv\n",
+ "\u2714 Atlas folder found: data/atlas/deterministic/hcp-ica\n"
+ ]
+ }
+ ],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Phenotype Preprocessing\n",
+ "\n",
+ "The phenotypic information in the dataset contains several missing values. We impute and encode these variables to make them suitable for modeling. The `preprocess_phenotypic_data` function handles both steps for us.\n",
+ "\n",
+ "### Categorical Variables\n",
+ "\n",
+ "The following categorical phenotypes are used and will be **one-hot encoded**:\n",
+ "\n",
+ "- `SITE_ID`\n",
+ "- `SEX`\n",
+ "- `HANDEDNESS_CATEGORY`\n",
+ "- `EYE_STATUS_AT_SCAN`\n",
+ "\n",
+ "### Continuous Variables\n",
+ "\n",
+ "The following continuous phenotypes will optionally be **standardized**:\n",
+ "\n",
+ "- `AGE_AT_SCAN`\n",
+ "- `FIQ`\n",
+ "\n",
+ "Standardization options for continuous phenotypes (`standardize` argument):\n",
+ "\n",
+ "- `\"all\"` or `True`: Standardize across all subjects.\n",
+ "- `\"site\"`: Standardize within each site.\n",
+ "- `False`: No standardization.\n",
+ "\n",
+ "### Handling Missing Values\n",
+ "\n",
+ "- `HANDEDNESS_CATEGORY`: Missing values are assumed to correspond to `right-handed` subjects.\n",
+ "- `FIQ`: Missing values are imputed with a default score of `100`.\n",
+ "\n",
+ "### Label Encoding\n",
+ "\n",
+ "The diagnostic label `DX_GROUP` is used to assign the target class:\n",
+ "\n",
+ "- `CONTROL` \u2192 `0`\n",
+ "- `ASD` \u2192 `1`"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "from preprocess import preprocess_phenotypic_data\n",
+ "\n",
+ "labels, sites, phenotypes = preprocess_phenotypic_data(\n",
+ " phenotypes, cfg.PHENOTYPE.STANDARDIZE\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Modeling\n",
+ "\n",
+ "We define and train machine learning models for classifying autism spectrum disorder (ASD) using functional connectivity features.\n",
+ "\n",
+ "We explore different configurations including a baseline model, domain adaptation using site information, and an extended approach that incorporates additional phenotypic variables.\n",
+ "\n",
+ "Each model is evaluated using cross-validation, and we analyze the impact of domain adaptation on classification performance."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "### Cross-Validation Split\n",
+ "\n",
+ "To evaluate model performance reliably, we define a cross-validation (CV) strategy. By default, we use **Repeated Stratified K-Fold**, which preserves class distribution across folds and supports repeated trials for more stable estimates.\n",
+ "\n",
+ "Alternatively, we can also use **Leave-P-Groups-Out (LPGO)** cross-validation. This strategy is particularly useful in multi-site studies, as it ensures that data from the same group (e.g., imaging site) are not shared between training and test sets, enabling more realistic generalization assessment under domain shift.\n",
+ "\n",
+ "In this tutorial, we specify the following arguments:\n",
+ "\n",
+ "- **`split`**: Defines the cross-validation strategy.\n",
+ " - Available options: \n",
+ " - `\"skf\"`: Stratified K-Fold to maintain label balance in each fold.\n",
+ " - `\"lpgo\"`: Leave-P-Groups-Out to evaluate generalization across sites by holding out entire groups (e.g., imaging sites).\n",
+ " - *Default:* `\"skf\"`\n",
+ "\n",
+ "- **`num_folds`**: The number of folds for Stratified K-Fold or the number of groups to leave out in LPGO.\n",
+ " - *Default:* `10`\n",
+ "\n",
+ "- **`num_repeats`**: The number of times the k-fold procedure is repeated to obtain more stable estimates (ignored when using LPGO).\n",
+ " - *Default:* `5`\n",
+ "\n",
+ "- **`random_state`**: Seed for random number generators for reproducibility.\n",
+ " - *Default:* `None`"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "from sklearn.model_selection import LeavePGroupsOut, RepeatedStratifiedKFold\n",
+ "\n",
+ "# Define the default cross-validation strategy:\n",
+ "# Repeated stratified k-fold maintains class distribution across folds and supports multiple repetitions\n",
+ "cv = RepeatedStratifiedKFold(\n",
+ " # Number of stratified folds\n",
+ " n_splits=cfg.CROSS_VALIDATION.NUM_FOLDS,\n",
+ " # Number of repeat rounds\n",
+ " n_repeats=cfg.CROSS_VALIDATION.NUM_REPEATS,\n",
+ " # Ensures reproducibility, intentionally set to the seed to have the same splits across runs\n",
+ " random_state=cfg.RANDOM_STATE,\n",
+ ")\n",
+ "\n",
+ "# Override with leave-p-groups-out if specified\n",
+ "# This strategy holds out `p` unique groups (e.g., sites) per fold, enabling group-level generalization\n",
+ "if cfg.CROSS_VALIDATION.SPLIT == \"lpgo\":\n",
+ " # Use group-based CV for domain adaptation or site bias evaluation\n",
+ " cv = LeavePGroupsOut(cfg.CROSS_VALIDATION.NUM_FOLDS)"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "### Model Definition\n",
+ "\n",
+ "We define several model configurations used for classification. Each model shares the same base classifier but differs in how domain adaptation is applied:\n",
+ "\n",
+ "- **Baseline**: A standard model trained directly on functional connectivity features without domain adaptation.\n",
+ "- **Site Only**: A domain-adapted model that uses site labels as the adaptation factor to reduce site-specific bias.\n",
+ "- **All Phenotypes**: An extended domain-adapted model that incorporates multiple phenotypic variables (e.g., age, sex, handedness) to further reduce inter-site variability.\n",
+ "\n",
+ "We also specify the hyperparameter search strategy and other training parameters for each configuration, including:\n",
+ "\n",
+ "- **`classifier`**: The base model used for classification.\n",
+ " - Available options:\n",
+ " - `\"lda\"`: Linear Discriminant Analysis\n",
+ " - `\"lr\"`: Logistic Regression\n",
+ " - `\"linear_svm\"`: Linear Support Vector Machine\n",
+ " - `\"svm\"`: Kernel Support Vector Machine\n",
+ " - `\"ridge\"`: Ridge Classifier (L2-regularized linear model)\n",
+ " - `\"auto\"`: Automatically selects an appropriate model based on data characteristics.\n",
+ " - *Default:* `\"lr\"`\n",
+ "\n",
+ "- **`nonlinear`**: Whether to apply non-linear transformations (non-interpretable).\n",
+ " - *Type:* `boolean`\n",
+ " - *Default:* `False`\n",
+ "\n",
+ "- **`search_strategy`**: The hyperparameter search method.\n",
+ " - Available options:\n",
+ " - `\"random\"`: Randomly sample a fixed number of hyperparameter combinations.\n",
+ " - `\"grid\"`: Search over all possible combinations.\n",
+ " - *Default:* `\"random\"`\n",
+ "\n",
+ "- **`num_search_iter`**: The number of hyperparameter combinations to evaluate in randomized search.\n",
+ " - *Default:* `1,000`\n",
+ "\n",
+ "- **`num_solver_iter`**: The maximum number of iterations allowed for solver convergence.\n",
+ " - *Default:* `1,000,000`\n",
+ "\n",
+ "- **`scoring`**: A list of performance metrics used during cross-validation.\n",
+ " - Available options:\n",
+ " - `\"accuracy\"`: Accuracy\n",
+ " - `\"precision\"`: Precision\n",
+ " - `\"recall\"`: Recall\n",
+ " - `\"f1\"`: F1-Score\n",
+ " - `\"roc_auc\"`: Area Under ROC Curve (AUROC)\n",
+ " - `\"matthews_corrcoef\"`: Matthews Correlation Coefficient (MCC)\n",
+ " - *Default:* `[\"accuracy\", \"roc_auc\"]`\n",
+ "\n",
+ "- **`refit`**: The metric used to refit the best model after hyperparameter tuning.\n",
+ " - *Default:* `\"accuracy\"`\n",
+ "\n",
+ "- **`n_jobs`**: The number of CPU cores used for training and hyperparameter search.\n",
+ " - Set to `-1` to use all CPUs; a negative value `-k` uses all but `k - 1` CPUs (e.g., `-2` leaves one CPU free).\n",
+ " - *Default:* `-1`\n",
+ "\n",
+ "- **`pre_dispatch`**: Controls job pre-dispatching for parallel execution.\n",
+ " - *Default:* `\"2*n_jobs\"`\n",
+ "\n",
+ "- **`verbose`**: Controls verbosity of training output.\n",
+ " - *Default:* `0`\n",
+ "\n",
+ "- **`random_state`**: Seed for random number generators for reproducibility.\n",
+ " - *Default:* `None`"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "from sklearn.base import clone\n",
+ "from kale.pipeline.multi_domain_adapter import AutoMIDAClassificationTrainer as Trainer\n",
+ "\n",
+ "# Configuration with cv and random_state/seed included\n",
+ "trainer_cfg = {k.lower(): v for k, v in cfg.TRAINER.items()}\n",
+ "trainer_cfg = {**trainer_cfg, \"cv\": cv, \"random_state\": cfg.RANDOM_STATE}\n",
+ "\n",
+ "# Initialize dictionary for different trainers\n",
+ "trainers = {}\n",
+ "\n",
+ "# Create a baseline trainer without domain adaptation (MIDA disabled)\n",
+ "trainers[\"baseline\"] = Trainer(use_mida=False, **trainer_cfg)\n",
+ "\n",
+ "# Create a trainer with MIDA enabled, using site labels as domain adaptation factors\n",
+ "trainers[\"site_only\"] = Trainer(use_mida=True, **trainer_cfg)\n",
+ "\n",
+ "# Clone the 'site_only' trainer to create 'all_phenotypes' trainer\n",
+ "# This enables reusing the same training configuration, while modifying only the input domain factors\n",
+ "trainers[\"all_phenotypes\"] = clone(trainers[\"site_only\"])"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
{
- "output_type": "display_data",
- "data": {
- "text/plain": [
- ""
- ]
- },
- "metadata": {}
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Training\n",
+ "\n",
+ "We train each model configuration using the previously defined cross-validation strategy. The training process involves fitting the model on functional connectivity features and evaluating its performance using multiple scoring metrics (e.g., accuracy, F1-score, AUROC).\n",
+ "\n",
+ "For models with domain adaptation, we pass additional domain factors (such as site or phenotypic variables) to guide the alignment of the embeddings. Cross-validation is performed to ensure robust performance estimates and to select the best hyperparameter configuration for each model."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# Define common training arguments for all models: features (X), labels (y), and group info (sites)\n",
+ "fit_args = {\"x\": fc, \"y\": labels, \"groups\": sites}\n",
+ "\n",
+ "cv_results = {}\n",
+ "for model in (pbar := tqdm(trainers)):\n",
+ " args = clone(fit_args, safe=False)\n",
+ " if model == \"site_only\":\n",
+ " args[\"group_labels\"] = sites\n",
+ " if model == \"all_phenotypes\":\n",
+ " args[\"group_labels\"] = phenotypes\n",
+ "\n",
+ " pbar.set_description(f\"Fitting {model} model\")\n",
+ " trainers[model].fit(**args)\n",
+ " cv_results[model] = pd.DataFrame(trainers[model].cv_results_)"
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Fitting all_phenotypes model: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 3/3 [00:24<00:00, 8.19s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Evaluation\n",
+ "\n",
+ "We evaluate and compare the performance of different model configurations using cross-validation results. We aggregate the top-performing scores for each model based on a specified evaluation metric (e.g., accuracy), allowing us to assess the effectiveness of domain adaptation strategies.\n",
+ "\n",
+ "By comparing models with and without domain adaptation, we can determine the impact of incorporating site and phenotypic information on multi-site autism classification performance. This analysis helps identify which configurations generalize best across heterogeneous imaging sites."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "from parsing import compile_results\n",
+ "\n",
+ "# Compile the cross-validation results into a summary table,\n",
+ "# sorting by the model with the highest test accuracy across CV folds\n",
+ "compiled_results = compile_results(cv_results, \"accuracy\")\n",
+ "\n",
+ "# Display the compiled results DataFrame (models as rows, metrics as formatted strings)\n",
+ "display(compiled_results)"
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/html": [
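+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>Accuracy</th>\n",
+ "      <th>AUROC</th>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>Model</th>\n",
+ "      <th></th>\n",
+ "      <th></th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>Baseline</th>\n",
+ "      <td>0.6629 \u00b1 0.0523</td>\n",
+ "      <td>0.7105 \u00b1 0.0556</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>Site Only</th>\n",
+ "      <td>0.6609 \u00b1 0.0509</td>\n",
+ "      <td>0.7127 \u00b1 0.0596</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>All Phenotypes</th>\n",
+ "      <td>0.6474 \u00b1 0.0597</td>\n",
+ "      <td>0.7057 \u00b1 0.0514</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"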
+ ],
+ "text/plain": [
+ " Accuracy AUROC\n",
+ "Model \n",
+ "Baseline 0.6629 \u00b1 0.0523 0.7105 \u00b1 0.0556\n",
+ "Site Only 0.6609 \u00b1 0.0509 0.7127 \u00b1 0.0596\n",
+ "All Phenotypes 0.6474 \u00b1 0.0597 0.7057 \u00b1 0.0514"
+ ]
+ },
+ "metadata": {}
+ }
+ ],
+ "execution_count": null
},
{
- "output_type": "display_data",
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {}
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Interpretation"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "We interpret the trained models by analyzing the learned weights associated with functional connectivity features. Specifically, we extract the top-weighted ROI pairs that contributed most to the classification decision.\n",
+ "\n",
+ "These weights are visualized as a **connectome plot**, allowing us to examine which brain region interactions are most informative for distinguishing individuals with autism from controls. This not only enhances the interpretability of the model but also provides potential insights into neurobiological patterns relevant to autism."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "import numpy as np\n",
+ "from kale.interpret.visualize import visualize_connectome\n",
+ "\n",
+ "# Fetch coefficients to visualize feature importance\n",
+ "coef = trainers[\"baseline\"].coef_.ravel()\n",
+ "# If coef is longer than the feature vector, the features were augmented with phenotypes/sites;\n",
+ "# keep only the leading functional connectivity coefficients\n",
+ "if coef.shape[0] != fc.shape[1]:\n",
+ " coef, _ = np.split(coef, [fc.shape[1]])\n",
+ "\n",
+ "# Visualize the coefficients as a connectome plot\n",
+ "proj = visualize_connectome(\n",
+ " coef,\n",
+ " rois,\n",
+ " coords,\n",
+ " 0.015, # Take top 1.5% of connections\n",
+ " legend_params={\n",
+ " \"bbox_to_anchor\": (2.75, -0.1), # Align legend outside the plot\n",
+ " \"ncol\": 2,\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "# Display the resulting connectome plot\n",
+ "display(proj)"
+ ],
+ "cell_type": "code",
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {}
+ }
+ ],
+ "execution_count": null
+ },
+ {
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "### Interpretation of Discriminative Connectivity Patterns\n",
+ "\n",
+ "This plot shows the **most discriminative ROI connections** for classifying ASD vs Control subjects.\n",
+ "- **Red edges** indicate connections **stronger in ASD**.\n",
+ "- **Blue edges** indicate connections **stronger in Control**.\n",
+ "- Color intensity reflects the **magnitude of contribution** to the model\u2019s decision.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "**Key Patterns**:\n",
+ "\n",
+ "- **Default Mode Network (DMN)**:\n",
+ " - *DefaultMode.MPFC*, *DefaultMode.PCC*, *DefaultMode.LP (L/R)*\n",
+ " - Core hubs of the DMN, associated with **self-referential processing**, **social cognition**, and often disrupted in ASD.\n",
+ "\n",
+ "- **Fronto-Parietal Network**:\n",
+ " - *FrontoParietal.PPC (L)*\n",
+ " - Involved in **executive function** and **cognitive flexibility**, domains typically impaired in ASD.\n",
+ "\n",
+ "- **Dorsal Attention Network**:\n",
+ " - *DorsalAttention.IPS (L)*\n",
+ " - Associated with **goal-directed attention**, potentially altered in ASD subjects.\n",
+ "\n",
+ "- **Salience Network**:\n",
+ " - *Salience.SMG (R)*\n",
+ " - Plays a role in **interoception** and **social-emotional processing**, relevant for ASD symptoms.\n",
+ "\n",
+ "- **Language Network**:\n",
+ " - *Language.pSTG (R)*\n",
+ " - Critical for **language comprehension** and **social communication**, often affected in ASD.\n",
+ "\n",
+ "- **Sensorimotor and Cerebellar Regions**:\n",
+ " - *SensoriMotor.Lateral (L)*, *Cerebellar.Posterior*\n",
+ " - Linked to **motor coordination** and **sensorimotor integration**, commonly atypical in ASD.\n",
+ "\n",
+ "The interpretability analysis of the trained model highlights that **functional connectivity alterations across DMN, attention, salience, language, and sensorimotor systems** are key discriminative factors for distinguishing **ASD** from **Control** subjects."
+ ],
+ "cell_type": "markdown"
}
- ],
- "execution_count": null
- },
- {
- "metadata": {
- "tags": []
- },
- "source": [
- "### Interpretation of Discriminative Connectivity Patterns\n",
- "\n",
- "This plot shows the **most discriminative ROI connections** for classifying ASD vs Control subjects.\n",
- "- **Red edges** indicate connections **stronger in ASD**.\n",
- "- **Blue edges** indicate connections **stronger in Control**.\n",
- "- Color intensity reflects the **magnitude of contribution** to the model’s decision.\n",
- "\n",
- "---\n",
- "\n",
- "**Key Patterns**:\n",
- "\n",
- "- **Default Mode Network (DMN)**:\n",
- " - *DefaultMode.MPFC*, *DefaultMode.PCC*, *DefaultMode.LP (L/R)*\n",
- " - Core hubs of the DMN, associated with **self-referential processing**, **social cognition**, and often disrupted in ASD.\n",
- "\n",
- "- **Fronto-Parietal Network**:\n",
- " - *FrontoParietal.PPC (L)*\n",
- " - Involved in **executive function** and **cognitive flexibility**, domains typically impaired in ASD.\n",
- "\n",
- "- **Dorsal Attention Network**:\n",
- " - *DorsalAttention.IPS (L)*\n",
- " - Associated with **goal-directed attention**, potentially altered in ASD subjects.\n",
- "\n",
- "- **Salience Network**:\n",
- " - *Salience.SMG (R)*\n",
- " - Plays a role in **interoception** and **social-emotional processing**, relevant for ASD symptoms.\n",
- "\n",
- "- **Language Network**:\n",
- " - *Language.pSTG (R)*\n",
- " - Critical for **language comprehension** and **social communication**, often affected in ASD.\n",
- "\n",
- "- **Sensorimotor and Cerebellar Regions**:\n",
- " - *SensoriMotor.Lateral (L)*, *Cerebellar.Posterior*\n",
- " - Linked to **motor coordination** and **sensorimotor integration**, commonly atypical in ASD.\n",
- "\n",
- "The interpretability analysis of the trained model highlights that **functional connectivity alterations across DMN, attention, salience, language, and sensorimotor systems** are key discriminative factors for distinguishing **ASD** from **Control** subjects."
- ],
- "cell_type": "markdown"
- }
- ]
+ ]
}
diff --git a/tutorials/brain-disorder-diagnosis/preprocess.py b/tutorials/brain-disorder-diagnosis/preprocess.py
index 57779f5..9ea88a1 100644
--- a/tutorials/brain-disorder-diagnosis/preprocess.py
+++ b/tutorials/brain-disorder-diagnosis/preprocess.py
@@ -123,18 +123,19 @@ def preprocess_phenotypic_data(data, standardize=False):
@validate_params(
- {"data": ["array-like"], "measures": [list]}, prefer_skip_nested_validation=False
+ {"data": ["array-like"], "measures": [list, tuple]},
+ prefer_skip_nested_validation=False,
)
def extract_functional_connectivity(data, measures=["pearson"]):
"""Extract functional connectivity features from time series data.
Parameters
----------
- data : list[array-like] of shape (n_subjects,)
+ data : list[array-like] or tuple[array-like] of shape (n_subjects,)
An array of numpy arrays, where each array is a time series of shape (t, n_rois).
The time series data for each subject.
- measures : list[str], optional (default=["pearson"])
+ measures : list[str] or tuple[str], optional (default=["pearson"])
A list of connectivity measures to use for feature extraction.
Supported measures are "pearson", "partial", "tangent", "covariance", and "precision".
Multiple measures can be specified as a list to compose a higher-order measure.