# Reducing Phenotypical Effect to Improve Multi-site Autism Classification Performance

In this tutorial, we show how to leverage patients' phenotypic information with domain adaptation to reduce the site dependence of functional connectivity data and thereby improve multi-site autism classification performance.

This notebook extends the work of Kunda et al. (IEEE TMI 2022), which proposed Tangent-Pearson, a second-order functional connectivity measure obtained by applying the tangent embedding to Pearson correlation matrices, and applied domain adaptation to neuroimaging data to reduce site dependence.
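
For reference, here is our reading of what the `["tangent", "pearson"]` chain computes when run through nilearn's `ConnectivityMeasure` (a sketch of that library's conventions, not a formula quoted from Kunda et al.). For subject $i$ with time series $X_i \in \mathbb{R}^{t \times p}$ over $p$ ROIs:

```latex
\begin{aligned}
R_i &= \operatorname{corr}(X_i) \in \mathbb{R}^{p \times p}
  && \text{(Pearson correlation)} \\
\Sigma_i &= \widehat{\operatorname{cov}}(R_i)
  && \text{(rows of } R_i \text{ treated as } p \text{ samples)} \\
T_i &= \log\!\left(\bar{\Sigma}^{-1/2}\,\Sigma_i\,\bar{\Sigma}^{-1/2}\right)
  && \text{(tangent embedding)}
\end{aligned}
```

Here $\widehat{\operatorname{cov}}$ is a covariance estimator (Ledoit-Wolf by default in nilearn), $\bar{\Sigma}$ is the geometric mean of $\Sigma_1, \dots, \Sigma_n$, and $\log$ is the matrix logarithm. The final feature vector keeps the $p(p-1)/2$ entries of $T_i$ below the diagonal.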

Kunda et al. applied domain adaptation using only the site labels. We extend this work by also leveraging other phenotypic information, such as sex, handedness, age, and eye status.
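
The trainer in this notebook delegates domain adaptation to PyKale's `MIDATrainer` (Maximum Independence Domain Adaptation). To make "reducing dependence on phenotypic factors" concrete, below is a minimal numpy sketch of the unsupervised MIDA objective as we understand it, using linear kernels: it keeps high-variance directions of the features while penalizing their HSIC (Hilbert-Schmidt Independence Criterion) with a matrix of domain factors. This is our own illustration, not PyKale's implementation; `linear_mida`, `linear_cka`, and the synthetic data are hypothetical. Here the factors are one-hot sites, but the same matrix could also hold sex, handedness, age, FIQ, or eye status.

```python
import numpy as np


def linear_mida(x, factors, num_components=2, mu=1.0):
    """Minimal sketch of unsupervised MIDA with linear kernels (hypothetical helper).

    Finds W (with W^T W = I) maximising
        mu * tr(W^T K H K W) - tr(W^T K H K_d H K W),
    i.e. the variance of the embedding minus its linear HSIC with the factors.
    """
    n = x.shape[0]
    h = np.eye(n) - np.full((n, n), 1.0 / n)  # centering matrix H
    k = x @ x.T  # linear kernel on the features
    k_d = factors @ factors.T  # linear kernel on the domain factors
    objective = k @ (mu * h - h @ k_d @ h) @ k  # symmetric (n, n) matrix
    _, eigvecs = np.linalg.eigh(objective)  # eigenvalues in ascending order
    w = eigvecs[:, ::-1][:, :num_components]  # leading eigenvectors
    return k @ w  # embedded samples Z = K W, shape (n, num_components)


def linear_cka(a, b):
    """Linear HSIC normalised to [0, 1] (centered kernel alignment)."""
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    return np.linalg.norm(a.T @ b) ** 2 / (
        np.linalg.norm(a.T @ a) * np.linalg.norm(b.T @ b)
    )


# Hypothetical data: 3 sites x 30 subjects, 5 features, one carrying a site offset
rng = np.random.default_rng(0)
site = np.repeat([0, 1, 2], 30)
factors = np.eye(3)[site]  # one-hot site labels as domain factors
x = rng.normal(size=(90, 5))
x[:, 0] += np.array([-10.0, 0.0, 10.0])[site]
z = linear_mida(x, factors, num_components=2, mu=1.0)

# Dependence on site before and after the projection
print(linear_cka(x, factors), linear_cka(z, factors))
```

The site dependence measured on `z` is expected to be far lower than on `x`, because the strongly site-dependent feature direction receives a large negative objective and is not among the leading eigenvectors. `mu` trades retained variance against independence.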

Our **objectives** are to:
1. Load the ABIDE dataset using one of its available preprocessing options and atlases.
2. Extract functional connectivity from the time series of the preprocessed scans.
3. Preprocess the phenotypic information for domain adaptation and obtain the classification and site labels.
4. Create a pipeline to train and evaluate the classifier.

# Setup

As a starting point, we install the required packages and define helper functions that assist in this tutorial.

## Packages

The packages required for this tutorial are PyKale and Nilearn. PyKale is an interdisciplinary machine learning library developed at the University of Sheffield, and Nilearn is a neuroimaging library mainly intended for fMRI data analysis.

In [1]:
!pip install --quiet git+https://github.com/pykale/pykale@main nilearn \
    && echo "PyKale and Nilearn installed successfully ✅" \
    || echo "Failed to install PyKale and Nilearn ❌"

PyKale and Nilearn installed successfully ✅


## Helper Functions

The helper functions are adapted from the [source code](https://github.com/zaRizk7/abide-demo) originally written to demonstrate how containers improve reproducibility and reusability in ML experiments.

### Feature Extraction

Contains functions for imputing, categorically mapping, and standardizing the selected phenotypes, and for extracting (optionally chained) functional connectivity measures from the time series.

In [2]:
import logging

import numpy as np
import pandas as pd
from nilearn.connectome import ConnectivityMeasure
from sklearn.preprocessing import StandardScaler
from sklearn.utils._param_validation import (
    Integral,
    Interval,
    StrOptions,
    validate_params,
)

SELECTED_PHENOTYPES = [
    "SUB_ID",
    "SITE_ID",
    "SEX",
    "AGE_AT_SCAN",
    "FIQ",
    "HANDEDNESS_CATEGORY",
    "EYE_STATUS_AT_SCAN",
    "DX_GROUP",
]

MAPPING = {
    "SEX": {1: "MALE", 2: "FEMALE"},
    "HANDEDNESS_CATEGORY": {
        "L": "LEFT",
        "R": "RIGHT",
        "Mixed": "AMBIDEXTROUS",
        "Ambi": "AMBIDEXTROUS",
        "L->R": "AMBIDEXTROUS",
        "R->L": "AMBIDEXTROUS",
        "-9999": "LEFT",
        np.nan: "LEFT",
    },
    "EYE_STATUS_AT_SCAN": {1: "OPEN", 2: "CLOSED"},
    "DX_GROUP": {1: "ASD", 2: "CONTROL"},
}

AVAILABLE_FC_MEASURES = {
    "pearson": "correlation",
    "partial": "partial correlation",
    "tangent": "tangent",
    "covariance": "covariance",
    "precision": "precision",
}


@validate_params(
    {
        "data": [pd.DataFrame],
        "standardize": [StrOptions({"site", "all"}), "boolean"],
        "verbose": ["verbose"],
    },
    prefer_skip_nested_validation=False,
)
def process_phenotypic_data(data, standardize=False, verbose=0):
    """Process phenotypic data to impute missing values and encode categorical
    variables including sex, handedness, eye status at scan, and diagnostic group.

    Parameters
    ----------
    data : pd.DataFrame of shape (n_subjects, n_phenotypes)
        The phenotype data to be processed.

    standardize : bool or {"site", "all"}, default=False
        Standardize FIQ and age. Setting to True or "all" standardizes the
        values over all subjects, while "site" standardizes them within
        each site.

    verbose : int, default=0
        The verbosity level. verbose > 0 logs the current processing step.

    Returns
    -------
    labels : pd.Series of shape (n_subjects,)
        The encoded classification group: 0 is "CONTROL" and 1 is "ASD".

    sites : pd.Series of shape (n_subjects,)
        The site label of each subject.

    phenotypes : pd.DataFrame of shape (n_subjects, n_encoded_phenotypes)
        The processed phenotypes with imputed values, where categorical
        columns (including the site) are one-hot encoded.
    """
    logger = logging.getLogger("feature_extraction.process_phenotypic_data")
    if verbose > 0:
        logger.setLevel(logging.INFO)
        logger.info("Imputing missing values and encoding handedness...")

    # Avoid in-place modification
    data = data.copy()

    # Check for missing values, either -9999 or NaN
    # and impute them with FIQ = 100 following original code.
    fiq = data["FIQ"].copy()
    data["FIQ"] = fiq.where((fiq != -9999) & (~np.isnan(fiq)), 100)

    # Standardize FIQ and age by site
    if standardize == "site":
        for site in data["SITE_ID"].unique():
            mask = site == data["SITE_ID"]
            values = data.loc[mask, ["AGE_AT_SCAN", "FIQ"]]
            values = StandardScaler().fit_transform(values)
            data.loc[mask, ["AGE_AT_SCAN", "FIQ"]] = values
    elif standardize:
        values = data.loc[:, ["AGE_AT_SCAN", "FIQ"]]
        values = StandardScaler().fit_transform(values)
        data.loc[:, ["AGE_AT_SCAN", "FIQ"]] = values

    # Encode categorical variables to be more explicit categorical
    # values. For handedness, if we found missing values, we
    # impute them by using 'LEFT' as default. Values
    # like 'Ambi', 'Mixed', 'L->R', and 'R->L' are mapped to
    # 'AMBIDEXTROUS'. The rest of the values are mapped to 'LEFT' or 'RIGHT'
    # for 'L' or 'R' respectively.
    for key in MAPPING:
        values = data[key].copy().map(MAPPING[key])
        data[key] = values.astype("category")

    # Subsets the phenotypes
    data = data[SELECTED_PHENOTYPES].set_index("SUB_ID")

    # Separate the class labels, sites, and phenotypes
    labels = data["DX_GROUP"].map({"CONTROL": 0, "ASD": 1})
    sites = data["SITE_ID"]
    phenotypes = data.drop(columns=["DX_GROUP"])
    # One-hot encode categorical valued phenotypes
    phenotypes = pd.get_dummies(phenotypes)

    if verbose > 0:
        logger.info("Imputation and encoding completed.")

    return labels, sites, phenotypes


@validate_params(
    {
        "data": ["array-like"],
        "measures": [list],
        "verbose": ["verbose"],
    },
    prefer_skip_nested_validation=False,
)
def extract_functional_connectivity(data, measures=["pearson"], verbose=0):
    """Extract functional connectivity features from time series data.

    Parameters
    ----------
    data : list[array-like] of shape (n_subjects,)
        An array of numpy arrays, where each array is a time series of shape (t, n_rois).
        The time series data for each subject.

    measures : list[str], optional
        A list of connectivity measures to use for feature extraction.
        The default is ["pearson"].
        Supported measures are "pearson", "partial", "tangent", "covariance", and "precision".
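        Multiple measures can be chained to compose a higher-order measure;
        they are applied from last to first, so ["tangent", "pearson"]
        computes the tangent embedding of Pearson correlations.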

    verbose : int, optional
        The verbosity level. The default is 0.
        verbose > 0 will log the current processing step.

    Returns
    -------
    features : ndarray of shape (n_subjects, n_features)
        The extracted features, where n_features equals
        `n_rois * (n_rois - 1) / 2` for each subject.
    """
    if verbose > 0:
        logger = logging.getLogger("feature_extraction.extract_functional_connectivity")
        logger.setLevel(logging.INFO)
        logger.info("Extracting functional connectivity features...")
        logger.info(f"Using measures: {measures}")

    for i, k in enumerate(reversed(measures), 1):
        k = AVAILABLE_FC_MEASURES.get(k)

        # If it is the last transformation, vectorize and discard the diagonal
        # of shape (n_rois * (n_rois - 1) / 2)
        islast = i == len(measures)
        measure = ConnectivityMeasure(kind=k, vectorize=islast, discard_diagonal=islast)
        data = measure.fit_transform(data)

    if verbose > 0:
        logger.info("Functional connectivity features extracted.")

    return data

### Trainer

Contains the hyperparameter grid and a wrapper function to train the pipeline with or without domain adaptation.
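
The size of these grids explains why the Model Definition section uses `search_strategy = "random"`. A quick count, re-declaring the same grid values as the next cell (the variable names `c_values`, `classifier_grid`, `mida_grid`, and `grid_size` are illustrative, not part of the helper code), multiplied by the 10 folds x 5 repeats used in the Cross-Validation Split section and ignoring the final refit:

```python
import math

import numpy as np

# Same values as C, CLASSIFIER_GRID["logistic"], and MIDA_GRID in the Trainer cell
c_values = np.logspace(start=-15, stop=15, num=30 + 1, base=2)  # 2^-15 ... 2^15
classifier_grid = {"C": c_values}
mida_grid = {
    "num_components": [32, 64, 128, 256, None],
    "kernel": ["linear", "rbf"],
    "mu": 1 / (2 * c_values),
    "eta": 1 / (2 * c_values),
    "ignore_y": [True, False],
    "augment": [True, False],
}


def grid_size(*grids):
    """Number of candidate configurations in an exhaustive grid search."""
    return math.prod(len(values) for grid in grids for values in grid.values())


num_splits = 10 * 5  # num_folds * num_cv_repeats

print(grid_size(classifier_grid) * num_splits)  # 1550 fits: logistic, no MIDA
print(grid_size(classifier_grid, mida_grid) * num_splits)  # 59582000 fits: with MIDA
print(10 * num_splits)  # 500 fits: random search with 10 candidates
```

An exhaustive search with MIDA would need roughly 60 million fits, whereas random search with `num_search_iterations = 10` needs 500.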

In [3]:
import logging

import numpy as np
from kale.pipeline.mida_trainer import MIDATrainer
from sklearn.base import clone
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.metrics import get_scorer_names
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, check_cv
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.utils._param_validation import (
    Integral,
    Interval,
    StrOptions,
    validate_params,
)

__all__ = ["create_trainer"]

# Inverse regularization coefficients for the classifiers
# For Ridge (alpha) and MIDA (mu and eta), we use 1 / (2C)
C = np.logspace(start=-15, stop=15, num=30 + 1, base=2)

CLASSIFIER = {
    "logistic": LogisticRegression(),
    "svm": LinearSVC(),
    "ridge": RidgeClassifier(),
}

CLASSIFIER_GRID = {
    "logistic": {"C": C},
    "svm": {"C": C},
    "ridge": {"alpha": 1 / (2 * C)},
}

MIDA_GRID = {
    "num_components": [32, 64, 128, 256, None],
    "kernel": ["linear", "rbf"],
    "mu": 1 / (2 * C),
    "eta": 1 / (2 * C),
    "ignore_y": [True, False],
    "augment": [True, False],
}
MIDA_GRID = {f"domain_adapter__{key}": value for key, value in MIDA_GRID.items()}


@validate_params(
    {
        "classifier": [StrOptions({"logistic", "svm", "ridge"})],
        "mida": ["boolean"],
        "search_strategy": [StrOptions({"grid", "random"})],
        "cv": ["cv_object"],
        "scoring": [StrOptions(set(get_scorer_names())), list, None],
        "num_solver_iterations": [Interval(Integral, 1, None, closed="left")],
        "num_search_iterations": [Interval(Integral, 1, None, closed="left")],
        "num_jobs": [Integral, None],
        "random_state": ["random_state"],
        "verbose": ["verbose"],
    },
    prefer_skip_nested_validation=False,
)
def create_trainer(
    classifier="logistic",
    mida=False,
    search_strategy="grid",
    cv=None,
    scoring=None,
    num_solver_iterations=100,
    num_search_iterations=10,
    num_jobs=None,
    random_state=None,
    verbose=0,
):
    """Create a trainer for a classification model.

    Parameters
    ----------
    classifier : str, default="logistic"
        The classifier to use. Can be "logistic", "svm", or "ridge".

    mida : bool, default=False
        Whether to use MIDA for site-dependency reduction.

    search_strategy : str, default="grid"
        The search strategy for hyperparameter tuning. Can be "grid" or "random".

    cv : int, cross-validation generator, or iterable, default=None
        The cross-validation splitting strategy. If None, the default 5-fold
        cross-validation is used.

    scoring : str, list of str, or None, default=None
        A single string or a list of strings to use as the scoring metric(s).
        With a list, the first metric is used to select and refit the best
        model. If None, the estimator's default score method is used.

    num_solver_iterations : int, default=100
        The number of iterations for the solver. This is used to set the
        max_iter parameter of the classifier.

    num_search_iterations : int, default=10
        The number of iterations for the random search. This is only used
        if search_strategy is "random".

    num_jobs : int, default=None
        The number of jobs to run in parallel with joblib.Parallel. If None,
        the number of jobs is set to run on a single core.

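    random_state : int, RandomState instance, or None, default=None
        The random seed for the random number generator. If None, the
        random state is not set.

    verbose : int, default=0
        The verbosity level. verbose > 0 logs the trainer configuration and
        is also passed to the search object.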

    Returns
    -------
    trainer : sklearn.model_selection.BaseSearchCV or MIDATrainer
        The model trainer object. This can be either a GridSearchCV,
        RandomizedSearchCV, or MIDATrainer object.
    """
    if verbose > 0:
        logger = logging.getLogger("modeling.create_trainer")
        logger.setLevel(logging.INFO)

        logger.info(f"Creating trainer with classifier: {classifier}")
        logger.info(f"Using MIDA: {mida}")
        logger.info(f"Search strategy: {search_strategy}")
        logger.info(f"Scoring: {scoring}")
        logger.info(f"Number of solver iterations: {num_solver_iterations}")
        logger.info(f"Number of search iterations: {num_search_iterations}")
        logger.info(f"Number of jobs: {num_jobs}")
        logger.info(f"Random state: {random_state}")

    # Generate classifier with its parameter grid
    clf = clone(CLASSIFIER[classifier])
    clf.set_params(max_iter=num_solver_iterations, random_state=random_state)
    param_grid = clone(CLASSIFIER_GRID[classifier], safe=False)

    # Update with MIDA's parameters if we are using MIDA
    if mida:
        param_grid.update(MIDA_GRID)

    # Construct trainer
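    trainer_args = {
        "cv": check_cv(cv, [0, 1], classifier=True),
        "scoring": scoring,
        # Refit on the first metric when several are given; scikit-learn's
        # search objects reject refit=None, so fall back to True
        "refit": scoring[0] if isinstance(scoring, list) else (scoring or True),
        "n_jobs": num_jobs,
        "error_score": "raise",
        "verbose": verbose,
    }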

    if verbose > 0:
        logger.info("Finished constructing trainer.")

    if mida:
        return MIDATrainer(
            estimator=clf,
            param_grid=param_grid,
            search_strategy=search_strategy,
            num_iter=num_search_iterations,
            random_state=random_state,
            **trainer_args,
        )

    if search_strategy == "grid":
        return GridSearchCV(estimator=clf, param_grid=param_grid, **trainer_args)

    return RandomizedSearchCV(
        estimator=clf,
        param_distributions=param_grid,
        n_iter=num_search_iterations,
        random_state=random_state,
        **trainer_args,
    )

# Pipeline

## Resting-state fMRI Preprocessing

Usually, the fMRI scans must be preprocessed before running the pipeline. However, the ABIDE Preprocessed Connectomes Project (PCP) provides several preprocessed derivatives that can be downloaded directly. The options we focus on are:
- `atlas`: Brain atlas used to extract the ROI time series. Available ones are: `"aal"`, `"cc200"`, `"cc400"`, `"dosenbach160"`, `"ez"`, `"ho"`, and `"tt"`. Default: `"cc200"`.
- `bp`: Band-pass filter the signals between 0.01 Hz and 0.1 Hz. Default: `False`.
- `gsr`: Apply global signal regression to the signals. Default: `False`.
- `qc`: Use only scans that pass all quality checks. Default: `True`.
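
The downloaded derivatives already include these steps, so nothing needs to be run here. Purely as a conceptual illustration of the `gsr` option, a minimal numpy sketch of global signal regression: the mean signal across ROIs is regressed out of every ROI by least squares. This is not the pipeline used to produce the ABIDE derivatives; `global_signal_regression` and the synthetic scan are hypothetical.

```python
import numpy as np


def global_signal_regression(time_series):
    """Regress the global signal (mean over ROIs) out of each ROI time series.

    Hypothetical helper; time_series has shape (n_timepoints, n_rois).
    """
    ts = time_series - time_series.mean(axis=0)  # demean each ROI
    global_signal = ts.mean(axis=1, keepdims=True)  # (n_timepoints, 1)
    # One least-squares coefficient per ROI, shape (1, n_rois)
    beta, *_ = np.linalg.lstsq(global_signal, ts, rcond=None)
    return ts - global_signal @ beta


rng = np.random.default_rng(0)
shared = rng.normal(size=(150, 1))  # fluctuation shared by all ROIs
scan = rng.normal(size=(150, 10)) + 3 * shared  # hypothetical 10-ROI scan
cleaned = global_signal_regression(scan)
```

After regression, every ROI is orthogonal to the global signal, so the mean across ROIs of `cleaned` is zero at each time point.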

In [4]:
atlas = "cc200"
bp = False
gsr = False
qc = True

In [5]:
from nilearn.datasets import fetch_abide_pcp

dataset = fetch_abide_pcp(
    derivatives=[f"rois_{atlas}"],
    band_pass_filtering=bp,
    global_signal_regression=gsr,
    quality_checked=qc,
)

[get_dataset_dir] Dataset found in /home/zarizky/nilearn_data/ABIDE_pcp


## Phenotype Preprocessing 

The phenotypic information contains missing values, so before using it for modeling we impute them and encode the categorical variables. The categorical phenotypes `SITE_ID`, `SEX`, `HANDEDNESS_CATEGORY`, and `EYE_STATUS_AT_SCAN` are one-hot encoded, while the continuous phenotypes `AGE_AT_SCAN` and `FIQ` can optionally be standardized with the following option:
- `standardize`: Standardization strategy for the subjects' age and FIQ. If set to `True` or `"all"`, the values are standardized over all subjects; if set to `"site"`, they are standardized within each site. Default: `False`.
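
The two strategies differ only in which subjects the mean and standard deviation are computed over. A minimal pandas sketch on hypothetical toy data (the helper code uses `StandardScaler` in a per-site loop instead of `groupby`; both use the population standard deviation, `ddof=0`):

```python
import numpy as np
import pandas as pd

# Hypothetical toy phenotypes: two sites with very different age ranges
toy = pd.DataFrame(
    {
        "SITE_ID": ["A", "A", "A", "B", "B", "B"],
        "AGE_AT_SCAN": [10.0, 12.0, 14.0, 30.0, 40.0, 50.0],
        "FIQ": [90.0, 100.0, 110.0, 95.0, 105.0, 115.0],
    }
)
columns = ["AGE_AT_SCAN", "FIQ"]

# standardize="all": z-score over every subject
all_scaled = (toy[columns] - toy[columns].mean()) / toy[columns].std(ddof=0)

# standardize="site": z-score within each site separately
site_scaled = toy.groupby("SITE_ID")[columns].transform(
    lambda col: (col - col.mean()) / col.std(ddof=0)
)
```

With `"site"`, each site ends up with zero mean and unit variance, which removes site-level offsets in age and FIQ. One difference to keep in mind: a site with a single subject has zero variance, which yields `NaN` in this sketch, whereas `StandardScaler` maps such values to 0.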

Several `HANDEDNESS_CATEGORY` values are missing. Following the `MAPPING` defined in the helper functions, missing handedness values (`NaN` or `-9999`) are imputed as left-handed (`LEFT`). Missing `FIQ` values (`NaN` or `-9999`) are imputed as `100`, the population mean of standardized IQ scales.

The classification labels come from the `DX_GROUP` phenotype, which is binary encoded so that `CONTROL` is `0` and `ASD` is `1`.
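
The imputation and label encoding can be traced on a few hypothetical rows. This sketch mirrors the steps in `process_phenotypic_data`, with one deliberate difference: it uses `fillna("LEFT")` after mapping instead of an `np.nan` key in the dictionary, which also sends any unexpected handedness code to `LEFT`.

```python
import numpy as np
import pandas as pd

# Hypothetical toy rows mimicking the raw ABIDE phenotype columns
toy = pd.DataFrame(
    {
        "FIQ": [110.0, -9999.0, np.nan, 95.0],
        "HANDEDNESS_CATEGORY": ["R", "-9999", np.nan, "Ambi"],
        "DX_GROUP": [1, 2, 2, 1],
    }
)

# FIQ: replace the -9999 sentinel and NaN with 100
fiq = toy["FIQ"]
toy["FIQ"] = fiq.where((fiq != -9999) & fiq.notna(), 100)

# Handedness: known codes map to labels; everything else falls back to "LEFT"
handedness = {
    "L": "LEFT",
    "R": "RIGHT",
    "Mixed": "AMBIDEXTROUS",
    "Ambi": "AMBIDEXTROUS",
    "L->R": "AMBIDEXTROUS",
    "R->L": "AMBIDEXTROUS",
}
toy["HANDEDNESS_CATEGORY"] = toy["HANDEDNESS_CATEGORY"].map(handedness).fillna("LEFT")

# Labels: DX_GROUP 1 -> "ASD" -> 1, DX_GROUP 2 -> "CONTROL" -> 0
toy_labels = toy["DX_GROUP"].map({1: "ASD", 2: "CONTROL"}).map({"CONTROL": 0, "ASD": 1})
```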

In [6]:
standardize = "site"

In [7]:
labels, sites, phenotypes = process_phenotypic_data(dataset["phenotypic"], standardize)

## Feature Extraction

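- `measures`: Sequence of connectivity measures used to extract features from the time series. Available ones are `"pearson"`, `"partial"`, `"tangent"`, `"covariance"`, and `"precision"`. The measures are applied from the last element to the first, so `["tangent", "pearson"]` computes Pearson correlations first and then their tangent embedding (Tangent-Pearson). Default: `["pearson"]`.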
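
For intuition about the chained Tangent-Pearson measure, here is a minimal numpy sketch of the chain under two stated simplifications: a fixed-shrinkage covariance stands in for nilearn's default Ledoit-Wolf estimator, and a log-Euclidean mean stands in for the geometric mean that nilearn uses as the tangent reference point. The helper names and toy scans are hypothetical, and the output will not match `extract_functional_connectivity` numerically.

```python
import numpy as np


def _spd_fn(matrix, fn):
    """Apply a scalar function to the eigenvalues of a symmetric positive definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def shrunk_cov(samples, shrinkage=0.1):
    """Sample covariance shrunk towards a scaled identity (fixed shrinkage)."""
    cov = np.cov(samples, rowvar=False)
    target = np.trace(cov) / cov.shape[0] * np.eye(cov.shape[0])
    return (1 - shrinkage) * cov + shrinkage * target


def tangent_pearson(time_series, shrinkage=0.1):
    """time_series: list of (n_timepoints, n_rois) arrays -> (n_subjects, n_features)."""
    # Step 1: Pearson correlation per subject, shape (n_rois, n_rois)
    pearson = [np.corrcoef(ts, rowvar=False) for ts in time_series]
    # Step 2: treat each correlation matrix as n_rois "samples" and estimate its
    # covariance, regularised so that it is positive definite
    covs = [shrunk_cov(r, shrinkage) for r in pearson]
    # Step 3: reference point = log-Euclidean mean, then whiten and take the log
    reference = _spd_fn(np.mean([_spd_fn(c, np.log) for c in covs], axis=0), np.exp)
    whitening = _spd_fn(reference, lambda v: 1 / np.sqrt(v))
    tangents = [_spd_fn(whitening @ c @ whitening, np.log) for c in covs]
    # Vectorise: keep the n_rois * (n_rois - 1) / 2 entries below the diagonal
    rows, cols = np.tril_indices(len(reference), k=-1)
    return np.stack([t[rows, cols] for t in tangents])


rng = np.random.default_rng(0)
scans = [rng.normal(size=(120, 8)) for _ in range(4)]  # 4 subjects, 8 ROIs
toy_features = tangent_pearson(scans)
print(toy_features.shape)  # (4, 28)
```

Tangent features describe each subject relative to the group reference: a subject whose covariance equals the reference maps to the zero vector. With the CC200 atlas, the same vectorization gives 200 * 199 / 2 = 19,900 features per subject.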

In [8]:
measures = ["pearson"]

In [9]:
features = extract_functional_connectivity(dataset[f"rois_{atlas}"], measures)

## Modeling

### Random Seed

In [10]:
seed = 0

In [11]:
from sklearn.utils.validation import check_random_state

random_state = check_random_state(seed)

### Cross-Validation Split

In [12]:
split = "skf"
num_folds = 10
num_cv_repeats = 5

In [13]:
from sklearn.model_selection import LeavePGroupsOut, RepeatedStratifiedKFold

cv = RepeatedStratifiedKFold(
    n_splits=num_folds,
    n_repeats=num_cv_repeats,
    random_state=random_state,
)

if split == "lpgo":
    cv = LeavePGroupsOut(num_folds)
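
The two split strategies differ greatly in cost. `RepeatedStratifiedKFold` produces one split per fold per repeat and ignores `groups`, while `LeavePGroupsOut(p)` holds out every combination of `p` sites, using the site labels passed as `groups` during training. A quick count, assuming for illustration 20 distinct sites (a hypothetical number; `sites.nunique()` gives the actual count for the loaded data):

```python
from math import comb

folds, repeats = 10, 5  # same values as num_folds and num_cv_repeats above

print(folds * repeats)  # 50 splits for "skf"
print(comb(20, folds))  # 184756 splits for "lpgo" with 20 sites and p = 10
```

With `p = 1`, `LeavePGroupsOut` reduces to leave-one-site-out evaluation, giving one split per site.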

### Model Definition

In [14]:
classifier = "logistic"
mida = False
search_strategy = "random"
scoring = ["accuracy", "roc_auc"]
num_solver_iterations = 100
num_search_iterations = 10
num_jobs = None

In [15]:
trainer = create_trainer(
    classifier,
    mida,
    search_strategy,
    cv,
    scoring,
    num_solver_iterations,
    num_search_iterations,
    num_jobs,
    random_state,
)

### Training

In [16]:
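# When using MIDA, set to True to use only the one-hot site labels as domain
# factors (as in Kunda et al.); otherwise all processed phenotypes are used
site_only = False

# MIDATrainer is called with `x`, scikit-learn search objects with `X`
fit_args = {"x" if mida else "X": features, "y": labels, "groups": sites}

if mida and site_only:
    fit_args["factors"] = pd.get_dummies(sites)
elif mida:
    fit_args["factors"] = phenotypes

trainer.fit(**fit_args)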



### Evaluation

In [17]:
cv_results = pd.DataFrame(trainer.cv_results_)
cv_results = cv_results[
    [f"{aggregate}_test_{score}" for score in scoring for aggregate in ["mean", "std"]]
]

cv_results = cv_results.sort_values("mean_test_accuracy", ascending=False)
cv_results = cv_results.round(4).reset_index(drop=True)
cv_results.index.name = "Rank"

In [18]:
cv_results

| Rank | mean_test_accuracy | std_test_accuracy | mean_test_roc_auc | std_test_roc_auc |
|---:|---:|---:|---:|---:|
| 0 | 0.6966 | 0.0434 | 0.7568 | 0.0405 |
| 1 | 0.695 | 0.0421 | 0.7548 | 0.0399 |
| 2 | 0.6948 | 0.0317 | 0.7476 | 0.0384 |
| 3 | 0.6939 | 0.0314 | 0.7477 | 0.0384 |
| 4 | 0.6936 | 0.0318 | 0.747 | 0.0388 |
| 5 | 0.6934 | 0.0325 | 0.7482 | 0.0385 |
| 6 | 0.6934 | 0.0318 | 0.7479 | 0.0384 |
| 7 | 0.6886 | 0.0373 | 0.7515 | 0.0388 |
| 8 | 0.687 | 0.036 | 0.7466 | 0.0373 |
| 9 | 0.5612 | 0.0225 | 0.6736 | 0.0655 |
