Merged

Changes from all commits (45 commits)
dcf93a4  cast site to numpy (zaRizk7, Jun 4, 2025)
f0d92aa  add num_solver_iter and rename extension (zaRizk7, Jun 4, 2025)
7047c23  update notebook objectives and trainer imports (zaRizk7, Jun 4, 2025)
092bf9d  update base exp yaml (zaRizk7, Jun 5, 2025)
448c359  use skf by default (zaRizk7, Jun 5, 2025)
1e32ba2  add handle for google colab runtime (zaRizk7, Jun 5, 2025)
cf9e370  update output (zaRizk7, Jun 5, 2025)
bff672b  reduce preprocess_phenotypic_data functionality and use polars to rep… (zaRizk7, Jun 16, 2025)
72161f0  use polars to replace pandas (zaRizk7, Jun 16, 2025)
11da26e  add manifest and load_data function to fetch data from gdrive (zaRizk7, Jun 16, 2025)
d22141d  update default cfg and base exp yml (zaRizk7, Jun 16, 2025)
b5b5923  update notebook contents (zaRizk7, Jun 16, 2025)
5bbcc39  add polars and gdown to req (zaRizk7, Jun 16, 2025)
237691d  change nilearn req (zaRizk7, Jun 16, 2025)
afb3b0a  remove param_validation (zaRizk7, Jun 16, 2025)
fb65616  add handle to prioritize site-packages for colab (zaRizk7, Jun 16, 2025)
8567ddd  use single core only (zaRizk7, Jun 16, 2025)
ac4a5e5  update pre_dispatch config (zaRizk7, Jun 16, 2025)
7a175ba  add --user to handle site-packages (zaRizk7, Jun 16, 2025)
fe52be1  use default n_jobs (zaRizk7, Jun 16, 2025)
bebe42f  fallback to pandas (zaRizk7, Jun 16, 2025)
79f0336  update config and base yml (zaRizk7, Jun 16, 2025)
d93784d  remove polars (zaRizk7, Jun 16, 2025)
1ad406b  use tangent-pearson by default (zaRizk7, Jun 16, 2025)
544fb50  remove fc cfg (zaRizk7, Jun 16, 2025)
63d78c4  reduce search iter (zaRizk7, Jun 16, 2025)
d41007a  update notebook with new cfg (zaRizk7, Jun 16, 2025)
dba1953  Merge branch 'main' into brain-decoding (zaRizk7, Jun 16, 2025)
df96558  revert to use param_validation for load_data (zaRizk7, Jun 16, 2025)
8688472  fix pydoc typo (zaRizk7, Jun 16, 2025)
c0fb3f4  explicitly name loaded fc as fc_data (zaRizk7, Jun 16, 2025)
7aa3355  add dirname(__file__) to prevent relative dir errors (zaRizk7, Jun 16, 2025)
ed386a1  use dirname(__file__) for atlas_folder (zaRizk7, Jun 16, 2025)
627008b  remove note for colab (zaRizk7, Jun 16, 2025)
a584d53  remove check_random_state (zaRizk7, Jun 16, 2025)
667e097  annotate config for classifier and split (zaRizk7, Jun 16, 2025)
7efc7c2  fix seed with trainer (zaRizk7, Jun 16, 2025)
763c1b9  include cc400 in the validation for load_data (zaRizk7, Jun 16, 2025)
ac29f3e  update comments (zaRizk7, Jun 16, 2025)
98a8d16  remove nilearn imports (zaRizk7, Jun 16, 2025)
7ba387c  remove aal imports (zaRizk7, Jun 16, 2025)
2175be3  remove unused seaborn import (zaRizk7, Jun 16, 2025)
fabed46  reformat load_data validation (zaRizk7, Jun 16, 2025)
f11b4ba  resolve missing return in pydoc (zaRizk7, Jun 16, 2025)
42f8329  update markdown per-section (zaRizk7, Jun 16, 2025)
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,5 +3,6 @@ matplotlib==3.10.3
 seaborn==0.13.2
 numpy==1.26.4
 git+https://github.com/pykale/pykale@main
-nilearn==0.11.1
+nilearn==0.10.4
 yacs==0.1.8
+gdown==5.2.0
68 changes: 49 additions & 19 deletions tutorials/brain-disorder-diagnosis/config.py
@@ -1,24 +1,34 @@
import os
from yacs.config import CfgNode

DEFAULT_DIR = os.path.join(os.getcwd(), "data")

_C = CfgNode()

# Dataset configuration
_C.DATASET = CfgNode()
# Path to the dataset directory
-_C.DATASET.PATH = "nilearn_data"
+_C.DATASET.PATH = DEFAULT_DIR
# Name of the brain atlas to use
# Available options:
# - "aal" (AAL)
# - "cc200" (Cameron Craddock 200)
# - "cc400" (Cameron Craddock 400)
# - "difumo64" (DiFuMo 64)
# - "dos160" (Dosenbach 160)
# - "hcp-ica" (HCP-ICA)
# - "ho" (Harvard-Oxford)
# - "tt" (Talairach-Tournoux)
_C.DATASET.ATLAS = "cc200"
# Whether to apply bandpass filtering
_C.DATASET.BANDPASS = False
# Whether to apply global signal regression
_C.DATASET.GLOBAL_SIGNAL_REGRESSION = False
# Whether to use only quality-checked data
_C.DATASET.QUALITY_CHECKED = False

-# Connectivity configuration
-_C.CONNECTIVITY = CfgNode()
-# List of connectivity measures to compute
-_C.CONNECTIVITY.MEASURES = ["pearson"]
+# Functional connectivity to use
+# Available options:
+# - "pearson"
+# - "partial"
+# - "tangent"
+# - "precision"
+# - "covariance"
+# - "tangent-pearson"
+_C.DATASET.FC = "tangent-pearson"

# Phenotype configuration
_C.PHENOTYPE = CfgNode()
@@ -27,37 +37,57 @@

# Cross-validation configuration
_C.CROSS_VALIDATION = CfgNode()
-# Cross-validation split method (e.g., leave-p-groups-out)
+# Cross-validation split method
# Available options:
# - "skf" (Stratified K-Folds)
# - "lpgo" (Leave-P-Groups-Out)
_C.CROSS_VALIDATION.SPLIT = "skf"
# Number of folds for cross-validation
# or number of groups for Leave-P-Groups-Out
_C.CROSS_VALIDATION.NUM_FOLDS = 10
# Number of repeats for cross-validation
-_C.CROSS_VALIDATION.NUM_REPEATS = 1
+_C.CROSS_VALIDATION.NUM_REPEATS = 5
Review comment (Copilot AI, Jun 16, 2025):
Default NUM_REPEATS here (5) conflicts with NUM_REPEATS: 1 in base.yml; align these values to avoid confusion.
Suggested change:
-_C.CROSS_VALIDATION.NUM_REPEATS = 5
+_C.CROSS_VALIDATION.NUM_REPEATS = 1
Comment on lines 48 to +49
Review comment (Copilot AI, Jun 16, 2025):
[nitpick] This default of 5 repeats conflicts with the NUM_REPEATS: 1 set in base.yml. It's easy to confuse users when defaults diverge; consider aligning these values or documenting which configuration takes precedence.
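On the precedence question: with yacs, values loaded via merge_from_file overwrite the Python defaults, so the experiment YAML wins. A minimal sketch, assuming it is run from the tutorial directory and uses the get_cfg_defaults() helper defined at the bottom of config.py:

from config import get_cfg_defaults

cfg = get_cfg_defaults()
print(cfg.CROSS_VALIDATION.NUM_REPEATS)  # 5, the Python default

# merge_from_file overwrites defaults with the YAML values,
# so base.yml takes precedence over config.py.
cfg.merge_from_file("experiments/base.yml")
print(cfg.CROSS_VALIDATION.NUM_REPEATS)  # 1, from base.yml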

# Trainer configuration
_C.TRAINER = CfgNode()
-# Classifier to use (e.g., auto-select)
+# Classifier to use
# Available options:
# - "lda"
# - "lr"
# - "linear_svm"
# - "svm"
# - "ridge"
# - "auto"
_C.TRAINER.CLASSIFIER = "lr"
-# Use non-linear transformations
+# Use non-linear transformations (no interpretability)
_C.TRAINER.NONLINEAR = False
# Search strategy for hyperparameter tuning
_C.TRAINER.SEARCH_STRATEGY = "random"
# Number of iterations for hyperparameter search
-_C.TRAINER.NUM_SEARCH_ITER = 100
+_C.TRAINER.NUM_SEARCH_ITER = int(1e3)
# Number of iterations for solver
_C.TRAINER.NUM_SOLVER_ITER = int(1e6)
# List of scoring metrics
# Available options:
# - "accuracy"
# - "precision"
# - "recall"
# - "f1"
# - "roc_auc"
# - "matthews_corrcoef"
_C.TRAINER.SCORING = ["accuracy", "roc_auc"]
# Scoring metric used to select the best hyperparameters when refitting
_C.TRAINER.REFIT = "accuracy"
# Number of parallel jobs (-1: all CPUs; -4: all but 3 CPUs, per joblib's n_cpus + 1 + n_jobs convention)
-_C.TRAINER.N_JOBS = -4
+_C.TRAINER.N_JOBS = 1
# Pre-dispatch of jobs for parallel processing
_C.TRAINER.PRE_DISPATCH = "2*n_jobs"
# Verbosity level
_C.TRAINER.VERBOSE = 0

-# Random state for reproducibility
+# Seed for random number generators
-_C.RANDOM_STATE = 0
+_C.RANDOM_STATE = None


def get_cfg_defaults():
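The trainer that consumes TRAINER.CLASSIFIER is not part of this diff, so the mapping below is only a hypothetical sketch of how the string keys above might translate to scikit-learn estimators; build_classifier and its signature are illustrative, not the PR's code, and the "auto" option is omitted:

# Hypothetical sketch; the PR's actual trainer code is not shown in this diff.
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.svm import SVC, LinearSVC

def build_classifier(name, num_solver_iter=int(1e6), random_state=None):
    # Map a TRAINER.CLASSIFIER key to an estimator instance.
    classifiers = {
        "lda": LinearDiscriminantAnalysis(),
        "lr": LogisticRegression(max_iter=num_solver_iter, random_state=random_state),
        "linear_svm": LinearSVC(max_iter=num_solver_iter, random_state=random_state),
        "svm": SVC(max_iter=num_solver_iter, random_state=random_state),
        "ridge": RidgeClassifier(max_iter=num_solver_iter, random_state=random_state),
    }
    return classifiers[name]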
160 changes: 160 additions & 0 deletions tutorials/brain-disorder-diagnosis/data.py
@@ -0,0 +1,160 @@
import os
import json
import numpy as np
import pandas as pd
import gdown

from sklearn.utils._param_validation import StrOptions, validate_params


@validate_params(
    {
        "data_dir": [str],
        "atlas": [
            StrOptions(
                {"aal", "cc200", "cc400", "difumo64", "dos160", "hcp-ica", "ho", "tt"}
            )
        ],
        "fc": [
            StrOptions(
                {
                    "pearson",
                    "partial",
                    "tangent",
                    "precision",
                    "covariance",
                    "tangent-pearson",
                }
            )
        ],
        "vectorize": [bool],
        "verbose": [bool],
    },
    prefer_skip_nested_validation=False,
)
def load_data(
    data_dir="data", atlas="cc200", fc="tangent-pearson", vectorize=True, verbose=True
):
"""
Load functional connectivity data and phenotypic data with gdown support.

This function uses manifest files to download the required files from Google Drive if not present locally.
It automatically downloads files listed in manifests/abide.json and folders listed in manifests/atlas.json.

Parameters
----------
data_dir : str, optional (default="data")
Local directory to store the dataset.

atlas : str, optional (default="cc200")
Atlas name (subfolder inside fc/).

fc : str, optional (default="tangent-pearson")
Functional connectivity file name (without extension).

vectorize : bool, optional (default=True)
Whether to vectorize the upper triangle of the connectivity matrices.

verbose : bool, optional (default=True)
Whether to print download and progress messages.

Returns
-------
fc_data : np.ndarray
Functional connectivity data (vectorized if requested).

phenotypes : pd.DataFrame
Loaded phenotypic data.

rois : np.ndarray
ROI labels.

coords : np.ndarray
ROI coordinates.

Raises
------
FileNotFoundError
If the required file paths are not found after attempted download.
"""
# Paths
fc_path = os.path.join(data_dir, "abide", "fc", atlas, f"{fc}.npy")
is_proba = atlas in {"difumo64"}
atlas_type = "probabilistic" if is_proba else "deterministic"
atlas_path = os.path.join(data_dir, "atlas", atlas_type, atlas)
phenotypes_path = os.path.join(data_dir, "abide", "phenotypes.csv")

# Ensure all files exist (download if needed)
_ensure_abide_file(data_dir, fc_path, verbose)
_ensure_abide_file(data_dir, phenotypes_path, verbose)
_ensure_atlas_folder(data_dir, atlas_path, verbose)

# Load connectivity data
fc_data = np.load(fc_path)
if vectorize:
row, col = np.triu_indices(fc_data.shape[1], 1)
fc_data = fc_data[..., row, col]

phenotypes = pd.read_csv(phenotypes_path)

with open(os.path.join(atlas_path, "labels.txt"), "r") as f:
rois = np.array(f.read().strip().split("\n"))
coords = np.load(os.path.join(atlas_path, "coords.npy"))

return fc_data, phenotypes, rois, coords
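An editorial aside on the vectorize branch above: np.triu_indices(n, 1) indexes the strictly upper triangle, so each symmetric n x n connectivity matrix collapses to n*(n-1)/2 unique edge weights. A self-contained toy illustration (random data, not the ABIDE files):

import numpy as np

# Toy stand-in for fc_data: 3 subjects, 4 ROIs, symmetric matrices.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4, 4))
fc_data = (a + a.transpose(0, 2, 1)) / 2  # symmetrize each matrix

row, col = np.triu_indices(fc_data.shape[1], 1)
vec = fc_data[..., row, col]
print(vec.shape)  # (3, 6): 4 ROIs give 4*3/2 = 6 unique edges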


def _ensure_abide_file(data_dir, target_path, verbose):
    """Ensure abide file exists locally; download from manifest if missing."""
    if os.path.exists(target_path):
        if verbose:
            print(f"✔ File found: {target_path}")
        return

    manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "abide.json")
    with open(manifest_path, "r") as f:
        manifest = json.load(f)
Comment on lines +114 to +116
Review comment (Copilot AI, Jun 16, 2025):
Loading and parsing the full manifest JSON on every missing-file check can be costly for large manifests. Consider loading each manifest once (e.g., module-level cache) rather than inside each helper call.
Suggested change:
-    manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "abide.json")
-    with open(manifest_path, "r") as f:
-        manifest = json.load(f)
+    manifest = _load_abide_manifest()
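A minimal sketch of the cached helper the reviewer is proposing; _load_abide_manifest does not exist in this PR, and lru_cache is one assumption among several ways to memoize:

import json
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def _load_abide_manifest():
    # Parsed once per process; subsequent calls return the cached result.
    manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "abide.json")
    with open(manifest_path, "r") as f:
        return json.load(f)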

    rel_path = os.path.relpath(target_path, data_dir).replace("\\", "/")
    for file_entry in manifest:
        if file_entry["path"] == rel_path:
            if verbose:
                print(f"⬇ Downloading {rel_path} ...")
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            gdown.download(file_entry["url"], output=target_path, quiet=not verbose)
            if os.path.exists(target_path):
                return
            else:
                break

    raise FileNotFoundError(f"File not found and not found in manifest: {target_path}")


def _ensure_atlas_folder(data_dir, atlas_path, verbose):
    """Ensure atlas folder exists locally; download using gdown.download_folder if missing."""
    if os.path.exists(atlas_path):
        if verbose:
            print(f"✔ Atlas folder found: {atlas_path}")
        return

    manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "atlas.json")
    with open(manifest_path, "r") as f:
        manifest = json.load(f)

    rel_path = os.path.relpath(atlas_path, data_dir).replace("\\", "/")
    for folder_entry in manifest:
        if folder_entry["path"] == rel_path:
            if verbose:
                print(f"⬇ Downloading atlas folder {rel_path} ...")
            os.makedirs(os.path.dirname(atlas_path), exist_ok=True)
            gdown.download_folder(
                id=folder_entry["id"], output=atlas_path, quiet=not verbose
            )
            if os.path.exists(atlas_path):
                return
            else:
                break

    raise FileNotFoundError(
        f"Atlas folder not found and not found in manifest: {atlas_path}"
    )
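For reference, a typical call to the new loader might look like the following; the shape comments are illustrative (they assume the cc200 atlas with 200 ROIs), not values asserted by the PR:

from data import load_data

# Downloads from Google Drive on first use, then reads from data/.
fc_data, phenotypes, rois, coords = load_data(
    data_dir="data", atlas="cc200", fc="tangent-pearson", vectorize=True
)
print(fc_data.shape)     # (n_subjects, n_edges); 200 ROIs -> 200*199/2 = 19900 edges
print(phenotypes.shape)  # one row per subject
print(rois.shape, coords.shape)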
9 changes: 7 additions & 2 deletions tutorials/brain-disorder-diagnosis/experiments/base.yml
@@ -1,6 +1,11 @@
 DATASET:
-  ATLAS: aal
+  ATLAS: hcp-ica
 
+CROSS_VALIDATION:
+  NUM_REPEATS: 1
+
 TRAINER:
-  NUM_SEARCH_ITER: 50
+  NUM_SEARCH_ITER: 20
+  NUM_SOLVER_ITER: 100
+
+RANDOM_STATE: 0