# MRI Foundation Models: Pre-Trained Models' Evaluation

## SPM/CAT12 Preprocessing | AnatCL | HBN 

In [None]:
%%writefile ~/links/projects/rrg-glatard/arelbaha/HBN_BIDS/new_cat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Paths: Boutiques descriptor, HBN BIDS input tree, output root and container image.
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"
# Single definition of the container image (used for `bosh exec prepare` and the run).
container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"

# Make sure the container image referenced by the descriptor is available locally.
bosh(["exec",
      "prepare",
      boutiques_descriptor,
      "--imagepath",
      container_image
])

# Wrap the descriptor as a Python callable.
cat12 = function(boutiques_descriptor)

# Sorted so that a given SLURM_ARRAY_TASK_ID always refers to the same file:
# glob returns paths in arbitrary, filesystem-dependent order.
# NOTE: array indices chosen against an unsorted listing may now map to other files.
t1_nii_files = sorted(glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "sub-*_T1w.nii")))
print(f"Found {len(t1_nii_files)} T1w files.")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if task_id >= len(t1_nii_files) or task_id < 0:
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    raise SystemExit(1)

input_file = t1_nii_files[task_id]

# The first path component starting with "sub-" is the subject label.
path_parts = input_file.split(os.sep)
subject_id = next((part for part in path_parts if part.startswith("sub-")), None)

if subject_id is None:
    print(f"No subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f"--imagepath={container_image}",
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
print("Available result keys:", result_dict.keys())
print("\nExit code:", result_dict.get("exit_code"))
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Propagate a CAT12 failure so SLURM marks the array task as FAILED rather than COMPLETED.
exit_code = result_dict.get("exit_code")
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=500-520
#!/bin/bash
#SBATCH --job-name=CAT12_preproc
#SBATCH --account=rrg-glatard
#SBATCH --mem=12G
#SBATCH --cpus-per-task=2
#SBATCH --nodes=1
#SBATCH --output=CAT12_preproc_%A_%a.out
#SBATCH --error=CAT12_preproc_%A_%a.err
#SBATCH --time=1:0:0

# Virtual environment used to run the preprocessing script.
source ~/.venvs/jupyter_py3/bin/activate

module load apptainer

# Location of preprocessing data and script. Abort if it is unavailable instead of
# running the script from the submission directory.
cd ~/links/projects/rrg-glatard/arelbaha/HBN_BIDS || exit 1

echo "Running task ID: $SLURM_ARRAY_TASK_ID"

# The script picks its input from SLURM_ARRAY_TASK_ID; its exit status becomes the job's.
python new_cat12_preprocessing.py

## SPM/CAT12 Preprocessing | AnatCL | PPMI

In [None]:
%%writefile ~/links/projects/def-glatard/arelbaha/data/inputs/ppmicat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Paths: Boutiques descriptor, PPMI input tree, output root and container image.
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"
container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"

# Make sure the container image is available locally, then wrap the descriptor
# as a Python callable.
bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])
cat12 = function(boutiques_descriptor)

# Task ID extraction. Sorted so that a given SLURM_ARRAY_TASK_ID always refers to the
# same file: glob returns paths in arbitrary, filesystem-dependent order.
# NOTE: array indices chosen against an unsorted listing may now map to other files.
data = sorted(glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "*.nii")))
print(f"Found {len(data)} NIfTI files")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if task_id >= len(data) or task_id < 0:
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    raise SystemExit(1)

input_file = data[task_id]

# The first path component starting with "sub-" is the subject label.
path_parts = input_file.split(os.sep)
subject_id = next((part for part in path_parts if part.startswith("sub-")), None)

if subject_id is None:
    print(f"Could not find subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing SLURM_ARRAY_TASK_ID ={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f"--imagepath={container_image}",
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
print("Available result keys:", result_dict.keys())
print("\nExit code:", result_dict.get("exit_code"))
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Propagate a CAT12 failure so SLURM marks the array task as FAILED rather than COMPLETED.
exit_code = result_dict.get("exit_code")
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=1040-1040
#!/bin/bash
#SBATCH --job-name=CAT12_parkinson_preproc
#SBATCH --account=rrg-glatard
#SBATCH --mem=12G
#SBATCH --cpus-per-task=2
#SBATCH --nodes=1
#SBATCH --output=parkinson_preproc_%A_%a.out
#SBATCH --error=parkinson_preproc_%A_%a.err
#SBATCH --time=1:0:0

# Virtual environment used to run the preprocessing script.
source ~/.venvs/jupyter_py3/bin/activate

module load apptainer

# Location of preprocessing data and script. Abort if it is unavailable instead of
# running the script from the submission directory.
cd ~/links/projects/def-glatard/arelbaha/data/inputs || exit 1

echo "Running task ID: $SLURM_ARRAY_TASK_ID"

# Run the PPMI CAT12 preprocessing script written to this directory; it selects its
# input file from SLURM_ARRAY_TASK_ID.
# NOTE(review): this previously invoked classification_parkinson.py, which does not
# match this job's name, logs or array indexing — confirm this is the intended script.
python ppmicat12_preprocessing.py

## BrainIAC and CNN Preprocessing | HD-BET | PPMI

In [None]:
%%sbatch --array=23,24
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=20:00:00
#SBATCH --mem=16G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=brainiac_batch
#SBATCH --output=logs/prep_batch_%A_%a.out
#SBATCH --error=logs/prep_batch_%A_%a.err
# %A (job ID) keeps logs from separate submissions from overwriting each other.
# Make sure logs/ exists in the submission directory before calling sbatch.

export BASE_DIR="/home/arelbaha/links/projects/def-glatard/arelbaha/data"
export RAW_DIR="${BASE_DIR}/raw_files_brainiac"
export OUTPUT_DIR="${BASE_DIR}/processed_outputs"

# Load environment
module load python/3.11
module load opencv
source /home/arelbaha/.venvs/brainiac_env/bin/activate

# Each array task processes the input folder batch_<task id> into a matching output folder.
BATCH_DIR="${RAW_DIR}/batch_${SLURM_ARRAY_TASK_ID}"
BATCH_OUTPUT="${OUTPUT_DIR}/batch_${SLURM_ARRAY_TASK_ID}"

if [ ! -d "$BATCH_DIR" ]; then
    echo "Input batch directory not found: ${BATCH_DIR}" >&2
    exit 1
fi

mkdir -p "$BATCH_OUTPUT"

echo "Processing batch ${SLURM_ARRAY_TASK_ID} from ${BATCH_DIR}"

BRAINIAC_PREPROC="/home/arelbaha/.venvs/brainiac_env/lib/python3.11/site-packages/BrainIAC/src/preprocessing"

# Fail the array task (instead of reporting "Completed") if preprocessing fails.
if ! python "${BRAINIAC_PREPROC}/mri_preprocess_3d_simple.py" \
    --temp_img "${BRAINIAC_PREPROC}/atlases/temp_head.nii.gz" \
    --input_dir "$BATCH_DIR" \
    --output_dir "$BATCH_OUTPUT"; then
    echo "Preprocessing failed for batch ${SLURM_ARRAY_TASK_ID}" >&2
    exit 1
fi

echo "Completed batch ${SLURM_ARRAY_TASK_ID}"

## PPMI Multi-task Evaluation | AnatCL vs BrainIAC vs CNN vs FreeSurfer

In [1]:
%%writefile ~/links/projects/def-glatard/arelbaha/data/inputs/ppmi_multimodel_comparison.py

import os
import glob
import random
import numpy as np
import pandas as pd

# Limit OpenMP / MKL / OpenBLAS to a single thread each (set before importing torch).
# NOTE(review): numpy and pandas are imported above these lines, so numpy's BLAS may
# already have read these variables; consider moving them above `import numpy`.
for _thread_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'OPENBLAS_NUM_THREADS'):
    os.environ[_thread_var] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from anatcl import AnatCL
from monai.networks.nets import ViT
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
import scipy.ndimage as ndi

SEED = 42
# Seed the Python, NumPy and PyTorch RNGs (in that order) for reproducible runs.
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(SEED)
# Deterministic kernels and single-threaded execution for run-to-run reproducibility.
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Input locations.
CAT12_BASE_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
DATA_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data"
LABELS_PATH = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/processed_cohort_with_mri.csv"
BRAINIAC_MAPPING_CSV = os.path.join(DATA_DIR, "processed_files_mapping.csv")
# Directory holding the AnatCL fold{i}.pth checkpoints.
ANATCL_ENCODER_PATH = "/home/arelbaha/.venvs/jupyter_py3/bin"
BRAINIAC_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/BrainIAC.ckpt"
DROPOUT_RATE = 0.3  # default dropout probability for CNN3D

device = "cpu"

for _line in (f"Using device: {device}",
              f"PyTorch version: {torch.__version__}",
              f"Random seed: {SEED}\n"):
    print(_line)

# Load demographic labels: sex, diagnostic group and age, keyed by PATNO as a string.
print("Loading demographics")

labels_df = pd.read_csv(LABELS_PATH)
id_sex_dict = {}
id_parkinson_dict = {}
id_age_dict = {}

_SEX_CODES = {'F': 1, 'M': 0}      # Female = 1, Male = 0 (same as HBN)
_GROUP_CODES = {'PD': 1, 'HC': 0}  # Parkinson's disease = 1, healthy control = 0

for _, row in labels_df.iterrows():
    # Skip rows without a patient number instead of crashing on int(NaN).
    if pd.isna(row['PATNO']):
        continue
    patno = str(int(row['PATNO']))
    # str() makes missing (NaN) cells safe to .strip(); they then match no code and are skipped.
    sex = _SEX_CODES.get(str(row['Sex']).strip().upper())
    if sex is not None:
        id_sex_dict[patno] = sex
    group = _GROUP_CODES.get(str(row['Group']).strip())
    if group is not None:
        id_parkinson_dict[patno] = group
    # Only keep known ages so NaN never reaches age statistics or regression targets.
    if pd.notna(row['Age']):
        id_age_dict[patno] = row['Age']

print(f"Loaded demographics for {len(id_sex_dict)} subjects")
print(f"Sex distribution: {sum(id_sex_dict.values())} Female, {len(id_sex_dict) - sum(id_sex_dict.values())} Male")
print(f"PD distribution: {sum(id_parkinson_dict.values())} PD, {len(id_parkinson_dict) - sum(id_parkinson_dict.values())} HC")
if id_age_dict:
    print(f"Age range: {min(id_age_dict.values()):.1f} - {max(id_age_dict.values()):.1f} years\n")
else:
    print("Age range: no ages available\n")

def extract_patno_from_path(filepath):
    """Return the label that follows ``sub-`` in the first matching path component.

    *filepath* is split on ``os.sep``; the first component beginning with ``sub-``
    has that prefix removed and is returned. Returns None if no component matches.
    """
    prefix = 'sub-'
    labels = (component[len(prefix):]
              for component in filepath.split(os.sep)
              if component.startswith(prefix))
    return next(labels, None)

# Load CAT12 VBM data: map each labelled PATNO to its s6mwp1 volume path.
print("Finding CAT12 (s6mwp1) files")

# Sorted so the file kept for a subject with several matches is deterministic
# (glob order is filesystem-dependent); the last match in sorted order wins.
cat12_files = sorted(glob.glob(os.path.join(CAT12_BASE_DIR, "**", "*s6mwp1*.nii*"), recursive=True))
cat12_data = {}
for f in cat12_files:
    if not os.path.isfile(f):
        continue
    patno = extract_patno_from_path(f)
    if patno and patno in id_sex_dict and patno in id_parkinson_dict and patno in id_age_dict:
        if patno in cat12_data:
            print(f"WARNING: multiple CAT12 files for subject {patno}; using {f}")
        cat12_data[patno] = f

print(f"Found {len(cat12_data)} CAT12 files\n")

# Load BrainIAC data: map each labelled PATNO to its preprocessed volume path.
print("Finding BrainIAC preprocessed files")


def _normalize_patno(value):
    """Return *value* as a PATNO string comparable with the demographics keys.

    If the subject_id column contained missing values, pandas parses it as float,
    and str(3001.0) == "3001.0" would never match the key "3001". Numeric values
    are therefore converted via int; anything non-numeric is returned as a
    stripped string.
    """
    try:
        return str(int(float(value)))
    except (TypeError, ValueError, OverflowError):
        return str(value).strip()


brainiac_df = pd.read_csv(BRAINIAC_MAPPING_CSV).dropna(subset=["processed_file", "Age", "subject_id"])
brainiac_data = {}
for _, row in brainiac_df.iterrows():
    patno = _normalize_patno(row['subject_id'])
    if patno in id_sex_dict and patno in id_parkinson_dict and patno in id_age_dict:
        if os.path.exists(row['processed_file']):
            brainiac_data[patno] = row['processed_file']

print(f"Found {len(brainiac_data)} BrainIAC files\n")

# Load FreeSurfer data: baseline (EVENT_ID == 'BL') cortical thickness and surface area.
print("Extracting FreeSurfer features")

_FS_CTH_FILE = "FS7_APARC_CTH_23Oct2025.csv"
_FS_SA_FILE = "FS7_APARC_SA_23Oct2025.csv"
_FS_ID_COLUMNS = ['PATNO', 'EVENT_ID']


def _load_fs_baseline(filename):
    """Read one FreeSurfer CSV from CAT12_BASE_DIR, keep BL rows, cast PATNO to str."""
    table = pd.read_csv(os.path.join(CAT12_BASE_DIR, filename))
    table = table[table['EVENT_ID'] == 'BL'].copy()
    table['PATNO'] = table['PATNO'].astype(str)
    return table


fs_cth_df = _load_fs_baseline(_FS_CTH_FILE)
fs_sa_df = _load_fs_baseline(_FS_SA_FILE)
# Every column other than the identifiers is treated as a feature.
cth_features = [c for c in fs_cth_df.columns if c not in _FS_ID_COLUMNS]
sa_features = [c for c in fs_sa_df.columns if c not in _FS_ID_COLUMNS]

# One metadata record per feature: thickness features first, then surface area.
fs_feature_info = (
    [{'feature_name': feat, 'feature_type': 'Thickness', 'source_file': _FS_CTH_FILE}
     for feat in cth_features]
    + [{'feature_name': feat, 'feature_type': 'Surface_Area', 'source_file': _FS_SA_FILE}
       for feat in sa_features]
)

print(f"Stored feature info: {len(cth_features)} Thickness + {len(sa_features)} Surface_Area features")
print(f"Total: {len(fs_feature_info)} FreeSurfer features\n")

# Overall Common Subjects: keep subjects present in CAT12, BrainIAC and FreeSurfer.
print("Overall common subjects across all modalities")

common_subjects = sorted(set(cat12_data) & set(brainiac_data))

# Index each baseline table by PATNO once instead of rescanning it per subject.
# Only the first row per subject is kept: duplicate rows would otherwise be flattened
# into a longer feature vector that no longer lines up with fs_feature_info.
_n_dup_fs_rows = int(fs_cth_df['PATNO'].duplicated().sum() + fs_sa_df['PATNO'].duplicated().sum())
if _n_dup_fs_rows:
    print(f"WARNING: {_n_dup_fs_rows} duplicate baseline FreeSurfer rows; keeping the first per subject")
_cth_by_patno = fs_cth_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')
_sa_by_patno = fs_sa_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')

fs_data = {}
for patno in common_subjects:
    if patno not in _cth_by_patno.index or patno not in _sa_by_patno.index:
        continue
    # Thickness features first, then surface area (same order as fs_feature_info).
    combined = np.concatenate([
        _cth_by_patno.loc[patno, cth_features].to_numpy(dtype=float),
        _sa_by_patno.loc[patno, sa_features].to_numpy(dtype=float),
    ])
    if not np.any(np.isnan(combined)):
        fs_data[patno] = combined

common_subjects = sorted(set(cat12_data) & set(brainiac_data) & set(fs_data))

print(f"Subjects with CAT12: {len(cat12_data)}")
print(f"Subjects with BrainIAC: {len(brainiac_data)}")
print(f"Subjects with FreeSurfer: {len(fs_data)}")
print(f"Common subjects (all modalities): {len(common_subjects)}\n")

if len(common_subjects) == 0:
    print("ERROR: No common subjects found across all modalities")
    raise SystemExit(1)

# Per-subject inputs and labels, all aligned with common_subjects.
cat12_paths = [cat12_data[p] for p in common_subjects]
brainiac_paths = [brainiac_data[p] for p in common_subjects]
fs_features = np.array([fs_data[p] for p in common_subjects])
sex_labels = np.array([id_sex_dict[p] for p in common_subjects])
parkinson_labels = np.array([id_parkinson_dict[p] for p in common_subjects])
age_labels = np.array([id_age_dict[p] for p in common_subjects])

print(f"Final matched dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex distribution: {np.sum(sex_labels)} Female, {len(sex_labels) - np.sum(sex_labels)} Male")
print(f"PD distribution: {np.sum(parkinson_labels)} PD, {len(parkinson_labels) - np.sum(parkinson_labels)} HC")
print(f"Age: mean={age_labels.mean():.1f}, range={age_labels.min():.1f}-{age_labels.max():.1f}\n")

pd.DataFrame({
    'subject_id': common_subjects,
    'cat12_path': cat12_paths,
    'brainiac_path': brainiac_paths,
    'sex': sex_labels,
    'parkinson': parkinson_labels,
    'age': age_labels
}).to_csv('ppmi_matched_subjects.csv', index=False)
print("Saved: ppmi_matched_subjects.csv\n")

#Dataset Classes

class CAT12VBMDataset(Dataset):
    """Dataset yielding ``(volume, label)`` pairs from CAT12 NIfTI paths.

    Each item loads ``data[idx]`` with nibabel, applies ``transform`` to the voxel
    array, adds a leading channel axis with ``unsqueeze(0)`` and pairs the result
    with ``labels[idx]`` as a float32 tensor.
    """

    def __init__(self, data, labels, transform):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        voxels = nib.load(self.data[idx]).get_fdata()
        volume = self.transform(voxels).unsqueeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return volume, label

class BrainIACDataset(Dataset):
    """Dataset of BrainIAC-preprocessed volumes, resampled to 96x96x96 when needed.

    Each item is ``(transform(volume)[None], label)``: the volume is cast to
    float32, trilinearly resized if its shape differs from (96, 96, 96), passed
    to ``transform`` as a tensor and given a leading channel axis.
    """

    def __init__(self, paths, labels, transform):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target:
            # F.interpolate works on (N, C, D, H, W): add batch/channel axes, then squeeze them off.
            batched = torch.from_numpy(volume[None, None])
            resized = F.interpolate(batched, size=target, mode='trilinear', align_corners=False)
            volume = resized.squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return normalized[None], label

class CNNDataset(Dataset):
    """Unlabelled dataset of volumes resampled to 92x110x92 and z-scored per volume.

    Items are float tensors with a leading channel axis.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target = (92, 110, 92)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target:
            factors = [want / have for want, have in zip(target, volume.shape)]
            volume = ndi.zoom(volume, factors, order=1)
        # Per-volume z-score; the epsilon avoids division by zero for constant volumes.
        volume = (volume - volume.mean()) / (volume.std() + 1e-6)
        return torch.from_numpy(volume[None]).float()

class CNN3D(nn.Module):
    """3D convolutional encoder mapping a 1-channel volume to a 512-d feature vector.

    Four stages of conv(3x3x3) -> ReLU -> max-pool(2) -> batch-norm -> 3D dropout
    (channels 64, 64, 128, 256), then global average pooling, Linear(256, 512),
    ReLU and dropout. No pretrained weights are loaded here.
    """

    def __init__(self, dropout_rate=DROPOUT_RATE):
        super().__init__()
        # Creation order (conv1, bn1, pool1, drop1, conv2, ...) fixes how the seeded RNG
        # is consumed by weight initialisation; keep it unchanged.
        self.conv1 = nn.Conv3d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(64)
        self.pool1 = nn.MaxPool3d(2)
        self.drop1 = nn.Dropout3d(dropout_rate)
        self.conv2 = nn.Conv3d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.pool2 = nn.MaxPool3d(2)
        self.drop2 = nn.Dropout3d(dropout_rate)
        self.conv3 = nn.Conv3d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm3d(128)
        self.pool3 = nn.MaxPool3d(2)
        self.drop3 = nn.Dropout3d(dropout_rate)
        self.conv4 = nn.Conv3d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm3d(256)
        self.pool4 = nn.MaxPool3d(2)
        self.drop4 = nn.Dropout3d(dropout_rate)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 512)
        self.drop5 = nn.Dropout(dropout_rate)

    def forward(self, x):
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            pool = getattr(self, f"pool{stage}")
            bn = getattr(self, f"bn{stage}")
            drop = getattr(self, f"drop{stage}")
            x = drop(bn(pool(F.relu(conv(x)))))
        pooled = self.global_pool(x).view(x.size(0), -1)
        return self.drop5(F.relu(self.fc1(pooled)))

def load_brainiac_vit(ckpt_path, device):
    """Build a MONAI ViT and load the BrainIAC backbone weights into it.

    Args:
        ckpt_path: checkpoint file. Either a dict with a "state_dict" entry or a
            bare state dict; in both cases the backbone weights are the keys
            prefixed with "backbone.".
        device: device the returned model is moved to.

    Returns:
        The ViT in eval mode on ``device``.
    """
    # weights_only=False matches how the AnatCL checkpoints are loaded in this script;
    # it unpickles arbitrary objects, so only use it with trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Resolve the state dict once so checkpoints without a "state_dict" key also work
    # (previously the membership test used the fallback but the lookup did not).
    state = ckpt.get("state_dict", ckpt)

    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, classification=False)

    pe_key = "backbone.patch_embedding.position_embeddings"
    if pe_key in state:
        # Replace the parameter object itself, presumably so a position-embedding tensor
        # whose shape differs from the freshly built one can still be adopted.
        vit.patch_embedding.position_embeddings = nn.Parameter(state[pe_key].clone())

    prefix = "backbone."
    backbone_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    load_result = vit.load_state_dict(backbone_state, strict=False)
    # strict=False tolerates non-matching keys; report them so a wrong checkpoint is visible.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"BrainIAC checkpoint: {len(load_result.missing_keys)} missing keys, "
              f"{len(load_result.unexpected_keys)} unexpected keys")
    return vit.to(device).eval()

# Extracting AnatCL Features: run the encoder of every fold and average the features.
print("Extracting AnatCL features (averaging across 5 cross-validation folds)")

# Voxel array -> float32 tensor. Normalize(mean=0.0, std=1.0) computes (x - 0) / 1,
# i.e. intensities are passed through unchanged.
anatcl_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
    transforms.Normalize(mean=0.0, std=1.0)
])

num_folds = 5
fold_paths = [os.path.join(ANATCL_ENCODER_PATH, f"fold{fold_idx}.pth") for fold_idx in range(num_folds)]

# Check every checkpoint before extracting anything, so a missing fold fails fast
# instead of after several full passes over the cohort.
for fold_idx, path in enumerate(fold_paths):
    if not os.path.exists(path):
        print(f"ERROR: fold{fold_idx}.pth not found at {path}")
        raise SystemExit(1)

# The loader does not depend on the fold, so build it once (labels are unused here).
dl = DataLoader(CAT12VBMDataset(cat12_paths, age_labels, anatcl_transform),
                batch_size=32, num_workers=0)

all_fold_features = []
for fold_idx, path in enumerate(fold_paths):
    print(f"Loading fold {fold_idx}...")
    encoder = AnatCL(descriptor="global", fold=fold_idx, pretrained=False).to(device).eval()
    # The checkpoint's 'model' entry holds the backbone state dict.
    encoder.backbone.load_state_dict(torch.load(path, map_location=device, weights_only=False)['model'])
    encoder.requires_grad_(False)  # inference only

    with torch.no_grad():
        fold_features = torch.cat([encoder(vol.to(device)).cpu() for vol, _ in dl]).numpy()

    all_fold_features.append(fold_features)
    print(f"  Fold {fold_idx} features: {fold_features.shape}")

    del encoder
    torch.cuda.empty_cache()

# Element-wise mean over folds; rows stay aligned with common_subjects.
anatcl_features = np.mean(all_fold_features, axis=0)

print(f"\nAnatCL features (averaged across {num_folds} folds): {anatcl_features.shape}\n")

# Extracting BrainIAC Features: take token 0 of the ViT output for each subject.
print("Extracting BrainIAC features")

brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)


def brainiac_transform(x):
    """Z-score a volume tensor; the epsilon guards against constant volumes."""
    return (x - x.mean()) / (x.std() + 1e-6)


dl = DataLoader(BrainIACDataset(brainiac_paths, age_labels, brainiac_transform),
                batch_size=16, num_workers=0)

brainiac_features = []

with torch.no_grad():
    for x, _ in dl:
        out = brainiac_vit(x.to(device))
        # The model may return a tuple whose first element holds the token sequence.
        tokens = out[0] if isinstance(out, tuple) else out
        # NOTE(review): the ViT is built with classification=False; if no CLS token is
        # prepended in that mode, index 0 is the first patch token rather than a global
        # summary — confirm (mean-pooling over tokens may be intended).
        brainiac_features.append(tokens[:, 0].cpu().numpy())

brainiac_features = np.vstack(brainiac_features)

print(f"BrainIAC features: {brainiac_features.shape}")
del brainiac_vit
torch.cuda.empty_cache()

# Extracting CNN Features with CNN3D in eval mode on the BrainIAC-preprocessed volumes.
print("Extracting CNN features")

cnn_model = CNN3D().to(device).eval()
dl = DataLoader(CNNDataset(brainiac_paths), batch_size=8, num_workers=0)

cnn_batches = []
with torch.no_grad():
    for batch in dl:
        cnn_batches.append(cnn_model(batch.to(device)).cpu().numpy())
cnn_features = np.vstack(cnn_batches)

print(f"CNN features: {cnn_features.shape}\n")
del cnn_model

# Organize Features: one feature matrix per model, rows in common_subjects order.
features_dict = dict(
    AnatCL=anatcl_features,
    BrainIAC=brainiac_features,
    CNN=cnn_features,
    FreeSurfer=fs_features,
)
print("Feature extraction complete")
for name in features_dict:
    feats = features_dict[name]
    print(f"{name}: {feats.shape}")
print()

# Pre-filtering variance check: after standardisation, count columns whose standard
# deviation is below 1e-10 ("dead", i.e. constant across subjects).
print("Pre-filtering variance check:")
for model in features_dict:
    feats = StandardScaler().fit_transform(features_dict[model])
    n_total = feats.shape[1]
    n_dead = np.sum(np.std(feats, axis=0) < 1e-10)
    print(f"{model}: {n_total} total, {n_dead} dead ({100*n_dead/n_total:.1f}%)")
print()

# Correlation Analysis
print("Computing cross-model feature correlations")

def compute_correlation_within_models(features_dict, method='ward',
                                      model_order=('AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer')):
    """Build a cross-model feature correlation matrix with per-model clustered ordering.

    For each model in ``model_order`` that is present in ``features_dict``, the
    features are standardised, and columns containing NaN or with (near-)zero
    variance are dropped. The remaining columns are reordered by hierarchical
    clustering on the distance |1 - |r||. The reordered blocks are concatenated
    and the correlation matrix of all retained features is returned.

    Args:
        features_dict: mapping model name -> 2-D array (subjects x features).
        method: scipy ``linkage`` method for the within-model ordering. Note that
            'ward' formally assumes Euclidean distances, which 1 - |r| is not.
        model_order: order in which models are placed in the combined matrix.

    Returns:
        ``(corr, boundaries, model_names, reordered_indices)``:
        corr -- correlation matrix of all retained features (symmetric, unit diagonal);
        boundaries -- cumulative column offsets; model_names[i] spans
            boundaries[i]:boundaries[i + 1];
        model_names -- the models actually included, aligned with ``boundaries``;
        reordered_indices -- per model, the column indices into that model's INPUT
            array, listed in clustered order (dropped columns are absent).

    Raises:
        ValueError: if no model contributes any usable feature.
    """
    blocks = []
    boundaries = [0]
    model_names = []
    reordered_indices = {}

    for model in model_order:
        if model not in features_dict:
            continue

        feats = StandardScaler().fit_transform(features_dict[model])

        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        kept_columns = np.flatnonzero(valid)
        feats = feats[:, valid]

        n_feats = feats.shape[1]
        if n_feats == 0:
            continue

        print(f"  {model}: {n_feats} features")

        if n_feats == 1:
            # Nothing to order; linkage needs at least two observations.
            reorder_idx = [0]
        else:
            corr = np.corrcoef(feats.T)
            corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
            np.fill_diagonal(corr, 1.0)

            # Strongly correlated or anti-correlated features are close (distance ~0).
            dist = np.abs(1 - np.abs(corr))
            np.fill_diagonal(dist, 0)
            condensed = np.nan_to_num(squareform((dist + dist.T) / 2, checks=False),
                                      nan=0.0, posinf=1.0, neginf=0.0)

            reorder_idx = dendrogram(linkage(condensed, method=method), no_plot=True)['leaves']

        blocks.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + n_feats)
        model_names.append(model)
        # Translate positions among the retained columns back to input column indices,
        # so callers can use them to look up per-feature metadata.
        reordered_indices[model] = kept_columns[reorder_idx].tolist()

    if not blocks:
        raise ValueError("No model has any usable (finite, non-constant) feature")

    combined = np.hstack(blocks)
    # Compute once, symmetrise to remove floating-point asymmetry; atleast_2d covers
    # the single-feature case where corrcoef returns a scalar.
    full_corr = np.atleast_2d(np.corrcoef(combined.T))
    corr = np.nan_to_num((full_corr + full_corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)

    print(f"\nTotal features: {combined.shape[1]}")

    return corr, boundaries, model_names, reordered_indices

corr_matrix, boundaries, model_names, reorder_indices = compute_correlation_within_models(features_dict)
pd.DataFrame(corr_matrix).to_csv('ppmi_feature_correlation_matrix.csv', index=False)
print(f"Correlation matrix: {corr_matrix.shape}")
print(f"Model boundaries: {boundaries}\n")

# Plot correlation matrix with PDF output
print(f"Plotting FULL cross-model correlation matrix ({corr_matrix.shape[0]}x{corr_matrix.shape[1]} features)...")

# Centre of each model's block, paired with its name, for the axis labels.
# zip() stops at the shorter sequence, so labels never run past the plotted blocks.
model_labels = list(zip(
    [(start + end) / 2 for start, end in zip(boundaries[:-1], boundaries[1:])],
    model_names,
))

# Plotting is best-effort: the matrix is already saved as CSV above.
try:
    fig = plt.figure(figsize=(30, 28))
    print("Figure created, generating heatmap")

    ax = sns.heatmap(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, center=0,
                     square=True, linewidths=0, cbar_kws={"shrink": 0.3},
                     xticklabels=False, yticklabels=False, rasterized=True)
    print("Heatmap generated, adding boundaries")

    # Heatmap cells span [k, k+1), so block edges sit exactly at the boundary offsets.
    for b in boundaries[1:-1]:
        plt.axhline(y=b, color='black', linewidth=4)
        plt.axvline(x=b, color='black', linewidth=4)

    for pos, name in model_labels:
        plt.text(pos, -20, name, ha='center', fontsize=20, weight='bold')
        plt.text(-20, pos, name, ha='center', va='center', fontsize=20, weight='bold', rotation=90)

    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    print("Saving PDF...")
    plt.savefig('ppmi_cross_model_correlation.pdf', dpi=150, bbox_inches='tight')
    plt.close()

    print("Saved: ppmi_cross_model_correlation.pdf\n")

except Exception as e:
    # Discard the partially drawn figure before trying the fallback.
    plt.close('all')
    print(f"ERROR with seaborn heatmap: {e}")
    print("Attempting fallback with matplotlib imshow")

    try:
        fig = plt.figure(figsize=(24, 22))

        im = plt.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='equal',
                        interpolation='nearest', origin='upper')
        plt.colorbar(im, shrink=0.4)

        # imshow centres pixel k at k, so block edges are at boundary - 0.5.
        for b in boundaries[1:-1]:
            plt.axhline(y=b-0.5, color='black', linewidth=3)
            plt.axvline(x=b-0.5, color='black', linewidth=3)

        for pos, name in model_labels:
            plt.text(pos, -15, name, ha='center', fontsize=16, weight='bold')
            plt.text(-15, pos, name, ha='center', va='center', fontsize=16, weight='bold', rotation=90)

        plt.title(f'Cross-Model Feature Correlations ({corr_matrix.shape[0]} features)',
                  fontsize=18, weight='bold')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig('ppmi_cross_model_correlation.pdf', dpi=100, bbox_inches='tight')
        plt.close()

        print("Saved (imshow fallback): ppmi_cross_model_correlation.pdf\n")

    except Exception as e2:
        plt.close('all')
        print(f"ERROR with imshow too: {e2}")
        print("Correlation matrix NOT saved as image (CSV available)\n")

# FREESURFER CORRELATION CLUSTER ANALYSIS
print("ANALYZING FREESURFER CORRELATION CLUSTERS")


def _pct(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return 100 * count / total if total else 0.0


# Extract FreeSurfer block from correlation matrix. Look the model up by name rather
# than assuming it is the 4th block.
fs_model_pos = model_names.index('FreeSurfer')
fs_start_idx = boundaries[fs_model_pos]
fs_end_idx = boundaries[fs_model_pos + 1]
fs_corr_block = corr_matrix[fs_start_idx:fs_end_idx, fs_start_idx:fs_end_idx]

print(f"\nFreeSurfer correlation block: {fs_corr_block.shape}")

# Columns of fs_corr_block are already in hierarchical-clustering order; these indices
# are used to look up each column's metadata in fs_feature_info.
reordered_fs_indices = reorder_indices['FreeSurfer']

print(f"Using {len(fs_feature_info)} features from source files")
if len(reordered_fs_indices) != len(fs_feature_info):
    print(f"WARNING: only {len(reordered_fs_indices)} of {len(fs_feature_info)} FreeSurfer features "
          "were clustered; check that the reorder indices refer to fs_feature_info positions")

# Apply the reordering to get clustered feature labels
clustered_features = [{**fs_feature_info[idx], 'original_index': idx}
                      for idx in reordered_fs_indices]

# Perform hierarchical clustering on FreeSurfer block to identify clusters
dist_fs = np.abs(1 - np.abs(fs_corr_block))
np.fill_diagonal(dist_fs, 0)
condensed_fs = squareform(dist_fs, checks=False)
linkage_fs = linkage(condensed_fs, method='ward')

# Cut tree into at most 2 clusters (a cluster id may end up with no members)
cluster_assignments = fcluster(linkage_fs, 2, criterion='maxclust')

print(f"Identified {len(np.unique(cluster_assignments))} clusters")
print(f"Cluster 1: {np.sum(cluster_assignments == 1)} features")
print(f"Cluster 2: {np.sum(cluster_assignments == 2)} features")

# Create DataFrame with cluster assignments (row j = column j of fs_corr_block)
cluster_df = pd.DataFrame(clustered_features)
cluster_df['cluster'] = cluster_assignments
cluster_df['reordered_index'] = range(len(clustered_features))

# Analyze cluster composition
print("\nCluster Composition:")
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id]
    n_sa = np.sum(cluster_data['feature_type'] == 'Surface_Area')
    n_thick = np.sum(cluster_data['feature_type'] == 'Thickness')
    print(f"\nCluster {cluster_id}:")
    print(f"  Surface Area (from SA file): {n_sa} features ({_pct(n_sa, len(cluster_data)):.1f}%)")
    print(f"  Thickness (from CTH file): {n_thick} features ({_pct(n_thick, len(cluster_data)):.1f}%)")

# Save detailed cluster information
cluster_df_sorted = cluster_df.sort_values(['cluster', 'feature_type', 'feature_name'])
cluster_df_sorted.to_csv('ppmi_freesurfer_correlation_clusters.csv', index=False)
print("\nSaved: ppmi_freesurfer_correlation_clusters.csv")

# Create separate CSVs for each cluster
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id].copy()
    cluster_data = cluster_data.sort_values(['feature_type', 'feature_name'])
    cluster_data.to_csv(f'ppmi_freesurfer_cluster{cluster_id}_features.csv', index=False)
    print(f"Saved: ppmi_freesurfer_cluster{cluster_id}_features.csv")

# Create summary of cluster characteristics
cluster_summary = []
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id]
    for feat_type in ['Surface_Area', 'Thickness']:
        type_data = cluster_data[cluster_data['feature_type'] == feat_type]
        if len(type_data) > 0:
            cluster_summary.append({
                'Cluster': cluster_id,
                'Feature_Type': feat_type,
                'Count': len(type_data),
                'Percentage': _pct(len(type_data), len(cluster_data)),
                'Example_Regions': ', '.join(type_data['feature_name'].head(5).tolist())
            })

pd.DataFrame(cluster_summary).to_csv('ppmi_freesurfer_cluster_summary.csv', index=False)
print("Saved: ppmi_freesurfer_cluster_summary.csv")

# Check for any "misclassified" features
cluster1_data = cluster_df[cluster_df['cluster'] == 1]
cluster2_data = cluster_df[cluster_df['cluster'] == 2]

# Determine which cluster is predominantly SA vs Thickness
cluster1_sa_pct = _pct(np.sum(cluster1_data['feature_type'] == 'Surface_Area'), len(cluster1_data))
cluster2_sa_pct = _pct(np.sum(cluster2_data['feature_type'] == 'Surface_Area'), len(cluster2_data))

if cluster1_sa_pct > 50:
    sa_cluster = 1
    thick_cluster = 2
else:
    sa_cluster = 2
    thick_cluster = 1

# Find thickness features in SA cluster (misclassified)
sa_cluster_data = cluster_df[cluster_df['cluster'] == sa_cluster]
misclassified_thick = sa_cluster_data[sa_cluster_data['feature_type'] == 'Thickness']

# Find SA features in thickness cluster (misclassified)
thick_cluster_data = cluster_df[cluster_df['cluster'] == thick_cluster]
misclassified_sa = thick_cluster_data[thick_cluster_data['feature_type'] == 'Surface_Area']

print(f"\nMisclassified Features:")
print(f"  Thickness (CTH file) features in SA-dominant cluster: {len(misclassified_thick)}")
if len(misclassified_thick) > 0:
    print(f"    Features: {misclassified_thick['feature_name'].tolist()}")
print(f"  Surface Area (SA file) features in Thickness-dominant cluster: {len(misclassified_sa)}")
if len(misclassified_sa) > 0:
    print(f"    Features: {misclassified_sa['feature_name'].tolist()}")
print()

# Model Training Functions

def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Benchmark each feature set on a classification task over a training-size sweep.

    Procedure:
      1. Hold out a stratified test split of ``int(test_size * len(y))`` subjects.
      2. On the remaining CV pool run 5-fold ``StratifiedKFold``. Inside each fold,
         subsample the training part to ``int(alpha * len(train))`` subjects for each
         ``alpha`` in the sweep, then fit StandardScaler + RandomForestClassifier.
      3. Pick the alpha with the highest mean validation ROC AUC. Refit on a subsample
         of the whole CV pool of that relative size and score it on the test split.

    A ``DummyClassifier(strategy='stratified')`` baseline is evaluated on the same
    folds (without the sweep) and stored under ``'DummyClassifier'``.

    Args:
        X_dict: mapping model name -> feature matrix. Its rows align with ``y`` and it
            must support integer-array row indexing (e.g. a NumPy array).
        y: NumPy label array. The positive-class score is ``predict_proba(...)[:, 1]``,
            so presumably binary labels (TODO confirm for new tasks).
        test_size: fraction of subjects held out for the final test.
        task_name: label used in the progress print.

    Returns:
        dict mapping each model name (plus ``'DummyClassifier'``) to a results dict. It
        holds the test metrics (``test_auc``, ``test_balanced_accuracy``), CV metrics at
        ``best_alpha``, per-alpha CV curves under ``'cv_results'``, train-validation AUC
        gaps, and (for real models) the fitted ``trained_model`` and ``scaler``.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)), 
                                        stratify=y, random_state=SEED)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def _make_rf():
        # Identical hyper-parameters for every fold, every alpha and the final refit.
        return RandomForestClassifier(n_estimators=200, max_depth=6, min_samples_split=5,
                                      random_state=SEED, n_jobs=1,
                                      max_features='sqrt', class_weight='balanced')

    def _subsample(pool, n):
        """Return *n* indices drawn from *pool*, stratified on y when possible."""
        if n >= len(pool):
            # Whole pool requested: train_test_split rejects train_size == len(pool).
            return pool
        try:
            subset, _ = train_test_split(pool, train_size=n, stratify=y[pool], random_state=SEED)
        except ValueError:
            # Stratification is impossible for very small n (fewer samples than classes).
            # Use a seeded draw, not the global RNG, so results do not depend on how
            # many random numbers earlier calls consumed.
            subset = np.random.RandomState(SEED).choice(pool, n, replace=False)
        return subset

    print("DummyClassifier")
    dummy_val_auc = []
    dummy_val_acc = []
    
    for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyClassifier(strategy='stratified', random_state=SEED)
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
        
        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        pred_proba = dummy.predict_proba(np.zeros((len(val_idx), 1)))[:, 1]
        
        dummy_val_acc.append(balanced_accuracy_score(y[val_idx], pred))
        dummy_val_auc.append(roc_auc_score(y[val_idx], pred_proba))
    
    dummy = DummyClassifier(strategy='stratified', random_state=SEED)
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    dummy_test_proba = dummy.predict_proba(np.zeros((len(test_idx), 1)))[:, 1]
    
    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], dummy_test_pred),
        'test_auc': roc_auc_score(y[test_idx], dummy_test_proba),
        'cv_auc_mean': np.mean(dummy_val_auc),
        'cv_auc_std': np.std(dummy_val_auc),
        'cv_acc_mean': np.mean(dummy_val_acc),
        'cv_acc_std': np.std(dummy_val_acc),
        'fold_results': {'val_auc': dummy_val_auc, 'val_acc': dummy_val_acc},
        'test_idx': test_idx
    }
    print(f"DummyClassifier: AUC={results['DummyClassifier']['test_auc']:.4f}, Bal_Acc={results['DummyClassifier']['test_balanced_accuracy']:.4f}")

    for model, X in X_dict.items():
        val_auc = {a: [] for a in alphas}
        val_acc = {a: [] for a in alphas}
        train_auc = {a: [] for a in alphas}
        train_acc = {a: [] for a in alphas}

        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                tr_idx = _subsample(train_idx, int(alpha * len(train_idx)))

                # A classifier (and ROC AUC on its training data) needs both classes.
                if len(np.unique(y[tr_idx])) < 2:
                    continue

                scaler = StandardScaler()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf = _make_rf()
                rf.fit(X_tr, y[tr_idx])

                val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                val_auc[alpha].append(roc_auc_score(y[val_idx], rf.predict_proba(X_val)[:, 1]))

                train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

        avg_auc = {a: np.mean(val_auc[a]) for a in alphas if val_auc[a]}
        best_alpha = max(avg_auc, key=avg_auc.get)

        # Calculate train-val gaps for all alphas
        train_val_gaps = {}
        for alpha in alphas:
            if train_auc[alpha] and val_auc[alpha]:
                train_val_gaps[alpha] = np.mean(train_auc[alpha]) - np.mean(val_auc[alpha])

        # Final refit on best_alpha of the whole CV pool; if a stratified subsample of
        # that size is impossible, fall back to the full pool.
        n_final = int(best_alpha * len(cv_idx))
        if n_final >= len(cv_idx):
            final_tr = cv_idx
        else:
            try:
                final_tr, _ = train_test_split(cv_idx, train_size=n_final,
                                               stratify=y[cv_idx], random_state=SEED)
            except ValueError:
                final_tr = cv_idx

        scaler = StandardScaler()
        rf = _make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        X_test = scaler.transform(X[test_idx])
        test_pred_proba = rf.predict_proba(X_test)[:, 1]

        train_val_gap = train_val_gaps[best_alpha]
        # Flag a mean train-minus-validation AUC gap above 0.25 as overfitting.
        is_overfitting = train_val_gap > 0.25

        results[model] = {
            'best_alpha': best_alpha,
            'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], rf.predict(X_test)),
            'test_auc': roc_auc_score(y[test_idx], test_pred_proba),
            'cv_auc_mean': np.mean(val_auc[best_alpha]),
            'cv_auc_std': np.std(val_auc[best_alpha]),
            'cv_acc_mean': np.mean(val_acc[best_alpha]),
            'cv_acc_std': np.std(val_acc[best_alpha]),
            'train_auc_mean': np.mean(train_auc[best_alpha]),
            'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_auc': val_auc[best_alpha],
                'val_acc': val_acc[best_alpha],
                'train_auc': train_auc[best_alpha]
            },
            'cv_results': {
                'alphas': alphas, 
                'acc': {a: np.mean(val_acc[a]) for a in alphas if val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(val_acc[a]) for a in alphas if val_acc[a]},
                'auc_std': {a: np.std(val_auc[a]) for a in alphas if val_auc[a]},
                'train_auc': {a: np.mean(train_auc[a]) for a in alphas if train_auc[a]},
                'train_acc': {a: np.mean(train_acc[a]) for a in alphas if train_acc[a]}
            },
            'trained_model': rf,
            'scaler': scaler,
            'test_idx': test_idx
        }
        print(f"{model}: AUC={results[model]['test_auc']:.4f}, Bal_Acc={results[model]['test_balanced_accuracy']:.4f}" + 
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

def run_regression(X_dict, y, test_size=0.1, task_name="Regression"):
    """Benchmark each feature set on a continuous target over a training-size sweep.

    Procedure:
      1. Hold out a random (unstratified) test split of ``int(test_size * len(y))``
         subjects.
      2. On the remaining CV pool run 5-fold ``KFold``. Inside each fold, subsample the
         training part to ``int(alpha * len(train))`` subjects for every ``alpha``,
         then fit StandardScaler + RandomForestRegressor.
      3. Pick the alpha with the highest mean validation R². Refit on that fraction of
         the whole CV pool and score it on the test split.

    A mean-predicting ``DummyRegressor`` baseline is evaluated on the same folds and
    stored under ``'DummyRegressor'``.

    Args:
        X_dict: mapping model name -> feature matrix. Its rows align with ``y`` and it
            must support integer-array row indexing (e.g. a NumPy array).
        y: NumPy array of continuous targets.
        test_size: fraction of subjects held out for the final test.
        task_name: label used in the progress print.

    Returns:
        dict mapping each model name (plus ``'DummyRegressor'``) to a results dict. It
        holds the test metrics (``test_r2``, ``test_mae``, ``test_rmse``), CV metrics at
        ``best_alpha``, per-alpha CV curves under ``'cv_results'``, train-validation R²
        gaps, and (for real models) test predictions plus the fitted ``trained_model``
        and ``scaler``.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)), random_state=SEED)
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def _make_rf():
        # Identical hyper-parameters for every fold, every alpha and the final refit.
        return RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_split=5,
                                     random_state=SEED, n_jobs=1)

    print("  DummyRegressor")
    dummy_val_r2 = []
    dummy_val_mae = []
    
    for train_rel, val_rel in kf.split(cv_idx):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyRegressor(strategy='mean')
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
        
        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        dummy_val_r2.append(r2_score(y[val_idx], pred))
        dummy_val_mae.append(mean_absolute_error(y[val_idx], pred))
    
    dummy = DummyRegressor(strategy='mean')
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    
    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2': r2_score(y[test_idx], dummy_test_pred),
        'test_mae': mean_absolute_error(y[test_idx], dummy_test_pred),
        'test_rmse': np.sqrt(mean_squared_error(y[test_idx], dummy_test_pred)),
        'cv_r2_mean': np.mean(dummy_val_r2),
        'cv_r2_std': np.std(dummy_val_r2),
        'cv_mae_mean': np.mean(dummy_val_mae),
        'cv_mae_std': np.std(dummy_val_mae),
        'fold_results': {'val_r2': dummy_val_r2, 'val_mae': dummy_val_mae},
        'test_idx': test_idx
    }
    print(f"  DummyRegressor: R2={results['DummyRegressor']['test_r2']:.4f}, MAE={results['DummyRegressor']['test_mae']:.2f}")

    for model, X in X_dict.items():
        val_r2 = {a: [] for a in alphas}
        val_mae = {a: [] for a in alphas}
        train_r2 = {a: [] for a in alphas}
        train_mae = {a: [] for a in alphas}

        for train_rel, val_rel in kf.split(cv_idx):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                n = int(alpha * len(train_idx))
                # On small cohorts (e.g. alpha=0.01) n can be 0 or 1. train_test_split
                # rejects train_size=0 and R² is undefined on a single sample, so skip;
                # the alpha is then just absent from this fold's metrics.
                if n < 2:
                    continue
                if n < len(train_idx):
                    tr_idx, _ = train_test_split(train_idx, train_size=n, random_state=SEED)
                else:
                    tr_idx = train_idx

                scaler = StandardScaler()
                X_tr = scaler.fit_transform(X[tr_idx])
                rf = _make_rf()
                rf.fit(X_tr, y[tr_idx])

                pred = rf.predict(scaler.transform(X[val_idx]))
                val_r2[alpha].append(r2_score(y[val_idx], pred))
                val_mae[alpha].append(mean_absolute_error(y[val_idx], pred))
                
                train_pred = rf.predict(X_tr)
                train_r2[alpha].append(r2_score(y[tr_idx], train_pred))
                train_mae[alpha].append(mean_absolute_error(y[tr_idx], train_pred))

        avg_r2 = {a: np.mean(val_r2[a]) for a in alphas if val_r2[a]}
        best_alpha = max(avg_r2, key=avg_r2.get)

        # Calculate train-val gaps for all alphas
        train_val_gaps = {}
        for alpha in alphas:
            if train_r2[alpha] and val_r2[alpha]:
                train_val_gaps[alpha] = np.mean(train_r2[alpha]) - np.mean(val_r2[alpha])

        n_final = int(best_alpha * len(cv_idx))
        if n_final < len(cv_idx):
            final_tr, _ = train_test_split(cv_idx, train_size=n_final, random_state=SEED)
        else:
            final_tr = cv_idx

        scaler = StandardScaler()
        rf = _make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        test_pred = rf.predict(scaler.transform(X[test_idx]))

        train_val_gap = train_val_gaps[best_alpha]
        # Flag a mean train-minus-validation R² gap above 0.35 as overfitting.
        is_overfitting = train_val_gap > 0.35

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2': r2_score(y[test_idx], test_pred),
            'test_mae': mean_absolute_error(y[test_idx], test_pred),
            'test_rmse': np.sqrt(mean_squared_error(y[test_idx], test_pred)),
            'predictions': test_pred,
            'actual': y[test_idx],
            'cv_r2_mean': np.mean(val_r2[best_alpha]),
            'cv_r2_std': np.std(val_r2[best_alpha]),
            'cv_mae_mean': np.mean(val_mae[best_alpha]),
            'cv_mae_std': np.std(val_mae[best_alpha]),
            'train_r2_mean': np.mean(train_r2[best_alpha]),
            'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_r2': val_r2[best_alpha],
                'val_mae': val_mae[best_alpha],
                'train_r2': train_r2[best_alpha]
            },
            'cv_results': {
                'alphas': alphas, 
                'r2': avg_r2, 
                'mae': {a: np.mean(val_mae[a]) for a in alphas if val_mae[a]},
                'r2_std': {a: np.std(val_r2[a]) for a in alphas if val_r2[a]},
                'mae_std': {a: np.std(val_mae[a]) for a in alphas if val_mae[a]},
                'train_r2': {a: np.mean(train_r2[a]) for a in alphas if train_r2[a]},
                'train_mae': {a: np.mean(train_mae[a]) for a in alphas if train_mae[a]}
            },
            'trained_model': rf,
            'scaler': scaler,
            'test_idx': test_idx
        }
        print(f"  {model}: R2={results[model]['test_r2']:.4f}, MAE={results[model]['test_mae']:.2f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

# Run Sex Classification, then Parkinson Classification, on the full cohort.
print("Sex Classification")
sex_results = run_classification(features_dict, sex_labels, test_size=0.1, task_name="Sex classification")
print()

print("Parkinson Classification")
parkinson_results = run_classification(features_dict, parkinson_labels, test_size=0.1, task_name="Parkinson classification")
print()

# NEW: SEX-STRATIFIED AGE PREDICTION (matching HBN)
print("AGE PREDICTION - SPLIT BY SEX")


def _sex_subset(sex_code, label):
    """Select subjects whose sex label equals *sex_code* and print their age summary.

    Returns (row indices, features_dict restricted to those rows, their ages).
    """
    subset_idx = np.where(sex_labels == sex_code)[0]
    subset_features = {name: matrix[subset_idx] for name, matrix in features_dict.items()}
    subset_ages = age_labels[subset_idx]
    print(f"\n{label} subjects: N={len(subset_idx)}")
    print(f"{label} age: mean={subset_ages.mean():.1f}, std={subset_ages.std():.2f}, "
          f"range={subset_ages.min():.1f}-{subset_ages.max():.1f}")
    return subset_idx, subset_features, subset_ages


# Male subjects (sex_labels == 0)
male_indices, male_features_dict, male_age_labels = _sex_subset(0, "Male")
male_age_results = run_regression(male_features_dict, male_age_labels, test_size=0.1, task_name="Male age prediction")
print()

# Female subjects (sex_labels == 1)
female_indices, female_features_dict, female_age_labels = _sex_subset(1, "Female")
female_age_results = run_regression(female_features_dict, female_age_labels, test_size=0.1, task_name="Female age prediction")
print()

# Overfitting Analysis
print("Overfitting Analysis")

# (CSV column prefix, per-task results, label in the printout, metric the gap is measured in)
overfit_tasks = [
    ('Sex', sex_results, 'Sex classification', 'AUC'),
    ('Parkinson', parkinson_results, 'Parkinson classification', 'AUC'),
    ('Male_Age', male_age_results, 'Male age prediction', 'R²'),
    ('Female_Age', female_age_results, 'Female age prediction', 'R²'),
]

overfitting_report = []
for model in features_dict.keys():
    report_row = {'Model': model}
    flagged_lines = []
    for prefix, task_results, label, metric in overfit_tasks:
        task_entry = task_results[model]
        report_row[f'{prefix}_Overfitting'] = task_entry['is_overfitting']
        report_row[f'{prefix}_Train_Val_Gap'] = task_entry['train_val_gap']
        if task_entry['is_overfitting']:
            flagged_lines.append(f"  {label}: Train-Val {metric} gap = {task_entry['train_val_gap']:.3f}")
    overfitting_report.append(report_row)

    # Only models flagged on at least one task are listed.
    if flagged_lines:
        print(f"{model}:")
        print("\n".join(flagged_lines))

any_overfitting = any(row[f'{prefix}_Overfitting']
                      for row in overfitting_report
                      for prefix, *_ in overfit_tasks)
if not any_overfitting:
    print("No significant overfitting detected")

pd.DataFrame(overfitting_report).to_csv('ppmi_overfitting_analysis.csv', index=False)
print("\nSaved: ppmi_overfitting_analysis.csv\n")

# Save train-val gaps for all alphas
print("Saving train-val gaps for all alphas...")

# (CSV column, per-task results) in output column order.
gap_columns = [
    ('Sex_Train_Val_Gap', sex_results),
    ('Parkinson_Train_Val_Gap', parkinson_results),
    ('Male_Age_Train_Val_Gap', male_age_results),
    ('Female_Age_Train_Val_Gap', female_age_results),
]

# One row per (model, alpha). The alpha grid comes from the sex-classification run;
# an alpha missing from a task's gap dict is written as NaN.
gap_rows = [
    {'Model': model,
     'Alpha': alpha,
     **{column: task_results[model]['train_val_gaps_all_alphas'].get(alpha, np.nan)
        for column, task_results in gap_columns}}
    for model in features_dict.keys()
    for alpha in sex_results[model]['cv_results']['alphas']
]

pd.DataFrame(gap_rows).to_csv('ppmi_train_val_gaps_all_alphas.csv', index=False)
print("Saved: ppmi_train_val_gaps_all_alphas.csv\n")

# Saving Results
banner = "=" * 80
print(banner)
print("Saving Results")
print(banner)

all_models = ['DummyClassifier'] + list(features_dict.keys())
all_models_reg = ['DummyRegressor'] + list(features_dict.keys())

# (CSV column, key in the per-model results dict) for the per-task summary tables.
classification_summary_columns = [
    ('Best_Alpha', 'best_alpha'),
    ('Test_Balanced_Accuracy', 'test_balanced_accuracy'),
    ('Test_AUC', 'test_auc'),
    ('CV_AUC_Mean', 'cv_auc_mean'),
    ('CV_AUC_Std', 'cv_auc_std'),
]
regression_summary_columns = [
    ('Best_Alpha', 'best_alpha'),
    ('Test_R2', 'test_r2'),
    ('Test_MAE', 'test_mae'),
    ('Test_RMSE', 'test_rmse'),
    ('CV_R2_Mean', 'cv_r2_mean'),
    ('CV_R2_Std', 'cv_r2_std'),
    ('CV_MAE_Mean', 'cv_mae_mean'),
    ('CV_MAE_Std', 'cv_mae_std'),
]

summary_tables = [
    (sex_results, all_models, classification_summary_columns, 'ppmi_sex_classification_summary.csv'),
    (parkinson_results, all_models, classification_summary_columns, 'ppmi_parkinson_classification_summary.csv'),
    (male_age_results, all_models_reg, regression_summary_columns, 'ppmi_male_age_prediction_summary.csv'),
    (female_age_results, all_models_reg, regression_summary_columns, 'ppmi_female_age_prediction_summary.csv'),
]
for task_results, model_names, columns, csv_name in summary_tables:
    table = [{'Model': m, **{column: task_results[m][key] for column, key in columns}}
             for m in model_names]
    pd.DataFrame(table).to_csv(csv_name, index=False)

# Test-set metrics of every real (non-dummy) model across all tasks.
detailed_columns = [
    ('Sex_AUC', sex_results, 'test_auc'),
    ('Sex_Bal_Acc', sex_results, 'test_balanced_accuracy'),
    ('Parkinson_AUC', parkinson_results, 'test_auc'),
    ('Parkinson_Bal_Acc', parkinson_results, 'test_balanced_accuracy'),
    ('Male_Age_R2', male_age_results, 'test_r2'),
    ('Male_Age_MAE', male_age_results, 'test_mae'),
    ('Male_Age_RMSE', male_age_results, 'test_rmse'),
    ('Female_Age_R2', female_age_results, 'test_r2'),
    ('Female_Age_MAE', female_age_results, 'test_mae'),
    ('Female_Age_RMSE', female_age_results, 'test_rmse'),
]
all_results = [
    {'Model': m, **{column: res[m][key] for column, res, key in detailed_columns}}
    for m in features_dict.keys()
]
pd.DataFrame(all_results).to_csv('ppmi_detailed_comparison.csv', index=False)

# Rank models on each headline metric (sorted descending: higher score ranks first).
rankings = []
for task in ['Sex_AUC', 'Parkinson_AUC', 'Male_Age_R2', 'Female_Age_R2']:
    ordered = sorted(all_results, key=lambda row: row[task], reverse=True)
    rankings.extend(
        {'Task': task.replace('_', ' '), 'Rank': rank, 'Model': row['Model'], 'Score': row[task]}
        for rank, row in enumerate(ordered, start=1)
    )
pd.DataFrame(rankings).to_csv('ppmi_model_rankings.csv', index=False)

saved_files = [entry[-1] for entry in summary_tables] + ['ppmi_detailed_comparison.csv',
                                                         'ppmi_model_rankings.csv']
for csv_name in saved_files[:-1]:
    print(f"Saved: {csv_name}")
print(f"Saved: {saved_files[-1]}\n")

# Generate Learning Curve Visualizations (4 PLOTS)
print("="*80)
print("Generating Learning Curves")
print("="*80)

models = list(features_dict.keys())

# SAME MODEL STYLES AS HBN
model_styles = {
    'AnatCL': {'color': '#9B59B6', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'BrainIAC': {'color': '#E74C3C', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'CNN': {'color': '#3498DB', 'marker': 'o', 'linestyle': '--', 'label_prefix': ''},
    'FreeSurfer': {'color': '#E67E22', 'marker': 's', 'linestyle': '-', 'label_prefix': ''}
}
# Used for any feature set in features_dict without an entry above, so adding a new
# model does not abort plotting with a KeyError.
fallback_model_style = {'color': '#7F8C8D', 'marker': 'D', 'linestyle': '-.', 'label_prefix': ''}

# Size of each task's CV pool: all subjects minus the int(0.1 * N) test subjects.
n_cv_sex = len(sex_labels) - int(0.1 * len(sex_labels))
n_cv_parkinson = len(parkinson_labels) - int(0.1 * len(parkinson_labels))
n_cv_male = len(male_age_labels) - int(0.1 * len(male_age_labels))
n_cv_female = len(female_age_labels) - int(0.1 * len(female_age_labels))

alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]
# Approximate x-positions: a 5-fold CV training fold holds ~80% of the CV pool.
training_sizes_sex = [int(alpha * n_cv_sex * 0.8) for alpha in alphas]
training_sizes_parkinson = [int(alpha * n_cv_parkinson * 0.8) for alpha in alphas]
training_sizes_male = [int(alpha * n_cv_male * 0.8) for alpha in alphas]
training_sizes_female = [int(alpha * n_cv_female * 0.8) for alpha in alphas]


def _plot_cv_curves(ax, task_results, mean_key, std_key, sizes):
    """Draw one mean +/- std CV learning curve per model on *ax*.

    Reads ``task_results[model]['cv_results'][mean_key]`` and ``[std_key]`` (dicts
    keyed by alpha). Alphas without an entry are skipped. ``sizes[i]`` is the
    x-position for ``alphas[i]``.
    """
    for model in models:
        style = model_styles.get(model, fallback_model_style)
        cv = task_results[model]['cv_results']
        present = [i for i, alpha in enumerate(alphas) if alpha in cv[mean_key]]
        valid_sizes = np.array([sizes[i] for i in present])
        means = np.array([cv[mean_key][alphas[i]] for i in present])
        stds = np.array([cv[std_key][alphas[i]] for i in present])

        ax.plot(valid_sizes, means, marker=style['marker'], linestyle=style['linestyle'],
                linewidth=3, markersize=10, label=style['label_prefix'] + model,
                color=style['color'])
        ax.fill_between(valid_sizes, means - stds, means + stds,
                        alpha=0.2, color=style['color'])


def _decorate_axis(ax, baseline, baseline_fmt, ylabel, title, legend_loc, ylim=None):
    """Add the dummy-baseline line, axis labels, title, legend and grid to *ax*."""
    ax.axhline(y=baseline, color='gray', linestyle=':', linewidth=2,
               label=f'Baseline ({baseline:{baseline_fmt}})', alpha=0.7)
    ax.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
    ax.set_ylabel(ylabel, fontsize=12, weight='bold')
    ax.set_title(title, fontsize=14, weight='bold')
    ax.legend(fontsize=9, loc=legend_loc)
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(ylim)


# 4 PLOTS: Parkinson, Male Age, Sex, Female Age
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 12))

# PLOT 1: Parkinson classification (balanced accuracy)
_plot_cv_curves(ax1, parkinson_results, 'acc', 'acc_std', training_sizes_parkinson)
dummy_acc_cv_pd = parkinson_results['DummyClassifier']['cv_acc_mean']
_decorate_axis(ax1, dummy_acc_cv_pd, '.3f', 'Balanced Accuracy', 'Parkinson Classification',
               'lower right', ylim=[0, 1])

# PLOT 2: Male age prediction (MAE)
_plot_cv_curves(ax2, male_age_results, 'mae', 'mae_std', training_sizes_male)
dummy_mae_cv_male = male_age_results['DummyRegressor']['cv_mae_mean']
_decorate_axis(ax2, dummy_mae_cv_male, '.2f', 'MAE (years)', 'Male Age Prediction', 'upper right')

# PLOT 3: Sex classification (balanced accuracy)
_plot_cv_curves(ax3, sex_results, 'acc', 'acc_std', training_sizes_sex)
dummy_acc_cv_sex = sex_results['DummyClassifier']['cv_acc_mean']
_decorate_axis(ax3, dummy_acc_cv_sex, '.3f', 'Balanced Accuracy', 'Sex Classification',
               'lower right', ylim=[0, 1])

# PLOT 4: Female age prediction (MAE)
_plot_cv_curves(ax4, female_age_results, 'mae', 'mae_std', training_sizes_female)
dummy_mae_cv_female = female_age_results['DummyRegressor']['cv_mae_mean']
_decorate_axis(ax4, dummy_mae_cv_female, '.2f', 'MAE (years)', 'Female Age Prediction', 'upper right')

plt.tight_layout()
plt.savefig('ppmi_cv_learning_curves.pdf', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: ppmi_cv_learning_curves.pdf\n")

# Final cohort summary. sex_labels uses 1 = Female, 0 = Male; parkinson_labels counts PD as 1.
n_female = np.sum(sex_labels)
n_pd = np.sum(parkinson_labels)
summary_lines = [
    f"\nDataset: {len(common_subjects)} subjects",
    f"Age range: {age_labels.min():.1f} - {age_labels.max():.1f} years",
    f"Sex: {n_female} Female, {len(sex_labels) - n_female} Male",
    f"Males: {len(male_indices)} subjects",
    f"Females: {len(female_indices)} subjects",
    f"PD: {n_pd} PD, {len(parkinson_labels) - n_pd} HC",
]
print("\n".join(summary_lines))

Overwriting /home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs/ppmi_multimodel_comparison.py


## BrainIAC and CNN Preprocessing | HD-BET | HBN

In [None]:
%%sbatch --array=1-100
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=12:00:00
#SBATCH --mem=16G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=brainiac_batch
#SBATCH --output=logs/prep_batch_%a.out
#SBATCH --error=logs/prep_batch_%a.err
# NOTE: SLURM opens the --output/--error files before this script starts, so the
# logs/ directory must already exist in the submission directory.

export BASE_DIR="/home/arelbaha/links/projects/rrg-glatard/arelbaha"
export RAW_DIR="${BASE_DIR}/brainiac_p_files"
export OUTPUT_DIR="${BASE_DIR}/brainiac_p_outputs"

module load python/3.11
module load opencv
source /home/arelbaha/.venvs/brainiac_env/bin/activate

# BrainIAC preprocessing sources installed inside the virtualenv.
BRAINIAC_PREPROC="/home/arelbaha/.venvs/brainiac_env/lib/python3.11/site-packages/BrainIAC/src/preprocessing"

#Batch: array task N -> zero-padded batch_NNN input/output directories
BATCH_NUM=$(printf "%03d" "${SLURM_ARRAY_TASK_ID}")
BATCH_DIR="${RAW_DIR}/batch_${BATCH_NUM}"
BATCH_OUTPUT="${OUTPUT_DIR}/batch_${BATCH_NUM}"

if [ ! -d "$BATCH_DIR" ]; then
    echo "Input directory ${BATCH_DIR} not found" >&2
    exit 1
fi

mkdir -p "$BATCH_OUTPUT"
echo "Processing batch ${BATCH_NUM} from ${BATCH_DIR}"

#Preprocessing: fail the array task (non-zero exit) if the Python step fails, so that
#a broken batch is not reported as completed.
if ! python "${BRAINIAC_PREPROC}/mri_preprocess_3d_simple.py" \
    --temp_img "${BRAINIAC_PREPROC}/atlases/temp_head.nii.gz" \
    --input_dir "$BATCH_DIR" \
    --output_dir "$BATCH_OUTPUT"; then
    echo "Preprocessing failed for batch ${BATCH_NUM}" >&2
    exit 1
fi

echo "Completed batch ${BATCH_NUM}"

## HBN Multi-task Evaluation | AnatCL vs BrainIAC vs CNN vs FreeSurfer

In [10]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/debug_hbn_multimodel_comparison.py

import os

# Cap the BLAS / OpenMP thread pools BEFORE numpy (or anything that imports it:
# pandas, nilearn, torch, sklearn) is loaded. OpenBLAS/MKL/OpenMP read these
# variables when the library is initialised, so setting them after `import numpy`
# leaves numpy/pandas linear algebra uncapped.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import glob
import random
import numpy as np
import pandas as pd
import fnmatch
import json
import pickle
from nilearn import plotting, datasets

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from anatcl import AnatCL
from monai.networks.nets import ViT
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.inspection import permutation_importance
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.ndimage as ndi
from matplotlib.patches import Patch

# Seed every RNG in play and force deterministic, single-threaded torch ops.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# PATHS
# CAT12 outputs are looked up in both spellings of the BIDS root (HBN_BIDS first).
HBN_BIDS = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
HBN_BIDS_LOWER = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/hbn_bids"
BRAINIAC_OUT = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/brainiac_p_outputs"
FREESURFER_DIR = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_FreeSurfer/freesurfer"
DEMO_FILE = os.path.join(HBN_BIDS, "final_preprocessed_subjects_with_demographics.tsv")
# NOTE(review): points at a venv bin/ directory, presumably where the AnatCL weights live; confirm.
ANATCL_ENCODER_PATH = "/home/arelbaha/.venvs/jupyter_py3/bin"
BRAINIAC_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/BrainIAC.ckpt"
DROPOUT_RATE = 0.3
PARCELLATION = "Schaefer2018_400Parcels_17Networks_order"

device = "cpu"

#Loading Demographics
print("Loading demographics")

demo_df = pd.read_csv(DEMO_FILE, sep='\t')
demo_df['participant_id'] = demo_df['participant_id'].astype(str)

# participant_id -> sex code (Female = 1, Male = 0) and participant_id -> age.
# Only subjects present in id_sex_dict are searched for imaging data below.
id_sex_dict = {}
id_age_dict = {}
n_missing_age = 0

for _, row in demo_df.iterrows():
    subject_id = row['participant_id']
    age = row['age']
    if pd.isna(age):
        # Every matched subject contributes to age_labels (dataset labels and
        # regression targets); a NaN age would break those and the summary below.
        n_missing_age += 1
        continue
    raw_sex = row['sex']
    # Empty cells are read as float NaN, which has no .strip(); treat any
    # non-string value as an unknown sex instead of crashing.
    sex = raw_sex.strip() if isinstance(raw_sex, str) else ''
    if sex == 'Female':
        id_sex_dict[subject_id] = 1
    elif sex == 'Male':
        id_sex_dict[subject_id] = 0
    id_age_dict[subject_id] = age

if n_missing_age:
    print(f"Skipped {n_missing_age} subjects with missing age")
print(f"Loaded demographics for {len(id_sex_dict)} subjects")
print(f"Sex distribution: {sum(id_sex_dict.values())} Female, {len(id_sex_dict) - sum(id_sex_dict.values())} Male")
if id_age_dict:
    print(f"Age range: {min(id_age_dict.values()):.1f} - {max(id_age_dict.values()):.1f} years\n")

#Finding CAT12 Files
print("Finding CAT12 (s6mwp1) files")

cat12_data = {}

# Roots are searched in priority order: a subject already found under
# HBN_BIDS is not looked up again under HBN_BIDS_LOWER.
for search_root in (HBN_BIDS, HBN_BIDS_LOWER):
    print(f"Searching in {search_root}")
    for subject_id in id_sex_dict:
        if subject_id in cat12_data:
            continue
        pattern = os.path.join(search_root, f"sub-{subject_id}", "ses-*", "anat", "mri", "s6mwp1sub*.nii")
        # glob order is filesystem-dependent; sort so the file picked for a
        # subject with several matches (e.g. sessions) is the same every run.
        files = sorted(glob.glob(pattern))
        if files:
            cat12_data[subject_id] = files[0]

print(f"Found {len(cat12_data)} CAT12 files\n")

#BrainIAC Files
print("Finding BrainIAC preprocessed files")

brainiac_data = {}
batch_dirs = sorted(glob.glob(os.path.join(BRAINIAC_OUT, "batch_*")))
print(f"Searching in {len(batch_dirs)} batch directories...")

for batch_dir in batch_dirs:
    # Sorted so that, if a subject appears more than once, the file kept
    # (first seen wins) does not depend on filesystem listing order.
    for nifti_path in sorted(glob.glob(os.path.join(batch_dir, "sub-*_0000.nii.gz"))):
        basename = os.path.basename(nifti_path)
        # "sub-<ID>_..." -> "<ID>": the glob guarantees the "sub-" prefix, and
        # only that leading prefix is removed.
        subject_id = basename.split('_')[0][len('sub-'):]

        if subject_id in id_sex_dict and subject_id not in brainiac_data:
            brainiac_data[subject_id] = nifti_path

print(f"Found {len(brainiac_data)} BrainIAC files\n")

#Extracting FreeSurfer Features
print("Extracting FreeSurfer features")

N_FS_REGIONS = 400  # parcels kept per subject (PARCELLATION is a 400-parcel Schaefer atlas)

fs_data = {}
fs_region_info = {}  # Region names / hemispheres of the first valid subject (reference order)

for subject_id in id_sex_dict.keys():
    subject_dir = os.path.join(FREESURFER_DIR, f"sub-{subject_id}")
    stats_file = os.path.join(subject_dir, f"sub-{subject_id}_regionsurfacestats.tsv")

    if not os.path.exists(stats_file):
        continue

    try:
        df = pd.read_csv(stats_file, sep='\t')
        filtered_df = df[df["atlas"] == PARCELLATION]
        if filtered_df.empty or not {"SurfArea", "ThickAvg"}.issubset(filtered_df.columns):
            continue

        # Sort by region name so every subject's vector uses the same order.
        filtered_df = filtered_df.sort_values("StructName")

        # Fewer rows would yield a shorter vector that cannot be stacked with
        # the other subjects into fs_features.
        if len(filtered_df) < N_FS_REGIONS:
            print(f"Skipping sub-{subject_id}: only {len(filtered_df)} {PARCELLATION} regions")
            continue

        subject_regions = filtered_df['StructName'].values[:N_FS_REGIONS].tolist()
        # Column i must refer to the same region for every subject.
        if fs_region_info and subject_regions != fs_region_info['region_names']:
            print(f"Skipping sub-{subject_id}: region names differ from the reference subject")
            continue

        surf_area = filtered_df["SurfArea"].values[:N_FS_REGIONS]
        thick_avg = filtered_df["ThickAvg"].values[:N_FS_REGIONS]
        # Layout: [SurfArea x N_FS_REGIONS, ThickAvg x N_FS_REGIONS]
        combined = np.concatenate([surf_area, thick_avg])

        if np.any(np.isnan(combined)):
            continue

        fs_data[subject_id] = combined

        # Store region info from first valid subject
        if not fs_region_info:
            fs_region_info = {
                'region_names': subject_regions,
                'hemisphere': filtered_df['hemisphere'].values[:N_FS_REGIONS].tolist(),
                'parcellation': PARCELLATION
            }
    except Exception as e:
        # Best-effort: a malformed stats file skips that subject only.
        print(f"Error reading {stats_file}: {e}")

print(f"Found {len(fs_data)} FreeSurfer subjects with 800 features (400 SurfArea + 400 ThickAvg)\n")

#Overall Common Subjects
print("Overall common subjects across all modalities")

# Subjects with data in every modality; sorted so row order is reproducible.
common_subjects = sorted(set(cat12_data) & set(brainiac_data) & set(fs_data))

print(f"Subjects with CAT12: {len(cat12_data)}")
print(f"Subjects with BrainIAC: {len(brainiac_data)}")
print(f"Subjects with FreeSurfer: {len(fs_data)}")
print(f"Common subjects (all modalities): {len(common_subjects)}\n")

if not common_subjects:
    print("ERROR: No common subjects found across all modalities")
    # `exit` is a site-module convenience intended for the interactive shell;
    # SystemExit works in every interpreter configuration.
    raise SystemExit(1)

# Element i of every list/array below refers to common_subjects[i].
cat12_paths = [cat12_data[s] for s in common_subjects]
brainiac_paths = [brainiac_data[s] for s in common_subjects]
fs_features = np.array([fs_data[s] for s in common_subjects])
sex_labels = np.array([id_sex_dict[s] for s in common_subjects])
age_labels = np.array([id_age_dict[s] for s in common_subjects])

print(f"Final matched dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex labels: {np.sum(sex_labels)} Female, {len(sex_labels) - np.sum(sex_labels)} Male")
print(f"Age: mean={age_labels.mean():.1f}, std={age_labels.std():.2f}, range={age_labels.min():.1f}-{age_labels.max():.1f}")
# MAE of always predicting the mean age if ages were normally distributed:
# E|X - mu| = sigma * sqrt(2 / pi).
print(f"Expected baseline MAE: {age_labels.std() * np.sqrt(2/np.pi):.2f} years\n")

pd.DataFrame({
    'subject_id': common_subjects,
    'cat12_path': cat12_paths,
    'brainiac_path': brainiac_paths,
    'sex': sex_labels,
    'age': age_labels
}).to_csv('hbn_matched_subjects.csv', index=False)
print("Saved: hbn_matched_subjects.csv\n")

#Dataset Classes

class CAT12VBMDataset(Dataset):
    """Map-style dataset over CAT12 VBM NIfTI volumes.

    Each item is ``(volume, label)``. The volume is read with nibabel, passed
    through ``transform`` (which must return a tensor) and given a leading
    channel axis. The label is a float32 scalar tensor.
    """

    def __init__(self, data, labels, transform):
        self.data = data  # sequence of NIfTI paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raw_volume = nib.load(self.data[idx]).get_fdata()
        volume = self.transform(raw_volume)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return volume.unsqueeze(0), target

class BrainIACDataset(Dataset):
    """Dataset of volumes resized to 96^3 for the BrainIAC ViT.

    Each item is ``(volume, label)``. The volume is loaded as float32 and
    trilinearly resampled to (96, 96, 96) when needed. It is then passed
    through ``transform`` as a tensor and returned with a leading channel axis.
    """

    def __init__(self, paths, labels, transform):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            # F.interpolate needs (N, C, D, H, W); add then drop those axes.
            as_batch = torch.from_numpy(volume[None, None])
            resized = F.interpolate(as_batch, size=target_shape,
                                    mode='trilinear', align_corners=False)
            volume = resized.squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return normalized[None], label

class CNNDataset(Dataset):
    """Unlabelled dataset of z-scored volumes resampled to (92, 110, 92).

    Each item is a float32 tensor of shape (1, 92, 110, 92). Volumes of other
    shapes are linearly zoomed, then standardised with their own mean/std.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (92, 110, 92)
        vol = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if vol.shape != target_shape:
            zoom_factors = [size / vol.shape[axis] for axis, size in enumerate(target_shape)]
            vol = ndi.zoom(vol, zoom_factors, order=1)
        centered = vol - vol.mean()
        vol = centered / (vol.std() + 1e-6)
        return torch.from_numpy(vol[None]).float()

class CNN3D(nn.Module):
    """3D CNN mapping a batched single-channel volume to 512-d features.

    Four stages of Conv3d(k=3, padding=1) -> ReLU -> MaxPool3d(2) ->
    BatchNorm3d -> Dropout3d with 64, 64, 128 and 256 channels, then global
    average pooling and Linear(256, 512) -> ReLU -> Dropout.

    NOTE: in the visible script this model is used right after construction
    with no weights loaded, so its parameters come from the global torch RNG
    (seeded with SEED). Constructing the submodules in a different order would
    change the initial weights and therefore the extracted features.
    """

    def __init__(self, dropout_rate: float = DROPOUT_RATE):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 64, 3, padding=1)
        self.bn1, self.pool1, self.drop1 = nn.BatchNorm3d(64), nn.MaxPool3d(2), nn.Dropout3d(dropout_rate)
        self.conv2 = nn.Conv3d(64, 64, 3, padding=1)
        self.bn2, self.pool2, self.drop2 = nn.BatchNorm3d(64), nn.MaxPool3d(2), nn.Dropout3d(dropout_rate)
        self.conv3 = nn.Conv3d(64, 128, 3, padding=1)
        self.bn3, self.pool3, self.drop3 = nn.BatchNorm3d(128), nn.MaxPool3d(2), nn.Dropout3d(dropout_rate)
        self.conv4 = nn.Conv3d(128, 256, 3, padding=1)
        self.bn4, self.pool4, self.drop4 = nn.BatchNorm3d(256), nn.MaxPool3d(2), nn.Dropout3d(dropout_rate)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1, self.drop5 = nn.Linear(256, 512), nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects (B, 1, D, H, W); the final view flattens per batch item.
        x = self.drop1(self.bn1(self.pool1(F.relu(self.conv1(x)))))
        x = self.drop2(self.bn2(self.pool2(F.relu(self.conv2(x)))))
        x = self.drop3(self.bn3(self.pool3(F.relu(self.conv3(x)))))
        x = self.drop4(self.bn4(self.pool4(F.relu(self.conv4(x)))))
        # (B, 256, 1, 1, 1) -> (B, 256) -> (B, 512)
        return self.drop5(F.relu(self.fc1(self.global_pool(x).view(x.size(0), -1))))

def load_brainiac_vit(ckpt_path, device):
    """Build a MONAI 3D ViT and load the BrainIAC backbone weights into it.

    Args:
        ckpt_path: path to the BrainIAC checkpoint. This is either a dict with a
            ``"state_dict"`` entry or a bare state dict. Only keys prefixed with
            ``"backbone."`` are used, with that prefix stripped.
        device: torch device the model is moved to.

    Returns:
        The ViT in eval mode on ``device``.
    """
    # weights_only=False (as for the AnatCL checkpoints): training checkpoints
    # may hold non-tensor objects. This unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Use the same dict for the membership test and the lookups below; the
    # original indexed ckpt["state_dict"] even when falling back to ckpt.
    state_dict = ckpt.get("state_dict", ckpt)

    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, classification=False)

    pe_key = "backbone.patch_embedding.position_embeddings"
    if pe_key in state_dict:
        # Replacing the Parameter outright also accepts embeddings whose shape
        # differs from the freshly built ViT (load_state_dict raises on size
        # mismatches even with strict=False).
        vit.patch_embedding.position_embeddings = nn.Parameter(state_dict[pe_key].clone())

    prefix = "backbone."
    backbone_state = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    load_result = vit.load_state_dict(backbone_state, strict=False)
    # strict=False tolerates mismatched keys silently; report them so a wrong
    # checkpoint (e.g. nothing matched) is visible in the logs.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"BrainIAC checkpoint: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys")
    return vit.to(device).eval()

#AnatCL Features

# NOTE(review): Normalize(mean=0.0, std=1.0) computes (x - 0) / 1, i.e. it is
# an identity on the volume. If per-volume scaling was intended, it has to be
# added explicitly.
anatcl_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
    transforms.Normalize(mean=0.0, std=1.0)
])

num_folds = 5
all_fold_features = []

# The (unshuffled) inputs are identical for every fold encoder, so the loader
# is built once and re-iterated per fold.
dl = DataLoader(CAT12VBMDataset(cat12_paths, age_labels, anatcl_transform),
                batch_size=32, num_workers=0)

for fold_idx in range(num_folds):
    path = os.path.join(ANATCL_ENCODER_PATH, f"fold{fold_idx}.pth")

    print(f"Loading fold {fold_idx}...")
    # pretrained=False: the backbone weights come from the local checkpoint.
    encoder = AnatCL(descriptor="global", fold=fold_idx, pretrained=False).to(device).eval()
    # weights_only=False unpickles arbitrary objects; only use trusted files.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    encoder.backbone.load_state_dict(checkpoint['model'])
    del checkpoint

    for p in encoder.parameters():
        p.requires_grad = False

    with torch.no_grad():
        fold_features = torch.cat([encoder(vol.to(device)).cpu() for vol, _ in dl]).numpy()

    all_fold_features.append(fold_features)
    print(f"Fold {fold_idx} features: {fold_features.shape}")

    del encoder

# NOTE(review): element-wise averaging of embeddings from independently
# trained fold encoders assumes their feature dimensions correspond — confirm
# this matches the intended AnatCL usage.
anatcl_features = np.mean(all_fold_features, axis=0)

print(f"\nAnatCL features (averaged across {num_folds} folds): {anatcl_features.shape}\n")

#BrainIAC Features
print("BrainIAC features")

brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)


def brainiac_transform(volume):
    """Return ``volume`` z-scored with its own mean and std (std + 1e-6)."""
    return (volume - volume.mean()) / (volume.std() + 1e-6)


dl = DataLoader(
    BrainIACDataset(brainiac_paths, age_labels, brainiac_transform),
    batch_size=16,
    num_workers=0,
)

embedding_batches = []
with torch.no_grad():
    for volumes, _ in dl:
        vit_out = brainiac_vit(volumes.to(device))
        # The ViT may return a (tokens, hidden_states) tuple; keep the tokens.
        token_tensor = vit_out[0] if isinstance(vit_out, tuple) else vit_out
        # NOTE(review): load_brainiac_vit builds the ViT with
        # classification=False; if no CLS token is prepended in that mode,
        # index 0 is the first patch token rather than a CLS summary — confirm.
        embedding_batches.append(token_tensor[:, 0].cpu().numpy())

brainiac_features = np.vstack(embedding_batches)

print(f"BrainIAC features: {brainiac_features.shape}")
del brainiac_vit

#CNN Features
print("CNN features")

# CNN3D is used as constructed (no weights are loaded here), on the BrainIAC
# preprocessed volumes resampled by CNNDataset.
cnn_model = CNN3D().to(device).eval()
dl = DataLoader(CNNDataset(brainiac_paths), batch_size=8, num_workers=0)

cnn_batches = []
with torch.no_grad():
    for volumes in dl:
        cnn_batches.append(cnn_model(volumes.to(device)).cpu().numpy())
cnn_features = np.vstack(cnn_batches)

print(f"CNN features: {cnn_features.shape}\n")
del cnn_model

#Organize Features
features_dict = dict(
    AnatCL=anatcl_features,
    BrainIAC=brainiac_features,
    CNN=cnn_features,
    FreeSurfer=fs_features,
)
print("Feature extraction complete")
for name in features_dict:
    print(f"{name}: {features_dict[name].shape}")
print()

#Pre-filtering variance check
# A column counts as "dead" when it is (near-)constant across subjects, i.e.
# its std is ~0 after standard scaling.
print("Pre-filtering variance check:")
for model in ('AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer'):
    scaled = StandardScaler().fit_transform(features_dict[model])
    n_total = scaled.shape[1]
    n_dead = np.sum(np.std(scaled, axis=0) < 1e-10)
    print(f"{model}: {n_total} total, {n_dead} dead ({100*n_dead/n_total:.1f}%)")
print()

#Correlation Analysis
print("Computing cross-model feature correlations")

def compute_correlation_within_models(features_dict,
                                      model_order=('AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer')):
    """Build one feature-by-feature correlation matrix across all models.

    For each model (in ``model_order``), the features are standardised, and
    columns that contain NaN or are constant are dropped. The remaining columns
    are reordered by Ward hierarchical clustering on the distance 1 - |r|, so
    correlated features sit next to each other inside the model's block. The
    reordered blocks are concatenated and correlated jointly.

    Args:
        features_dict: model name -> (n_subjects, n_features) array. All arrays
            must share the same subject (row) order.
        model_order: models to include, in block order. Names missing from
            ``features_dict`` (or with no usable feature) are skipped.

    Returns:
        (corr, boundaries, model_names, reordered_indices):
        corr -- (F, F) correlation matrix over the F kept features;
        boundaries -- cumulative block offsets, len(model_names) + 1 entries;
        model_names -- models actually included, aligned with ``boundaries``;
        reordered_indices -- model -> list of that model's original column
        indices, in the order they appear in ``corr``.

    Raises:
        ValueError: if no model contributes any usable feature.
    """
    all_features = []
    boundaries = [0]
    included_models = []
    reordered_indices = {}

    for model in model_order:
        if model not in features_dict:
            continue

        feats = StandardScaler().fit_transform(features_dict[model])

        # Drop columns that are NaN anywhere or constant after scaling, but
        # remember which original columns survive.
        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        kept_columns = np.flatnonzero(valid)
        feats = feats[:, valid]

        if feats.shape[1] == 0:
            continue

        print(f"  {model}: {feats.shape[1]} features")

        if feats.shape[1] == 1:
            # corrcoef of a single column is a scalar and linkage needs >= 2 items.
            reorder_idx = [0]
        else:
            corr = np.corrcoef(feats.T)
            corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
            np.fill_diagonal(corr, 1.0)

            # Distance 1 - |r|: strongly correlated or anti-correlated features are close.
            dist = np.abs(1 - np.abs(corr))
            np.fill_diagonal(dist, 0)

            condensed = np.nan_to_num(
                squareform((dist + dist.T) / 2, checks=False),
                nan=0.0, posinf=1.0, neginf=0.0
            )

            # NOTE: SciPy documents Ward linkage as correct only for Euclidean
            # distances; 1 - |r| is not one, so this is a heuristic ordering.
            reorder_idx = dendrogram(linkage(condensed, method='ward'), no_plot=True)['leaves']

        all_features.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + feats.shape[1])
        included_models.append(model)
        # Positions within the filtered matrix mapped back to original columns,
        # so callers can label features even when some columns were dropped.
        reordered_indices[model] = kept_columns[reorder_idx].tolist()

    if not all_features:
        raise ValueError("No model in features_dict has any usable (non-NaN, non-constant) feature")

    combined = np.hstack(all_features)

    corr = np.atleast_2d(np.corrcoef(combined.T))
    corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)

    print(f"\nTotal features: {combined.shape[1]}")

    # Return the models actually included so names stay aligned with boundaries.
    return corr, boundaries, included_models, reordered_indices
    
corr_matrix, boundaries, model_names, reorder_indices = compute_correlation_within_models(features_dict)
pd.DataFrame(corr_matrix).to_csv('hbn_feature_correlation_matrix.csv', index=False)
print(f"Correlation matrix: {corr_matrix.shape}")
print(f"Model boundaries: {boundaries}\n")

print("ANALYZING FREESURFER CORRELATION CLUSTERS")

# Extract the FreeSurfer block from the correlation matrix. Its position is
# looked up by name, so the slice stays correct if an earlier model was skipped.
fs_block_pos = model_names.index('FreeSurfer')
fs_start_idx = boundaries[fs_block_pos]
fs_end_idx = boundaries[fs_block_pos + 1]
fs_corr_block = corr_matrix[fs_start_idx:fs_end_idx, fs_start_idx:fs_end_idx]

print(f"\nFreeSurfer correlation block: {fs_corr_block.shape}")

# Order in which the FreeSurfer features appear in the block (already
# reordered by hierarchical clustering). These are used as indices into
# fs_feature_labels, i.e. fs_features columns; they equal positions within the
# block only when no FreeSurfer column was dropped as constant/NaN.
reordered_fs_indices = reorder_indices['FreeSurfer']

# Map back to original feature indices and names
region_names = fs_region_info['region_names']
hemisphere = fs_region_info['hemisphere']
n_fs_regions = len(region_names)

# Labels in fs_features column order: the first n_fs_regions columns are
# surface area, the next n_fs_regions are thickness (SurfArea then ThickAvg).
fs_feature_labels = [
    {
        'original_index': offset + i,
        'region_name': region_names[i],
        'hemisphere': hemisphere[i],
        'feature_type': feature_type,
    }
    for offset, feature_type in ((0, 'Surface_Area'), (n_fs_regions, 'Thickness'))
    for i in range(n_fs_regions)
]

# Apply the reordering to get clustered feature labels
clustered_features = [fs_feature_labels[i] for i in reordered_fs_indices]

# Re-cluster the (already reordered) FreeSurfer block with the same 1 - |r|
# distance and Ward linkage.
dist_fs = np.abs(1 - np.abs(fs_corr_block))
np.fill_diagonal(dist_fs, 0)
condensed_fs = squareform(dist_fs, checks=False)
linkage_fs = linkage(condensed_fs, method='ward')

# Cut the tree into at most 2 clusters (labels 1 and 2)
cluster_assignments = fcluster(linkage_fs, 2, criterion='maxclust')

print(f"Identified {len(np.unique(cluster_assignments))} clusters")
print(f"Cluster 1: {np.sum(cluster_assignments == 1)} features")
print(f"Cluster 2: {np.sum(cluster_assignments == 2)} features")

# Create DataFrame with cluster assignments
cluster_df = pd.DataFrame(clustered_features)
cluster_df['cluster'] = cluster_assignments
cluster_df['reordered_index'] = range(len(clustered_features))


def _share_pct(part, whole):
    """Return part/whole as a percentage, or 0.0 for an empty whole."""
    return 100 * part / whole if whole else 0.0


# Analyze cluster composition (a cluster may be empty if the cut yields one)
print("\nCluster Composition:")
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id]
    n_sa = np.sum(cluster_data['feature_type'] == 'Surface_Area')
    n_thick = np.sum(cluster_data['feature_type'] == 'Thickness')
    print(f"\nCluster {cluster_id}:")
    print(f"Surface Area: {n_sa} features ({_share_pct(n_sa, len(cluster_data)):.1f}%)")
    print(f"Thickness: {n_thick} features ({_share_pct(n_thick, len(cluster_data)):.1f}%)")

# Save detailed cluster information
cluster_df_sorted = cluster_df.sort_values(['cluster', 'feature_type', 'region_name'])
cluster_df_sorted.to_csv('freesurfer_correlation_clusters.csv', index=False)
print("\nSaved: freesurfer_correlation_clusters.csv")

# Create separate CSVs for each cluster
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id].copy()
    cluster_data = cluster_data.sort_values(['feature_type', 'region_name'])
    cluster_data.to_csv(f'freesurfer_cluster{cluster_id}_features.csv', index=False)
    print(f"Saved: freesurfer_cluster{cluster_id}_features.csv")

# Create summary of cluster characteristics (first 10 region names per type)
cluster_summary = []
for cluster_id in [1, 2]:
    cluster_data = cluster_df[cluster_df['cluster'] == cluster_id]
    for feat_type in ['Surface_Area', 'Thickness']:
        type_data = cluster_data[cluster_data['feature_type'] == feat_type]
        region_list = type_data['region_name'].tolist()
        cluster_summary.append({
            'Cluster': cluster_id,
            'Feature_Type': feat_type,
            'Count': len(type_data),
            'Percentage': _share_pct(len(type_data), len(cluster_data)),
            'Brain_Regions': ', '.join(region_list[:10]) + ('...' if len(region_list) > 10 else '')
        })

pd.DataFrame(cluster_summary).to_csv('freesurfer_cluster_summary.csv', index=False)
print("Saved: freesurfer_cluster_summary.csv\n")

# Plot correlation matrix WITH PDF OUTPUT
print(f"Plotting FULL cross-model correlation matrix ({corr_matrix.shape[0]}x{corr_matrix.shape[1]} features)...")

fig = None
try:
    fig = plt.figure(figsize=(30, 28))
    print("Figure created, generating heatmap")

    sns.heatmap(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, center=0,
                square=True, linewidths=0, cbar_kws={"shrink": 0.3},
                xticklabels=False, yticklabels=False, rasterized=True)
    print("Heatmap generated, adding boundaries")

    # Thick separators between model blocks (the outer edges are skipped).
    for b in boundaries[1:-1]:
        plt.axhline(y=b, color='black', linewidth=4)
        plt.axvline(x=b, color='black', linewidth=4)

    # Label each block at its midpoint, above and to the left of the matrix.
    for start, end, name in zip(boundaries[:-1], boundaries[1:], model_names):
        pos = (start + end) / 2
        plt.text(pos, -20, name, ha='center', fontsize=20, weight='bold')
        plt.text(-20, pos, name, ha='center', va='center', fontsize=20, weight='bold', rotation=90)

    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    print("Saving PDF...")
    plt.savefig('hbn_cross_model_correlation.pdf', dpi=150, bbox_inches='tight')

    print("Saved: hbn_cross_model_correlation.pdf\n")

except Exception as e:
    # Plotting is best-effort; report and continue with the analysis.
    print(f"ERROR: {e}\n")
finally:
    # Release the (large) figure on both the success and the error path.
    if fig is not None:
        plt.close(fig)

#Model Training Functions

def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Evaluate RandomForest classifiers per feature set with a training-size sweep.

    A stratified, seeded hold-out test set (``test_size`` fraction) is split
    off once. On the remaining CV pool, 5-fold stratified CV is run for each
    training fraction ``alpha``: the classifier is trained on an ``alpha``
    fraction of each training fold and scored on the full validation fold.
    The alpha with the best mean validation ROC-AUC is then used to train a
    final model on that fraction of the CV pool, which is scored on the test
    set. A stratified DummyClassifier baseline is reported as well.

    Args:
        X_dict: feature-set name -> (n_samples, n_features) array.
        y: binary labels, shape (n_samples,).
        test_size: fraction of samples held out for the final test.
        task_name: label used in the progress message.

    Returns:
        dict mapping each feature-set name (plus 'DummyClassifier') to its
        metrics: test/CV scores, train-validation AUC gaps, the fitted model
        and scaler, and the test indices.

    Raises:
        ValueError: if, for a feature set, no training fraction yields a
            subset containing both classes.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)),
                                        stratify=y, random_state=SEED)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def make_rf():
        # Shallow, class-balanced forest shared by the CV and final fits.
        return RandomForestClassifier(n_estimators=200, max_depth=6, min_samples_split=5,
                                      random_state=SEED, n_jobs=1,
                                      max_features='sqrt', class_weight='balanced')

    def subsample(idx, n):
        # Seeded subset of `idx` with n elements, stratified when possible.
        if n <= 0:
            return idx[:0]
        if n >= len(idx):
            # train_test_split rejects train_size == n_samples; the full
            # fraction is simply every index, in its original order.
            return idx
        try:
            subset, _ = train_test_split(idx, train_size=n, stratify=y[idx], random_state=SEED)
        except ValueError:
            # Too few samples per class to stratify: use a seeded plain split
            # (not the global NumPy RNG), so results do not depend on how many
            # models or tasks ran before this one.
            subset, _ = train_test_split(idx, train_size=n, random_state=SEED)
        return subset

    print("DummyClassifier")
    dummy_val_auc = []
    dummy_val_acc = []

    for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyClassifier(strategy='stratified', random_state=SEED)
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])

        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        pred_proba = dummy.predict_proba(np.zeros((len(val_idx), 1)))[:, 1]

        dummy_val_acc.append(balanced_accuracy_score(y[val_idx], pred))
        dummy_val_auc.append(roc_auc_score(y[val_idx], pred_proba))

    dummy = DummyClassifier(strategy='stratified', random_state=SEED)
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    dummy_test_proba = dummy.predict_proba(np.zeros((len(test_idx), 1)))[:, 1]

    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], dummy_test_pred),
        'test_auc': roc_auc_score(y[test_idx], dummy_test_proba),
        'cv_auc_mean': np.mean(dummy_val_auc),
        'cv_auc_std': np.std(dummy_val_auc),
        'cv_acc_mean': np.mean(dummy_val_acc),
        'cv_acc_std': np.std(dummy_val_acc),
        'fold_results': {'val_auc': dummy_val_auc, 'val_acc': dummy_val_acc},
        'test_idx': test_idx
    }
    print(f"DummyClassifier: AUC={results['DummyClassifier']['test_auc']:.4f}, Bal_Acc={results['DummyClassifier']['test_balanced_accuracy']:.4f}")

    for model, X in X_dict.items():
        val_auc = {a: [] for a in alphas}
        val_acc = {a: [] for a in alphas}
        train_auc = {a: [] for a in alphas}
        train_acc = {a: [] for a in alphas}

        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                tr_idx = subsample(train_idx, int(alpha * len(train_idx)))

                # predict_proba[:, 1] and train AUC need both classes present.
                if len(np.unique(y[tr_idx])) < 2:
                    continue

                scaler = StandardScaler()
                rf = make_rf()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf.fit(X_tr, y[tr_idx])

                val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                val_auc[alpha].append(roc_auc_score(y[val_idx], rf.predict_proba(X_val)[:, 1]))

                train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

        avg_auc = {a: np.mean(val_auc[a]) for a in alphas if val_auc[a]}
        if not avg_auc:
            raise ValueError(f"{model}: no training fraction produced a subset with both classes")
        # Ties resolve to the smallest alpha (first maximum in alphas order).
        best_alpha = max(avg_auc, key=avg_auc.get)

        train_val_gaps = {a: np.mean(train_auc[a]) - np.mean(val_auc[a])
                          for a in alphas if train_auc[a] and val_auc[a]}

        n_final = int(best_alpha * len(cv_idx))
        if n_final >= len(cv_idx):
            final_tr = cv_idx
        else:
            try:
                final_tr, _ = train_test_split(cv_idx, train_size=n_final,
                                               stratify=y[cv_idx], random_state=SEED)
            except ValueError:
                # Fraction too small to stratify: train on the whole CV pool.
                final_tr = cv_idx

        scaler = StandardScaler()
        rf = make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        X_test = scaler.transform(X[test_idx])
        test_pred_proba = rf.predict_proba(X_test)[:, 1]

        train_val_gap = train_val_gaps[best_alpha]
        # Heuristic flag on the mean train-minus-validation AUC.
        is_overfitting = train_val_gap > 0.25

        results[model] = {
            'best_alpha': best_alpha,
            'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], rf.predict(X_test)),
            'test_auc': roc_auc_score(y[test_idx], test_pred_proba),
            'cv_auc_mean': np.mean(val_auc[best_alpha]),
            'cv_auc_std': np.std(val_auc[best_alpha]),
            'cv_acc_mean': np.mean(val_acc[best_alpha]),
            'cv_acc_std': np.std(val_acc[best_alpha]),
            'train_auc_mean': np.mean(train_auc[best_alpha]),
            'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_auc': val_auc[best_alpha],
                'val_acc': val_acc[best_alpha],
                'train_auc': train_auc[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'acc': {a: np.mean(val_acc[a]) for a in alphas if val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(val_acc[a]) for a in alphas if val_acc[a]},
                'auc_std': {a: np.std(val_auc[a]) for a in alphas if val_auc[a]},
                'train_auc': {a: np.mean(train_auc[a]) for a in alphas if train_auc[a]},
                'train_acc': {a: np.mean(train_acc[a]) for a in alphas if train_acc[a]}
            },
            'trained_model': rf,
            'scaler': scaler,
            'test_idx': test_idx
        }
        print(f"{model}: AUC={results[model]['test_auc']:.4f}, Bal_Acc={results[model]['test_balanced_accuracy']:.4f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

def run_regression(X_dict, y, test_size=0.1, task_name="Regression"):
    """Evaluate RandomForest regressors per feature set with a training-size sweep.

    A seeded hold-out test set (``test_size`` fraction) is split off once. On
    the remaining CV pool, 5-fold CV is run for each training fraction
    ``alpha``: the regressor is trained on an ``alpha`` fraction of each
    training fold and scored on the full validation fold. The alpha with the
    best mean validation R² is used to train a final model on that fraction of
    the CV pool, which is scored on the test set. A mean-predicting
    DummyRegressor baseline is reported as well.

    Args:
        X_dict: feature-set name -> (n_samples, n_features) array.
        y: continuous targets, shape (n_samples,).
        test_size: fraction of samples held out for the final test.
        task_name: label used in the progress message.

    Returns:
        dict mapping each feature-set name (plus 'DummyRegressor') to its
        metrics: test/CV scores, test predictions, train-validation R² gaps,
        the fitted model and scaler, and the test indices.

    Raises:
        ValueError: if, for a feature set, no training fraction has at least
            two training samples.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)), random_state=SEED)
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def make_rf():
        # Shallow forest shared by the CV and final fits.
        return RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_split=5,
                                     random_state=SEED, n_jobs=1)

    def subsample(idx, n):
        # Seeded subset of `idx` with n elements; the full fraction keeps every
        # index in order (train_test_split rejects train_size == n_samples).
        if n >= len(idx):
            return idx
        subset, _ = train_test_split(idx, train_size=n, random_state=SEED)
        return subset

    print("DummyRegressor")
    dummy_val_r2 = []
    dummy_val_mae = []

    for train_rel, val_rel in kf.split(cv_idx):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyRegressor(strategy='mean')
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])

        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        dummy_val_r2.append(r2_score(y[val_idx], pred))
        dummy_val_mae.append(mean_absolute_error(y[val_idx], pred))

    dummy = DummyRegressor(strategy='mean')
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))

    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2': r2_score(y[test_idx], dummy_test_pred),
        'test_mae': mean_absolute_error(y[test_idx], dummy_test_pred),
        'test_rmse': np.sqrt(mean_squared_error(y[test_idx], dummy_test_pred)),
        'cv_r2_mean': np.mean(dummy_val_r2),
        'cv_r2_std': np.std(dummy_val_r2),
        'cv_mae_mean': np.mean(dummy_val_mae),
        'cv_mae_std': np.std(dummy_val_mae),
        'fold_results': {'val_r2': dummy_val_r2, 'val_mae': dummy_val_mae},
        'test_idx': test_idx
    }
    print(f"DummyRegressor: R2={results['DummyRegressor']['test_r2']:.4f}, MAE={results['DummyRegressor']['test_mae']:.2f}")

    for model, X in X_dict.items():
        val_r2 = {a: [] for a in alphas}
        val_mae = {a: [] for a in alphas}
        train_r2 = {a: [] for a in alphas}
        train_mae = {a: [] for a in alphas}

        for train_rel, val_rel in kf.split(cv_idx):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                n = int(alpha * len(train_idx))
                if n < 2:
                    # train_size=0 is rejected by train_test_split, and a single
                    # training sample makes the training R² undefined (NaN).
                    continue
                tr_idx = subsample(train_idx, n)

                scaler = StandardScaler()
                rf = make_rf()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf.fit(X_tr, y[tr_idx])

                pred = rf.predict(X_val)
                val_r2[alpha].append(r2_score(y[val_idx], pred))
                val_mae[alpha].append(mean_absolute_error(y[val_idx], pred))

                train_pred = rf.predict(X_tr)
                train_r2[alpha].append(r2_score(y[tr_idx], train_pred))
                train_mae[alpha].append(mean_absolute_error(y[tr_idx], train_pred))

        avg_r2 = {a: np.mean(val_r2[a]) for a in alphas if val_r2[a]}
        if not avg_r2:
            raise ValueError(f"{model}: no training fraction had at least two samples")
        # Ties resolve to the smallest alpha (first maximum in alphas order).
        best_alpha = max(avg_r2, key=avg_r2.get)

        train_val_gaps = {a: np.mean(train_r2[a]) - np.mean(val_r2[a])
                          for a in alphas if train_r2[a] and val_r2[a]}

        final_tr = subsample(cv_idx, int(best_alpha * len(cv_idx)))

        scaler = StandardScaler()
        rf = make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        test_pred = rf.predict(scaler.transform(X[test_idx]))

        train_val_gap = train_val_gaps[best_alpha]
        # Heuristic flag on the mean train-minus-validation R².
        is_overfitting = train_val_gap > 0.35

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2': r2_score(y[test_idx], test_pred),
            'test_mae': mean_absolute_error(y[test_idx], test_pred),
            'test_rmse': np.sqrt(mean_squared_error(y[test_idx], test_pred)),
            'predictions': test_pred,
            'actual': y[test_idx],
            'cv_r2_mean': np.mean(val_r2[best_alpha]),
            'cv_r2_std': np.std(val_r2[best_alpha]),
            'cv_mae_mean': np.mean(val_mae[best_alpha]),
            'cv_mae_std': np.std(val_mae[best_alpha]),
            'train_r2_mean': np.mean(train_r2[best_alpha]),
            'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_r2': val_r2[best_alpha],
                'val_mae': val_mae[best_alpha],
                'train_r2': train_r2[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'r2': avg_r2,
                'mae': {a: np.mean(val_mae[a]) for a in alphas if val_mae[a]},
                'r2_std': {a: np.std(val_r2[a]) for a in alphas if val_r2[a]},
                'mae_std': {a: np.std(val_mae[a]) for a in alphas if val_mae[a]},
                'train_r2': {a: np.mean(train_r2[a]) for a in alphas if train_r2[a]},
                'train_mae': {a: np.mean(train_mae[a]) for a in alphas if train_mae[a]}
            },
            'trained_model': rf,
            'scaler': scaler,
            'test_idx': test_idx
        }
        print(f"{model}: R2={results[model]['test_r2']:.4f}, MAE={results[model]['test_mae']:.2f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results


print("Sex Classification")
sex_results = run_classification(features_dict, sex_labels, test_size=0.1,
                                 task_name="Sex classification")
print()

# Age regression is fitted separately within each sex
# (sex coding from the demographics: 0 = Male, 1 = Female).

print("AGE PREDICTION - SPLIT BY SEX")

#Male subjects (sex_labels == 0)
male_indices = np.flatnonzero(sex_labels == 0)
male_features_dict = {name: matrix[male_indices] for name, matrix in features_dict.items()}
male_age_labels = age_labels[male_indices]

print(f"\nMale subjects: N={len(male_indices)}")
print(
    f"Male age: mean={male_age_labels.mean():.1f}, std={male_age_labels.std():.2f}, "
    f"range={male_age_labels.min():.1f}-{male_age_labels.max():.1f}"
)

male_age_results = run_regression(male_features_dict, male_age_labels, test_size=0.1,
                                  task_name="Male age prediction")
print()

#Female subjects (sex_labels == 1)
female_indices = np.flatnonzero(sex_labels == 1)
female_features_dict = {name: matrix[female_indices] for name, matrix in features_dict.items()}
female_age_labels = age_labels[female_indices]

print(f"\nFemale subjects: N={len(female_indices)}")
print(
    f"Female age: mean={female_age_labels.mean():.1f}, std={female_age_labels.std():.2f}, "
    f"range={female_age_labels.min():.1f}-{female_age_labels.max():.1f}"
)

female_age_results = run_regression(female_features_dict, female_age_labels, test_size=0.1,
                                    task_name="Female age prediction")
print()

print("Overfitting Analysis")

# (CSV column prefix, results dict, console label, metric the gap is measured in)
_overfit_tasks = [
    ("Sex", sex_results, "Sex classification", "AUC"),
    ("Male_Age", male_age_results, "Male age prediction", "R²"),
    ("Female_Age", female_age_results, "Female age prediction", "R²"),
]

# One row per feature set: overfitting flag and train-val gap for every task.
overfitting_report = []
for name in features_dict.keys():
    row = {'Model': name}
    flagged_lines = []
    for prefix, task_res, label, metric in _overfit_tasks:
        entry = task_res[name]
        row[f'{prefix}_Overfitting'] = entry['is_overfitting']
        row[f'{prefix}_Train_Val_Gap'] = entry['train_val_gap']
        if entry['is_overfitting']:
            flagged_lines.append(f"{label}: Train-Val {metric} gap = {entry['train_val_gap']:.3f}")
    overfitting_report.append(row)

    if flagged_lines:
        print(f"{name}:")
        for line in flagged_lines:
            print(line)

any_flagged = any(
    row['Sex_Overfitting'] or row['Male_Age_Overfitting'] or row['Female_Age_Overfitting']
    for row in overfitting_report
)
if not any_flagged:
    print("No significant overfitting detected")

pd.DataFrame(overfitting_report).to_csv('hbn_overfitting_analysis.csv', index=False)
print("\nSaved: hbn_overfitting_analysis.csv\n")

# Save train-val gaps for all alphas (one row per feature set x training fraction;
# NaN where a task produced no gap for that alpha).
print("Saving train-val gaps for all alphas...")
gap_rows = [
    {
        'Model': name,
        'Alpha': alpha,
        'Sex_Train_Val_Gap': sex_results[name]['train_val_gaps_all_alphas'].get(alpha, np.nan),
        'Male_Age_Train_Val_Gap': male_age_results[name]['train_val_gaps_all_alphas'].get(alpha, np.nan),
        'Female_Age_Train_Val_Gap': female_age_results[name]['train_val_gaps_all_alphas'].get(alpha, np.nan),
    }
    for name in features_dict.keys()
    for alpha in sex_results[name]['cv_results']['alphas']
]

pd.DataFrame(gap_rows).to_csv('hbn_train_val_gaps_all_alphas.csv', index=False)
print("Saved: hbn_train_val_gaps_all_alphas.csv\n")

print("Results")

all_models = ['DummyClassifier'] + list(features_dict.keys())

# Sex classification summary (dummy baseline first).
sex_summary_rows = []
for m in all_models:
    res = sex_results[m]
    sex_summary_rows.append({
        'Model': m,
        'Best_Alpha': res['best_alpha'],
        'Test_Balanced_Accuracy': res['test_balanced_accuracy'],
        'Test_AUC': res['test_auc'],
        'CV_AUC_Mean': res['cv_auc_mean'],
        'CV_AUC_Std': res['cv_auc_std'],
    })
pd.DataFrame(sex_summary_rows).to_csv('hbn_sex_classification_summary.csv', index=False)

all_models_reg = ['DummyRegressor'] + list(features_dict.keys())


def _age_summary_rows(age_results, model_names):
    """Return one summary row per model: best alpha, test metrics and CV mean/std."""
    rows = []
    for m in model_names:
        res = age_results[m]
        rows.append({
            'Model': m,
            'Best_Alpha': res['best_alpha'],
            'Test_R2': res['test_r2'],
            'Test_MAE': res['test_mae'],
            'Test_RMSE': res['test_rmse'],
            'CV_R2_Mean': res['cv_r2_mean'],
            'CV_R2_Std': res['cv_r2_std'],
            'CV_MAE_Mean': res['cv_mae_mean'],
            'CV_MAE_Std': res['cv_mae_std'],
        })
    return rows


pd.DataFrame(_age_summary_rows(male_age_results, all_models_reg)).to_csv(
    'hbn_male_age_prediction_summary.csv', index=False)
pd.DataFrame(_age_summary_rows(female_age_results, all_models_reg)).to_csv(
    'hbn_female_age_prediction_summary.csv', index=False)

# Side-by-side test metrics for the real feature sets (no dummy baselines).
all_results = []
for m in features_dict:
    sex_res, male_res, female_res = sex_results[m], male_age_results[m], female_age_results[m]
    all_results.append({
        'Model': m,
        'Sex_AUC': sex_res['test_auc'],
        'Sex_Bal_Acc': sex_res['test_balanced_accuracy'],
        'Male_Age_R2': male_res['test_r2'],
        'Male_Age_MAE': male_res['test_mae'],
        'Male_Age_RMSE': male_res['test_rmse'],
        'Female_Age_R2': female_res['test_r2'],
        'Female_Age_MAE': female_res['test_mae'],
        'Female_Age_RMSE': female_res['test_rmse'],
    })
pd.DataFrame(all_results).to_csv('hbn_detailed_comparison.csv', index=False)

# Rank models per task, best (highest score) first.
rankings = []
for task in ['Sex_AUC', 'Male_Age_R2', 'Female_Age_R2']:
    ordered = sorted(all_results, key=lambda row: row[task], reverse=True)
    rankings.extend(
        {'Task': task.replace('_', ' '), 'Rank': rank, 'Model': row['Model'], 'Score': row[task]}
        for rank, row in enumerate(ordered, start=1)
    )
pd.DataFrame(rankings).to_csv('hbn_model_rankings.csv', index=False)

_saved_csvs = [
    'hbn_sex_classification_summary.csv',
    'hbn_male_age_prediction_summary.csv',
    'hbn_female_age_prediction_summary.csv',
    'hbn_detailed_comparison.csv',
    'hbn_model_rankings.csv',
]
for fname in _saved_csvs[:-1]:
    print(f"Saved: {fname}")
print(f"Saved: {_saved_csvs[-1]}\n")

# Visualization: cross-validated learning curves (CV metric vs. training-set size)

print("Learning Curves")

models = list(features_dict.keys())

model_styles = {
    'AnatCL': {'color': '#9B59B6', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'BrainIAC': {'color': '#E74C3C', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'CNN': {'color': '#3498DB', 'marker': 'o', 'linestyle': '--', 'label_prefix': ''},
    'FreeSurfer': {'color': '#E67E22', 'marker': 's', 'linestyle': '-', 'label_prefix': ''}
}
# Used for any feature set not listed above, so adding a model to features_dict
# does not make the plotting crash with a KeyError.
_default_style = {'color': '#7F8C8D', 'marker': 'D', 'linestyle': '-', 'label_prefix': ''}

# Size of each CV pool once the 10% hold-out split is removed (test_size=0.1 above).
n_cv_sex = len(sex_labels) - int(0.1 * len(sex_labels))
n_cv_male = len(male_age_labels) - int(0.1 * len(male_age_labels))
n_cv_female = len(female_age_labels) - int(0.1 * len(female_age_labels))

alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]
# Approximate training-set size per alpha: 5-fold CV trains on 80% of the pool.
training_sizes_sex = [int(alpha * n_cv_sex * 0.8) for alpha in alphas]
training_sizes_male = [int(alpha * n_cv_male * 0.8) for alpha in alphas]
training_sizes_female = [int(alpha * n_cv_female * 0.8) for alpha in alphas]


def _plot_cv_learning_curve(ax, task_results, mean_key, std_key, sizes):
    """Plot, for every model in ``models``, the CV mean +/- std of one metric on *ax*.

    Args:
        ax: matplotlib axes to draw on.
        task_results: results dict returned by run_classification / run_regression.
        mean_key: key in ``cv_results`` holding the per-alpha mean (e.g. 'acc', 'mae').
        std_key: key in ``cv_results`` holding the per-alpha std (e.g. 'acc_std').
        sizes: training-set size for each entry of the module-level ``alphas``.

    Alphas missing from ``cv_results[mean_key]`` (skipped during CV) are not plotted.
    """
    for model in models:
        style = model_styles.get(model, _default_style)
        cv = task_results[model]['cv_results']
        kept = [i for i, alpha in enumerate(alphas) if alpha in cv[mean_key]]
        x = np.array([sizes[i] for i in kept])
        means = np.array([cv[mean_key][alphas[i]] for i in kept])
        stds = np.array([cv[std_key][alphas[i]] for i in kept])

        ax.plot(x, means, marker=style['marker'], linestyle=style['linestyle'],
                linewidth=3, markersize=10, label=style['label_prefix'] + model,
                color=style['color'])
        ax.fill_between(x, means - stds, means + stds, alpha=0.2, color=style['color'])


def _decorate_axis(ax, baseline, baseline_label, ylabel, title, legend_loc):
    """Draw the dummy-model baseline and add axis labels, title, legend and grid."""
    ax.axhline(y=baseline, color='gray', linestyle=':', linewidth=2,
               label=baseline_label, alpha=0.7)
    ax.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
    ax.set_ylabel(ylabel, fontsize=12, weight='bold')
    ax.set_title(title, fontsize=14, weight='bold')
    ax.legend(fontsize=9, loc=legend_loc)
    ax.grid(True, alpha=0.3)


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 6))

# Sex classification (balanced accuracy)
_plot_cv_learning_curve(ax1, sex_results, 'acc', 'acc_std', training_sizes_sex)
dummy_acc_cv = sex_results['DummyClassifier']['cv_acc_mean']
_decorate_axis(ax1, dummy_acc_cv, f'Baseline ({dummy_acc_cv:.3f})',
               'Balanced Accuracy', 'Sex Classification', 'lower right')
ax1.set_ylim([0, 1])

# Male age prediction (MAE)
_plot_cv_learning_curve(ax2, male_age_results, 'mae', 'mae_std', training_sizes_male)
dummy_mae_cv_male = male_age_results['DummyRegressor']['cv_mae_mean']
_decorate_axis(ax2, dummy_mae_cv_male, f'Baseline ({dummy_mae_cv_male:.2f})',
               'MAE (years)', 'Male Age Prediction', 'upper right')

# Female age prediction (MAE)
_plot_cv_learning_curve(ax3, female_age_results, 'mae', 'mae_std', training_sizes_female)
dummy_mae_cv_female = female_age_results['DummyRegressor']['cv_mae_mean']
_decorate_axis(ax3, dummy_mae_cv_female, f'Baseline ({dummy_mae_cv_female:.2f})',
               'MAE (years)', 'Female Age Prediction', 'upper right')

plt.tight_layout()
plt.savefig('hbn_cv_learning_curves.pdf', dpi=300, bbox_inches='tight')
plt.close(fig)
print("Saved: hbn_cv_learning_curves.pdf\n")

# Final cohort summary (sex_labels: 1 = female, 0 = male).
n_female = np.sum(sex_labels)
n_male = len(sex_labels) - n_female
print(f"\nDataset: {len(common_subjects)} subjects")
print(f"Age range: {age_labels.min():.1f} - {age_labels.max():.1f} years")
print(f"Sex: {n_female} Female, {n_male} Male")
for group_name, group_idx in (("Males", male_indices), ("Females", female_indices)):
    print(f"{group_name}: {len(group_idx)} subjects")

Overwriting /home/arelbaha/links/projects/rrg-glatard/arelbaha/debug_hbn_multimodel_comparison.py


In [4]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/combined_multimodel_comparison.py

import os

# Thread limits must be in the environment BEFORE numpy / scipy / torch /
# scikit-learn are imported: OpenBLAS, MKL and OpenMP read these variables when
# their runtime is loaded/initialised, so setting them after `import numpy`
# leaves numpy's BLAS thread pool unrestricted.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import glob
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from anatcl import AnatCL
from monai.networks.nets import ViT
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.ndimage as ndi
from matplotlib.patches import Patch

SEED = 42

# Seed every RNG used by the pipeline (Python, NumPy, PyTorch) in a fixed order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Deterministic, single-threaded PyTorch execution.
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

device = "cpu"
DROPOUT_RATE = 0.3  # default dropout for CNN3D

# Checkpoint locations (AnatCL fold{i}.pth files live in ANATCL_ENCODER_PATH).
ANATCL_ENCODER_PATH = "/home/arelbaha/.venvs/jupyter_py3/bin"
BRAINIAC_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/BrainIAC.ckpt"

class CAT12VBMDataset(Dataset):
    """Map-style dataset pairing image files with scalar labels.

    ``data`` holds image paths readable by nibabel (in this file, CAT12 outputs
    for AnatCL). Each item is ``(image, label)``: ``transform`` applied to the
    voxel array with a leading channel axis added, and the label as a float32
    tensor.
    """

    def __init__(self, data, labels, transform):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume = nib.load(self.data[idx]).get_fdata()
        image = self.transform(volume).unsqueeze(0)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, target

class BrainIACDataset(Dataset):
    """Dataset yielding 96x96x96 volumes (plus label) for the BrainIAC ViT.

    Volumes of any other shape are trilinearly resampled to 96^3 before
    ``transform`` is applied; items are ``(image[None], float32 label tensor)``.
    """

    def __init__(self, paths, labels, transform):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            # F.interpolate expects (N, C, D, H, W): add batch and channel axes,
            # resample, then drop them again.
            batched = torch.from_numpy(volume[None, None])
            resized = F.interpolate(batched, size=target_shape,
                                    mode='trilinear', align_corners=False)
            volume = resized.squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return normalized[None], label

class CNNDataset(Dataset):
    """Unlabelled dataset of z-scored volumes resampled to 92x110x92 for CNN3D.

    Each item is a float32 tensor of shape (1, 92, 110, 92).
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target = (92, 110, 92)
        vol = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if vol.shape != target:
            # Linear (order=1) resampling with one zoom factor per spatial axis.
            factors = [target[axis] / vol.shape[axis] for axis in range(3)]
            vol = ndi.zoom(vol, factors, order=1)
        # Per-volume z-scoring; the epsilon guards against constant volumes.
        vol = (vol - vol.mean()) / (vol.std() + 1e-6)
        return torch.from_numpy(vol[None]).float()

class CNN3D(nn.Module):
    """3D CNN mapping a 1-channel volume to a 512-dimensional feature vector.

    Four stages of Conv3d(k=3, padding=1) -> ReLU -> MaxPool3d(2) -> BatchNorm3d
    -> Dropout3d with 64/64/128/256 channels, then global average pooling, a
    256->512 linear layer, ReLU and dropout.
    """

    def __init__(self, dropout_rate=DROPOUT_RATE):
        super().__init__()
        stage_channels = ((1, 64), (64, 64), (64, 128), (128, 256))
        # Layers are registered stage by stage as conv{i}, bn{i}, pool{i}, drop{i};
        # creation order fixes the order in which parameters are randomly initialised.
        for stage, (in_ch, out_ch) in enumerate(stage_channels, start=1):
            setattr(self, f"conv{stage}", nn.Conv3d(in_ch, out_ch, 3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm3d(out_ch))
            setattr(self, f"pool{stage}", nn.MaxPool3d(2))
            setattr(self, f"drop{stage}", nn.Dropout3d(dropout_rate))
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 512)
        self.drop5 = nn.Dropout(dropout_rate)

    def forward(self, x):
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            bn = getattr(self, f"bn{stage}")
            pool = getattr(self, f"pool{stage}")
            drop = getattr(self, f"drop{stage}")
            x = drop(bn(pool(F.relu(conv(x)))))
        flat = self.global_pool(x).view(x.size(0), -1)
        return self.drop5(F.relu(self.fc1(flat)))

def load_brainiac_vit(ckpt_path, device):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, classification=False)
    pe_key = "backbone.patch_embedding.position_embeddings"
    if pe_key in ckpt.get("state_dict", ckpt):
        vit.patch_embedding.position_embeddings = nn.Parameter(ckpt["state_dict"][pe_key].clone())
    backbone_state = {k.replace("backbone.", ""): v for k, v in ckpt.get("state_dict", ckpt).items() 
                     if k.startswith("backbone.")}
    vit.load_state_dict(backbone_state, strict=False)
    return vit.to(device).eval()

def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Benchmark each feature set on a binary classification task with learning curves.

    Protocol:
      1. Hold out ``int(test_size * len(y))`` subjects as a stratified test split.
      2. On the remaining CV pool run 5-fold stratified CV. In every fold and for
         every training fraction ``alpha``, subsample the fold's training indices
         to ``int(alpha * len(train_fold))`` subjects (stratified when possible),
         standardise, fit a class-balanced random forest and record balanced
         accuracy / ROC-AUC on the validation fold and on the training subset.
      3. Select the alpha with the highest mean validation AUC, refit on that
         fraction of the whole CV pool and evaluate on the test split.
    A stratified ``DummyClassifier`` is scored on the same splits as a baseline.

    Args:
        X_dict: mapping model name -> feature matrix whose rows align with ``y``.
        y: label array; ``predict_proba(...)[:, 1]`` is used as the positive-class
            score, so two classes are expected.
        test_size: fraction of subjects held out for the final test.
        task_name: name printed in the progress output.

    Returns:
        dict keyed by ``'DummyClassifier'`` and every model name. Each value holds
        test metrics, CV summaries at the best alpha, per-alpha curves under
        ``'cv_results'``, train-validation AUC gaps, and the fitted model/scaler.

    Raises:
        ValueError: if, for some model, no training fraction contained both classes.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)),
                                        stratify=y, random_state=SEED)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def _subsample(pool, n):
        """Return *n* indices drawn from *pool*, stratified on ``y`` when possible."""
        if n >= len(pool):
            # Whole pool requested (sklearn rejects train_size == len(pool)); use it
            # as-is, like run_regression, instead of a global-RNG permutation.
            return pool
        try:
            subset, _ = train_test_split(pool, train_size=n, stratify=y[pool], random_state=SEED)
        except ValueError:
            # Stratification impossible (n == 0 or fewer draws than classes):
            # fall back to an unstratified random draw.
            subset = np.random.choice(pool, n, replace=False)
        return subset

    def _make_rf():
        """Return the random-forest configuration shared by CV and final fits."""
        return RandomForestClassifier(n_estimators=200, max_depth=6, min_samples_split=5,
                                      random_state=SEED, n_jobs=1, max_features='sqrt',
                                      class_weight='balanced')

    # Chance-level baseline evaluated on exactly the same folds and test split.
    dummy_val_auc, dummy_val_acc = [], []
    for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyClassifier(strategy='stratified', random_state=SEED)
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        pred_proba = dummy.predict_proba(np.zeros((len(val_idx), 1)))[:, 1]
        dummy_val_acc.append(balanced_accuracy_score(y[val_idx], pred))
        dummy_val_auc.append(roc_auc_score(y[val_idx], pred_proba))

    dummy = DummyClassifier(strategy='stratified', random_state=SEED)
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    dummy_test_proba = dummy.predict_proba(np.zeros((len(test_idx), 1)))[:, 1]

    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], dummy_test_pred),
        'test_auc': roc_auc_score(y[test_idx], dummy_test_proba),
        'cv_auc_mean': np.mean(dummy_val_auc), 'cv_auc_std': np.std(dummy_val_auc),
        'cv_acc_mean': np.mean(dummy_val_acc), 'cv_acc_std': np.std(dummy_val_acc),
        'fold_results': {'val_auc': dummy_val_auc, 'val_acc': dummy_val_acc}, 'test_idx': test_idx
    }
    print(f"  DummyClassifier: AUC={results['DummyClassifier']['test_auc']:.4f}")

    for model, X in X_dict.items():
        val_auc = {a: [] for a in alphas}
        val_acc = {a: [] for a in alphas}
        train_auc = {a: [] for a in alphas}
        train_acc = {a: [] for a in alphas}

        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
            for alpha in alphas:
                tr_idx = _subsample(train_idx, int(alpha * len(train_idx)))
                if len(np.unique(y[tr_idx])) < 2:
                    # AUC / balanced accuracy are undefined with a single class.
                    continue
                scaler = StandardScaler()
                rf = _make_rf()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf.fit(X_tr, y[tr_idx])
                val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                val_auc[alpha].append(roc_auc_score(y[val_idx], rf.predict_proba(X_val)[:, 1]))
                train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

        avg_auc = {a: np.mean(val_auc[a]) for a in alphas if val_auc[a]}
        if not avg_auc:
            raise ValueError(f"{task_name}: no training fraction produced a two-class fit for {model}")
        best_alpha = max(avg_auc, key=avg_auc.get)
        train_val_gaps = {a: np.mean(train_auc[a]) - np.mean(val_auc[a])
                          for a in alphas if train_auc[a] and val_auc[a]}

        # Refit on the selected fraction of the whole CV pool and score the test split.
        final_tr = _subsample(cv_idx, int(best_alpha * len(cv_idx)))
        scaler = StandardScaler()
        rf = _make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        X_test = scaler.transform(X[test_idx])
        test_pred_proba = rf.predict_proba(X_test)[:, 1]
        train_val_gap = train_val_gaps[best_alpha]

        results[model] = {
            'best_alpha': best_alpha,
            'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], rf.predict(X_test)),
            'test_auc': roc_auc_score(y[test_idx], test_pred_proba),
            'cv_auc_mean': np.mean(val_auc[best_alpha]), 'cv_auc_std': np.std(val_auc[best_alpha]),
            'cv_acc_mean': np.mean(val_acc[best_alpha]), 'cv_acc_std': np.std(val_acc[best_alpha]),
            'train_auc_mean': np.mean(train_auc[best_alpha]), 'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps, 'is_overfitting': train_val_gap > 0.25,
            'fold_results': {'val_auc': val_auc[best_alpha], 'val_acc': val_acc[best_alpha], 'train_auc': train_auc[best_alpha]},
            'cv_results': {
                'alphas': alphas,
                'acc': {a: np.mean(val_acc[a]) for a in alphas if val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(val_acc[a]) for a in alphas if val_acc[a]},
                'auc_std': {a: np.std(val_auc[a]) for a in alphas if val_auc[a]},
            },
            'trained_model': rf, 'scaler': scaler, 'test_idx': test_idx
        }
        print(f"  {model}: AUC={results[model]['test_auc']:.4f}, Bal_Acc={results[model]['test_balanced_accuracy']:.4f}")
    return results

def run_regression(X_dict, y, test_size=0.1, task_name="Regression"):
    """Benchmark each feature set on a continuous target with learning curves.

    Protocol:
      1. Hold out ``int(test_size * len(y))`` subjects as a random test split.
      2. On the remaining CV pool run 5-fold CV. In every fold and for every
         training fraction ``alpha``, subsample the fold's training indices to
         ``int(alpha * len(train_fold))`` subjects, standardise, fit a random
         forest and record R² / MAE on the validation fold and the training subset.
         Fractions yielding fewer than 2 training subjects are skipped.
      3. Select the alpha with the highest mean validation R², refit on that
         fraction of the whole CV pool and evaluate on the test split.
    A mean-predicting ``DummyRegressor`` is scored on the same splits as a baseline.

    Args:
        X_dict: mapping model name -> feature matrix whose rows align with ``y``.
        y: numeric target array.
        test_size: fraction of subjects held out for the final test.
        task_name: name printed in the progress output.

    Returns:
        dict keyed by ``'DummyRegressor'`` and every model name. Each value holds
        test metrics (plus test predictions for models), CV summaries at the best
        alpha, per-alpha curves under ``'cv_results'``, train-validation R² gaps,
        and the fitted model/scaler.

    Raises:
        ValueError: if, for some model, no training fraction was large enough to fit.
    """
    print(f"Running {task_name}")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)), random_state=SEED)
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

    def _subsample(pool, n):
        """Return a seeded random subset of *n* indices from *pool* (all of it if n >= len)."""
        if n >= len(pool):
            return pool
        subset, _ = train_test_split(pool, train_size=n, random_state=SEED)
        return subset

    def _make_rf():
        """Return the random-forest configuration shared by CV and final fits."""
        return RandomForestRegressor(n_estimators=200, max_depth=6, min_samples_split=5,
                                     random_state=SEED, n_jobs=1)

    # Mean-predicting baseline evaluated on exactly the same folds and test split.
    dummy_val_r2, dummy_val_mae = [], []
    for train_rel, val_rel in kf.split(cv_idx):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyRegressor(strategy='mean')
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])
        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        dummy_val_r2.append(r2_score(y[val_idx], pred))
        dummy_val_mae.append(mean_absolute_error(y[val_idx], pred))

    dummy = DummyRegressor(strategy='mean')
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))

    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2': r2_score(y[test_idx], dummy_test_pred),
        'test_mae': mean_absolute_error(y[test_idx], dummy_test_pred),
        'test_rmse': np.sqrt(mean_squared_error(y[test_idx], dummy_test_pred)),
        'cv_r2_mean': np.mean(dummy_val_r2), 'cv_r2_std': np.std(dummy_val_r2),
        'cv_mae_mean': np.mean(dummy_val_mae), 'cv_mae_std': np.std(dummy_val_mae),
        'fold_results': {'val_r2': dummy_val_r2, 'val_mae': dummy_val_mae}, 'test_idx': test_idx
    }
    print(f"  DummyRegressor: R2={results['DummyRegressor']['test_r2']:.4f}, MAE={results['DummyRegressor']['test_mae']:.2f}")

    for model, X in X_dict.items():
        val_r2 = {a: [] for a in alphas}
        val_mae = {a: [] for a in alphas}
        train_r2 = {a: [] for a in alphas}
        train_mae = {a: [] for a in alphas}

        for train_rel, val_rel in kf.split(cv_idx):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
            for alpha in alphas:
                n = int(alpha * len(train_idx))
                if n < 2:
                    # n == 0 makes train_test_split raise, and with n == 1 the
                    # training R² is NaN, which can then win the max() below.
                    continue
                tr_idx = _subsample(train_idx, n)
                scaler = StandardScaler()
                rf = _make_rf()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf.fit(X_tr, y[tr_idx])
                pred = rf.predict(X_val)
                val_r2[alpha].append(r2_score(y[val_idx], pred))
                val_mae[alpha].append(mean_absolute_error(y[val_idx], pred))
                train_pred = rf.predict(X_tr)
                train_r2[alpha].append(r2_score(y[tr_idx], train_pred))
                train_mae[alpha].append(mean_absolute_error(y[tr_idx], train_pred))

        avg_r2 = {a: np.mean(val_r2[a]) for a in alphas if val_r2[a]}
        if not avg_r2:
            raise ValueError(f"{task_name}: no training fraction had at least 2 subjects for {model}")
        best_alpha = max(avg_r2, key=avg_r2.get)
        train_val_gaps = {a: np.mean(train_r2[a]) - np.mean(val_r2[a])
                          for a in alphas if train_r2[a] and val_r2[a]}

        # Refit on the selected fraction of the whole CV pool and score the test split.
        final_tr = _subsample(cv_idx, int(best_alpha * len(cv_idx)))
        scaler = StandardScaler()
        rf = _make_rf()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        test_pred = rf.predict(scaler.transform(X[test_idx]))
        train_val_gap = train_val_gaps[best_alpha]

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2': r2_score(y[test_idx], test_pred),
            'test_mae': mean_absolute_error(y[test_idx], test_pred),
            'test_rmse': np.sqrt(mean_squared_error(y[test_idx], test_pred)),
            'predictions': test_pred, 'actual': y[test_idx],
            'cv_r2_mean': np.mean(val_r2[best_alpha]), 'cv_r2_std': np.std(val_r2[best_alpha]),
            'cv_mae_mean': np.mean(val_mae[best_alpha]), 'cv_mae_std': np.std(val_mae[best_alpha]),
            'train_r2_mean': np.mean(train_r2[best_alpha]), 'train_val_gap': train_val_gap,
            'train_val_gaps_all_alphas': train_val_gaps, 'is_overfitting': train_val_gap > 0.35,
            'fold_results': {'val_r2': val_r2[best_alpha], 'val_mae': val_mae[best_alpha], 'train_r2': train_r2[best_alpha]},
            'cv_results': {
                'alphas': alphas,
                'r2': avg_r2,
                'mae': {a: np.mean(val_mae[a]) for a in alphas if val_mae[a]},
                'r2_std': {a: np.std(val_r2[a]) for a in alphas if val_r2[a]},
                'mae_std': {a: np.std(val_mae[a]) for a in alphas if val_mae[a]},
            },
            'trained_model': rf, 'scaler': scaler, 'test_idx': test_idx
        }
        print(f"  {model}: R2={results[model]['test_r2']:.4f}, MAE={results[model]['test_mae']:.2f}")
    return results

def compute_correlation_with_cluster_labels(features_dict, fs_feature_info, dataset_name):
    """Build a cross-model feature correlation matrix with per-model clustered ordering.

    For each of AnatCL, BrainIAC, CNN and FreeSurfer present in *features_dict*:
    standardise the features, drop columns that contain NaN or are constant,
    order the remaining columns by Ward clustering on ``1 - |corr|``, and append
    them to a combined matrix. The FreeSurfer ordering is then annotated with
    *fs_feature_info*.

    Args:
        features_dict: mapping model name -> (subjects x features) array.
        fs_feature_info: per-FreeSurfer-column metadata (indexed by original
            column position); each entry must be a mapping containing at least
            ``'feature_type'``.
        dataset_name: label used in the console summary.

    Returns:
        (corr, boundaries, model_names, reordered_indices, cluster_df) where
        ``boundaries[i]:boundaries[i+1]`` is the span of ``model_names[i]`` in
        ``corr``; ``model_names`` lists only the models actually included;
        ``reordered_indices[model]`` are leaf positions within that model's
        retained (filtered) columns; ``cluster_df`` has one row per retained
        FreeSurfer feature in display order, with ``original_index`` pointing
        into *fs_feature_info*.
    """
    model_order = ['AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer']
    all_features = []
    boundaries = [0]
    reordered_indices = {}
    kept_columns = {}  # model -> original column index of each retained feature
    included = []

    for model in model_order:
        if model not in features_dict:
            continue
        feats = StandardScaler().fit_transform(features_dict[model])
        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        feats = feats[:, valid]
        if feats.shape[1] == 0:
            continue
        print(f"  {model}: {feats.shape[1]} features")

        corr = np.corrcoef(feats.T)
        corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
        np.fill_diagonal(corr, 1.0)
        # Distance = 1 - |r|: strongly (anti-)correlated features cluster together.
        dist = np.abs(1 - np.abs(corr))
        np.fill_diagonal(dist, 0)
        condensed = np.nan_to_num(squareform((dist + dist.T) / 2, checks=False), nan=0.0, posinf=1.0, neginf=0.0)
        reorder_idx = dendrogram(linkage(condensed, method='ward'), no_plot=True)['leaves']
        all_features.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + feats.shape[1])
        reordered_indices[model] = reorder_idx
        kept_columns[model] = np.flatnonzero(valid)
        included.append(model)

    combined = np.hstack(all_features)
    corr = np.corrcoef(combined.T)  # computed once; it can be large
    corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)

    # Leaves index the *filtered* FreeSurfer columns; map them back to original
    # column positions so labels from fs_feature_info stay aligned even when
    # NaN/constant columns were dropped above.
    fs_original_cols = kept_columns['FreeSurfer'][reordered_indices['FreeSurfer']]

    cluster_df = pd.DataFrame([
        {**fs_feature_info[int(idx)], 'original_index': int(idx), 'reordered_position': pos}
        for pos, idx in enumerate(fs_original_cols)
    ])

    n_sa = np.sum(cluster_df['feature_type'] == 'Surface_Area')
    n_thick = np.sum(cluster_df['feature_type'] == 'Thickness')
    total = len(cluster_df)

    print(f"\n{dataset_name} FreeSurfer ({total} features reordered by correlation):")
    print(f"  Surface Area: {n_sa} ({100*n_sa/total:.1f}%)")
    print(f"  Thickness: {n_thick} ({100*n_thick/total:.1f}%)")

    return corr, boundaries, included, reordered_indices, cluster_df

def extract_features(cat12_paths, brainiac_paths, age_labels, device):
    """Extract per-subject feature vectors with the AnatCL, BrainIAC and CNN encoders.

    Args:
        cat12_paths: image paths fed to the AnatCL encoders (via CAT12VBMDataset).
        brainiac_paths: image paths fed to both BrainIAC and the CNN baseline.
        age_labels: per-subject labels; only passed so the datasets can yield
            (image, label) pairs — the label is discarded in this function.
        device: torch device the encoders run on.

    Returns:
        (anatcl_features, brainiac_features, cnn_features): NumPy arrays with one
        row per input path, in input order (the DataLoaders are not shuffled).

    NOTE(review): statement order matters here. CNN3D is built without loading
    any weights, so its features depend on the seeded random initial
    parameters; if AnatCL or ViT construction consumes RNG draws, reordering
    these steps changes the CNN features.
    """
    # NOTE(review): Normalize(mean=0.0, std=1.0) computes (x - 0) / 1, i.e. it is
    # an identity — confirm whether AnatCL expects a different normalisation.
    anatcl_transform = transforms.Compose([
        transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
        transforms.Normalize(mean=0.0, std=1.0)
    ])
    
    print("Extracting AnatCL features...")
    # One encoder per fold checkpoint (fold0..fold4.pth); features are averaged
    # over the five folds after the loop.
    all_fold_features = []
    for fold_idx in range(5):
        path = os.path.join(ANATCL_ENCODER_PATH, f"fold{fold_idx}.pth")
        encoder = AnatCL(descriptor="global", fold=fold_idx, pretrained=False).to(device).eval()
        # weights_only=False unpickles arbitrary objects — only safe for trusted files.
        encoder.backbone.load_state_dict(torch.load(path, map_location=device, weights_only=False)['model'])
        for p in encoder.parameters():
            p.requires_grad = False
        dl = DataLoader(CAT12VBMDataset(cat12_paths, age_labels, anatcl_transform), batch_size=32, num_workers=0)
        with torch.no_grad():
            fold_features = torch.cat([encoder(vol.to(device)).cpu() for vol, _ in dl]).numpy()
        all_fold_features.append(fold_features)
        del encoder  # drop the reference before loading the next fold
    anatcl_features = np.mean(all_fold_features, axis=0)
    print(f"  AnatCL: {anatcl_features.shape}")
    
    print("Extracting BrainIAC features...")
    brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)
    # Per-volume z-scoring; the epsilon guards against constant volumes.
    brainiac_transform = lambda x: (x - x.mean()) / (x.std() + 1e-6)
    dl = DataLoader(BrainIACDataset(brainiac_paths, age_labels, brainiac_transform), batch_size=16, num_workers=0)
    brainiac_features = []
    with torch.no_grad():
        for x, _ in dl:
            out = brainiac_vit(x.to(device))
            # Token 0 of the (possibly tuple-wrapped) token sequence is kept.
            # NOTE(review): load_brainiac_vit builds the ViT with classification=False;
            # MONAI's ViT only prepends a CLS token when classification=True, so
            # index 0 may be the first patch token rather than a CLS token —
            # confirm against BrainIAC's reference feature extraction.
            cls_token = out[0][:, 0] if isinstance(out, tuple) else out[:, 0]
            brainiac_features.append(cls_token.cpu().numpy())
    brainiac_features = np.vstack(brainiac_features)
    print(f"  BrainIAC: {brainiac_features.shape}")
    del brainiac_vit
    
    print("Extracting CNN features...")
    # No checkpoint is loaded: the CNN baseline uses its initial (seeded) weights.
    cnn_model = CNN3D().to(device).eval()
    dl = DataLoader(CNNDataset(brainiac_paths), batch_size=8, num_workers=0)
    with torch.no_grad():
        cnn_features = np.vstack([cnn_model(x.to(device)).cpu().numpy() for x in dl])
    print(f"  CNN: {cnn_features.shape}")
    del cnn_model
    
    return anatcl_features, brainiac_features, cnn_features

def plot_correlation_with_labels(corr_matrix, boundaries, model_names, cluster_df, dataset_name):
    """Render a cross-model feature-correlation heatmap and save it as a PDF.

    Args:
        corr_matrix: square correlation matrix, drawn on a fixed [-1, 1] colour scale.
        boundaries: cumulative block edges; block ``k`` spans
            ``boundaries[k]:boundaries[k + 1]``. Block 3 (``boundaries[3]`` to
            ``boundaries[4]``) is treated as the FreeSurfer block.
        model_names: one label per block, written above and to the left of the matrix.
        cluster_df: table with a ``feature_type`` column, one row per FreeSurfer
            feature in display order; it drives the coloured side bars.
        dataset_name: lower-cased and used as the output file-name prefix.

    Side effects: writes ``<dataset_name.lower()>_cross_model_correlation.pdf`` to the
    working directory, closes the figure and prints the saved path.
    """
    figure, axis = plt.subplots(figsize=(32, 30))
    heatmap = axis.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1,
                          aspect='equal', interpolation='nearest')
    figure.colorbar(heatmap, ax=axis, shrink=0.3)

    # Thick separators between consecutive model blocks (outer edges skipped).
    for edge in boundaries[1:-1]:
        axis.axhline(y=edge - 0.5, color='black', linewidth=4)
        axis.axvline(x=edge - 0.5, color='black', linewidth=4)

    fs_first, fs_last = boundaries[3], boundaries[4]
    bar_w = 8
    palette = {'Surface_Area': '#2ECC71', 'Thickness': '#3498DB'}

    # One coloured cell per FreeSurfer row, drawn on the right and then the left side.
    for offset, kind in enumerate(cluster_df['feature_type'].values):
        shade = palette.get(kind, '#95A5A6')
        row_bottom = fs_first + offset - 0.5
        for x_left in (fs_last + 3, -bar_w - 5):
            axis.add_patch(plt.Rectangle((x_left, row_bottom), bar_w, 1,
                                         facecolor=shade, edgecolor='none'))

    # Centre every model's name over its column block and beside its row block.
    for k, (lo, hi) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        name = model_names[k]
        centre = (lo + hi) / 2
        axis.text(centre, -40, name, ha='center', fontsize=20, weight='bold')
        axis.text(-50, centre, name, ha='center', va='center', fontsize=20,
                  weight='bold', rotation=90)

    axis.set_xlim(-bar_w - 60, fs_last + bar_w + 50)
    axis.set_ylim(fs_last + 20, -60)  # inverted: row 0 at the top, room for labels
    axis.set_xticks([])
    axis.set_yticks([])
    axis.axis('off')

    axis.legend(
        handles=[
            Patch(facecolor=palette['Surface_Area'], edgecolor='black', label='Surface Area'),
            Patch(facecolor=palette['Thickness'], edgecolor='black', label='Thickness'),
        ],
        loc='upper right',
        fontsize=14,
    )

    out_path = f'{dataset_name.lower()}_cross_model_correlation.pdf'
    figure.tight_layout()
    figure.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(figure)
    print(f"Saved: {out_path}")


# PPMI section: input locations and per-subject demographic lookups.
print("PPMI DATASET")
PPMI_CAT12_BASE_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
PPMI_DATA_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data"
PPMI_LABELS_PATH = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/processed_cohort_with_mri.csv"
PPMI_BRAINIAC_MAPPING_CSV = os.path.join(PPMI_DATA_DIR, "processed_files_mapping.csv")

print("Loading PPMI demographics...")
ppmi_labels_df = pd.read_csv(PPMI_LABELS_PATH)
# Three dicts keyed by PATNO as a string: sex (1 = 'F', else 0),
# Parkinson status (1 = Group 'PD', else 0) and age.
ppmi_id_sex_dict, ppmi_id_parkinson_dict, ppmi_id_age_dict = {}, {}, {}

for _, row in ppmi_labels_df.iterrows():
    # int() first so a float-typed PATNO (e.g. 3001.0) still becomes the key '3001'.
    patno = str(int(row['PATNO']))
    sex = row['Sex'].strip().upper()
    ppmi_id_sex_dict[patno] = 1 if sex == 'F' else 0
    group = row['Group'].strip()
    # Every group other than exactly 'PD' is coded as 0.
    ppmi_id_parkinson_dict[patno] = 1 if group == 'PD' else 0
    ppmi_id_age_dict[patno] = row['Age']
    # A PATNO that appears on several rows keeps the values of its last row.

print(f"  {len(ppmi_id_sex_dict)} subjects")

def extract_patno_from_path(filepath):
    """Return the subject ID from the first ``sub-<ID>`` component of *filepath*.

    The path is split on ``os.sep``. The first component that starts with
    ``sub-`` has that prefix removed and the remainder is returned.
    Returns ``None`` when no component starts with ``sub-``.
    """
    prefix = 'sub-'
    return next(
        (piece[len(prefix):] for piece in filepath.split(os.sep) if piece.startswith(prefix)),
        None,
    )

# Locate smoothed CAT12 grey-matter maps ("s6mwp1" files) for subjects with demographics.
print("Finding PPMI CAT12 files...")
# glob() returns paths in filesystem order, which is not guaranteed. Sort them so every
# run picks the same file. When a subject matches several files, the last path in sorted
# order wins (later entries overwrite earlier ones), and such subjects are counted below.
ppmi_cat12_files = sorted(glob.glob(os.path.join(PPMI_CAT12_BASE_DIR, "**", "*s6mwp1*.nii*"), recursive=True))
ppmi_cat12_data = {}
ppmi_cat12_duplicates = set()
for f in ppmi_cat12_files:
    if not os.path.isfile(f):
        continue
    patno = extract_patno_from_path(f)
    if patno and patno in ppmi_id_sex_dict:
        if patno in ppmi_cat12_data:
            ppmi_cat12_duplicates.add(patno)
        ppmi_cat12_data[patno] = f
print(f"  Found {len(ppmi_cat12_data)} CAT12 files")
if ppmi_cat12_duplicates:
    print(f"  WARNING: {len(ppmi_cat12_duplicates)} subjects matched several CAT12 files; kept the last in sorted order")

# Map PATNO -> preprocessed BrainIAC input volume, using the mapping CSV.
print("Finding PPMI BrainIAC files...")
ppmi_brainiac_df = pd.read_csv(PPMI_BRAINIAC_MAPPING_CSV).dropna(subset=["processed_file", "Age", "subject_id"])
ppmi_brainiac_data = {}
for _, row in ppmi_brainiac_df.iterrows():
    raw_id = row['subject_id']
    # read_csv stores an integer ID column as float64 whenever it contains a NaN, and
    # dropna() above does not change that dtype. str(3001.0) == '3001.0' would never
    # match the '3001'-style keys built from PATNO, so integral floats are normalised
    # exactly like the demographics keys.
    if isinstance(raw_id, float) and raw_id.is_integer():
        patno = str(int(raw_id))
    else:
        patno = str(raw_id)
    if patno in ppmi_id_sex_dict and os.path.exists(row['processed_file']):
        ppmi_brainiac_data[patno] = row['processed_file']
print(f"  Found {len(ppmi_brainiac_data)} BrainIAC files")

# Load PPMI FreeSurfer 7 aparc cortical thickness (CTH) and surface area (SA) tables,
# keeping only baseline visits (EVENT_ID == 'BL').
print("Loading PPMI FreeSurfer features...")
ppmi_fs_cth_df = pd.read_csv(os.path.join(PPMI_CAT12_BASE_DIR, "FS7_APARC_CTH_23Oct2025.csv"))
ppmi_fs_sa_df = pd.read_csv(os.path.join(PPMI_CAT12_BASE_DIR, "FS7_APARC_SA_23Oct2025.csv"))
ppmi_fs_cth_df = ppmi_fs_cth_df[ppmi_fs_cth_df['EVENT_ID'] == 'BL'].copy()
ppmi_fs_sa_df = ppmi_fs_sa_df[ppmi_fs_sa_df['EVENT_ID'] == 'BL'].copy()
# NOTE(review): if PATNO were read as float, astype(str) would give '3001.0' and never
# match the str(int(...)) demographic keys. Confirm that this column is integer-typed.
ppmi_fs_cth_df['PATNO'] = ppmi_fs_cth_df['PATNO'].astype(str)
ppmi_fs_sa_df['PATNO'] = ppmi_fs_sa_df['PATNO'].astype(str)
# Every column except PATNO/EVENT_ID is used as a feature. This presumably means all
# remaining columns are numeric regional measures (TODO confirm); a non-numeric column
# would make the later np.isnan check fail.
ppmi_cth_features = [c for c in ppmi_fs_cth_df.columns if c not in ['PATNO', 'EVENT_ID']]
ppmi_sa_features = [c for c in ppmi_fs_sa_df.columns if c not in ['PATNO', 'EVENT_ID']]

# Per-feature metadata in concatenation order: all thickness columns, then all
# surface-area columns.
ppmi_fs_feature_info = []
for feat in ppmi_cth_features:
    ppmi_fs_feature_info.append({'feature_name': feat, 'feature_type': 'Thickness'})
for feat in ppmi_sa_features:
    ppmi_fs_feature_info.append({'feature_name': feat, 'feature_type': 'Surface_Area'})

# Subjects with both image types; FreeSurfer vectors are only built for these.
ppmi_common_subjects = sorted(list(set(ppmi_cat12_data.keys()) & set(ppmi_brainiac_data.keys())))
# Index the baseline tables by PATNO once, giving O(1) lookups instead of a full scan per
# subject. Only the first baseline row per PATNO is kept: with duplicated rows,
# flattening a per-subject filter would concatenate several rows into one over-long
# vector, and the resulting ragged rows cannot be stacked into a feature matrix.
ppmi_cth_by_patno = ppmi_fs_cth_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')
ppmi_sa_by_patno = ppmi_fs_sa_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')
ppmi_fs_data = {}
for patno in ppmi_common_subjects:
    if patno in ppmi_cth_by_patno.index and patno in ppmi_sa_by_patno.index:
        # Thickness first, then surface area, matching ppmi_fs_feature_info.
        combined = np.concatenate([
            ppmi_cth_by_patno.loc[patno, ppmi_cth_features].to_numpy(),
            ppmi_sa_by_patno.loc[patno, ppmi_sa_features].to_numpy(),
        ])
        # Subjects with any missing FreeSurfer value are excluded.
        if not np.any(np.isnan(combined)):
            ppmi_fs_data[patno] = combined

# Final cohort: subjects with CAT12, BrainIAC and complete FreeSurfer data.
ppmi_common_subjects = sorted(list(set(ppmi_cat12_data.keys()) & set(ppmi_brainiac_data.keys()) & set(ppmi_fs_data.keys())))
print(f"  Common subjects: {len(ppmi_common_subjects)}")

# Per-subject inputs and labels, all aligned with the sorted ppmi_common_subjects order.
ppmi_cat12_paths = [ppmi_cat12_data[p] for p in ppmi_common_subjects]
ppmi_brainiac_paths = [ppmi_brainiac_data[p] for p in ppmi_common_subjects]
ppmi_fs_features = np.array([ppmi_fs_data[p] for p in ppmi_common_subjects])
ppmi_sex_labels = np.array([ppmi_id_sex_dict[p] for p in ppmi_common_subjects])
ppmi_parkinson_labels = np.array([ppmi_id_parkinson_dict[p] for p in ppmi_common_subjects])
ppmi_age_labels = np.array([ppmi_id_age_dict[p] for p in ppmi_common_subjects])

# extract_features builds its DataLoaders without shuffle, so each returned array keeps
# one row per subject in the same order as the paths passed in.
ppmi_anatcl_features, ppmi_brainiac_features, ppmi_cnn_features = extract_features(
    ppmi_cat12_paths, ppmi_brainiac_paths, ppmi_age_labels, device)

# Model name -> subject-by-feature matrix. The key order is the order used in the
# correlation matrix and the learning curves.
ppmi_features_dict = {
    'AnatCL': ppmi_anatcl_features,
    'BrainIAC': ppmi_brainiac_features,
    'CNN': ppmi_cnn_features,
    'FreeSurfer': ppmi_fs_features
}

print("\nComputing PPMI correlation matrix...")
ppmi_corr_matrix, ppmi_boundaries, ppmi_model_names, ppmi_reorder_indices, ppmi_cluster_df = \
    compute_correlation_with_cluster_labels(ppmi_features_dict, ppmi_fs_feature_info, "PPMI")

print("\nPPMI Sex Classification")
ppmi_sex_results = run_classification(ppmi_features_dict, ppmi_sex_labels, task_name="PPMI Sex")

print("\nPPMI Parkinson Classification")
ppmi_parkinson_results = run_classification(ppmi_features_dict, ppmi_parkinson_labels, task_name="PPMI Parkinson")

# Age regression is run separately per sex (0 = male, 1 = female, as coded above).
ppmi_male_indices = np.where(ppmi_sex_labels == 0)[0]
ppmi_male_features_dict = {k: v[ppmi_male_indices] for k, v in ppmi_features_dict.items()}
ppmi_male_age_labels = ppmi_age_labels[ppmi_male_indices]
print(f"\nPPMI Male subjects: {len(ppmi_male_indices)}")
ppmi_male_age_results = run_regression(ppmi_male_features_dict, ppmi_male_age_labels, task_name="PPMI Male Age")

ppmi_female_indices = np.where(ppmi_sex_labels == 1)[0]
ppmi_female_features_dict = {k: v[ppmi_female_indices] for k, v in ppmi_features_dict.items()}
ppmi_female_age_labels = ppmi_age_labels[ppmi_female_indices]
print(f"\nPPMI Female subjects: {len(ppmi_female_indices)}")
ppmi_female_age_results = run_regression(ppmi_female_features_dict, ppmi_female_age_labels, task_name="PPMI Female Age")


# HBN section: input locations and per-subject demographic lookups.
print("\n\nHBN DATASET")
HBN_BIDS = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
HBN_BIDS_LOWER = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/hbn_bids"
HBN_BRAINIAC_OUT = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/brainiac_p_outputs"
HBN_FREESURFER_DIR = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_FreeSurfer/freesurfer"
HBN_DEMO_FILE = os.path.join(HBN_BIDS, "final_preprocessed_subjects_with_demographics.tsv")
# Atlas name selected from the FreeSurfer region-stats tables.
HBN_PARCELLATION = "Schaefer2018_400Parcels_17Networks_order"

print("Loading HBN demographics...")
hbn_demo_df = pd.read_csv(HBN_DEMO_FILE, sep='\t')
# Later paths are built as f"sub-{subject_id}", so these IDs are presumably stored
# without the 'sub-' prefix. TODO confirm against the TSV.
hbn_demo_df['participant_id'] = hbn_demo_df['participant_id'].astype(str)
hbn_id_sex_dict, hbn_id_age_dict = {}, {}

for _, row in hbn_demo_df.iterrows():
    subject_id = row['participant_id']
    sex = row['sex'].strip()
    # 1 = 'Female', anything else = 0 (same coding as the PPMI section).
    hbn_id_sex_dict[subject_id] = 1 if sex == 'Female' else 0
    hbn_id_age_dict[subject_id] = row['age']

print(f"  {len(hbn_id_sex_dict)} subjects")

print("Finding HBN CAT12 files...")
# CAT12 smoothed grey-matter maps are searched under HBN_BIDS first. HBN_BIDS_LOWER is
# only searched for subjects not found there. glob() order depends on the filesystem,
# so matches are sorted before taking the first one; this keeps the chosen session
# reproducible when a subject has several.
hbn_cat12_data = {}
for search_root in (HBN_BIDS, HBN_BIDS_LOWER):
    for subject_id in hbn_id_sex_dict.keys():
        if subject_id in hbn_cat12_data:
            continue
        pattern = os.path.join(search_root, f"sub-{subject_id}", "ses-*", "anat", "mri", "s6mwp1sub*.nii")
        files = sorted(glob.glob(pattern))
        if files:
            hbn_cat12_data[subject_id] = files[0]
print(f"  Found {len(hbn_cat12_data)} CAT12 files")

print("Finding HBN BrainIAC files...")
hbn_brainiac_data = {}
batch_dirs = glob.glob(os.path.join(HBN_BRAINIAC_OUT, "batch_*"))
# First match per subject wins. Batches and the files inside each batch are both visited
# in sorted order, so the choice does not depend on directory listing order.
for batch_dir in sorted(batch_dirs):
    files = sorted(glob.glob(os.path.join(batch_dir, "sub-*_0000.nii.gz")))
    for f in files:
        basename = os.path.basename(f)
        # "sub-<ID>_..." -> "<ID>"
        subject_id = basename.split('_')[0].replace('sub-', '')
        if subject_id in hbn_id_sex_dict and subject_id not in hbn_brainiac_data:
            hbn_brainiac_data[subject_id] = f
print(f"  Found {len(hbn_brainiac_data)} BrainIAC files")

print("Loading HBN FreeSurfer features...")
hbn_fs_data = {}
hbn_fs_region_info = {}
hbn_fs_row_counts = []
hbn_fs_region_sets = {}
hbn_fs_extra_regions = {}
# subject_id -> "ExceptionType: message" for stats files that could not be processed.
hbn_fs_read_errors = {}

# For each subject, keep the Schaefer-400 rows of the FreeSurfer region table, sorted by
# region name, and build an 800-vector: [400 surface areas, 400 mean thicknesses].
for subject_id in hbn_id_sex_dict.keys():
    subject_dir = os.path.join(HBN_FREESURFER_DIR, f"sub-{subject_id}")
    stats_file = os.path.join(subject_dir, f"sub-{subject_id}_regionsurfacestats.tsv")
    if os.path.exists(stats_file):
        try:
            df = pd.read_csv(stats_file, sep='\t')
            filtered_df = df[df["atlas"] == HBN_PARCELLATION]
            if not filtered_df.empty:
                filtered_df = filtered_df.sort_values("StructName")
                hbn_fs_row_counts.append(len(filtered_df))
                # Rows after the first 400 (in name order) are dropped; record their
                # names so the truncation can be reported.
                if len(filtered_df) > 400:
                    extra = filtered_df['StructName'].values[400:].tolist()
                    hbn_fs_extra_regions[subject_id] = extra
                if "SurfArea" in filtered_df.columns and "ThickAvg" in filtered_df.columns:
                    surf_area = filtered_df["SurfArea"].values[:400]
                    thick_avg = filtered_df["ThickAvg"].values[:400]
                    combined = np.concatenate([surf_area, thick_avg])
                    if not np.any(np.isnan(combined)) and len(surf_area) == 400 and len(thick_avg) == 400:
                        hbn_fs_data[subject_id] = combined
                        region_names_400 = tuple(filtered_df['StructName'].values[:400].tolist())
                        hbn_fs_region_sets[subject_id] = region_names_400
                        # Region names/hemispheres come from the first accepted subject.
                        if not hbn_fs_region_info:
                            hbn_fs_region_info = {
                                'region_names': list(region_names_400),
                                'hemisphere': filtered_df['hemisphere'].values[:400].tolist(),
                            }
        except Exception as exc:
            # Best-effort loading. Catch Exception (not a bare except) so that
            # KeyboardInterrupt/SystemExit still propagate, and record the reason
            # instead of silently dropping the subject.
            hbn_fs_read_errors[subject_id] = f"{type(exc).__name__}: {exc}"

if hbn_fs_read_errors:
    print(f"  WARNING: {len(hbn_fs_read_errors)} Schaefer stats files failed to load, e.g.:")
    for failed_id, reason in list(hbn_fs_read_errors.items())[:5]:
        print(f"    sub-{failed_id}: {reason}")

# Consistency report for the Schaefer tables loaded above. It only prints; no subject
# is removed here.
if hbn_fs_row_counts:
    unique_counts = np.unique(hbn_fs_row_counts)
    if len(unique_counts) > 1:
        print(f"  WARNING: Inconsistent row counts found: {unique_counts}")
        for cnt in unique_counts:
            print(f"    {cnt} rows: {np.sum(np.array(hbn_fs_row_counts) == cnt)} subjects")

# Which region names fell beyond row 400 (in name order) and were therefore discarded.
if hbn_fs_extra_regions:
    print(f"  Subjects with >400 rows: {len(hbn_fs_extra_regions)}")
    all_extras = [r for regions in hbn_fs_extra_regions.values() for r in regions]
    unique_extras = set(all_extras)
    print(f"  Truncated regions (appear after row 400): {unique_extras}")

# Check that every accepted subject kept the same ordered tuple of 400 region names.
# NOTE(review): subjects whose region set differs are reported here but stay in
# hbn_fs_data, so their feature columns may not line up with other subjects.
if hbn_fs_region_sets:
    unique_region_sets = set(hbn_fs_region_sets.values())
    if len(unique_region_sets) == 1:
        print(f"  VERIFIED: All {len(hbn_fs_region_sets)} subjects have identical 400 regions")
    else:
        print(f"  WARNING: Found {len(unique_region_sets)} different region sets!")
        # Arbitrary element of the set, not necessarily the majority layout.
        reference_set = list(unique_region_sets)[0]
        for subject_id, regions in hbn_fs_region_sets.items():
            if regions != reference_set:
                diff = set(reference_set) ^ set(regions)
                # Only the first differing subject is printed.
                print(f"    Subject {subject_id} differs by: {diff}")
                break

print(f"  Found {len(hbn_fs_data)} FreeSurfer subjects with exactly 800 features")

# Per-feature metadata for the 800 Schaefer features, in the same order as each
# subject's vector: 400 surface-area entries, then 400 thickness entries. Region names
# and hemispheres come from the first accepted subject; stays empty if none was accepted.
hbn_fs_feature_info = []
if hbn_fs_region_info:
    for i in range(400):
        hbn_fs_feature_info.append({
            'feature_name': hbn_fs_region_info['region_names'][i],
            'feature_type': 'Surface_Area',
            'hemisphere': hbn_fs_region_info['hemisphere'][i]
        })
    for i in range(400):
        hbn_fs_feature_info.append({
            'feature_name': hbn_fs_region_info['region_names'][i],
            'feature_type': 'Thickness',
            'hemisphere': hbn_fs_region_info['hemisphere'][i]
        })

# Load HBN aparc parcellation (68 regions) from the same per-subject region-stats tables.
print("\nLoading HBN FreeSurfer aparc features...")
HBN_APARC_PARCELLATION = "aparc"
hbn_aparc_fs_data = {}
hbn_aparc_region_info = {}
# subject_id -> "ExceptionType: message" for stats files that could not be processed.
hbn_aparc_read_errors = {}

for subject_id in hbn_id_sex_dict.keys():
    subject_dir = os.path.join(HBN_FREESURFER_DIR, f"sub-{subject_id}")
    stats_file = os.path.join(subject_dir, f"sub-{subject_id}_regionsurfacestats.tsv")
    if os.path.exists(stats_file):
        try:
            df = pd.read_csv(stats_file, sep='\t')
            filtered_df = df[df["atlas"] == HBN_APARC_PARCELLATION]
            if not filtered_df.empty:
                # aparc region names repeat across hemispheres (hemisphere is a separate
                # column). A single-key sort uses the unstable default quicksort, which
                # does not pin the lh/rh order of tied names. Break ties on hemisphere
                # so feature k is the same region and side for every subject.
                filtered_df = filtered_df.sort_values(["StructName", "hemisphere"])
                if "SurfArea" in filtered_df.columns and "ThickAvg" in filtered_df.columns:
                    surf_area = filtered_df["SurfArea"].values
                    thick_avg = filtered_df["ThickAvg"].values
                    # Surface area first, then thickness (68 + 68 = 136 features).
                    combined = np.concatenate([surf_area, thick_avg])
                    if not np.any(np.isnan(combined)) and len(surf_area) == 68 and len(thick_avg) == 68:
                        hbn_aparc_fs_data[subject_id] = combined
                        # Region names/hemispheres come from the first accepted subject.
                        if not hbn_aparc_region_info:
                            hbn_aparc_region_info = {
                                'region_names': filtered_df['StructName'].values.tolist(),
                                'hemisphere': filtered_df['hemisphere'].values.tolist(),
                            }
        except Exception as exc:
            # Best-effort loading. Catch Exception (not a bare except) so that
            # KeyboardInterrupt/SystemExit still propagate, and record the reason.
            hbn_aparc_read_errors[subject_id] = f"{type(exc).__name__}: {exc}"

if hbn_aparc_read_errors:
    print(f"  WARNING: {len(hbn_aparc_read_errors)} aparc stats files failed to load, e.g.:")
    for failed_id, reason in list(hbn_aparc_read_errors.items())[:5]:
        print(f"    sub-{failed_id}: {reason}")

print(f"  Found {len(hbn_aparc_fs_data)} FreeSurfer aparc subjects with 136 features (68 SA + 68 Th)")

# Per-feature metadata for the 136 aparc features, in vector order: 68 surface-area
# entries, then 68 thickness entries. Names/hemispheres come from the first accepted
# subject; stays empty if no subject was accepted.
hbn_aparc_fs_feature_info = []
if hbn_aparc_region_info:
    for i in range(68):
        hbn_aparc_fs_feature_info.append({
            'feature_name': hbn_aparc_region_info['region_names'][i],
            'feature_type': 'Surface_Area',
            'hemisphere': hbn_aparc_region_info['hemisphere'][i]
        })
    for i in range(68):
        hbn_aparc_fs_feature_info.append({
            'feature_name': hbn_aparc_region_info['region_names'][i],
            'feature_type': 'Thickness',
            'hemisphere': hbn_aparc_region_info['hemisphere'][i]
        })

# One cohort for both parcellations: subjects need CAT12, BrainIAC, Schaefer and aparc
# data, so the Schaefer and aparc analyses use exactly the same subjects.
hbn_common_subjects = sorted(list(set(hbn_cat12_data.keys()) & set(hbn_brainiac_data.keys()) & set(hbn_fs_data.keys()) & set(hbn_aparc_fs_data.keys())))
print(f"  Common subjects: {len(hbn_common_subjects)}")

# Per-subject inputs and labels, aligned with the sorted hbn_common_subjects order.
hbn_cat12_paths = [hbn_cat12_data[s] for s in hbn_common_subjects]
hbn_brainiac_paths = [hbn_brainiac_data[s] for s in hbn_common_subjects]
hbn_fs_features = np.array([hbn_fs_data[s] for s in hbn_common_subjects])
hbn_sex_labels = np.array([hbn_id_sex_dict[s] for s in hbn_common_subjects])
hbn_age_labels = np.array([hbn_id_age_dict[s] for s in hbn_common_subjects])

# extract_features builds its DataLoaders without shuffle, so rows stay in subject order.
hbn_anatcl_features, hbn_brainiac_features, hbn_cnn_features = extract_features(
    hbn_cat12_paths, hbn_brainiac_paths, hbn_age_labels, device)

# Schaefer-400 variant: the FreeSurfer entry holds the 800 Schaefer features.
hbn_features_dict = {
    'AnatCL': hbn_anatcl_features,
    'BrainIAC': hbn_brainiac_features,
    'CNN': hbn_cnn_features,
    'FreeSurfer': hbn_fs_features
}

# Create aparc features dict: same image-model features, with FreeSurfer swapped for the
# 136 aparc features.
hbn_aparc_fs_features = np.array([hbn_aparc_fs_data[s] for s in hbn_common_subjects])
hbn_aparc_features_dict = {
    'AnatCL': hbn_anatcl_features,
    'BrainIAC': hbn_brainiac_features,
    'CNN': hbn_cnn_features,
    'FreeSurfer': hbn_aparc_fs_features
}

print("\nComputing HBN Schaefer correlation matrix...")
hbn_corr_matrix, hbn_boundaries, hbn_model_names, hbn_reorder_indices, hbn_cluster_df = \
    compute_correlation_with_cluster_labels(hbn_features_dict, hbn_fs_feature_info, "HBN_Schaefer")

print("\nComputing HBN aparc correlation matrix...")
hbn_aparc_corr_matrix, hbn_aparc_boundaries, hbn_aparc_model_names, hbn_aparc_reorder_indices, hbn_aparc_cluster_df = \
    compute_correlation_with_cluster_labels(hbn_aparc_features_dict, hbn_aparc_fs_feature_info, "HBN_aparc")

# Schaefer experiments: sex classification, then age regression separately per sex
# (0 = male, 1 = female, as coded in the demographics loop).
print("\nHBN Sex Classification")
hbn_sex_results = run_classification(hbn_features_dict, hbn_sex_labels, task_name="HBN Sex")

hbn_male_indices = np.where(hbn_sex_labels == 0)[0]
hbn_male_features_dict = {k: v[hbn_male_indices] for k, v in hbn_features_dict.items()}
hbn_male_age_labels = hbn_age_labels[hbn_male_indices]
print(f"\nHBN Male subjects: {len(hbn_male_indices)}")
hbn_male_age_results = run_regression(hbn_male_features_dict, hbn_male_age_labels, task_name="HBN Male Age")

hbn_female_indices = np.where(hbn_sex_labels == 1)[0]
hbn_female_features_dict = {k: v[hbn_female_indices] for k, v in hbn_features_dict.items()}
hbn_female_age_labels = hbn_age_labels[hbn_female_indices]
print(f"\nHBN Female subjects: {len(hbn_female_indices)}")
hbn_female_age_results = run_regression(hbn_female_features_dict, hbn_female_age_labels, task_name="HBN Female Age")

# Run HBN aparc experiments: same subjects, labels and sex split as above; only the
# FreeSurfer features differ.
print("\n\nHBN APARC EXPERIMENTS")
print("\nHBN aparc Sex Classification")
hbn_aparc_sex_results = run_classification(hbn_aparc_features_dict, hbn_sex_labels, task_name="HBN aparc Sex")

hbn_aparc_male_features_dict = {k: v[hbn_male_indices] for k, v in hbn_aparc_features_dict.items()}
print(f"\nHBN aparc Male subjects: {len(hbn_male_indices)}")
hbn_aparc_male_age_results = run_regression(hbn_aparc_male_features_dict, hbn_male_age_labels, task_name="HBN aparc Male Age")

hbn_aparc_female_features_dict = {k: v[hbn_female_indices] for k, v in hbn_aparc_features_dict.items()}
print(f"\nHBN aparc Female subjects: {len(hbn_female_indices)}")
hbn_aparc_female_age_results = run_regression(hbn_aparc_female_features_dict, hbn_female_age_labels, task_name="HBN aparc Female Age")


print("\n\nGENERATING COMBINED FIGURE")
# Plot order and per-model styling. plot_classification/plot_regression read these
# module-level names. AnatCL and BrainIAC get the 'FM: ' legend prefix (presumably
# "foundation model", per the notebook title) and triangle markers.
models = ['AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer']
model_styles = {
    'AnatCL': {'color': '#9B59B6', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'BrainIAC': {'color': '#E74C3C', 'marker': '^', 'linestyle': '-', 'label_prefix': 'FM: '},
    'CNN': {'color': '#3498DB', 'marker': 'o', 'linestyle': '--', 'label_prefix': ''},
    'FreeSurfer': {'color': '#E67E22', 'marker': 's', 'linestyle': '-', 'label_prefix': ''}
}
# Training-set fractions; results[model]['cv_results'][metric] is looked up by these keys.
alphas = [0.01, 0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8, 1.0]

# x-axis positions: approximate number of training subjects behind each alpha.
# NOTE(review): this presumably mirrors a 10% hold-out followed by an 80% training share
# inside run_classification/run_regression (not visible here). TODO confirm; if those
# split fractions differ, the x-axis values are wrong.
ppmi_n_cv_sex = len(ppmi_sex_labels) - int(0.1 * len(ppmi_sex_labels))
ppmi_n_cv_parkinson = len(ppmi_parkinson_labels) - int(0.1 * len(ppmi_parkinson_labels))
ppmi_n_cv_male = len(ppmi_male_age_labels) - int(0.1 * len(ppmi_male_age_labels))
ppmi_n_cv_female = len(ppmi_female_age_labels) - int(0.1 * len(ppmi_female_age_labels))

hbn_n_cv_sex = len(hbn_sex_labels) - int(0.1 * len(hbn_sex_labels))
hbn_n_cv_male = len(hbn_male_age_labels) - int(0.1 * len(hbn_male_age_labels))
hbn_n_cv_female = len(hbn_female_age_labels) - int(0.1 * len(hbn_female_age_labels))

ppmi_sizes_sex = [int(a * ppmi_n_cv_sex * 0.8) for a in alphas]
ppmi_sizes_parkinson = [int(a * ppmi_n_cv_parkinson * 0.8) for a in alphas]
ppmi_sizes_male = [int(a * ppmi_n_cv_male * 0.8) for a in alphas]
ppmi_sizes_female = [int(a * ppmi_n_cv_female * 0.8) for a in alphas]

hbn_sizes_sex = [int(a * hbn_n_cv_sex * 0.8) for a in alphas]
hbn_sizes_male = [int(a * hbn_n_cv_male * 0.8) for a in alphas]
hbn_sizes_female = [int(a * hbn_n_cv_female * 0.8) for a in alphas]

# 4 task rows x 2 dataset columns (PPMI left, HBN right).
fig, axes = plt.subplots(4, 2, figsize=(20, 24))

fig.text(0.28, 0.95, 'PPMI', ha='center', fontsize=18, weight='bold')
fig.text(0.73, 0.95, 'HBN', ha='center', fontsize=18, weight='bold')

def plot_classification(ax, results, training_sizes, ylabel=None, show_legend=False):
    """Draw balanced-accuracy learning curves for every model onto *ax*.

    For each name in the module-level ``models`` list, points come from
    ``results[name]['cv_results']['acc']`` and ``['acc_std']``, keyed by the
    fractions in the module-level ``alphas``. Only fractions present in ``'acc'``
    are plotted, at the matching position of *training_sizes*, with a +/-1 std band.
    ``results['DummyClassifier']['cv_acc_mean']`` is drawn as a dotted baseline and
    the y-axis is fixed to [0, 1].
    """
    for name in models:
        look = model_styles[name]
        curves = results[name]['cv_results']
        present = [idx for idx, frac in enumerate(alphas) if frac in curves['acc']]
        x = np.array([training_sizes[idx] for idx in present])
        y = np.array([curves['acc'][alphas[idx]] for idx in present])
        err = np.array([curves['acc_std'][alphas[idx]] for idx in present])
        ax.plot(x, y, marker=look['marker'], linestyle=look['linestyle'],
                linewidth=3, markersize=10, label=look['label_prefix'] + name, color=look['color'])
        ax.fill_between(x, y - err, y + err, alpha=0.2, color=look['color'])

    chance = results['DummyClassifier']['cv_acc_mean']
    ax.axhline(y=chance, color='gray', linestyle=':', linewidth=2,
               label=f'Baseline ({chance:.3f})', alpha=0.7)
    ax.set_ylim([0, 1])
    ax.grid(True, alpha=0.3)
    if ylabel:
        ax.set_ylabel(ylabel, fontsize=12, weight='bold')
    ax.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
    if show_legend:
        ax.legend(fontsize=9, loc='lower right')

def plot_regression(ax, results, training_sizes, ylabel=None, show_legend=False, ylim=None):
    """Draw mean-absolute-error learning curves for every model onto *ax*.

    For each name in the module-level ``models`` list, points come from
    ``results[name]['cv_results']['mae']`` and ``['mae_std']``, keyed by the
    fractions in the module-level ``alphas``. Only fractions present in ``'mae'``
    are plotted, at the matching position of *training_sizes*, with a +/-1 std band.
    ``results['DummyRegressor']['cv_mae_mean']`` is drawn as a dotted baseline. A
    truthy *ylim* is applied as the y-axis limits.
    """
    for name in models:
        look = model_styles[name]
        curves = results[name]['cv_results']
        present = [idx for idx, frac in enumerate(alphas) if frac in curves['mae']]
        x = np.array([training_sizes[idx] for idx in present])
        y = np.array([curves['mae'][alphas[idx]] for idx in present])
        err = np.array([curves['mae_std'][alphas[idx]] for idx in present])
        ax.plot(x, y, marker=look['marker'], linestyle=look['linestyle'],
                linewidth=3, markersize=10, label=look['label_prefix'] + name, color=look['color'])
        ax.fill_between(x, y - err, y + err, alpha=0.2, color=look['color'])

    baseline = results['DummyRegressor']['cv_mae_mean']
    ax.axhline(y=baseline, color='gray', linestyle=':', linewidth=2,
               label=f'Baseline ({baseline:.2f})', alpha=0.7)
    ax.grid(True, alpha=0.3)
    if ylim:
        ax.set_ylim(ylim)
    if ylabel:
        ax.set_ylabel(ylabel, fontsize=12, weight='bold')
    ax.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
    if show_legend:
        ax.legend(fontsize=9, loc='upper right')

# Row 0: sex classification (PPMI left, HBN right).
plot_classification(axes[0, 0], ppmi_sex_results, ppmi_sizes_sex, ylabel='Balanced Accuracy', show_legend=True)
plot_classification(axes[0, 1], hbn_sex_results, hbn_sizes_sex, show_legend=True)
# Row title, placed in axes coordinates just left of the left-hand panel.
axes[0, 0].text(-0.15, 0.5, 'Sex Classification', transform=axes[0, 0].transAxes, fontsize=14, weight='bold', 
                va='center', ha='center', rotation=90)

# Compute shared y-axis limits for all MAE plots. The upper limit covers every baseline
# and every mean +/- std point in the four age panels, plus 5% headroom, so the PPMI/HBN
# and male/female panels share one scale.
all_mae_values = []
for results in [ppmi_male_age_results, ppmi_female_age_results, hbn_male_age_results, hbn_female_age_results]:
    all_mae_values.append(results['DummyRegressor']['cv_mae_mean'])
    for model in models:
        for alpha in alphas:
            if alpha in results[model]['cv_results']['mae']:
                mae = results[model]['cv_results']['mae'][alpha]
                mae_std = results[model]['cv_results']['mae_std'][alpha]
                all_mae_values.extend([mae - mae_std, mae + mae_std])
mae_ylim = (0, max(all_mae_values) * 1.05)

# Row 1: male age regression.
plot_regression(axes[1, 0], ppmi_male_age_results, ppmi_sizes_male, ylabel='MAE (years)', show_legend=True, ylim=mae_ylim)
plot_regression(axes[1, 1], hbn_male_age_results, hbn_sizes_male, ylim=mae_ylim, show_legend=True)
axes[1, 0].text(-0.15, 0.5, 'Male Age Prediction', transform=axes[1, 0].transAxes, fontsize=14, weight='bold',
                va='center', ha='center', rotation=90)

# Row 2: female age regression.
plot_regression(axes[2, 0], ppmi_female_age_results, ppmi_sizes_female, ylabel='MAE (years)', ylim=mae_ylim, show_legend=True)
plot_regression(axes[2, 1], hbn_female_age_results, hbn_sizes_female, ylim=mae_ylim, show_legend=True)
axes[2, 0].text(-0.15, 0.5, 'Female Age Prediction', transform=axes[2, 0].transAxes, fontsize=14, weight='bold',
                va='center', ha='center', rotation=90)

# Row 3: Parkinson classification exists only for PPMI. The HBN cell is left blank
# (ticks removed, spines hidden).
plot_classification(axes[3, 0], ppmi_parkinson_results, ppmi_sizes_parkinson, ylabel='Balanced Accuracy', show_legend=True)
axes[3, 1].set_xticks([])
axes[3, 1].set_yticks([])
for spine in axes[3, 1].spines.values():
    spine.set_visible(False)
axes[3, 0].text(-0.15, 0.5, 'Parkinson Classification', transform=axes[3, 0].transAxes, fontsize=14, weight='bold',
                va='center', ha='center', rotation=90)

# Reserve space on the left for the row titles and at the top for the dataset titles.
plt.tight_layout(rect=[0.05, 0, 1, 0.94])
plt.savefig('combined_learning_curves.pdf', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: combined_learning_curves.pdf")

# Generate HBN aparc learning curves (separate PDF, one column of three panels).
# The x positions reuse hbn_sizes_*: the aparc runs use the same HBN subjects and
# label arrays as the Schaefer runs.
print("\nGenerating HBN aparc learning curves...")
fig_aparc, axes_aparc = plt.subplots(3, 1, figsize=(10, 18))

fig_aparc.suptitle('HBN aparc Parcellation (68 regions)', fontsize=16, weight='bold')

plot_classification(axes_aparc[0], hbn_aparc_sex_results, hbn_sizes_sex, ylabel='Balanced Accuracy', show_legend=True)
axes_aparc[0].set_title('Sex Classification', fontsize=14, weight='bold')

# Compute shared y-axis limits for aparc MAE plots: baselines and mean +/- std of both
# age panels, plus 5% headroom.
aparc_mae_values = []
for results in [hbn_aparc_male_age_results, hbn_aparc_female_age_results]:
    aparc_mae_values.append(results['DummyRegressor']['cv_mae_mean'])
    for model in models:
        for alpha in alphas:
            if alpha in results[model]['cv_results']['mae']:
                mae = results[model]['cv_results']['mae'][alpha]
                mae_std = results[model]['cv_results']['mae_std'][alpha]
                aparc_mae_values.extend([mae - mae_std, mae + mae_std])
aparc_mae_ylim = (0, max(aparc_mae_values) * 1.05)

plot_regression(axes_aparc[1], hbn_aparc_male_age_results, hbn_sizes_male, ylabel='MAE (years)', show_legend=True, ylim=aparc_mae_ylim)
axes_aparc[1].set_title('Male Age Prediction', fontsize=14, weight='bold')

plot_regression(axes_aparc[2], hbn_aparc_female_age_results, hbn_sizes_female, ylabel='MAE (years)', show_legend=True, ylim=aparc_mae_ylim)
axes_aparc[2].set_title('Female Age Prediction', fontsize=14, weight='bold')

plt.tight_layout()
plt.savefig('hbn_aparc_learning_curves.pdf', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: hbn_aparc_learning_curves.pdf")

# One correlation-matrix PDF per analysis (written by plot_correlation_with_labels).
print("\nGenerating correlation matrix plots...")
plot_correlation_with_labels(ppmi_corr_matrix, ppmi_boundaries, ppmi_model_names, ppmi_cluster_df, "PPMI")
plot_correlation_with_labels(hbn_corr_matrix, hbn_boundaries, hbn_model_names, hbn_cluster_df, "HBN_Schaefer")
plot_correlation_with_labels(hbn_aparc_corr_matrix, hbn_aparc_boundaries, hbn_aparc_model_names, hbn_aparc_cluster_df, "HBN_aparc")

# Export the FreeSurfer feature tables returned by compute_correlation_with_cluster_labels
# (the ordering used in the plots) for inspection.
ppmi_cluster_df.to_csv('ppmi_freesurfer_features_reordered.csv', index=False)
hbn_cluster_df.to_csv('hbn_schaefer_freesurfer_features_reordered.csv', index=False)
hbn_aparc_cluster_df.to_csv('hbn_aparc_freesurfer_features_reordered.csv', index=False)
print("Saved: ppmi_freesurfer_features_reordered.csv")
print("Saved: hbn_schaefer_freesurfer_features_reordered.csv")
print("Saved: hbn_aparc_freesurfer_features_reordered.csv")

print(f"\nDONE - PPMI: {len(ppmi_common_subjects)} subjects, HBN: {len(hbn_common_subjects)} subjects")

Overwriting /home/arelbaha/links/projects/rrg-glatard/arelbaha/combined_multimodel_comparison.py
