-
Notifications
You must be signed in to change notification settings - Fork 2
Update for brain decoding demo #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dcf93a4
f0d92aa
7047c23
092bf9d
448c359
1e32ba2
cf9e370
bff672b
72161f0
11da26e
d22141d
b5b5923
5bbcc39
237691d
afb3b0a
fb65616
8567ddd
ac4a5e5
7a175ba
fe52be1
bebe42f
79f0336
d93784d
1ad406b
544fb50
63d78c4
d41007a
dba1953
df96558
8688472
c0fb3f4
7aa3355
ed386a1
627008b
a584d53
667e097
7efc7c2
763c1b9
ac29f3e
98a8d16
7ba387c
2175be3
fabed46
f11b4ba
42f8329
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,24 +1,34 @@ | ||
| import os | ||
| from yacs.config import CfgNode | ||
|
|
||
| DEFAULT_DIR = os.path.join(os.getcwd(), "data") | ||
|
|
||
| _C = CfgNode() | ||
|
|
||
| # Dataset configuration | ||
| _C.DATASET = CfgNode() | ||
| # Path to the dataset directory | ||
| _C.DATASET.PATH = "nilearn_data" | ||
| _C.DATASET.PATH = DEFAULT_DIR | ||
| # Name of the brain atlas to use | ||
| # Available options: | ||
| # - "aal" (AAL) | ||
| # - "cc200" (Cameron Craddock 200) | ||
| # - "cc400" (Cameron Craddock 400) | ||
| # - "difumo64" (DiFuMo 64) | ||
| # - "dos160" (Dosenbach 160) | ||
| # - "hcp-ica" (HCP-ICA) | ||
| # - "ho" (Harvard-Oxford) | ||
| # - "tt" (Talairach-Tournoux) | ||
| _C.DATASET.ATLAS = "cc200" | ||
| # Whether to apply bandpass filtering | ||
| _C.DATASET.BANDPASS = False | ||
| # Whether to apply global signal regression | ||
| _C.DATASET.GLOBAL_SIGNAL_REGRESSION = False | ||
| # Whether to use only quality-checked data | ||
| _C.DATASET.QUALITY_CHECKED = False | ||
|
|
||
| # Connectivity configuration | ||
| _C.CONNECTIVITY = CfgNode() | ||
| # List of connectivity measures to compute | ||
| _C.CONNECTIVITY.MEASURES = ["pearson"] | ||
| # Functional connectivity to use | ||
| # Available options: | ||
| # - "pearson" | ||
| # - "partial" | ||
| # - "tangent" | ||
| # - "precision" | ||
| # - "covariance" | ||
| # - "tangent-pearson" | ||
| _C.DATASET.FC = "tangent-pearson" | ||
|
|
||
| # Phenotype configuration | ||
| _C.PHENOTYPE = CfgNode() | ||
|
|
@@ -27,37 +37,57 @@ | |
|
|
||
| # Cross-validation configuration | ||
| _C.CROSS_VALIDATION = CfgNode() | ||
| # Cross-validation split method (e.g., leave-p-groups-out) | ||
| # Cross-validation split method | ||
| # Available options: | ||
| # - "skf" (Stratified K-Folds) | ||
| # - "lpgo" (Leave-P-Groups-Out) | ||
| _C.CROSS_VALIDATION.SPLIT = "skf" | ||
| # Number of folds for cross-validation | ||
| # or number of groups for Leave-P-Groups-Out | ||
| _C.CROSS_VALIDATION.NUM_FOLDS = 10 | ||
| # Number of repeats for cross-validation | ||
| _C.CROSS_VALIDATION.NUM_REPEATS = 1 | ||
| _C.CROSS_VALIDATION.NUM_REPEATS = 5 | ||
|
Comment on lines 48 to +49
|
||
|
|
||
| # Trainer configuration | ||
| _C.TRAINER = CfgNode() | ||
| # Classifier to use (e.g., auto-select) | ||
| # Classifier to use | ||
| # Available options: | ||
| # - "lda" | ||
| # - "lr" | ||
| # - "linear_svm" | ||
| # - "svm" | ||
| # - "ridge" | ||
| # - "auto" | ||
| _C.TRAINER.CLASSIFIER = "lr" | ||
| # Use non-linear transformations | ||
| # Use non-linear transformations (no interpretability) | ||
| _C.TRAINER.NONLINEAR = False | ||
| # Search strategy for hyperparameter tuning | ||
| _C.TRAINER.SEARCH_STRATEGY = "random" | ||
| # Number of iterations for hyperparameter search | ||
| _C.TRAINER.NUM_SEARCH_ITER = 100 | ||
| _C.TRAINER.NUM_SEARCH_ITER = int(1e3) | ||
| # Number of iterations for solver | ||
| _C.TRAINER.NUM_SOLVER_ITER = int(1e6) | ||
| # List of scoring metrics | ||
| # Available options: | ||
| # - "accuracy" | ||
| # - "precision" | ||
| # - "recall" | ||
| # - "f1" | ||
| # - "roc_auc" | ||
| # - "matthews_corrcoef" | ||
| _C.TRAINER.SCORING = ["accuracy", "roc_auc"] | ||
| # Refit based on the best hyperparameters on a scoring metric | ||
| _C.TRAINER.REFIT = "accuracy" | ||
| # Number of parallel jobs (-1: all CPUs, -4: all but 3 CPUs) | ||
| _C.TRAINER.N_JOBS = -4 | ||
| _C.TRAINER.N_JOBS = 1 | ||
| # Pre-dispatch of jobs for parallel processing | ||
| _C.TRAINER.PRE_DISPATCH = "2*n_jobs" | ||
| # Verbosity level | ||
| _C.TRAINER.VERBOSE = 0 | ||
|
|
||
| # Random state for reproducibility | ||
| # Seed for random number generators | ||
| _C.RANDOM_STATE = 0 | ||
| _C.RANDOM_STATE = None | ||
|
|
||
|
|
||
| def get_cfg_defaults(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,160 @@ | ||||||||||
| import os | ||||||||||
| import json | ||||||||||
| import numpy as np | ||||||||||
| import pandas as pd | ||||||||||
| import gdown | ||||||||||
|
|
||||||||||
| from sklearn.utils._param_validation import StrOptions, validate_params | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @validate_params( | ||||||||||
| { | ||||||||||
| "data_dir": [str], | ||||||||||
| "atlas": [ | ||||||||||
| StrOptions( | ||||||||||
zaRizk7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| {"aal", "cc200", "cc400", "difumo64", "dos160", "hcp-ica", "ho", "tt"} | ||||||||||
| ) | ||||||||||
| ], | ||||||||||
| "fc": [ | ||||||||||
| StrOptions( | ||||||||||
| { | ||||||||||
| "pearson", | ||||||||||
| "partial", | ||||||||||
| "tangent", | ||||||||||
| "precision", | ||||||||||
| "covariance", | ||||||||||
| "tangent-pearson", | ||||||||||
| } | ||||||||||
| ) | ||||||||||
| ], | ||||||||||
| "vectorize": [bool], | ||||||||||
| "verbose": [bool], | ||||||||||
| }, | ||||||||||
| prefer_skip_nested_validation=False, | ||||||||||
| ) | ||||||||||
| def load_data( | ||||||||||
| data_dir="data", atlas="cc200", fc="tangent-pearson", vectorize=True, verbose=True | ||||||||||
| ): | ||||||||||
| """ | ||||||||||
| Load functional connectivity data and phenotypic data with gdown support. | ||||||||||
|
|
||||||||||
| This function uses manifest files to download the required files from Google Drive if not present locally. | ||||||||||
| It automatically downloads files listed in manifests/abide.json and folders listed in manifests/atlas.json. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| data_dir : str, optional (default="data") | ||||||||||
| Local directory to store the dataset. | ||||||||||
|
|
||||||||||
| atlas : str, optional (default="cc200") | ||||||||||
| Atlas name (subfolder inside fc/). | ||||||||||
|
|
||||||||||
| fc : str, optional (default="tangent-pearson") | ||||||||||
| Functional connectivity file name (without extension). | ||||||||||
|
|
||||||||||
| vectorize : bool, optional (default=True) | ||||||||||
| Whether to vectorize the upper triangle of the connectivity matrices. | ||||||||||
|
|
||||||||||
| verbose : bool, optional (default=True) | ||||||||||
| Whether to print download and progress messages. | ||||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
| fc_data : np.ndarray | ||||||||||
| Functional connectivity data (vectorized if requested). | ||||||||||
|
|
||||||||||
| phenotypes : pd.DataFrame | ||||||||||
| Loaded phenotypic data. | ||||||||||
|
|
||||||||||
| rois : np.ndarray | ||||||||||
| ROI labels. | ||||||||||
|
|
||||||||||
| coords : np.ndarray | ||||||||||
| ROI coordinates. | ||||||||||
|
|
||||||||||
| Raises | ||||||||||
| ------ | ||||||||||
| FileNotFoundError | ||||||||||
| If the required file paths are not found after attempted download. | ||||||||||
| """ | ||||||||||
| # Paths | ||||||||||
| fc_path = os.path.join(data_dir, "abide", "fc", atlas, f"{fc}.npy") | ||||||||||
| is_proba = atlas in {"difumo64"} | ||||||||||
| atlas_type = "probabilistic" if is_proba else "deterministic" | ||||||||||
| atlas_path = os.path.join(data_dir, "atlas", atlas_type, atlas) | ||||||||||
| phenotypes_path = os.path.join(data_dir, "abide", "phenotypes.csv") | ||||||||||
|
|
||||||||||
| # Ensure all files exist (download if needed) | ||||||||||
| _ensure_abide_file(data_dir, fc_path, verbose) | ||||||||||
| _ensure_abide_file(data_dir, phenotypes_path, verbose) | ||||||||||
| _ensure_atlas_folder(data_dir, atlas_path, verbose) | ||||||||||
|
|
||||||||||
| # Load connectivity data | ||||||||||
| fc_data = np.load(fc_path) | ||||||||||
| if vectorize: | ||||||||||
| row, col = np.triu_indices(fc_data.shape[1], 1) | ||||||||||
| fc_data = fc_data[..., row, col] | ||||||||||
|
|
||||||||||
| phenotypes = pd.read_csv(phenotypes_path) | ||||||||||
|
|
||||||||||
| with open(os.path.join(atlas_path, "labels.txt"), "r") as f: | ||||||||||
| rois = np.array(f.read().strip().split("\n")) | ||||||||||
| coords = np.load(os.path.join(atlas_path, "coords.npy")) | ||||||||||
|
|
||||||||||
| return fc_data, phenotypes, rois, coords | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _ensure_abide_file(data_dir, target_path, verbose): | ||||||||||
| """Ensure abide file exists locally; download from manifest if missing.""" | ||||||||||
| if os.path.exists(target_path): | ||||||||||
| if verbose: | ||||||||||
| print(f"✔ File found: {target_path}") | ||||||||||
| return | ||||||||||
|
|
||||||||||
| manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "abide.json") | ||||||||||
| with open(manifest_path, "r") as f: | ||||||||||
| manifest = json.load(f) | ||||||||||
|
Comment on lines +114 to +116
|
||||||||||
| manifest_path = os.path.join(os.path.dirname(__file__), "manifests", "abide.json") | |
| with open(manifest_path, "r") as f: | |
| manifest = json.load(f) | |
| manifest = _load_abide_manifest() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,11 @@ | ||
| DATASET: | ||
| ATLAS: aal | ||
| ATLAS: hcp-ica | ||
|
|
||
| CROSS_VALIDATION: | ||
| NUM_REPEATS: 1 | ||
|
|
||
| TRAINER: | ||
| NUM_SEARCH_ITER: 50 | ||
| NUM_SEARCH_ITER: 20 | ||
| NUM_SOLVER_ITER: 100 | ||
|
|
||
| RANDOM_STATE: 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default `NUM_REPEATS` here (5) conflicts with `NUM_REPEATS: 1` in `base.yml`; align these values to avoid confusion.