# Structural MRI Foundation Models: Systematic Benchmarking and Evaluation

In [None]:
%load_ext slurm_magic

## CAT12 Preprocessing | AnatCL | HBN 

In [None]:
%%writefile ~/links/projects/rrg-glatard/arelbaha/HBN_BIDS/new_cat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Per-task CAT12 preprocessing driver for the HBN dataset.
# Each SLURM array task selects one T1w image by SLURM_ARRAY_TASK_ID and runs
# the CAT12 Boutiques descriptor on it, writing into a per-subject directory.

container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"

bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])

cat12 = boutiques.descriptor2func.function(boutiques_descriptor)

# glob returns paths in arbitrary order; sort so that every array task (and
# every resubmission) maps the same task ID to the same input file.
t1_nii_files = sorted(
    glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "sub-*_T1w.nii"))
)
print(f"Found {len(t1_nii_files)} T1w files.")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if not 0 <= task_id < len(t1_nii_files):
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    # SystemExit works even where the interactive-only `exit` helper is absent.
    raise SystemExit(1)

input_file = t1_nii_files[task_id]

# The subject ID is the first path component that starts with "sub-".
path_parts = input_file.split(os.sep)
subject_id = next((part for part in path_parts if part.startswith("sub-")), None)

if subject_id is None:
    print(f"No subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f'--imagepath={container_image}',
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
exit_code = result_dict.get("exit_code")
print("Available result keys:", result_dict.keys())
print("\nExit code:", exit_code)
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Surface a failed CAT12 run to SLURM instead of finishing with status 0.
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=500-520
#!/bin/bash
#SBATCH --account=rrg-glatard
#SBATCH --job-name=CAT12_preproc
#SBATCH --nodes=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=12G
#SBATCH --time=1:0:0
#SBATCH --output=CAT12_hbn_preproc_%A_%a.out
#SBATCH --error=CAT12_hbn_preproc_%A_%a.err

# Runs the HBN CAT12 preprocessing script for one array task.

# Python environment first, then the container runtime module.
source ~/.venvs/jupyter_py3/bin/activate
module load apptainer

# Directory holding both the preprocessing data and the script.
WORK_DIR=~/links/projects/rrg-glatard/arelbaha/HBN_BIDS
cd "$WORK_DIR"

echo "Running task ID: ${SLURM_ARRAY_TASK_ID}"

python new_cat12_preprocessing.py

## CAT12 Preprocessing | AnatCL | PPMI

In [None]:
%%writefile ~/links/projects/def-glatard/arelbaha/data/inputs/ppmicat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Per-task CAT12 preprocessing driver for the PPMI dataset.
# Each SLURM array task selects one NIfTI image by SLURM_ARRAY_TASK_ID and runs
# the CAT12 Boutiques descriptor on it, writing into a per-subject directory.

# Container image, descriptor and data paths
container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"

bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])
cat12 = boutiques.descriptor2func.function(boutiques_descriptor)

# glob returns paths in arbitrary order; sort so that every array task (and
# every resubmission) maps the same task ID to the same input file.
data = sorted(glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "*.nii")))
print(f"Found {len(data)}")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if not 0 <= task_id < len(data):
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    # SystemExit works even where the interactive-only `exit` helper is absent.
    raise SystemExit(1)

input_file = data[task_id]

# The subject ID is the first path component that starts with "sub-".
path_parts = input_file.split(os.sep)
subject_id = next((part for part in path_parts if part.startswith("sub-")), None)

if subject_id is None:
    print(f"Could not find subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing SLURM_ARRAY_TASK_ID ={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f'--imagepath={container_image}',
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
exit_code = result_dict.get("exit_code")
print("Available result keys:", result_dict.keys())
print("\nExit code:", exit_code)
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Surface a failed CAT12 run to SLURM instead of finishing with status 0.
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=1040-1040
#!/bin/bash
#SBATCH --job-name=CAT12_parkinson_preproc
#SBATCH --account=rrg-glatard
#SBATCH --mem=12G
#SBATCH --cpus-per-task=2
#SBATCH --nodes=1
#SBATCH --output=cat12_ppmi_preproc_%A_%a.out
#SBATCH --error=cat12_ppmi_preproc_%A_%a.err
#SBATCH --time=1:0:0

# Runs the PPMI CAT12 preprocessing script for one array task.

source ~/.venvs/jupyter_py3/bin/activate

module load apptainer

# Location of preprocessing data and script; abort rather than run from the
# wrong directory if it is missing.
cd ~/links/projects/def-glatard/arelbaha/data/inputs || exit 1

echo "Running task ID: $SLURM_ARRAY_TASK_ID"

# Run the CAT12 driver that the %%writefile cell places in this directory
# (previously this launched classification_parkinson.py, an unrelated script).
python ppmicat12_preprocessing.py

## CAT12 Preprocessing | AnatCL | NKI

In [None]:
%%writefile /home/arelbaha/links/scratch/NKI/cat12_script/nki_cat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Per-task CAT12 preprocessing driver for the NKI dataset.
# Subjects come from final_subjects.txt; each SLURM array task selects one
# baseline (ses-BAS1) T1w image by SLURM_ARRAY_TASK_ID and runs the CAT12
# Boutiques descriptor on it, writing into a per-subject directory.

container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/scratch/NKI/NKI_BIDS"
output_dir = "/home/arelbaha/links/scratch/cat12_NKI"

bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])

cat12 = boutiques.descriptor2func.function(boutiques_descriptor)

with open("/home/arelbaha/links/scratch/NKI/final_subjects.txt") as f:
    subjects = [line.strip() for line in f if line.strip()]

# One T1w file per subject, in subject-list order. Subjects with no matching
# file are skipped, so list indices follow the subjects that have data.
t1_nii_files = []
for subj in subjects:
    files = glob.glob(os.path.join(base_dir, f"sub-{subj}", "ses-BAS1", "anat", "*T1w.nii"))
    if files:
        # glob order is arbitrary; take the lexicographically first match so
        # the choice is the same in every task and every resubmission.
        t1_nii_files.append(min(files))

print(f"Found {len(t1_nii_files)} T1w files from {len(subjects)} subjects.")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
if not 0 <= task_id < len(t1_nii_files):
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    # SystemExit works even where the interactive-only `exit` helper is absent.
    raise SystemExit(1)

input_file = t1_nii_files[task_id]

# The subject ID is the first path component that starts with "sub-".
path_parts = input_file.split(os.sep)
subject_id = next((part for part in path_parts if part.startswith("sub-")), None)

if subject_id is None:
    print(f"No subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)
subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f'--imagepath={container_image}',
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
exit_code = result_dict.get("exit_code")
print("Available result keys:", result_dict.keys())
print("\nExit code:", exit_code)
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Surface a failed CAT12 run to SLURM instead of finishing with status 0.
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch
#!/bin/bash
#SBATCH --account=rrg-glatard
#SBATCH --job-name=cat12_nki
#SBATCH --array=21-957
#SBATCH --nodes=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=4:00:00
#SBATCH --output=/home/arelbaha/links/scratch/NKI/cat12_script/logs/cat12_nki_%A_%a.out
#SBATCH --error=/home/arelbaha/links/scratch/NKI/cat12_script/logs/cat12_nki_%A_%a.err

# Runs the NKI CAT12 preprocessing script for one array task.

# Python environment first, then the container runtime module.
source ~/.venvs/jupyter_py3/bin/activate
module load apptainer

# Directory containing nki_cat12_preprocessing.py.
SCRIPT_DIR=/home/arelbaha/links/scratch/NKI/cat12_script
cd "$SCRIPT_DIR"

echo "Running task ID: ${SLURM_ARRAY_TASK_ID}"
python nki_cat12_preprocessing.py

## Turboprep Preprocessing | BrainIAC, SwinBrain, 3D-Neuro-SimCLR | HBN

In [None]:
%%sbatch --array=75
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=4:00:00
#SBATCH --mem=32G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=turboprep
#SBATCH --output=logs/turboprep_batch_%a.out
#SBATCH --error=logs/turboprep_batch_%a.err

# Runs turboprep on every .nii.gz file of one HBN batch directory.
# The batch is chosen by the array task ID, zero-padded to three digits.

INPUT_BASE="/home/arelbaha/links/projects/rrg-glatard/arelbaha/brainiac_p_files"
OUTPUT_BASE="/home/arelbaha/links/scratch/turboprep_output"
TEMPLATE="/home/arelbaha/links/scratch/t1_template/mni_icbm152_nlin_sym_09c/mni_icbm152_t1_tal_nlin_sym_09c.nii"
CONTAINER="/home/arelbaha/links/scratch/turboprep_container/turboprep.sif"


BATCH_NUM=$(printf "%03d" "$SLURM_ARRAY_TASK_ID")
BATCH_DIR="${INPUT_BASE}/batch_${BATCH_NUM}"

echo "Processing batch ${BATCH_NUM}"

module load apptainer

# With nullglob an unmatched pattern expands to nothing instead of the literal
# "*.nii.gz", so an empty or missing batch is reported rather than "processed".
shopt -s nullglob
T1_FILES=("$BATCH_DIR"/*.nii.gz)

if [ "${#T1_FILES[@]}" -eq 0 ]; then
    echo "ERROR: no .nii.gz files found in ${BATCH_DIR}" >&2
    exit 1
fi

FAILED=0

for T1_FILE in "${T1_FILES[@]}"; do
    # Subject ID is the text between "sub-" and the first underscore.
    SUBJ_ID=$(basename "$T1_FILE" | sed 's/sub-\([^_]*\)_.*/\1/')

    SUBJ_OUTPUT="${OUTPUT_BASE}/${SUBJ_ID}"

    mkdir -p "$SUBJ_OUTPUT"

    echo "Processing $SUBJ_ID"

    # Keep going after a failure so one bad subject does not block the batch,
    # but remember it so the job does not end with status 0.
    if ! apptainer run "$CONTAINER" "$T1_FILE" "$SUBJ_OUTPUT" "$TEMPLATE" -m t1 -r r; then
        echo "ERROR: turboprep failed for $SUBJ_ID" >&2
        FAILED=$((FAILED + 1))
    fi

done

if [ "$FAILED" -gt 0 ]; then
    echo "ERROR: batch ${BATCH_NUM} finished with ${FAILED} failed subject(s)" >&2
    exit 1
fi

echo "Completed batch ${BATCH_NUM}"

## Turboprep Preprocessing | BrainIAC, SwinBrain, 3D-Neuro-SimCLR | NKI

In [None]:
%%sbatch --array=21-958
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=2:00:00
#SBATCH --mem=32G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=turboprep_nki
#SBATCH --output=logs/turboprep_nki_%a.out
#SBATCH --error=logs/turboprep_nki_%a.err

# Runs turboprep on the baseline (ses-BAS1) T1w image of one NKI subject.
# The subject is the line of SUBJ_LIST whose 1-based number is the array task ID.

SUBJ_LIST="/home/arelbaha/links/scratch/NKI/final_subjects.txt"
BIDS_DIR="/home/arelbaha/links/scratch/NKI/NKI_BIDS"
OUTPUT_BASE="/home/arelbaha/links/scratch/turboprep_NKI"
TEMPLATE="/home/arelbaha/links/scratch/t1_template/mni_icbm152_nlin_sym_09c/mni_icbm152_t1_tal_nlin_sym_09c.nii"
CONTAINER="/home/arelbaha/links/scratch/turboprep_container/turboprep.sif"

SUBJ_ID=$(sed -n "${SLURM_ARRAY_TASK_ID}p" "$SUBJ_LIST")

# An empty ID means the task ID is past the end of the list (or the line is
# blank); report that directly instead of searching a "sub-" directory.
if [ -z "$SUBJ_ID" ]; then
    echo "ERROR: no subject on line ${SLURM_ARRAY_TASK_ID} of ${SUBJ_LIST}" >&2
    exit 1
fi

echo "Processing sub-${SUBJ_ID}"

# -L follows symlinks; only the first match is used.
T1_FILE=$(find -L "${BIDS_DIR}/sub-${SUBJ_ID}/ses-BAS1/anat" -name "*T1w.nii.gz" 2>/dev/null | head -1)

if [ -z "$T1_FILE" ]; then
    echo "ERROR: No downloaded T1w found for sub-${SUBJ_ID}, skipping"
    exit 1
fi

echo "Using: $T1_FILE"

SUBJ_OUTPUT="${OUTPUT_BASE}/${SUBJ_ID}"
mkdir -p "$SUBJ_OUTPUT"

module load apptainer

# Propagate a turboprep failure so the job is not recorded as successful.
if ! apptainer run "$CONTAINER" "$T1_FILE" "$SUBJ_OUTPUT" "$TEMPLATE" -m t1 -r r; then
    echo "ERROR: turboprep failed for sub-${SUBJ_ID}" >&2
    exit 1
fi

echo "Completed sub-${SUBJ_ID}"

## Turboprep Preprocessing | BrainIAC, SwinBrain, 3D-Neuro-SimCLR | PPMI

In [None]:
%%sbatch --array=10
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=8:00:00
#SBATCH --mem=32G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=turboprep_ppmi
#SBATCH --output=logs/turboprep_ppmi_batch_%a.out
#SBATCH --error=logs/turboprep_ppmi_batch_%a.err

# Runs turboprep on every .nii.gz file of one PPMI batch directory.
# The batch is chosen by the array task ID (not zero-padded).

INPUT_BASE="/home/arelbaha/links/projects/def-glatard/arelbaha/data/raw_files_brainiac"
OUTPUT_BASE="/home/arelbaha/links/scratch/turboprep_output_ppmi"
TEMPLATE="/home/arelbaha/links/scratch/t1_template/mni_icbm152_nlin_sym_09c/mni_icbm152_t1_tal_nlin_sym_09c.nii"
CONTAINER="/home/arelbaha/links/scratch/turboprep_container/turboprep.sif"

BATCH_DIR="${INPUT_BASE}/batch_${SLURM_ARRAY_TASK_ID}"

echo "Processing batch ${SLURM_ARRAY_TASK_ID}"

module load apptainer

# With nullglob an unmatched pattern expands to nothing instead of the literal
# "*.nii.gz", so an empty or missing batch is reported rather than "processed".
shopt -s nullglob
T1_FILES=("$BATCH_DIR"/*.nii.gz)

if [ "${#T1_FILES[@]}" -eq 0 ]; then
    echo "ERROR: no .nii.gz files found in ${BATCH_DIR}" >&2
    exit 1
fi

FAILED=0

for T1_FILE in "${T1_FILES[@]}"; do
    # Subject ID is the run of digits between "PPMI_" and the next underscore.
    SUBJ_ID=$(basename "$T1_FILE" | sed 's/PPMI_\([0-9]*\)_.*/\1/')

    SUBJ_OUTPUT="${OUTPUT_BASE}/${SUBJ_ID}"
    mkdir -p "$SUBJ_OUTPUT"

    echo "Processing $SUBJ_ID"

    # Keep going after a failure so one bad subject does not block the batch,
    # but remember it so the job does not end with status 0.
    if ! apptainer run "$CONTAINER" "$T1_FILE" "$SUBJ_OUTPUT" "$TEMPLATE" -m t1 -r r; then
        echo "ERROR: turboprep failed for $SUBJ_ID" >&2
        FAILED=$((FAILED + 1))
    fi

done

if [ "$FAILED" -gt 0 ]; then
    echo "ERROR: batch ${SLURM_ARRAY_TASK_ID} finished with ${FAILED} failed subject(s)" >&2
    exit 1
fi

echo "Completed batch ${SLURM_ARRAY_TASK_ID}"

## Multi-Task Evaluation of Models

In [None]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/all_multimodel_comparison.py

import os
import glob
import numpy as np
import pandas as pd

# Request a single thread from the OpenMP, MKL and OpenBLAS runtimes.
# These are set before torch is imported below.
# NOTE(review): numpy is already imported above; numerical backends commonly
# read these variables when they are first loaded, so they may not affect
# numpy's own thread pool here - confirm, or move these lines above that import.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
import sys
import types
import anatcl.models as _anatcl_models

# Import shim: registers `anatcl.models` under the top-level name `models` and
# provides a synthetic `models.estimators` module with an `AgeEstimator`
# attribute, so that lookups of those dotted names resolve.
# NOTE(review): presumably needed because saved AnatCL checkpoints reference
# `models.estimators.AgeEstimator` when unpickled - confirm against the loader.
class _AgeEstimator:
    # Empty placeholder; only the name needs to exist.
    pass

_estimators_mod = types.ModuleType('models.estimators')
_estimators_mod.AgeEstimator = _AgeEstimator
sys.modules['models'] = _anatcl_models
sys.modules['models.estimators'] = _estimators_mod

from anatcl import AnatCL
from monai.networks.nets import ViT, SwinUNETR, resnet18
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform
from scipy.stats import permutation_test
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Random seeds; each one drives a train/test split, the CV folds and the
# random-forest initialisation in the downstream evaluation functions.
SEEDS = [0, 1, 2, 3, 42]

# Named random-CNN baselines, each mapped to the list of seeds it uses.
# NOTE(review): presumably each entry is passed as `seeds` to
# extract_cnn_features_multiseed - confirm against the caller.
CNN_SEED_CONFIGS = {
    'CNN_0':  [0],
    'CNN_42': [42],
}

# Make torch raise on non-deterministic ops and run single-threaded.
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# All models run on CPU.
device = "cpu"
# Default dropout probability for CNN3D.
DROPOUT_RATE = 0.3
# Locations of the pretrained weights.
# NOTE(review): ANATCL_ENCODER_PATH is a directory while the others are files;
# confirm that is what the AnatCL loader expects.
ANATCL_ENCODER_PATH = "/home/arelbaha/.venvs/jupyter_py3/bin"
BRAINIAC_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/BrainIAC.ckpt"
SWINBRAIN_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/Brain_Swin_UNETR.pth"
SIMCLR_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/simclr_3d_brain_foundation.tar"

class CAT12VBMDataset(Dataset):
    """Dataset of NIfTI volumes resized to (121, 128, 121) and passed through `transform`.

    `data` is a sequence of paths readable by nibabel. `transform` is called
    on the float32 numpy volume and must return an object with `unsqueeze`
    (a tensor); a leading channel axis is then added.
    """

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform
        self.target_shape = (121, 128, 121)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume = nib.load(self.data[idx]).get_fdata().astype(np.float32)
        if volume.shape != self.target_shape:
            # Trilinear resampling needs a (batch, channel, ...) tensor.
            resized = F.interpolate(
                torch.from_numpy(volume[None, None]),
                size=self.target_shape,
                mode='trilinear',
                align_corners=False,
            )
            volume = resized.squeeze().numpy()
        return self.transform(volume).unsqueeze(0)

class BrainIACDataset(Dataset):
    """Dataset of NIfTI volumes resized to 96x96x96 and passed through `transform`.

    `paths` is a sequence of paths readable by nibabel. `transform` receives
    the volume as a torch tensor; its result is returned with a leading
    channel axis added.
    """

    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        size = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != size:
            # Trilinear resampling needs a (batch, channel, ...) tensor.
            resized = F.interpolate(
                torch.from_numpy(volume[None, None]),
                size=size,
                mode='trilinear',
                align_corners=False,
            )
            volume = resized.squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        return normalized[None]

class SwinBrainDataset(Dataset):
    """Dataset of NIfTI volumes resized to 128x128x64, z-scored and tiled to 3 channels."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        batched = torch.from_numpy(volume[None, None])
        # Always resample; dropping only the batch axis leaves (1, 128, 128, 64).
        resized = F.interpolate(
            batched, size=(128, 128, 64), mode='trilinear', align_corners=False
        ).squeeze(0)
        standardized = (resized - resized.mean()) / (resized.std() + 1e-6)
        # Repeat the single channel three times to get a 3-channel input.
        return standardized.repeat(3, 1, 1, 1)

class CNNDataset(Dataset):
    """Dataset of NIfTI volumes resized to (91, 109, 91) and z-scored, with a channel axis."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        size = (91, 109, 91)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != size:
            # Trilinear resampling needs a (batch, channel, ...) tensor.
            resized = F.interpolate(
                torch.from_numpy(volume[None, None]),
                size=size,
                mode='trilinear',
                align_corners=False,
            )
            volume = resized.squeeze().numpy()
        # Per-volume z-score; the epsilon guards against a constant image.
        volume = (volume - volume.mean()) / (volume.std() + 1e-6)
        return torch.from_numpy(volume[None]).float()

class TurboPrepDataset(Dataset):
    """Dataset of NIfTI volumes transposed, center-cropped to (150, 192, 192) and z-scored.

    `paths` is a sequence of paths readable by nibabel. Each item is a float
    tensor with a leading channel axis.
    """

    def __init__(self, paths):
        self.paths = paths
        self.target_shape = (150, 192, 192)

    def __len__(self):
        return len(self.paths)

    def _center_crop(self, img, target_shape):
        """Return the central `target_shape` region of `img`.

        Axes of `img` that are shorter than the target are first zero-padded
        symmetrically, so the result always has `target_shape` on the leading
        axes. Without this, a negative start index would wrap around and
        silently yield a smaller (or empty) array.
        """
        pad_width = []
        for size, target in zip(img.shape, target_shape):
            deficit = max(target - size, 0)
            pad_width.append((deficit // 2, deficit - deficit // 2))
        if any(before or after for before, after in pad_width):
            # Leave any axes beyond those named in target_shape untouched.
            pad_width += [(0, 0)] * (img.ndim - len(pad_width))
            img = np.pad(img, pad_width, mode='constant')
        starts = [(s - t) // 2 for s, t in zip(img.shape, target_shape)]
        return img[starts[0]:starts[0]+target_shape[0],
                   starts[1]:starts[1]+target_shape[1],
                   starts[2]:starts[2]+target_shape[2]]

    def __getitem__(self, idx):
        img = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        # Move the last spatial axis to the front before cropping.
        img = np.transpose(img, (2, 0, 1))
        img = self._center_crop(img, self.target_shape)
        # Per-volume z-score; the epsilon guards against a constant image.
        img = (img - img.mean()) / (img.std() + 1e-6)
        return torch.from_numpy(img[None]).float()

# Models

class CNN3D(nn.Module):
    """Four-stage 3D CNN that maps a single-channel volume to a 512-d feature vector.

    Each stage applies conv -> ReLU -> max-pool -> batch-norm -> dropout.
    The stages are followed by global average pooling and one linear layer
    with ReLU and dropout.
    """

    def __init__(self, dropout_rate=DROPOUT_RATE):
        super().__init__()
        # Layers are created in a fixed order so that seeded weight
        # initialisation is reproducible.
        self.conv1 = nn.Conv3d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(64)
        self.pool1 = nn.MaxPool3d(2)
        self.drop1 = nn.Dropout3d(dropout_rate)

        self.conv2 = nn.Conv3d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.pool2 = nn.MaxPool3d(2)
        self.drop2 = nn.Dropout3d(dropout_rate)

        self.conv3 = nn.Conv3d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm3d(128)
        self.pool3 = nn.MaxPool3d(2)
        self.drop3 = nn.Dropout3d(dropout_rate)

        self.conv4 = nn.Conv3d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm3d(256)
        self.pool4 = nn.MaxPool3d(2)
        self.drop4 = nn.Dropout3d(dropout_rate)

        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 512)
        self.drop5 = nn.Dropout(dropout_rate)

    def forward(self, x):
        stages = (
            (self.conv1, self.pool1, self.bn1, self.drop1),
            (self.conv2, self.pool2, self.bn2, self.drop2),
            (self.conv3, self.pool3, self.bn3, self.drop3),
            (self.conv4, self.pool4, self.bn4, self.drop4),
        )
        for conv, pool, norm, drop in stages:
            x = drop(norm(pool(F.relu(conv(x)))))
        # Collapse the spatial axes and flatten to (batch, 256).
        flat = self.global_pool(x).view(x.size(0), -1)
        return self.drop5(F.relu(self.fc1(flat)))

def load_brainiac_vit(ckpt_path, device):
    """Build a MONAI ViT and load the BrainIAC backbone weights into it.

    Args:
        ckpt_path: Path to the checkpoint. The weights may sit under a
            "state_dict" key or at the top level of the loaded mapping.
        device: Device the model is moved to.

    Returns:
        The ViT on `device`, in eval mode.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Resolve the weight mapping once, so that both checkpoint layouts are
    # handled the same way everywhere below.
    state = ckpt.get("state_dict", ckpt)
    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, save_attn=True)
    pe_key = "backbone.patch_embedding.position_embeddings"
    if pe_key in state:
        # Replace the parameter outright with the checkpoint tensor.
        # NOTE(review): presumably because its shape can differ from the
        # freshly built one - confirm.
        vit.patch_embedding.position_embeddings = nn.Parameter(state[pe_key].clone())
    prefix = "backbone."
    # Strip only the leading prefix; a later "backbone." in a key is kept.
    backbone_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    load_result = vit.load_state_dict(backbone_state, strict=False)
    # strict=False tolerates mismatches; report them so that a partly
    # random-initialised model does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"BrainIAC checkpoint: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys")
    return vit.to(device).eval()

def load_swinbrain(ckpt_path, device):
    """Build a SwinUNETR and strictly load the SwinBrain checkpoint into it.

    Any "module." substring is removed from the checkpoint keys before
    loading. Returns the model on `device`, in eval mode.
    """
    network = SwinUNETR(
        img_size=(128, 128, 64),
        in_channels=3,
        out_channels=3,
        spatial_dims=3,
        feature_size=24,
        drop_rate=0.0,
        attn_drop_rate=0.0,
    )
    raw_state = torch.load(ckpt_path, map_location="cpu")
    stripped = {}
    for key, tensor in raw_state.items():
        stripped[key.replace("module.", "")] = tensor
    network.load_state_dict(stripped, strict=True)
    network = network.to(device)
    return network.eval()

def load_simclr_encoder(ckpt_path, device):
    """Build a 3D ResNet-18 encoder and strictly load the SimCLR encoder weights.

    Only checkpoint entries whose key contains "encoder" and not "projector"
    are used, with "module.encoder." removed from the key. Returns the
    encoder on `device`, in eval mode.
    """
    backbone = resnet18(spatial_dims=3, n_input_channels=1, num_classes=0)
    # Drop the classification head so the network outputs features.
    backbone.fc = nn.Identity()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    encoder_weights = {}
    for key, tensor in checkpoint['model_state_dict'].items():
        if "encoder" not in key or "projector" in key:
            continue
        encoder_weights[key.replace("module.encoder.", "")] = tensor
    backbone.load_state_dict(encoder_weights, strict=True)
    backbone = backbone.to(device)
    return backbone.eval()

def extract_cnn_features_multiseed(turboprep_paths, seeds, device):
    """Extract features from randomly initialised CNN3D models, one per seed.

    For each seed the torch RNG is seeded, a fresh CNN3D is built in eval
    mode and run over every volume. The per-seed feature matrices are
    concatenated column-wise, in the order of `seeds`.
    """
    loader = DataLoader(CNNDataset(turboprep_paths), batch_size=8, num_workers=0)
    seed_blocks = []
    for seed in seeds:
        torch.manual_seed(seed)
        net = CNN3D().to(device).eval()
        batch_feats = []
        with torch.no_grad():
            for batch in loader:
                batch_feats.append(net(batch.to(device)).cpu().numpy())
        seed_blocks.append(np.vstack(batch_feats))
        # Release the model before building the next one.
        del net
    return np.hstack(seed_blocks)
    
# Downstream Analysis

def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Evaluate each feature set with a random-forest classifier.

    For every seed in SEEDS, a stratified held-out test set is split off and
    5-fold stratified CV is run on the remainder. Within each fold a model is
    trained on a stratified fraction `alpha` of the training fold. The test
    score for a seed always comes from a model trained on the full CV
    portion. A stratified DummyClassifier baseline is evaluated the same way.

    Args:
        X_dict: Mapping of model name to a feature array, indexable by
            integer index arrays along the first axis.
        y: Label array aligned with the feature rows. Assumes binary labels
            (column 1 of `predict_proba` is used for AUC) - TODO confirm.
        test_size: Fraction of samples held out for testing.
        task_name: Heading printed before the results.

    Returns:
        Dict keyed by 'DummyClassifier' and each model name, holding test
        and CV summaries, per-fold scores and per-alpha learning-curve
        statistics.
    """
    print(f"\n{task_name}")
    results = {}
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]
    n_test = int(test_size * len(y))

    dummy_cv_acc, dummy_cv_auc, dummy_test_acc, dummy_test_auc = [], [], [], []
    for seed in SEEDS:
        cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=n_test,
                                            stratify=y, random_state=seed)
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
            # The dummy ignores features, so placeholder inputs suffice.
            val_placeholder = np.zeros((len(val_idx), 1))
            dummy = DummyClassifier(strategy='stratified', random_state=seed)
            dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
            dummy_cv_acc.append(balanced_accuracy_score(y[val_idx], dummy.predict(val_placeholder)))
            dummy_cv_auc.append(roc_auc_score(y[val_idx], dummy.predict_proba(val_placeholder)[:, 1]))
        test_placeholder = np.zeros((len(test_idx), 1))
        dummy = DummyClassifier(strategy='stratified', random_state=seed)
        dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
        dummy_test_acc.append(balanced_accuracy_score(y[test_idx], dummy.predict(test_placeholder)))
        dummy_test_auc.append(roc_auc_score(y[test_idx], dummy.predict_proba(test_placeholder)[:, 1]))

    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_acc_mean': np.mean(dummy_test_acc), 'test_acc_std': np.std(dummy_test_acc),
        'test_auc_mean': np.mean(dummy_test_auc), 'test_auc_std': np.std(dummy_test_auc),
        'cv_auc_mean': np.mean(dummy_cv_auc), 'cv_auc_std': np.std(dummy_cv_auc),
        'cv_acc_mean': np.mean(dummy_cv_acc), 'cv_acc_std': np.std(dummy_cv_acc),
        'fold_results': {'val_auc': dummy_cv_auc, 'val_acc': dummy_cv_acc},
        'test_seed_results': {'test_acc': dummy_test_acc, 'test_auc': dummy_test_auc},
    }
    print(f"Dummy: AUC={np.mean(dummy_test_auc):.3f}+/-{np.std(dummy_test_auc):.3f}")

    for model, X in X_dict.items():
        all_val_auc = {a: [] for a in alphas}
        all_val_acc = {a: [] for a in alphas}
        all_train_auc = {a: [] for a in alphas}
        all_train_acc = {a: [] for a in alphas}
        all_test_acc, all_test_auc = [], []

        for seed in SEEDS:
            # Same seed and arguments as the dummy loop, so splits match.
            cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=n_test,
                                                stratify=y, random_state=seed)
            skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
            for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
                train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
                n_classes = len(np.unique(y[train_idx]))
                for alpha in alphas:
                    n = int(alpha * len(train_idx))
                    if n >= len(train_idx):
                        tr_idx = train_idx
                    elif n < n_classes or len(train_idx) - n < n_classes:
                        # A stratified split needs at least one sample per
                        # class on each side; skip this alpha rather than
                        # let train_test_split raise (e.g. n == 0).
                        continue
                    else:
                        tr_idx, _ = train_test_split(train_idx, train_size=n, stratify=y[train_idx], random_state=seed)
                    if len(np.unique(y[tr_idx])) < 2:
                        continue
                    scaler = StandardScaler()
                    rf = RandomForestClassifier(n_estimators=200, max_depth=6, min_samples_split=5,
                                               random_state=seed, n_jobs=1, max_features='sqrt', class_weight='balanced')
                    # Scale once per split and reuse for fit and scoring.
                    X_tr = scaler.fit_transform(X[tr_idx])
                    X_val = scaler.transform(X[val_idx])
                    rf.fit(X_tr, y[tr_idx])
                    val_pred_proba = rf.predict_proba(X_val)[:, 1]
                    all_val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                    all_val_auc[alpha].append(roc_auc_score(y[val_idx], val_pred_proba))
                    all_train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                    all_train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

            # Held-out test evaluation: train on ALL cv_idx, predict test_idx once per seed
            scaler = StandardScaler()
            rf = RandomForestClassifier(n_estimators=200, max_depth=6, min_samples_split=5,
                                        random_state=seed, n_jobs=1, max_features='sqrt', class_weight='balanced')
            rf.fit(scaler.fit_transform(X[cv_idx]), y[cv_idx])
            X_test = scaler.transform(X[test_idx])
            test_pred_proba = rf.predict_proba(X_test)[:, 1]
            all_test_acc.append(balanced_accuracy_score(y[test_idx], rf.predict(X_test)))
            all_test_auc.append(roc_auc_score(y[test_idx], test_pred_proba))

        # Alphas with no recorded folds are left out of the summaries.
        avg_auc = {a: np.mean(all_val_auc[a]) for a in alphas if all_val_auc[a]}
        best_alpha = max(avg_auc, key=avg_auc.get)
        train_val_gaps = {a: np.mean(all_train_auc[a]) - np.mean(all_val_auc[a]) for a in alphas if all_train_auc[a] and all_val_auc[a]}

        results[model] = {
            'best_alpha': best_alpha,
            'test_acc_mean': np.mean(all_test_acc), 'test_acc_std': np.std(all_test_acc),
            'test_auc_mean': np.mean(all_test_auc), 'test_auc_std': np.std(all_test_auc),
            'cv_auc_mean': np.mean(all_val_auc[best_alpha]), 'cv_auc_std': np.std(all_val_auc[best_alpha]),
            'cv_acc_mean': np.mean(all_val_acc[best_alpha]), 'cv_acc_std': np.std(all_val_acc[best_alpha]),
            'train_val_gap': train_val_gaps.get(best_alpha, 0),
            'is_overfitting': train_val_gaps.get(best_alpha, 0) > 0.25,
            'fold_results': {'val_auc': all_val_auc[best_alpha], 'val_acc': all_val_acc[best_alpha]},
            'fold_results_alpha1': {'val_acc': all_val_acc[1.0], 'val_auc': all_val_auc[1.0]},
            'test_seed_results': {'test_acc': all_test_acc, 'test_auc': all_test_auc},
            'cv_results': {
                'alphas': alphas,
                'acc': {a: np.mean(all_val_acc[a]) for a in alphas if all_val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(all_val_acc[a]) for a in alphas if all_val_acc[a]},
                'auc_std': {a: np.std(all_val_auc[a]) for a in alphas if all_val_auc[a]},
                'train_acc': {a: np.mean(all_train_acc[a]) for a in alphas if all_train_acc[a]},
                'train_acc_std': {a: np.std(all_train_acc[a]) for a in alphas if all_train_acc[a]},
                'train_auc': {a: np.mean(all_train_auc[a]) for a in alphas if all_train_auc[a]},
                'train_auc_std': {a: np.std(all_train_auc[a]) for a in alphas if all_train_auc[a]},
            }
        }
        print(f"{model}: AUC={np.mean(all_test_auc):.3f}+/-{np.std(all_test_auc):.3f}, Acc={np.mean(all_test_acc):.3f}+/-{np.std(all_test_acc):.3f}")
    return results

def run_regression(X_dict, y, test_size=0.1, task_name="Regression"):
    """Evaluate each feature set with a random-forest regressor.

    For every seed in SEEDS, a held-out test set is split off and 5-fold CV
    is run on the remainder. Within each fold a model is trained on a random
    fraction `alpha` of the training fold. The test score for a seed always
    comes from a model trained on the full CV portion. A mean-predicting
    DummyRegressor baseline is evaluated the same way.

    Args:
        X_dict: Mapping of model name to a feature array, indexable by
            integer index arrays along the first axis.
        y: Target array aligned with the feature rows.
        test_size: Fraction of samples held out for testing.
        task_name: Heading printed before the results.

    Returns:
        Dict keyed by 'DummyRegressor' and each model name, holding test
        and CV summaries, per-fold scores and per-alpha learning-curve
        statistics.
    """
    print(f"\n{task_name}")
    results = {}
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]
    n_test = int(test_size * len(y))

    dummy_cv_r2, dummy_cv_mae, dummy_test_r2, dummy_test_mae = [], [], [], []
    for seed in SEEDS:
        cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=n_test, random_state=seed)
        kf = KFold(n_splits=5, shuffle=True, random_state=seed)
        for train_rel, val_rel in kf.split(cv_idx):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
            # The dummy ignores features, so placeholder inputs suffice.
            dummy = DummyRegressor(strategy='mean')
            dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
            pred = dummy.predict(np.zeros((len(val_idx), 1)))
            dummy_cv_r2.append(r2_score(y[val_idx], pred))
            dummy_cv_mae.append(mean_absolute_error(y[val_idx], pred))
        dummy = DummyRegressor(strategy='mean')
        dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
        test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
        dummy_test_r2.append(r2_score(y[test_idx], test_pred))
        dummy_test_mae.append(mean_absolute_error(y[test_idx], test_pred))

    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2_mean': np.mean(dummy_test_r2), 'test_r2_std': np.std(dummy_test_r2),
        'test_mae_mean': np.mean(dummy_test_mae), 'test_mae_std': np.std(dummy_test_mae),
        'cv_r2_mean': np.mean(dummy_cv_r2), 'cv_r2_std': np.std(dummy_cv_r2),
        'cv_mae_mean': np.mean(dummy_cv_mae), 'cv_mae_std': np.std(dummy_cv_mae),
        'fold_results': {'val_r2': dummy_cv_r2, 'val_mae': dummy_cv_mae},
        'test_seed_results': {'test_mae': dummy_test_mae, 'test_r2': dummy_test_r2},
    }
    print(f"Dummy: MAE={np.mean(dummy_test_mae):.2f}+/-{np.std(dummy_test_mae):.2f}")

    for model, X in X_dict.items():
        all_val_r2 = {a: [] for a in alphas}
        all_val_mae = {a: [] for a in alphas}
        all_train_r2 = {a: [] for a in alphas}
        all_train_mae = {a: [] for a in alphas}
        all_test_r2, all_test_mae = [], []

        for seed in SEEDS:
            # Same seed and arguments as the dummy loop, so splits match.
            cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=n_test, random_state=seed)
            kf = KFold(n_splits=5, shuffle=True, random_state=seed)
            for train_rel, val_rel in kf.split(cv_idx):
                train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
                for alpha in alphas:
                    n = int(alpha * len(train_idx))
                    if n >= len(train_idx):
                        tr_idx = train_idx
                    elif n < 2:
                        # train_size=0 makes train_test_split raise, and
                        # R^2 is undefined on a single training sample;
                        # skip this alpha for the fold.
                        continue
                    else:
                        tr_idx, _ = train_test_split(train_idx, train_size=n, random_state=seed)
                    scaler = StandardScaler()
                    rf = RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_split=5,
                                               random_state=seed, n_jobs=1)
                    # Scale once per split and reuse for fit and scoring.
                    X_tr = scaler.fit_transform(X[tr_idx])
                    rf.fit(X_tr, y[tr_idx])
                    pred = rf.predict(scaler.transform(X[val_idx]))
                    train_pred = rf.predict(X_tr)
                    all_val_r2[alpha].append(r2_score(y[val_idx], pred))
                    all_val_mae[alpha].append(mean_absolute_error(y[val_idx], pred))
                    all_train_r2[alpha].append(r2_score(y[tr_idx], train_pred))
                    all_train_mae[alpha].append(mean_absolute_error(y[tr_idx], train_pred))

            # Held-out test evaluation: train on ALL cv_idx, predict test_idx once per seed
            scaler = StandardScaler()
            rf = RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_split=5,
                                        random_state=seed, n_jobs=1)
            rf.fit(scaler.fit_transform(X[cv_idx]), y[cv_idx])
            test_pred = rf.predict(scaler.transform(X[test_idx]))
            all_test_r2.append(r2_score(y[test_idx], test_pred))
            all_test_mae.append(mean_absolute_error(y[test_idx], test_pred))

        # Alphas with no recorded folds are left out of the summaries.
        avg_r2 = {a: np.mean(all_val_r2[a]) for a in alphas if all_val_r2[a]}
        best_alpha = max(avg_r2, key=avg_r2.get)
        train_val_gaps = {a: np.mean(all_train_r2[a]) - np.mean(all_val_r2[a]) for a in alphas if all_train_r2[a] and all_val_r2[a]}

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2_mean': np.mean(all_test_r2), 'test_r2_std': np.std(all_test_r2),
            'test_mae_mean': np.mean(all_test_mae), 'test_mae_std': np.std(all_test_mae),
            'cv_r2_mean': np.mean(all_val_r2[best_alpha]), 'cv_r2_std': np.std(all_val_r2[best_alpha]),
            'cv_mae_mean': np.mean(all_val_mae[best_alpha]), 'cv_mae_std': np.std(all_val_mae[best_alpha]),
            'train_val_gap': train_val_gaps.get(best_alpha, 0),
            'is_overfitting': train_val_gaps.get(best_alpha, 0) > 0.35,
            'fold_results': {'val_r2': all_val_r2[best_alpha], 'val_mae': all_val_mae[best_alpha]},
            'fold_results_alpha1': {'val_mae': all_val_mae[1.0], 'val_r2': all_val_r2[1.0]},
            'test_seed_results': {'test_mae': all_test_mae, 'test_r2': all_test_r2},
            'cv_results': {
                'alphas': alphas,
                'r2': avg_r2,
                'mae': {a: np.mean(all_val_mae[a]) for a in alphas if all_val_mae[a]},
                'r2_std': {a: np.std(all_val_r2[a]) for a in alphas if all_val_r2[a]},
                'mae_std': {a: np.std(all_val_mae[a]) for a in alphas if all_val_mae[a]},
                'train_r2': {a: np.mean(all_train_r2[a]) for a in alphas if all_train_r2[a]},
                'train_r2_std': {a: np.std(all_train_r2[a]) for a in alphas if all_train_r2[a]},
                'train_mae': {a: np.mean(all_train_mae[a]) for a in alphas if all_train_mae[a]},
                'train_mae_std': {a: np.std(all_train_mae[a]) for a in alphas if all_train_mae[a]},
            }
        }
        print(f"{model}: MAE={np.mean(all_test_mae):.2f}+/-{np.std(all_test_mae):.2f}, R2={np.mean(all_test_r2):.3f}+/-{np.std(all_test_r2):.3f}")
    return results

# Correlation Analysis

def compute_pairwise_correlations(features_dict, dataset_name):
    """Print the mean absolute cross-correlation between every pair of feature sets.

    Each entry of ``features_dict`` is standardized with ``StandardScaler``;
    columns that contain NaN or have (near-)zero standard deviation after
    scaling are discarded. For every unordered pair of models the feature-by-
    feature correlation block is computed and summarized by the mean of its
    absolute values. Results are printed only; nothing is returned.

    Args:
        features_dict: mapping of model name -> feature matrix accepted by
            ``StandardScaler.fit_transform``.
        dataset_name: label used in the printed banner.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"MEAN ABSOLUTE CORRELATIONS - {dataset_name}")
    print(f"{banner}")

    model_names = list(features_dict.keys())

    # Standardize each feature set and keep only usable columns.
    scaled_features = {}
    for name in model_names:
        scaled = StandardScaler().fit_transform(features_dict[name])
        has_nan = np.isnan(scaled).any(axis=0)
        has_variance = np.std(scaled, axis=0) > 1e-10
        scaled_features[name] = scaled[:, ~has_nan & has_variance]

    # Visit each unordered pair once, keyed as (earlier model, later model).
    corr_results = {}
    for pos, first in enumerate(model_names):
        for second in model_names[pos + 1:]:
            feats_a = scaled_features[first]
            feats_b = scaled_features[second]
            n_a = feats_a.shape[1]
            # corrcoef stacks both variable sets; the off-diagonal block holds
            # the correlations between features of `first` and `second`.
            stacked_corr = np.corrcoef(feats_a.T, feats_b.T)
            corr_results[(first, second)] = np.mean(np.abs(stacked_corr[:n_a, n_a:]))

    # Report every model against each FreeSurfer-derived feature set.
    fs_keys = [k for k in model_names if 'FreeSurfer' in k or 'FS_' in k]
    for fs_key in fs_keys:
        print(f"\nvs {fs_key}:")
        for other in model_names:
            if other != fs_key:
                pair = (other, fs_key)
                if pair not in corr_results:
                    pair = (fs_key, other)
                print(f"  {other} vs {fs_key}: {corr_results[pair]:.4f}")

    print("\nAll pairwise (sorted by correlation):")
    ranked = sorted(corr_results.items(), key=lambda item: item[1], reverse=True)
    for (m1, m2), corr in ranked:
        print(f"  {m1} vs {m2}: {corr:.4f}")

def _cluster_feature_order(feats):
    """Return column positions of `feats` ordered by Ward hierarchical clustering on 1 - |corr|."""
    if feats.shape[1] < 2:
        # A single column has nothing to cluster (and corrcoef would return a scalar).
        return list(range(feats.shape[1]))
    corr = np.corrcoef(feats.T)
    corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)
    dist = np.abs(1 - np.abs(corr))
    np.fill_diagonal(dist, 0)
    condensed = np.nan_to_num(squareform((dist + dist.T) / 2, checks=False), nan=0.0, posinf=1.0, neginf=0.0)
    return dendrogram(linkage(condensed, method='ward'), no_plot=True)['leaves']


def compute_correlation_with_cluster_labels(features_dict, fs_feature_infos, fs_keys, dataset_name):
    """Build one feature-by-feature correlation matrix across all models.

    Models are processed in a fixed order (AnatCL, BrainIAC, SwinBrain, the
    keys of ``CNN_SEED_CONFIGS``, 3D-Neuro-SimCLR, then ``fs_keys``); models
    absent from ``features_dict`` or left with no usable column are skipped.
    Each model's features are standardized, filtered (NaN / near-constant
    columns removed) and reordered by hierarchical clustering. For models in
    ``fs_keys`` that have metadata, the clustered order is regrouped as
    Thickness, then Surface_Area, then any other feature type.

    Args:
        features_dict: model name -> feature matrix accepted by
            ``StandardScaler.fit_transform``.
        fs_feature_infos: model name -> list of per-column metadata dicts
            (must contain ``'feature_type'``), indexed by ORIGINAL column.
        fs_keys: names of the FreeSurfer-derived feature sets.
        dataset_name: unused; kept for interface compatibility.

    Returns:
        Tuple ``(corr, boundaries, model_names_used, reordered_indices, cluster_dfs)``:
        the symmetrized correlation matrix, cumulative column offsets per
        model (starting with 0), the models actually included,
        model -> column order (positions in the filtered matrix), and
        FreeSurfer model -> DataFrame describing each reordered feature.

    Raises:
        ValueError: if no model contributes any feature.
    """
    all_features, boundaries, reordered_indices, model_names_used, cluster_dfs = [], [0], {}, [], {}
    # Per model: original column index of every column kept after filtering.
    kept_columns = {}
    cnn_keys = list(CNN_SEED_CONFIGS.keys())
    model_order = ['AnatCL_Global', 'AnatCL_Local', 'BrainIAC', 'SwinBrain'] + \
                  cnn_keys + ['3D-Neuro-SimCLR'] + fs_keys

    for model in model_order:
        if model not in features_dict:
            continue
        feats = StandardScaler().fit_transform(features_dict[model])
        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        original_cols = np.flatnonzero(valid)
        feats = feats[:, valid]
        if feats.shape[1] == 0:
            continue
        print(f"  {model}: {feats.shape[1]} features")
        reorder_idx = _cluster_feature_order(feats)
        if model in fs_keys and model in fs_feature_infos:
            fs_info = fs_feature_infos[model]
            # Metadata is indexed by original column, so translate filtered
            # positions back before looking up the feature type.
            types = {i: fs_info[original_cols[i]]['feature_type'] for i in reorder_idx}
            thickness_indices = [i for i in reorder_idx if types[i] == 'Thickness']
            surfarea_indices = [i for i in reorder_idx if types[i] == 'Surface_Area']
            other_indices = [i for i in reorder_idx if types[i] not in ('Thickness', 'Surface_Area')]
            # Keep every feature so the block width matches what is appended.
            reorder_idx = thickness_indices + surfarea_indices + other_indices
        all_features.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + len(reorder_idx))
        reordered_indices[model] = reorder_idx
        kept_columns[model] = original_cols
        model_names_used.append(model)

    if not all_features:
        raise ValueError("No model in features_dict contributed any usable feature")

    combined = np.hstack(all_features)
    # Compute the (potentially large) correlation matrix once, then symmetrize.
    corr = np.atleast_2d(np.corrcoef(combined.T))
    corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)
    for fs_key in fs_keys:
        if fs_key in reordered_indices and fs_key in fs_feature_infos:
            fs_info = fs_feature_infos[fs_key]
            original_cols = kept_columns[fs_key]
            rows = []
            for pos, idx in enumerate(reordered_indices[fs_key]):
                orig_idx = int(original_cols[idx])
                rows.append({**fs_info[orig_idx], 'original_index': orig_idx, 'reordered_position': pos})
            cluster_dfs[fs_key] = pd.DataFrame(rows)
    return corr, boundaries, model_names_used, reordered_indices, cluster_dfs

def plot_correlation_with_labels(corr_matrix, boundaries, model_names, cluster_dfs, fs_keys, dataset_name):
    """Render the cross-model correlation heatmap and save it as a PDF.

    Draws ``corr_matrix`` with a diverging colormap, separates the per-model
    blocks with black lines at the interior ``boundaries``, adds colored side
    bars marking the feature type of every FreeSurfer feature, labels each
    block with its model name, and writes
    ``<dataset_name lowercased>_cross_model_correlation.pdf`` to the current
    working directory.

    Args:
        corr_matrix: square correlation matrix to display.
        boundaries: cumulative column offsets, one more entry than models.
        model_names: model name for each block, in block order.
        cluster_dfs: FreeSurfer model -> DataFrame with a ``feature_type``
            column listing features in displayed order.
        fs_keys: FreeSurfer model names to annotate with side bars.
        dataset_name: used for the output file name and the label font size.
    """
    bar_width = 8
    type_colors = {'Surface_Area': '#2ECC71', 'Thickness': '#3498DB'}
    fallback_color = '#95A5A6'
    display_names = {'FS_Schaefer': 'FS: Schaefer', 'FS_aparc': 'FS: aparc', 'FreeSurfer': 'FS'}
    total = boundaries[-1]
    out_name = f'{dataset_name.lower()}_cross_model_correlation.pdf'

    fig, ax = plt.subplots(figsize=(32, 30))
    heatmap = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='equal', interpolation='nearest')
    plt.colorbar(heatmap, ax=ax, shrink=0.3)

    # Separator lines between consecutive model blocks.
    for edge in boundaries[1:-1]:
        ax.axhline(y=edge-0.5, color='black', linewidth=4)
        ax.axvline(x=edge-0.5, color='black', linewidth=4)

    # Feature-type bars on the right and left of each FreeSurfer block.
    for fs_key in fs_keys:
        if fs_key in cluster_dfs:
            feature_types = cluster_dfs[fs_key]['feature_type'].values
            block_start = boundaries[model_names.index(fs_key)]
            for offset, ft in enumerate(feature_types):
                color = type_colors.get(ft, fallback_color)
                row_top = block_start + offset - 0.5
                ax.add_patch(plt.Rectangle((total + 3, row_top), bar_width, 1, facecolor=color, edgecolor='none'))
                ax.add_patch(plt.Rectangle((-bar_width - 5, row_top), bar_width, 1, facecolor=color, edgecolor='none'))

    # Block labels along the top and the left edge, centered on each block.
    if dataset_name.lower() == 'ppmi':
        label_fontsize = 18
    else:
        label_fontsize = 10
    for i, (start, stop) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        center = (start + stop) / 2
        name = model_names[i]
        label = display_names.get(name, name)
        ax.text(center, -30, label, ha='center', fontsize=label_fontsize, weight='bold')
        ax.text(-30, center, label, ha='right', va='center', fontsize=label_fontsize, weight='bold', rotation=90)

    # Leave room around the matrix for the bars and labels; y axis is inverted.
    ax.set_xlim(-bar_width - 60, total + bar_width + 50)
    ax.set_ylim(total + 20, -60)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')

    legend_handles = [
        Patch(facecolor=type_colors['Surface_Area'], edgecolor='black', label='Surface Area'),
        Patch(facecolor=type_colors['Thickness'], edgecolor='black', label='Thickness'),
    ]
    ax.legend(handles=legend_handles, loc='upper left', fontsize=14,
              bbox_to_anchor=(1.02, 1.0), frameon=True)

    plt.tight_layout()
    plt.savefig(out_name, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Saved: {out_name}")

def export_correlation_data(corr_matrix, boundaries, model_names, cluster_dfs, dataset_name, output_dir=''):
    """Save the correlation matrix and its per-model block layout to disk.

    Writes ``<prefix>_corr_matrix.npy`` (the matrix) and
    ``<prefix>_corr_metadata.csv`` (one row per model with its start/end
    offsets and feature count), where ``prefix`` is ``output_dir`` joined
    with the lowercased ``dataset_name``.

    Args:
        corr_matrix: array to save; its ``shape`` is reported.
        boundaries: cumulative column offsets, one more entry than models.
        model_names: model name for each block, in block order.
        cluster_dfs: unused; kept for interface compatibility.
        dataset_name: dataset label (stored in the CSV, lowercased in file names).
        output_dir: destination directory; ``''`` means the current directory.
            It is created if it does not exist yet.
    """
    if output_dir:
        # np.save / to_csv do not create missing parent directories.
        os.makedirs(output_dir, exist_ok=True)
    prefix = os.path.join(output_dir, dataset_name.lower())
    np.save(f'{prefix}_corr_matrix.npy', corr_matrix)
    meta_rows = [
        {'dataset': dataset_name, 'model': model,
         'boundary_start': boundaries[i], 'boundary_end': boundaries[i+1],
         'n_features': boundaries[i+1] - boundaries[i]}
        for i, model in enumerate(model_names)
    ]
    pd.DataFrame(meta_rows).to_csv(f'{prefix}_corr_metadata.csv', index=False)
    print(f"Saved: {prefix}_corr_matrix.npy ({corr_matrix.shape})")
    print(f"Saved: {prefix}_corr_metadata.csv")

# Feature Extraction

def _extract_anatcl_features(cat12_paths, descriptor, device, n_folds=5):
    """Return AnatCL embeddings for `descriptor` ("global"/"local"), averaged over the fold encoders."""
    print(f"Extracting AnatCL ({descriptor})")
    anatcl_transform = transforms.Compose([
        transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
        transforms.Normalize(mean=0.0, std=1.0)
    ])
    fold_features = []
    for fold_idx in range(n_folds):
        path = os.path.join(ANATCL_ENCODER_PATH, f"weights_{descriptor}_fold{fold_idx}.pth")
        encoder = AnatCL(descriptor=descriptor, fold=fold_idx, pretrained=False).to(device).eval()
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        encoder.backbone.load_state_dict(torch.load(path, map_location=device, weights_only=False)['model'])
        for p in encoder.parameters():
            p.requires_grad = False
        dl = DataLoader(CAT12VBMDataset(cat12_paths, anatcl_transform), batch_size=32, num_workers=0)
        with torch.no_grad():
            fold_features.append(torch.cat([encoder(v.to(device)).cpu() for v in dl]).numpy())
        # Drop the reference before the next fold's encoder is built.
        del encoder
    features = np.mean(fold_features, axis=0)
    print(f"  AnatCL ({descriptor}): {features.shape}")
    return features


def _extract_swinbrain_features(turboprep_paths, device):
    """Return SwinBrain features: the `encoder10` output captured by a hook and average-pooled to a vector."""
    print("Extracting SwinBrain")
    swinbrain_model = load_swinbrain(SWINBRAIN_CKPT, device)
    captured = [None]

    def swinbrain_hook(module, input, output):
        captured[0] = output

    hook_handle = swinbrain_model.encoder10.register_forward_hook(swinbrain_hook)
    pool = nn.AdaptiveAvgPool3d((1, 1, 1))
    dl = DataLoader(SwinBrainDataset(turboprep_paths), batch_size=8, num_workers=0)
    swinbrain_list = []
    try:
        with torch.no_grad():
            for x in dl:
                # The model output itself is discarded; only the hooked activation is used.
                _ = swinbrain_model(x.to(device))
                pooled = pool(captured[0])
                swinbrain_list.append(pooled.view(pooled.size(0), -1).cpu().numpy())
    finally:
        # Always detach the hook, even if inference fails part-way.
        hook_handle.remove()
    swinbrain_features = np.vstack(swinbrain_list)
    print(f"  SwinBrain: {swinbrain_features.shape}")
    return swinbrain_features


def _extract_brainiac_features(turboprep_paths, device):
    """Return BrainIAC features: the first token of the ViT output for each z-scored input volume."""
    print("Extracting BrainIAC")
    brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)
    brainiac_transform = lambda x: (x - x.mean()) / (x.std() + 1e-6)
    dl = DataLoader(BrainIACDataset(turboprep_paths, brainiac_transform), batch_size=16, num_workers=0)
    brainiac_list = []
    with torch.no_grad():
        for x in dl:
            out = brainiac_vit(x.to(device))
            # The model may return a tuple; in either case take token 0 of the sequence.
            cls_token = out[0][:, 0] if isinstance(out, tuple) else out[:, 0]
            brainiac_list.append(cls_token.cpu().numpy())
    brainiac_features = np.vstack(brainiac_list)
    print(f"  BrainIAC: {brainiac_features.shape}")
    return brainiac_features


def _extract_cnn_features(turboprep_paths, device):
    """Return {tag: features} for every entry of CNN_SEED_CONFIGS via extract_cnn_features_multiseed."""
    cnn_features_dict = {}
    for tag, seeds in CNN_SEED_CONFIGS.items():
        print(f"Extracting {tag} (seeds={seeds})")
        cnn_features_dict[tag] = extract_cnn_features_multiseed(turboprep_paths, seeds, device)
        print(f"  {tag}: {cnn_features_dict[tag].shape}")
    return cnn_features_dict


def _extract_simclr_features(turboprep_paths, device):
    """Return 3D-Neuro-SimCLR encoder outputs for every TurboPrep volume."""
    print("Extracting 3D-Neuro-SimCLR")
    simclr_encoder = load_simclr_encoder(SIMCLR_CKPT, device)
    dl = DataLoader(TurboPrepDataset(turboprep_paths), batch_size=8, num_workers=0)
    simclr_list = []
    with torch.no_grad():
        for x in dl:
            simclr_list.append(simclr_encoder(x.to(device)).cpu().numpy())
    simclr_features = np.vstack(simclr_list)
    print(f"  3D-Neuro-SimCLR: {simclr_features.shape}")
    return simclr_features


def extract_features_ppmi(cat12_paths, turboprep_paths, device):
    """Extract the reduced model set (no BrainIAC, no SimCLR) used for PPMI.

    Args:
        cat12_paths: CAT12 volume paths fed to the AnatCL encoders.
        turboprep_paths: TurboPrep volume paths fed to SwinBrain and the CNNs.
        device: torch device on which the models run.

    Returns:
        Tuple ``(anatcl_global, anatcl_local, swinbrain, cnn_features_dict)``.
        Rows follow the order of the input path lists.
    """
    anatcl_features = _extract_anatcl_features(cat12_paths, "global", device)
    anatcl_local_features = _extract_anatcl_features(cat12_paths, "local", device)
    swinbrain_features = _extract_swinbrain_features(turboprep_paths, device)
    cnn_features_dict = _extract_cnn_features(turboprep_paths, device)

    return anatcl_features, anatcl_local_features, swinbrain_features, cnn_features_dict


def extract_features_full(cat12_paths, turboprep_paths, device):
    """Extract features from every model: AnatCL, BrainIAC, SwinBrain, CNNs and SimCLR.

    Args:
        cat12_paths: CAT12 volume paths fed to the AnatCL encoders.
        turboprep_paths: TurboPrep volume paths fed to all other models.
        device: torch device on which the models run.

    Returns:
        Tuple ``(anatcl_global, anatcl_local, brainiac, swinbrain,
        cnn_features_dict, simclr)``. Rows follow the order of the input
        path lists.
    """
    anatcl_features = _extract_anatcl_features(cat12_paths, "global", device)
    anatcl_local_features = _extract_anatcl_features(cat12_paths, "local", device)
    brainiac_features = _extract_brainiac_features(turboprep_paths, device)
    swinbrain_features = _extract_swinbrain_features(turboprep_paths, device)
    cnn_features_dict = _extract_cnn_features(turboprep_paths, device)
    simclr_features = _extract_simclr_features(turboprep_paths, device)

    return anatcl_features, anatcl_local_features, brainiac_features, swinbrain_features, cnn_features_dict, simclr_features

def mean_statistic(sample, axis):
    """Return the mean of `sample` along `axis`; used as the permutation-test statistic."""
    return np.mean(sample, axis=axis)


def _run_paired_permutation_test(results, fs_key, models_to_test, metric, higher_is_better, n_resamples):
    """Shared one-sided paired permutation test of each model against the `fs_key` baseline on `metric`."""
    perm_results = {}
    if fs_key not in results:
        return perm_results
    fs_folds = np.array(results[fs_key]['fold_results_alpha1'][metric])
    fs_mean = np.mean(fs_folds)

    def beats_baseline(model):
        model_mean = np.mean(results[model]['fold_results_alpha1'][metric])
        return model_mean > fs_mean if higher_is_better else model_mean < fs_mean

    # Only models whose mean already beats the baseline are tested; the
    # Bonferroni factor below counts exactly these models.
    models_beating_fs = [m for m in models_to_test
                         if m != fs_key and m in results and 'Dummy' not in m and beats_baseline(m)]
    n_tests = len(models_beating_fs)
    if n_tests == 0:
        return perm_results
    alternative = 'greater' if higher_is_better else 'less'
    for model in models_beating_fs:
        model_folds = np.array(results[model]['fold_results_alpha1'][metric])
        if model_folds.shape != fs_folds.shape:
            # Without this check a length-1 fold list would broadcast silently.
            raise ValueError(
                f"Fold count mismatch for {model}: {model_folds.shape} vs {fs_key} {fs_folds.shape}")
        diff = model_folds - fs_folds
        # One-sample test on the paired differences (permutation_type='samples').
        res = permutation_test((diff,), mean_statistic, permutation_type='samples', n_resamples=n_resamples, vectorized=True, axis=0, alternative=alternative)
        p_corr = min(res.pvalue * n_tests, 1.0)
        perm_results[model] = {'p_raw': res.pvalue, 'p_corr': p_corr, 'n_tests': n_tests, 'diff': np.mean(diff), 'significant': p_corr < 0.05}
        print(f"    {model}: diff={np.mean(diff):.4f}, p_raw={res.pvalue:.4f}, p_corr={p_corr:.4f}")
    return perm_results


def run_permutation_test_classification(results, fs_key, models_to_test, n_resamples=9999):
    """Test which models have higher per-fold validation accuracy than the baseline.

    Uses ``results[m]['fold_results_alpha1']['val_acc']`` and a one-sided
    (``'greater'``) paired permutation test on the fold-wise differences.

    Args:
        results: model name -> result dict as produced by the classification runs.
        fs_key: name of the baseline model; if missing, ``{}`` is returned.
        models_to_test: candidate model names (baseline, absent and
            ``'Dummy'`` models are ignored).
        n_resamples: resample count passed to ``permutation_test``.

    Returns:
        model -> ``{'p_raw', 'p_corr', 'n_tests', 'diff', 'significant'}`` for
        each tested model; ``p_corr`` is ``p_raw * n_tests`` capped at 1.

    Raises:
        ValueError: if a model's fold count differs from the baseline's.
    """
    return _run_paired_permutation_test(results, fs_key, models_to_test,
                                        metric='val_acc', higher_is_better=True, n_resamples=n_resamples)


def run_permutation_test_regression(results, fs_key, models_to_test, n_resamples=9999):
    """Test which models have lower per-fold validation MAE than the baseline.

    Uses ``results[m]['fold_results_alpha1']['val_mae']`` and a one-sided
    (``'less'``) paired permutation test on the fold-wise differences.

    Args:
        results: model name -> result dict as produced by the regression runs.
        fs_key: name of the baseline model; if missing, ``{}`` is returned.
        models_to_test: candidate model names (baseline, absent and
            ``'Dummy'`` models are ignored).
        n_resamples: resample count passed to ``permutation_test``.

    Returns:
        model -> ``{'p_raw', 'p_corr', 'n_tests', 'diff', 'significant'}`` for
        each tested model; ``p_corr`` is ``p_raw * n_tests`` capped at 1.

    Raises:
        ValueError: if a model's fold count differs from the baseline's.
    """
    return _run_paired_permutation_test(results, fs_key, models_to_test,
                                        metric='val_mae', higher_is_better=False, n_resamples=n_resamples)

# Training-set fractions: get_training_sizes() multiplies each value by the
# per-fold training size, and the export_cv_results_* helpers emit one row per
# value. The regression results are also indexed by these values (e.g. 1.0).
alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

def get_training_sizes(n_total, alphas, test_size=0.1, n_folds=5):
    """Convert training fractions into per-fold training sample counts.

    A ``test_size`` share of ``n_total`` (truncated to an integer) is set
    aside first; one of ``n_folds`` equal parts (floor division) of the
    remainder is then removed as the validation fold. Each fraction in
    ``alphas`` is applied to what is left and truncated to an integer.

    Args:
        n_total: total number of samples.
        alphas: iterable of fractions of the per-fold training set.
        test_size: fraction of ``n_total`` held out for testing.
        n_folds: number of cross-validation folds.

    Returns:
        List of integer training sizes, one per entry of ``alphas``.
    """
    held_out = int(test_size * n_total)
    cv_pool = n_total - held_out
    per_fold_train = cv_pool - cv_pool // n_folds
    sizes = []
    for fraction in alphas:
        sizes.append(int(fraction * per_fold_train))
    return sizes

def export_cv_results_classification(results, training_sizes, dataset_name, task_name, models_list, alpha_grid=None):
    """Flatten classification CV results into one row per (model, alpha).

    For every model in ``models_list`` that has ``cv_results``, a row is
    produced for each alpha present in ``cv_results['acc']``. If
    ``'DummyClassifier'`` is in ``results``, its constant scores are repeated
    for every alpha under the model name ``'Dummy'`` (train columns are NaN).

    Args:
        results: model name -> result dict.
        training_sizes: training sample count for each alpha, in grid order.
        dataset_name: value for the ``dataset`` column.
        task_name: value for the ``task`` column.
        models_list: models to export; names missing from ``results`` are skipped.
        alpha_grid: alphas to iterate; defaults to the module-level ``alphas``.

    Returns:
        ``pandas.DataFrame`` with validation, train and test accuracy/AUC
        means and standard deviations.
    """
    grid = alphas if alpha_grid is None else alpha_grid
    rows = []
    for model in models_list:
        if model not in results:
            continue
        model_res = results[model]
        cv_res = model_res.get('cv_results')
        if not cv_res:
            continue
        acc_by_alpha = cv_res.get('acc', {})
        for i, alpha in enumerate(grid):
            if alpha not in acc_by_alpha:
                continue
            rows.append({
                'dataset': dataset_name, 'task': task_name, 'model': model, 'alpha': alpha,
                'n_train': training_sizes[i],
                'val_acc_mean': cv_res['acc'][alpha], 'val_acc_std': cv_res['acc_std'][alpha],
                'val_auc_mean': cv_res['auc'].get(alpha, np.nan), 'val_auc_std': cv_res['auc_std'].get(alpha, np.nan),
                'train_acc_mean': cv_res.get('train_acc', {}).get(alpha, np.nan),
                'train_acc_std': cv_res.get('train_acc_std', {}).get(alpha, np.nan),
                'train_auc_mean': cv_res.get('train_auc', {}).get(alpha, np.nan),
                'train_auc_std': cv_res.get('train_auc_std', {}).get(alpha, np.nan),
                'test_acc_mean': model_res['test_acc_mean'], 'test_acc_std': model_res['test_acc_std'],
                'test_auc_mean': model_res['test_auc_mean'], 'test_auc_std': model_res['test_auc_std'],
            })
    if 'DummyClassifier' in results:
        dummy = results['DummyClassifier']
        # The dummy baseline does not depend on alpha: same scores on every row.
        for i, alpha in enumerate(grid):
            rows.append({
                'dataset': dataset_name, 'task': task_name, 'model': 'Dummy', 'alpha': alpha,
                'n_train': training_sizes[i],
                'val_acc_mean': dummy['cv_acc_mean'], 'val_acc_std': dummy['cv_acc_std'],
                'val_auc_mean': dummy['cv_auc_mean'], 'val_auc_std': dummy['cv_auc_std'],
                'train_acc_mean': np.nan, 'train_acc_std': np.nan,
                'train_auc_mean': np.nan, 'train_auc_std': np.nan,
                'test_acc_mean': dummy['test_acc_mean'], 'test_acc_std': dummy['test_acc_std'],
                'test_auc_mean': dummy['test_auc_mean'], 'test_auc_std': dummy['test_auc_std'],
            })
    return pd.DataFrame(rows)

def export_cv_results_regression(results, training_sizes, dataset_name, task_name, models_list, alpha_grid=None):
    """Flatten regression CV results into one row per (model, alpha).

    For every model in ``models_list`` that has ``cv_results``, a row is
    produced for each alpha present in ``cv_results['mae']``. If
    ``'DummyRegressor'`` is in ``results``, its constant scores are repeated
    for every alpha under the model name ``'Dummy'`` (train columns are NaN).

    Args:
        results: model name -> result dict.
        training_sizes: training sample count for each alpha, in grid order.
        dataset_name: value for the ``dataset`` column.
        task_name: value for the ``task`` column.
        models_list: models to export; names missing from ``results`` are skipped.
        alpha_grid: alphas to iterate; defaults to the module-level ``alphas``.

    Returns:
        ``pandas.DataFrame`` with validation, train and test MAE/R2 means and
        standard deviations.
    """
    grid = alphas if alpha_grid is None else alpha_grid
    rows = []
    for model in models_list:
        if model not in results:
            continue
        model_res = results[model]
        cv_res = model_res.get('cv_results')
        if not cv_res:
            continue
        mae_by_alpha = cv_res.get('mae', {})
        for i, alpha in enumerate(grid):
            if alpha not in mae_by_alpha:
                continue
            rows.append({
                'dataset': dataset_name, 'task': task_name, 'model': model, 'alpha': alpha,
                'n_train': training_sizes[i],
                'val_mae_mean': cv_res['mae'][alpha], 'val_mae_std': cv_res['mae_std'][alpha],
                'val_r2_mean': cv_res['r2'].get(alpha, np.nan), 'val_r2_std': cv_res['r2_std'].get(alpha, np.nan),
                'train_mae_mean': cv_res.get('train_mae', {}).get(alpha, np.nan),
                'train_mae_std': cv_res.get('train_mae_std', {}).get(alpha, np.nan),
                'train_r2_mean': cv_res.get('train_r2', {}).get(alpha, np.nan),
                'train_r2_std': cv_res.get('train_r2_std', {}).get(alpha, np.nan),
                'test_mae_mean': model_res['test_mae_mean'], 'test_mae_std': model_res['test_mae_std'],
                'test_r2_mean': model_res['test_r2_mean'], 'test_r2_std': model_res['test_r2_std'],
            })
    if 'DummyRegressor' in results:
        dummy = results['DummyRegressor']
        # The dummy baseline does not depend on alpha: same scores on every row.
        for i, alpha in enumerate(grid):
            rows.append({
                'dataset': dataset_name, 'task': task_name, 'model': 'Dummy', 'alpha': alpha,
                'n_train': training_sizes[i],
                'val_mae_mean': dummy['cv_mae_mean'], 'val_mae_std': dummy['cv_mae_std'],
                'val_r2_mean': dummy['cv_r2_mean'], 'val_r2_std': dummy['cv_r2_std'],
                'train_mae_mean': np.nan, 'train_mae_std': np.nan,
                'train_r2_mean': np.nan, 'train_r2_std': np.nan,
                'test_mae_mean': dummy['test_mae_mean'], 'test_mae_std': dummy['test_mae_std'],
                'test_r2_mean': dummy['test_r2_mean'], 'test_r2_std': dummy['test_r2_std'],
            })
    return pd.DataFrame(rows)

# PPMI Dataset
print("PPMI Dataset")

# Input locations. PPMI_CAT12_BASE_DIR is searched for CAT12 "mwp1*" volumes
# and also holds the FreeSurfer CSV exports; PPMI_TURBOPREP_DIR contains one
# sub-folder per subject with a "normalized.nii.gz" volume.
PPMI_CAT12_BASE_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
PPMI_DATA_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data"
PPMI_LABELS_PATH = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/processed_cohort_with_mri.csv"
PPMI_TURBOPREP_DIR = "/home/arelbaha/links/scratch/turboprep_output_ppmi"

# Build per-subject label lookups keyed by PATNO rendered as an integer string.
print("Loading PPMI demographics...")
ppmi_labels_df = pd.read_csv(PPMI_LABELS_PATH)
ppmi_id_sex_dict, ppmi_id_parkinson_dict, ppmi_id_age_dict = {}, {}, {}
for _, row in ppmi_labels_df.iterrows():
    # int() strips any float formatting (e.g. 3001.0 -> "3001") so keys match
    # the IDs parsed from folder names below.
    patno = str(int(row['PATNO']))
    # Sex: 1 when the value is "F" (case-insensitive), 0 for anything else.
    # NOTE(review): .strip() raises on missing (NaN) values - confirm the CSV has none.
    ppmi_id_sex_dict[patno] = 1 if row['Sex'].strip().upper() == 'F' else 0
    # Diagnosis: 1 when Group is exactly "PD", 0 for every other group.
    ppmi_id_parkinson_dict[patno] = 1 if row['Group'].strip() == 'PD' else 0
    ppmi_id_age_dict[patno] = row['Age']
print(f"  {len(ppmi_id_sex_dict)} subjects")

def extract_patno_from_path(filepath):
    """Return the subject ID taken from the first ``sub-<ID>`` path component.

    The path is split on ``os.sep``; the first component starting with
    ``'sub-'`` is returned with that prefix removed. Returns ``None`` when no
    component matches.
    """
    subject_parts = (segment for segment in filepath.split(os.sep) if segment.startswith('sub-'))
    first_match = next(subject_parts, None)
    if first_match is None:
        return None
    # Drop the 4-character "sub-" prefix.
    return first_match[4:]

# Locate CAT12 "mwp1*" volumes anywhere under the base directory and keep the
# ones whose path contains a "sub-<PATNO>" component with known demographics.
print("Finding PPMI files")
ppmi_cat12_files = glob.glob(os.path.join(PPMI_CAT12_BASE_DIR, "**", "mwp1*.nii*"), recursive=True)
ppmi_cat12_data = {}
for f in ppmi_cat12_files:
    if os.path.isfile(f):
        patno = extract_patno_from_path(f)
        if patno and patno in ppmi_id_sex_dict:
            # If several files match one subject, the last one visited wins.
            ppmi_cat12_data[patno] = f

# TurboPrep outputs: the folder name itself is used as the subject ID and must
# match a demographics key exactly.
print("Finding PPMI TurboPrep files")
ppmi_turboprep_data = {}
for folder in os.listdir(PPMI_TURBOPREP_DIR):
    folder_path = os.path.join(PPMI_TURBOPREP_DIR, folder)
    if os.path.isdir(folder_path):
        normalized_file = os.path.join(folder_path, "normalized.nii.gz")
        if os.path.exists(normalized_file) and folder in ppmi_id_sex_dict:
            ppmi_turboprep_data[folder] = normalized_file
print(f"  Found {len(ppmi_turboprep_data)} TurboPrep subjects")

# Load the FreeSurfer cortical thickness (CTH) and surface area (SA) tables,
# keeping only rows whose EVENT_ID is 'BL'.
print("Loading PPMI FreeSurfer")
ppmi_fs_cth_df = pd.read_csv(os.path.join(PPMI_CAT12_BASE_DIR, "FS7_APARC_CTH_23Oct2025.csv"))
ppmi_fs_sa_df  = pd.read_csv(os.path.join(PPMI_CAT12_BASE_DIR, "FS7_APARC_SA_23Oct2025.csv"))
ppmi_fs_cth_df = ppmi_fs_cth_df[ppmi_fs_cth_df['EVENT_ID'] == 'BL'].copy()
ppmi_fs_sa_df  = ppmi_fs_sa_df[ppmi_fs_sa_df['EVENT_ID'] == 'BL'].copy()
# String IDs so they can be compared with the demographics keys.
# NOTE(review): if PATNO is parsed as float this yields "3001.0", which would
# not match the "3001"-style keys - confirm the column dtype.
ppmi_fs_cth_df['PATNO'] = ppmi_fs_cth_df['PATNO'].astype(str)
ppmi_fs_sa_df['PATNO']  = ppmi_fs_sa_df['PATNO'].astype(str)
# Every column other than the two identifiers is treated as a feature.
ppmi_cth_features = [c for c in ppmi_fs_cth_df.columns if c not in ['PATNO', 'EVENT_ID']]
ppmi_sa_features  = [c for c in ppmi_fs_sa_df.columns  if c not in ['PATNO', 'EVENT_ID']]

# Per-column metadata in the same order as the concatenated feature vector
# built below: all thickness features first, then all surface-area features.
ppmi_fs_feature_info = ([{'feature_name': f, 'feature_type': 'Thickness'}    for f in ppmi_cth_features] +
                         [{'feature_name': f, 'feature_type': 'Surface_Area'} for f in ppmi_sa_features])

# Assemble one [thickness | surface area] vector per subject that has both
# imaging inputs; subjects with any NaN feature are dropped.
ppmi_fs_data = {}
for patno in set(ppmi_cat12_data.keys()) & set(ppmi_turboprep_data.keys()):
    cth_row = ppmi_fs_cth_df[ppmi_fs_cth_df['PATNO'] == patno]
    sa_row  = ppmi_fs_sa_df[ppmi_fs_sa_df['PATNO'] == patno]
    if len(cth_row) > 0 and len(sa_row) > 0:
        # NOTE(review): flatten() concatenates ALL matching rows, so a subject
        # with more than one 'BL' row would get a longer vector - verify that
        # PATNO is unique within the baseline rows.
        combined = np.concatenate([cth_row[ppmi_cth_features].values.flatten(),
                                   sa_row[ppmi_sa_features].values.flatten()])
        if not np.any(np.isnan(combined)):
            ppmi_fs_data[patno] = combined

# Subjects available in every modality, sorted so that all arrays below share
# the same row order.
ppmi_common_subjects = sorted(list(set(ppmi_cat12_data.keys()) & set(ppmi_turboprep_data.keys()) & set(ppmi_fs_data.keys())))
print(f"  Common subjects: {len(ppmi_common_subjects)}")

# Row-aligned inputs and labels (index i refers to ppmi_common_subjects[i]).
ppmi_cat12_paths     = [ppmi_cat12_data[p]    for p in ppmi_common_subjects]
ppmi_turboprep_paths = [ppmi_turboprep_data[p] for p in ppmi_common_subjects]
ppmi_fs_features     = np.array([ppmi_fs_data[p]        for p in ppmi_common_subjects])
ppmi_sex_labels      = np.array([ppmi_id_sex_dict[p]    for p in ppmi_common_subjects])
ppmi_parkinson_labels= np.array([ppmi_id_parkinson_dict[p] for p in ppmi_common_subjects])
ppmi_age_labels      = np.array([ppmi_id_age_dict[p]    for p in ppmi_common_subjects])

# Run the learned feature extractors (reduced model set: no BrainIAC/SimCLR).
ppmi_anatcl_features, ppmi_anatcl_local_features, ppmi_swinbrain_features, ppmi_cnn_features_dict = \
    extract_features_ppmi(ppmi_cat12_paths, ppmi_turboprep_paths, device)

# One entry per feature set; insertion order determines the order of
# ppmi_models and of the pairwise-correlation report.
ppmi_features_dict = {
    'AnatCL_Global': ppmi_anatcl_features,
    'AnatCL_Local':  ppmi_anatcl_local_features,
    'SwinBrain':     ppmi_swinbrain_features,
    **ppmi_cnn_features_dict,
    'FreeSurfer':    ppmi_fs_features,
}

compute_pairwise_correlations(ppmi_features_dict, "PPMI")

print("\nComputing PPMI correlation matrix...")
ppmi_corr_matrix, ppmi_boundaries, ppmi_model_names, ppmi_reorder_indices, ppmi_cluster_dfs = \
    compute_correlation_with_cluster_labels(ppmi_features_dict, {'FreeSurfer': ppmi_fs_feature_info}, ['FreeSurfer'], "PPMI")

# Classification tasks on the full cohort.
ppmi_models = list(ppmi_features_dict.keys())
ppmi_sex_results       = run_classification(ppmi_features_dict, ppmi_sex_labels,       task_name="PPMI Sex Classification")
ppmi_parkinson_results = run_classification(ppmi_features_dict, ppmi_parkinson_labels,  task_name="PPMI Parkinson Classification")

# Age regression is run separately per sex. Sex label 0 is the "male" subset
# (every demographics value other than "F" was encoded as 0).
ppmi_male_indices   = np.where(ppmi_sex_labels == 0)[0]
ppmi_male_features_dict  = {k: v[ppmi_male_indices]   for k, v in ppmi_features_dict.items()}
ppmi_male_age_labels     = ppmi_age_labels[ppmi_male_indices]
print(f"\nPPMI Male: {len(ppmi_male_indices)} subjects")
ppmi_male_age_results = run_regression(ppmi_male_features_dict, ppmi_male_age_labels, task_name="PPMI Male Age")

# Sex label 1 is the "female" subset.
ppmi_female_indices = np.where(ppmi_sex_labels == 1)[0]
ppmi_female_features_dict = {k: v[ppmi_female_indices] for k, v in ppmi_features_dict.items()}
ppmi_female_age_labels    = ppmi_age_labels[ppmi_female_indices]
print(f"\nPPMI Female: {len(ppmi_female_indices)} subjects")
ppmi_female_age_results = run_regression(ppmi_female_features_dict, ppmi_female_age_labels, task_name="PPMI Female Age")

# HBN Dataset
print("HBN Dataset")

# Input locations. CAT12 outputs are looked up under both HBN_BIDS and
# HBN_BIDS_LOWER; HBN_PARCELLATION is the atlas name used to select rows from
# the per-subject FreeSurfer region statistics.
HBN_BIDS          = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
HBN_BIDS_LOWER    = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/hbn_bids"
HBN_FREESURFER_DIR= "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_FreeSurfer/freesurfer"
HBN_TURBOPREP_DIR = "/home/arelbaha/links/scratch/turboprep_output"
HBN_DEMO_FILE     = os.path.join(HBN_BIDS, "final_preprocessed_subjects_with_demographics.tsv")
HBN_PARCELLATION  = "Schaefer2018_400Parcels_17Networks_order"

# Build per-subject label lookups keyed by participant_id (as string).
print("Loading HBN demographics")
hbn_demo_df = pd.read_csv(HBN_DEMO_FILE, sep='\t')
hbn_demo_df['participant_id'] = hbn_demo_df['participant_id'].astype(str)
hbn_id_sex_dict, hbn_id_age_dict = {}, {}
for _, row in hbn_demo_df.iterrows():
    # Sex: 1 when the value is exactly "Female" (case-sensitive), 0 otherwise.
    # NOTE(review): .strip() raises on missing (NaN) values - confirm the TSV has none.
    hbn_id_sex_dict[row['participant_id']] = 1 if row['sex'].strip() == 'Female' else 0
    hbn_id_age_dict[row['participant_id']] = row['age']
print(f"  {len(hbn_id_sex_dict)} subjects")

# Locate each subject's CAT12 "mwp1sub*" volume. The two base directories are
# tried in order and the first one with a match is used.
print("Finding HBN files")
hbn_cat12_data = {}
for subject_id in hbn_id_sex_dict.keys():
    for base in [HBN_BIDS, HBN_BIDS_LOWER]:
        files = glob.glob(os.path.join(base, f"sub-{subject_id}", "ses-*", "anat", "mri", "mwp1sub*.nii"))
        if files:
            # Only the first glob hit is kept when several sessions/files match.
            hbn_cat12_data[subject_id] = files[0]
            break

# TurboPrep outputs: only folders whose name starts with 'NDAR' and matches a
# demographics key are considered.
print("Finding HBN TurboPrep files")
hbn_turboprep_data = {}
for folder in os.listdir(HBN_TURBOPREP_DIR):
    folder_path = os.path.join(HBN_TURBOPREP_DIR, folder)
    if os.path.isdir(folder_path) and folder.startswith('NDAR'):
        normalized_file = os.path.join(folder_path, "normalized.nii.gz")
        if os.path.exists(normalized_file) and folder in hbn_id_sex_dict:
            hbn_turboprep_data[folder] = normalized_file
print(f"  Found {len(hbn_turboprep_data)} TurboPrep subjects")

# Schaefer atlas features: per subject, an 800-value vector laid out as
# [400 surface areas | 400 thicknesses], with rows sorted by StructName.
print("Loading HBN FreeSurfer Schaefer")
hbn_fs_data, hbn_fs_region_info = {}, {}
for subject_id in hbn_id_sex_dict.keys():
    stats_file = os.path.join(HBN_FREESURFER_DIR, f"sub-{subject_id}", f"sub-{subject_id}_regionsurfacestats.tsv")
    if os.path.exists(stats_file):
        df = pd.read_csv(stats_file, sep='\t')
        fdf = df[df["atlas"] == HBN_PARCELLATION].sort_values("StructName")
        if not fdf.empty and "SurfArea" in fdf.columns and "ThickAvg" in fdf.columns:
            # Only the first 400 sorted rows are used; any further rows are ignored.
            sa = fdf["SurfArea"].values[:400]; th = fdf["ThickAvg"].values[:400]
            combined = np.concatenate([sa, th])
            # Keep the subject only if all 400 regions are present and NaN-free.
            if not np.any(np.isnan(combined)) and len(sa) == 400 and len(th) == 400:
                hbn_fs_data[subject_id] = combined
                # Region names/hemispheres are captured once, from the first
                # accepted subject, and assumed to hold for all subjects.
                if not hbn_fs_region_info:
                    hbn_fs_region_info = {'region_names': fdf['StructName'].values[:400].tolist(),
                                          'hemisphere': fdf['hemisphere'].values[:400].tolist()}

# Per-column metadata matching the Schaefer vector layout: surface area first,
# then thickness.
hbn_fs_feature_info = []
if hbn_fs_region_info:
    for i in range(400):
        hbn_fs_feature_info.append({'feature_name': hbn_fs_region_info['region_names'][i], 'feature_type': 'Surface_Area', 'hemisphere': hbn_fs_region_info['hemisphere'][i]})
    for i in range(400):
        hbn_fs_feature_info.append({'feature_name': hbn_fs_region_info['region_names'][i], 'feature_type': 'Thickness', 'hemisphere': hbn_fs_region_info['hemisphere'][i]})

# aparc atlas features: per subject, a 136-value vector laid out as
# [68 thicknesses | 68 surface areas] - note the opposite block order from the
# Schaefer vector above. The same TSV is read a second time here.
print("Loading HBN FreeSurfer aparc")
hbn_aparc_fs_data, hbn_aparc_region_info = {}, {}
for subject_id in hbn_id_sex_dict.keys():
    stats_file = os.path.join(HBN_FREESURFER_DIR, f"sub-{subject_id}", f"sub-{subject_id}_regionsurfacestats.tsv")
    if os.path.exists(stats_file):
        df = pd.read_csv(stats_file, sep='\t')
        # NOTE(review): rows are sorted by StructName only. If the same region
        # name appears once per hemisphere, the relative order of those ties
        # depends on the input row order - confirm it is identical across
        # subjects, otherwise columns may not line up between subjects.
        fdf = df[df["atlas"] == "aparc"].sort_values("StructName")
        if not fdf.empty and "SurfArea" in fdf.columns and "ThickAvg" in fdf.columns:
            sa = fdf["SurfArea"].values; th = fdf["ThickAvg"].values
            combined = np.concatenate([th, sa])
            # Exactly 68 regions are required (no truncation, unlike Schaefer).
            if not np.any(np.isnan(combined)) and len(sa) == 68 and len(th) == 68:
                hbn_aparc_fs_data[subject_id] = combined
                if not hbn_aparc_region_info:
                    hbn_aparc_region_info = {'region_names': fdf['StructName'].values.tolist(),
                                             'hemisphere': fdf['hemisphere'].values.tolist()}

# Per-column metadata matching the aparc vector layout: thickness first, then
# surface area.
hbn_aparc_fs_feature_info = []
if hbn_aparc_region_info:
    for i in range(68):
        hbn_aparc_fs_feature_info.append({'feature_name': hbn_aparc_region_info['region_names'][i], 'feature_type': 'Thickness', 'hemisphere': hbn_aparc_region_info['hemisphere'][i]})
    for i in range(68):
        hbn_aparc_fs_feature_info.append({'feature_name': hbn_aparc_region_info['region_names'][i], 'feature_type': 'Surface_Area', 'hemisphere': hbn_aparc_region_info['hemisphere'][i]})

# Subjects available in all four HBN sources, in sorted ID order so that every
# list/array built below shares the same row order.
hbn_common_subjects = sorted(
    set(hbn_cat12_data.keys()).intersection(
        hbn_turboprep_data.keys(),
        hbn_fs_data.keys(),
        hbn_aparc_fs_data.keys(),
    )
)
print(f"Common subjects: {len(hbn_common_subjects)}")

# Image paths (one per common subject).
hbn_cat12_paths = [hbn_cat12_data[sid] for sid in hbn_common_subjects]
hbn_turboprep_paths = [hbn_turboprep_data[sid] for sid in hbn_common_subjects]

# FreeSurfer feature matrices (rows follow hbn_common_subjects).
hbn_fs_features = np.array([hbn_fs_data[sid] for sid in hbn_common_subjects])
hbn_aparc_fs_features = np.array([hbn_aparc_fs_data[sid] for sid in hbn_common_subjects])

# Targets (rows follow hbn_common_subjects).
hbn_sex_labels = np.array([hbn_id_sex_dict[sid] for sid in hbn_common_subjects])
hbn_age_labels = np.array([hbn_id_age_dict[sid] for sid in hbn_common_subjects])

# extract_features_full (defined elsewhere in this file) returns six values,
# unpacked here in the order it returns them.
(
    hbn_anatcl_features,
    hbn_anatcl_local_features,
    hbn_brainiac_features,
    hbn_swinbrain_features,
    hbn_cnn_features_dict,
    hbn_simclr_features,
) = extract_features_full(hbn_cat12_paths, hbn_turboprep_paths, device)

# Insertion order matters: it determines the model order used downstream
# (hbn_models is derived from this dict's keys).
hbn_features_dict = {
    'AnatCL_Global': hbn_anatcl_features,
    'AnatCL_Local': hbn_anatcl_local_features,
    'BrainIAC': hbn_brainiac_features,
    'SwinBrain': hbn_swinbrain_features,
}
hbn_features_dict.update(hbn_cnn_features_dict)
hbn_features_dict['3D-Neuro-SimCLR'] = hbn_simclr_features
hbn_features_dict['FS_Schaefer'] = hbn_fs_features
hbn_features_dict['FS_aparc'] = hbn_aparc_fs_features

compute_pairwise_correlations(hbn_features_dict, "HBN")

print("\nComputing HBN correlation matrix...")
(
    hbn_corr_matrix,
    hbn_boundaries,
    hbn_model_names,
    hbn_reorder_indices,
    hbn_cluster_dfs,
) = compute_correlation_with_cluster_labels(
    hbn_features_dict,
    {'FS_Schaefer': hbn_fs_feature_info, 'FS_aparc': hbn_aparc_fs_feature_info},
    ['FS_Schaefer', 'FS_aparc'],
    "HBN",
)

# Model names in the insertion order of the features dict.
hbn_models = list(hbn_features_dict.keys())
hbn_sex_results = run_classification(
    hbn_features_dict, hbn_sex_labels, task_name="HBN Sex Classification"
)

# Age regression on the subjects labelled 0 ("Male" in the task names below).
hbn_male_indices = np.nonzero(hbn_sex_labels == 0)[0]
hbn_male_features_dict = {
    model_name: feats[hbn_male_indices] for model_name, feats in hbn_features_dict.items()
}
hbn_male_age_labels = hbn_age_labels[hbn_male_indices]
print(f"\nHBN Male: {len(hbn_male_indices)} subjects")
hbn_male_age_results = run_regression(
    hbn_male_features_dict, hbn_male_age_labels, task_name="HBN Male Age"
)

# Age regression on the subjects labelled 1 ("Female").
hbn_female_indices = np.nonzero(hbn_sex_labels == 1)[0]
hbn_female_features_dict = {
    model_name: feats[hbn_female_indices] for model_name, feats in hbn_features_dict.items()
}
hbn_female_age_labels = hbn_age_labels[hbn_female_indices]
print(f"\nHBN Female: {len(hbn_female_indices)} subjects")
hbn_female_age_results = run_regression(
    hbn_female_features_dict, hbn_female_age_labels, task_name="HBN Female Age"
)

# NKI Dataset
print("NKI Dataset")

# Input locations for the NKI data.
NKI_BASE = "/home/arelbaha/links/scratch/NKI"
NKI_BIDS_DIR = os.path.join(NKI_BASE, "NKI_BIDS")
NKI_TURBOPREP_DIR = "/home/arelbaha/links/scratch/turboprep_NKI"
NKI_FREESURFER_DIR = os.path.join(NKI_BASE, "NKI_FreeSurfer", "freesurfer")
NKI_DEMO_FILE = os.path.join(NKI_BASE, "final_demographics.txt")
NKI_SUBJECTS_FILE = os.path.join(NKI_BASE, "final_subjects.txt")
# Value matched against the "atlas" column of the FreeSurfer stats tables.
NKI_PARCELLATION = "Schaefer2018_400Parcels_17Networks_order"

# Demographics: tab-separated lines of exactly three fields
# (subject ID, sex, age). Lines with any other field count are ignored.
# Sex is encoded as 1 for 'Female' and 0 for anything else.
print("Loading NKI demographics")
nki_id_sex_dict, nki_id_age_dict = {}, {}
with open(NKI_DEMO_FILE) as demo_file:
    for raw_line in demo_file:
        fields = raw_line.strip().split('\t')
        if len(fields) != 3:
            continue
        subj_id, sex, age = fields
        nki_id_sex_dict[subj_id] = 1 if sex.strip() == 'Female' else 0
        nki_id_age_dict[subj_id] = float(age)
print(f"  {len(nki_id_sex_dict)} subjects from demographics")

# Locate one CAT12 "mwp1*.nii" image per NKI subject in the ses-BAS1 session
# folder. Subjects without a match are left out.
print("Finding NKI CAT12 files")
nki_cat12_data = {}
for subj_id in nki_id_sex_dict.keys():
    pattern = os.path.join(NKI_BIDS_DIR, f"sub-{subj_id}", "ses-BAS1", "anat", "mri", "mwp1*.nii")
    # glob returns matches in arbitrary order; sort them so that the file
    # picked for a subject with several matches is the same on every run.
    files = sorted(glob.glob(pattern))
    if files:
        nki_cat12_data[subj_id] = files[0]
print(f"  Found {len(nki_cat12_data)} CAT12 subjects")

# Map subject ID -> TurboPrep "normalized.nii.gz". Only directories whose name
# is a subject ID present in the demographics are considered.
print("Finding NKI TurboPrep files")
nki_turboprep_data = {}
for entry in os.listdir(NKI_TURBOPREP_DIR):
    entry_path = os.path.join(NKI_TURBOPREP_DIR, entry)
    if not os.path.isdir(entry_path):
        continue
    candidate = os.path.join(entry_path, "normalized.nii.gz")
    if not os.path.exists(candidate) or entry not in nki_id_sex_dict:
        continue
    nki_turboprep_data[entry] = candidate
print(f"  Found {len(nki_turboprep_data)} TurboPrep subjects")

# NKI FreeSurfer features for the Schaefer parcellation.
# Each accepted subject gets one vector laid out as
# [400 x SurfArea, 400 x ThickAvg]. Rows are sorted by StructName and only the
# first 400 are kept; subjects with fewer rows or any NaN are skipped.
print("Loading NKI FreeSurfer Schaefer")
nki_fs_data, nki_fs_region_info = {}, {}
for subj_id in nki_id_sex_dict.keys():
    session_label = f"sub-{subj_id}_ses-BAS1"
    stats_file = os.path.join(NKI_FREESURFER_DIR, session_label, f"{session_label}_regionsurfacestats.tsv")
    if not os.path.exists(stats_file):
        continue
    table = pd.read_csv(stats_file, sep='\t')
    parcels = table[table["atlas"] == NKI_PARCELLATION].sort_values("StructName")
    if parcels.empty or "SurfArea" not in parcels.columns or "ThickAvg" not in parcels.columns:
        continue
    area = parcels["SurfArea"].values[:400]
    thickness = parcels["ThickAvg"].values[:400]
    vector = np.concatenate([area, thickness])
    if np.any(np.isnan(vector)) or len(area) != 400 or len(thickness) != 400:
        continue
    nki_fs_data[subj_id] = vector
    if not nki_fs_region_info:
        # Region labels are captured once, from the first accepted subject.
        nki_fs_region_info = {
            'region_names': parcels['StructName'].values[:400].tolist(),
            'hemisphere': parcels['hemisphere'].values[:400].tolist(),
        }

# Per-feature metadata in the same order as the stored vectors:
# the surface-area block first, then the thickness block.
nki_fs_feature_info = []
if nki_fs_region_info:
    for feature_type in ('Surface_Area', 'Thickness'):
        for region_name, hemi in zip(nki_fs_region_info['region_names'],
                                     nki_fs_region_info['hemisphere']):
            nki_fs_feature_info.append({'feature_name': region_name,
                                        'feature_type': feature_type,
                                        'hemisphere': hemi})

# NKI FreeSurfer aparc features.
# For every subject with demographics, read its ses-BAS1 regionsurfacestats.tsv,
# keep the rows whose atlas is "aparc", and store one vector laid out as
# [68 x ThickAvg, 68 x SurfArea]. Subjects with a missing file, missing
# columns, NaN values, or a region count other than 68 are skipped.
print("Loading NKI FreeSurfer aparc")
nki_aparc_fs_data, nki_aparc_region_info = {}, {}
for subj_id in nki_id_sex_dict.keys():
    stats_file = os.path.join(NKI_FREESURFER_DIR, f"sub-{subj_id}_ses-BAS1", f"sub-{subj_id}_ses-BAS1_regionsurfacestats.tsv")
    if not os.path.exists(stats_file):
        continue
    df = pd.read_csv(stats_file, sep='\t')
    # Sort on (StructName, hemisphere) rather than StructName alone: if a region
    # name occurs once per hemisphere, a single-key sort leaves the relative
    # order of the tied rows unspecified (pandas' default single-column sort is
    # not stable), so feature columns could line up differently between
    # subjects and against the region labels recorded from the first subject.
    fdf = df[df["atlas"] == "aparc"].sort_values(["StructName", "hemisphere"])
    if fdf.empty or "SurfArea" not in fdf.columns or "ThickAvg" not in fdf.columns:
        continue
    sa = fdf["SurfArea"].values
    th = fdf["ThickAvg"].values
    combined = np.concatenate([th, sa])
    if np.any(np.isnan(combined)) or len(sa) != 68 or len(th) != 68:
        continue
    nki_aparc_fs_data[subj_id] = combined
    if not nki_aparc_region_info:
        # Region labels are captured once, from the first accepted subject.
        nki_aparc_region_info = {'region_names': fdf['StructName'].values.tolist(),
                                 'hemisphere': fdf['hemisphere'].values.tolist()}

# Per-feature metadata in the same order as the stored vectors:
# the thickness block first, then the surface-area block.
nki_aparc_fs_feature_info = []
if nki_aparc_region_info:
    for feature_type in ('Thickness', 'Surface_Area'):
        for region_name, hemi in zip(nki_aparc_region_info['region_names'],
                                     nki_aparc_region_info['hemisphere']):
            nki_aparc_fs_feature_info.append({'feature_name': region_name,
                                              'feature_type': feature_type,
                                              'hemisphere': hemi})

# Subjects available in all four NKI sources, in sorted ID order so that every
# list/array built below shares the same row order.
nki_common_subjects = sorted(
    set(nki_cat12_data.keys()).intersection(
        nki_turboprep_data.keys(),
        nki_fs_data.keys(),
        nki_aparc_fs_data.keys(),
    )
)
print(f"Common subjects: {len(nki_common_subjects)}")

# Image paths (one per common subject).
nki_cat12_paths = [nki_cat12_data[sid] for sid in nki_common_subjects]
nki_turboprep_paths = [nki_turboprep_data[sid] for sid in nki_common_subjects]

# FreeSurfer feature matrices (rows follow nki_common_subjects).
nki_fs_features = np.array([nki_fs_data[sid] for sid in nki_common_subjects])
nki_aparc_fs_features = np.array([nki_aparc_fs_data[sid] for sid in nki_common_subjects])

# Targets (rows follow nki_common_subjects).
nki_sex_labels = np.array([nki_id_sex_dict[sid] for sid in nki_common_subjects])
nki_age_labels = np.array([nki_id_age_dict[sid] for sid in nki_common_subjects])

# extract_features_full (defined elsewhere in this file) returns six values,
# unpacked here in the order it returns them.
(
    nki_anatcl_features,
    nki_anatcl_local_features,
    nki_brainiac_features,
    nki_swinbrain_features,
    nki_cnn_features_dict,
    nki_simclr_features,
) = extract_features_full(nki_cat12_paths, nki_turboprep_paths, device)

# Insertion order matters: it determines the model order used downstream
# (nki_models is derived from this dict's keys).
nki_features_dict = {
    'AnatCL_Global': nki_anatcl_features,
    'AnatCL_Local': nki_anatcl_local_features,
    'BrainIAC': nki_brainiac_features,
    'SwinBrain': nki_swinbrain_features,
}
nki_features_dict.update(nki_cnn_features_dict)
nki_features_dict['3D-Neuro-SimCLR'] = nki_simclr_features
nki_features_dict['FS_Schaefer'] = nki_fs_features
nki_features_dict['FS_aparc'] = nki_aparc_fs_features

compute_pairwise_correlations(nki_features_dict, "NKI")

print("\nComputing NKI correlation matrix...")
(
    nki_corr_matrix,
    nki_boundaries,
    nki_model_names,
    nki_reorder_indices,
    nki_cluster_dfs,
) = compute_correlation_with_cluster_labels(
    nki_features_dict,
    {'FS_Schaefer': nki_fs_feature_info, 'FS_aparc': nki_aparc_fs_feature_info},
    ['FS_Schaefer', 'FS_aparc'],
    "NKI",
)

# Model names in the insertion order of the features dict.
nki_models = list(nki_features_dict.keys())
nki_sex_results = run_classification(
    nki_features_dict, nki_sex_labels, task_name="NKI Sex Classification"
)

# Age regression on the subjects labelled 0 (every sex value other than
# 'Female' in the demographics file).
nki_male_indices = np.nonzero(nki_sex_labels == 0)[0]
nki_male_features_dict = {
    model_name: feats[nki_male_indices] for model_name, feats in nki_features_dict.items()
}
nki_male_age_labels = nki_age_labels[nki_male_indices]
print(f"\nNKI Male: {len(nki_male_indices)} subjects")
nki_male_age_results = run_regression(
    nki_male_features_dict, nki_male_age_labels, task_name="NKI Male Age"
)

# Age regression on the subjects labelled 1 ('Female').
nki_female_indices = np.nonzero(nki_sex_labels == 1)[0]
nki_female_features_dict = {
    model_name: feats[nki_female_indices] for model_name, feats in nki_features_dict.items()
}
nki_female_age_labels = nki_age_labels[nki_female_indices]
print(f"\nNKI Female: {len(nki_female_indices)} subjects")
nki_female_age_results = run_regression(
    nki_female_features_dict, nki_female_age_labels, task_name="NKI Female Age"
)

# Permutation Tests
# Each call passes a task's results, the name of the FreeSurfer feature set
# used as the reference, and the list of model names. PPMI has a single
# FreeSurfer reference; HBN and NKI are tested against both the Schaefer and
# the aparc feature sets.

print("\nPermutation Tests (Bonferroni corrected)")

# --- PPMI (reference: 'FreeSurfer') ---
print("\nPPMI Sex:")
ppmi_sex_perm = run_permutation_test_classification(
    ppmi_sex_results, 'FreeSurfer', ppmi_models)
print("\nPPMI Male Age:")
ppmi_male_age_perm = run_permutation_test_regression(
    ppmi_male_age_results, 'FreeSurfer', ppmi_models)
print("\nPPMI Female Age:")
ppmi_female_age_perm = run_permutation_test_regression(
    ppmi_female_age_results, 'FreeSurfer', ppmi_models)

# --- HBN (references: 'FS_Schaefer', then 'FS_aparc') ---
print("\nHBN Sex (vs Schaefer):")
hbn_sex_perm_schaefer = run_permutation_test_classification(
    hbn_sex_results, 'FS_Schaefer', hbn_models)
print("\nHBN Sex (vs aparc):")
hbn_sex_perm_aparc = run_permutation_test_classification(
    hbn_sex_results, 'FS_aparc', hbn_models)
print("\nHBN Male Age (vs Schaefer):")
hbn_male_perm_schaefer = run_permutation_test_regression(
    hbn_male_age_results, 'FS_Schaefer', hbn_models)
print("\nHBN Male Age (vs aparc):")
hbn_male_perm_aparc = run_permutation_test_regression(
    hbn_male_age_results, 'FS_aparc', hbn_models)
print("\nHBN Female Age (vs Schaefer):")
hbn_female_perm_schaefer = run_permutation_test_regression(
    hbn_female_age_results, 'FS_Schaefer', hbn_models)
print("\nHBN Female Age (vs aparc):")
hbn_female_perm_aparc = run_permutation_test_regression(
    hbn_female_age_results, 'FS_aparc', hbn_models)

# --- NKI (references: 'FS_Schaefer', then 'FS_aparc') ---
print("\nNKI Sex (vs Schaefer):")
nki_sex_perm_schaefer = run_permutation_test_classification(
    nki_sex_results, 'FS_Schaefer', nki_models)
print("\nNKI Sex (vs aparc):")
nki_sex_perm_aparc = run_permutation_test_classification(
    nki_sex_results, 'FS_aparc', nki_models)
print("\nNKI Male Age (vs Schaefer):")
nki_male_perm_schaefer = run_permutation_test_regression(
    nki_male_age_results, 'FS_Schaefer', nki_models)
print("\nNKI Male Age (vs aparc):")
nki_male_perm_aparc = run_permutation_test_regression(
    nki_male_age_results, 'FS_aparc', nki_models)
print("\nNKI Female Age (vs Schaefer):")
nki_female_perm_schaefer = run_permutation_test_regression(
    nki_female_age_results, 'FS_Schaefer', nki_models)
print("\nNKI Female Age (vs aparc):")
nki_female_perm_aparc = run_permutation_test_regression(
    nki_female_age_results, 'FS_aparc', nki_models)

# Correlation Plots & Exports
# All plots are produced first, then all exports, then the per-feature-set
# cluster tables are written to CSV (PPMI, HBN, NKI order throughout).

for corr_matrix, boundaries, model_names, cluster_dfs, fs_keys, plot_tag in (
    (ppmi_corr_matrix, ppmi_boundaries, ppmi_model_names, ppmi_cluster_dfs, ['FreeSurfer'], "ppmi"),
    (hbn_corr_matrix, hbn_boundaries, hbn_model_names, hbn_cluster_dfs, ['FS_Schaefer', 'FS_aparc'], "hbn"),
    (nki_corr_matrix, nki_boundaries, nki_model_names, nki_cluster_dfs, ['FS_Schaefer', 'FS_aparc'], "nki"),
):
    plot_correlation_with_labels(corr_matrix, boundaries, model_names, cluster_dfs, fs_keys, plot_tag)

for corr_matrix, boundaries, model_names, cluster_dfs, export_tag in (
    (ppmi_corr_matrix, ppmi_boundaries, ppmi_model_names, ppmi_cluster_dfs, "PPMI"),
    (hbn_corr_matrix, hbn_boundaries, hbn_model_names, hbn_cluster_dfs, "HBN"),
    (nki_corr_matrix, nki_boundaries, nki_model_names, nki_cluster_dfs, "NKI"),
):
    export_correlation_data(corr_matrix, boundaries, model_names, cluster_dfs, export_tag)

# (cluster table collection, feature-set key, output CSV); the lookup happens
# inside the loop so tables are written one at a time in this order.
for cluster_dfs, fs_key, csv_name in (
    (ppmi_cluster_dfs, 'FreeSurfer', 'ppmi_freesurfer_features_reordered.csv'),
    (hbn_cluster_dfs, 'FS_Schaefer', 'hbn_schaefer_freesurfer_features_reordered.csv'),
    (hbn_cluster_dfs, 'FS_aparc', 'hbn_aparc_freesurfer_features_reordered.csv'),
    (nki_cluster_dfs, 'FS_Schaefer', 'nki_schaefer_freesurfer_features_reordered.csv'),
    (nki_cluster_dfs, 'FS_aparc', 'nki_aparc_freesurfer_features_reordered.csv'),
):
    cluster_dfs[fs_key].to_csv(csv_name, index=False)

# CV Results Export
# Training-set sizes are derived from the full cohort for classification tasks
# and from the male/female subsets for the sex-stratified age regressions.
ppmi_class_sizes = get_training_sizes(len(ppmi_sex_labels), alphas)
ppmi_male_sizes = get_training_sizes(len(ppmi_male_indices), alphas)
ppmi_female_sizes = get_training_sizes(len(ppmi_female_indices), alphas)
hbn_class_sizes = get_training_sizes(len(hbn_sex_labels), alphas)
hbn_male_sizes = get_training_sizes(len(hbn_male_indices), alphas)
hbn_female_sizes = get_training_sizes(len(hbn_female_indices), alphas)
nki_class_sizes = get_training_sizes(len(nki_sex_labels), alphas)
nki_male_sizes = get_training_sizes(len(nki_male_indices), alphas)
nki_female_sizes = get_training_sizes(len(nki_female_indices), alphas)

# PPMI: two classification tasks and two regression tasks.
ppmi_sex_cv_df = export_cv_results_classification(
    ppmi_sex_results, ppmi_class_sizes, 'PPMI', 'Sex', ppmi_models)
ppmi_park_cv_df = export_cv_results_classification(
    ppmi_parkinson_results, ppmi_class_sizes, 'PPMI', 'Parkinson', ppmi_models)
ppmi_male_age_cv_df = export_cv_results_regression(
    ppmi_male_age_results, ppmi_male_sizes, 'PPMI', 'Male_Age', ppmi_models)
ppmi_female_age_cv_df = export_cv_results_regression(
    ppmi_female_age_results, ppmi_female_sizes, 'PPMI', 'Female_Age', ppmi_models)

# HBN: one classification task and two regression tasks.
hbn_sex_cv_df = export_cv_results_classification(
    hbn_sex_results, hbn_class_sizes, 'HBN', 'Sex', hbn_models)
hbn_male_age_cv_df = export_cv_results_regression(
    hbn_male_age_results, hbn_male_sizes, 'HBN', 'Male_Age', hbn_models)
hbn_female_age_cv_df = export_cv_results_regression(
    hbn_female_age_results, hbn_female_sizes, 'HBN', 'Female_Age', hbn_models)

# NKI: one classification task and two regression tasks.
nki_sex_cv_df = export_cv_results_classification(
    nki_sex_results, nki_class_sizes, 'NKI', 'Sex', nki_models)
nki_male_age_cv_df = export_cv_results_regression(
    nki_male_age_results, nki_male_sizes, 'NKI', 'Male_Age', nki_models)
nki_female_age_cv_df = export_cv_results_regression(
    nki_female_age_results, nki_female_sizes, 'NKI', 'Female_Age', nki_models)

# Per-dataset tables, one for classification and one for regression.
ppmi_classification_cv = pd.concat([ppmi_sex_cv_df, ppmi_park_cv_df], ignore_index=True)
ppmi_regression_cv = pd.concat([ppmi_male_age_cv_df, ppmi_female_age_cv_df], ignore_index=True)
hbn_classification_cv = hbn_sex_cv_df
hbn_regression_cv = pd.concat([hbn_male_age_cv_df, hbn_female_age_cv_df], ignore_index=True)
nki_classification_cv = nki_sex_cv_df
nki_regression_cv = pd.concat([nki_male_age_cv_df, nki_female_age_cv_df], ignore_index=True)

for cv_table, csv_name in (
    (ppmi_classification_cv, 'ppmi_classification_cv_results.csv'),
    (ppmi_regression_cv, 'ppmi_regression_cv_results.csv'),
    (hbn_classification_cv, 'hbn_classification_cv_results.csv'),
    (hbn_regression_cv, 'hbn_regression_cv_results.csv'),
    (nki_classification_cv, 'nki_classification_cv_results.csv'),
    (nki_regression_cv, 'nki_regression_cv_results.csv'),
):
    cv_table.to_csv(csv_name, index=False)

# One combined table, with a task_type column telling the two kinds apart.
all_cv_results = pd.concat(
    [
        cv_table.assign(task_type=task_kind)
        for cv_table, task_kind in (
            (ppmi_classification_cv, 'classification'),
            (ppmi_regression_cv, 'regression'),
            (hbn_classification_cv, 'classification'),
            (hbn_regression_cv, 'regression'),
            (nki_classification_cv, 'classification'),
            (nki_regression_cv, 'regression'),
        )
    ],
    ignore_index=True,
)
all_cv_results.to_csv('all_cv_results.csv', index=False)

# Fold-level CV results (5 seeds x 5 folds = 25 per model, alpha=1)
# Long-format table: one row per (dataset, task, model, fold index, metric).
# Models whose name contains 'Dummy' are left out.

fold_rows = []
fold_sources = {
    'PPMI': [('Sex', ppmi_sex_results), ('Parkinson', ppmi_parkinson_results),
             ('Male_Age', ppmi_male_age_results), ('Female_Age', ppmi_female_age_results)],
    'HBN': [('Sex', hbn_sex_results), ('Male_Age', hbn_male_age_results),
            ('Female_Age', hbn_female_age_results)],
    'NKI': [('Sex', nki_sex_results), ('Male_Age', nki_male_age_results),
            ('Female_Age', nki_female_age_results)],
}
for ds_name, task_results_list in fold_sources.items():
    for task, res in task_results_list:
        for model, r in res.items():
            if 'Dummy' in model:
                continue
            fr = r.get('fold_results_alpha1', {})
            # Emit whichever validation metrics this result carries, accuracy first.
            for metric_key in ('val_acc', 'val_mae'):
                if metric_key not in fr:
                    continue
                for fold_idx, fold_value in enumerate(fr[metric_key]):
                    fold_rows.append({'dataset': ds_name, 'task': task, 'model': model,
                                      'fold': fold_idx, 'metric': metric_key, 'value': fold_value})

pd.DataFrame(fold_rows).to_csv('fold_level_results.csv', index=False)
print(f"Saved: fold_level_results.csv ({len(fold_rows)} rows)")

# Held-out test set results (5 seeds x 1 test eval = 5 per model)
# Long-format table: one row per (dataset, task, model, seed, metric). The i-th
# value of each metric list is labelled with SEEDS[i]. Models whose name
# contains 'Dummy' are left out.

heldout_rows = []
heldout_sources = {
    'PPMI': [('Sex', ppmi_sex_results), ('Parkinson', ppmi_parkinson_results),
             ('Male_Age', ppmi_male_age_results), ('Female_Age', ppmi_female_age_results)],
    'HBN': [('Sex', hbn_sex_results), ('Male_Age', hbn_male_age_results),
            ('Female_Age', hbn_female_age_results)],
    'NKI': [('Sex', nki_sex_results), ('Male_Age', nki_male_age_results),
            ('Female_Age', nki_female_age_results)],
}
for ds_name, task_results_list in heldout_sources.items():
    for task, res in task_results_list:
        for model, r in res.items():
            if 'Dummy' in model:
                continue
            tsr = r.get('test_seed_results', {})
            # Emit whichever test metrics this result carries, in a fixed order.
            for metric_key in ('test_acc', 'test_auc', 'test_mae', 'test_r2'):
                if metric_key not in tsr:
                    continue
                for seed_idx, seed_value in enumerate(tsr[metric_key]):
                    heldout_rows.append({'dataset': ds_name, 'task': task, 'model': model,
                                         'seed': SEEDS[seed_idx], 'metric': metric_key,
                                         'value': seed_value})

pd.DataFrame(heldout_rows).to_csv('heldout_seed_results.csv', index=False)
print(f"Saved: heldout_seed_results.csv ({len(heldout_rows)} rows)")

print("\nAll Files Saved.")