# MRI Foundation Models: Pre-Trained Models' Evaluation

## SPM/CAT12 Preprocessing | AnatCL | HBN 

In [None]:
%%writefile ~/links/projects/rrg-glatard/arelbaha/HBN_BIDS/new_cat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/HBN_BIDS"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"
# Apptainer image used both for `bosh exec prepare` and for every CAT12 invocation.
container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"

# Prepare the container image declared by the Boutiques descriptor before running it.
bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])

# Python callable generated from the Boutiques descriptor.
cat12 = boutiques.descriptor2func.function(boutiques_descriptor)

# Each SLURM array task re-runs this glob and processes entry [task_id], so the list order
# must be identical in every task. glob.glob() order depends on the filesystem, hence sorted().
# NOTE: this changes the task-id -> file mapping compared with earlier (unsorted) runs.
t1_nii_files = sorted(glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "sub-*_T1w.nii")))
print(f"Found {len(t1_nii_files)} T1w files.")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if not 0 <= task_id < len(t1_nii_files):
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    raise SystemExit(1)

input_file = t1_nii_files[task_id]

# The subject ID is the first "sub-*" component of the input path.
subject_id = next((part for part in input_file.split(os.sep) if part.startswith("sub-")), None)
if subject_id is None:
    print(f"No subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f"--imagepath={container_image}",
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
print("Available result keys:", result_dict.keys())
print("\nExit code:", result_dict.get("exit_code"))
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Propagate a non-zero CAT12 exit code so failed subjects show up as failed array tasks.
exit_code = result_dict.get("exit_code")
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=500-520
#!/bin/bash
#SBATCH --job-name=CAT12_preproc
#SBATCH --account=rrg-glatard
#SBATCH --mem=12G
#SBATCH --cpus-per-task=2
#SBATCH --nodes=1
#SBATCH --output=CAT12_preproc_%A_%a.out
#SBATCH --error=CAT12_preproc_%A_%a.err
#SBATCH --time=1:0:0

# One array task per T1w file: new_cat12_preprocessing.py reads SLURM_ARRAY_TASK_ID
# to choose which file it processes.

source ~/.venvs/jupyter_py3/bin/activate

module load apptainer

# Location of the preprocessing data and script. Stop here rather than running python
# from the wrong working directory.
cd ~/links/projects/rrg-glatard/arelbaha/HBN_BIDS || { echo "Cannot cd to HBN_BIDS directory" >&2; exit 1; }

echo "Running task ID: $SLURM_ARRAY_TASK_ID"

python new_cat12_preprocessing.py

## SPM/CAT12 Preprocessing | AnatCL | PPMI

In [None]:
%%writefile ~/links/projects/def-glatard/arelbaha/data/inputs/ppmicat12_preprocessing.py

import os
import glob
import boutiques
from boutiques import bosh
from boutiques.descriptor2func import function

# Container image preparation and paths
boutiques_descriptor = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/descriptors/cat12_prepro.json"
base_dir = "/home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs"
output_dir = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/preprocessed_output"
# Apptainer image used both for `bosh exec prepare` and for every CAT12 invocation.
container_image = "/home/arelbaha/links/projects/rrg-glatard/arelbaha/containers/cat12_prepro.sif"

bosh(["exec", "prepare", boutiques_descriptor, "--imagepath", container_image])
cat12 = boutiques.descriptor2func.function(boutiques_descriptor)

# Task ID extraction.
# Each SLURM array task re-runs this glob and processes entry [task_id], so the order must be
# the same in every task; glob.glob() order depends on the filesystem, hence sorted().
# NOTE: this changes the task-id -> file mapping compared with earlier (unsorted) runs.
data = sorted(glob.glob(os.path.join(base_dir, "sub-*", "ses-*", "anat", "*.nii")))
print(f"Found {len(data)} .nii files")

task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])

if not 0 <= task_id < len(data):
    print(f"SLURM_ARRAY_TASK_ID={task_id} out of range")
    raise SystemExit(1)

input_file = data[task_id]

# The subject ID is the first "sub-*" component of the input path.
subject_id = next((part for part in input_file.split(os.sep) if part.startswith("sub-")), None)
if subject_id is None:
    print(f"Could not find subject ID in path: {input_file}")
    raise SystemExit(1)

filename = os.path.basename(input_file)

subject_output_dir = os.path.join(output_dir, subject_id)
os.makedirs(subject_output_dir, exist_ok=True)

print(f"Processing SLURM_ARRAY_TASK_ID ={task_id} -> file: {filename} (subject: {subject_id})")

result = cat12(f"--imagepath={container_image}",
               input_file=input_file,
               output_dir=subject_output_dir)

result_dict = vars(result)
print("Available result keys:", result_dict.keys())
print("\nExit code:", result_dict.get("exit_code"))
print("\nStdout:\n", result_dict.get("stdout"))
print("\nStderr:\n", result_dict.get("stderr"))
print("\nOutput files:\n", result_dict.get("output_files"))

# Propagate a non-zero CAT12 exit code so failed subjects show up as failed array tasks.
exit_code = result_dict.get("exit_code")
if exit_code:
    raise SystemExit(exit_code)

In [None]:
%%sbatch --array=1040-1040
#!/bin/bash
#SBATCH --job-name=CAT12_parkinson_preproc
#SBATCH --account=rrg-glatard
#SBATCH --mem=12G
#SBATCH --cpus-per-task=2
#SBATCH --nodes=1
#SBATCH --output=parkinson_preproc_%A_%a.out
#SBATCH --error=parkinson_preproc_%A_%a.err
#SBATCH --time=1:0:0

source ~/.venvs/jupyter_py3/bin/activate

module load apptainer

# Location of the preprocessing data and script. Stop here rather than running python
# from the wrong working directory.
cd ~/links/projects/def-glatard/arelbaha/data/inputs || { echo "Cannot cd to PPMI inputs directory" >&2; exit 1; }

echo "Running task ID: $SLURM_ARRAY_TASK_ID"

# CAT12 preprocessing script written into this directory by the PPMI %%writefile cell.
python ppmicat12_preprocessing.py

## BrainIAC and CNN Preprocessing | HD-BET | PPMI

In [None]:
%%sbatch --array=23,24
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=20:00:00
#SBATCH --mem=16G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=brainiac_batch
#SBATCH --output=logs/prep_batch_%a.out
#SBATCH --error=logs/prep_batch_%a.err

# NOTE: Slurm opens the --output/--error files before this script runs, so the logs/
# directory must already exist in the submission directory.

# Each array task processes one pre-split input directory: raw_files_brainiac/batch_<task id>.
export BASE_DIR="/home/arelbaha/links/projects/def-glatard/arelbaha/data"
export RAW_DIR="${BASE_DIR}/raw_files_brainiac"
export OUTPUT_DIR="${BASE_DIR}/processed_outputs"

# Load environment
module load python/3.11
module load opencv
source /home/arelbaha/.venvs/brainiac_env/bin/activate

BRAINIAC_PREPROC="/home/arelbaha/.venvs/brainiac_env/lib/python3.11/site-packages/BrainIAC/src/preprocessing"
BATCH_DIR="${RAW_DIR}/batch_${SLURM_ARRAY_TASK_ID}"
BATCH_OUTPUT="${OUTPUT_DIR}/batch_${SLURM_ARRAY_TASK_ID}"

if [ ! -d "$BATCH_DIR" ]; then
    echo "Input batch directory not found: ${BATCH_DIR}" >&2
    exit 1
fi

mkdir -p "$BATCH_OUTPUT"

echo "Processing batch ${SLURM_ARRAY_TASK_ID} from ${BATCH_DIR}"

# Report failure (and fail the array task) instead of always printing "Completed".
if ! python "${BRAINIAC_PREPROC}/mri_preprocess_3d_simple.py" \
    --temp_img "${BRAINIAC_PREPROC}/atlases/temp_head.nii.gz" \
    --input_dir "$BATCH_DIR" \
    --output_dir "$BATCH_OUTPUT"; then
    echo "Preprocessing failed for batch ${SLURM_ARRAY_TASK_ID}" >&2
    exit 1
fi

echo "Completed batch ${SLURM_ARRAY_TASK_ID}"

## PPMI Multi-task Evaluation | AnatCL vs BrainIAC vs CNN vs FreeSurfer

In [2]:
%%writefile ~/links/projects/def-glatard/arelbaha/data/inputs/ppmi_multimodel_comparison.py

import os
import glob
import random

# Limit native thread pools to one thread each. These variables are read when the
# OpenMP / MKL / OpenBLAS runtimes initialise, i.e. when numpy (and later torch) is first
# imported, so they must be set BEFORE `import numpy`; setting them afterwards has no
# effect on numpy's BLAS.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from anatcl import AnatCL
from monai.networks.nets import ViT
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
import scipy.ndimage as ndi

# Reproducibility: seed every RNG used below and force deterministic, single-threaded torch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Data locations, all derived from DATA_DIR so they stay consistent with each other.
DATA_DIR = "/home/arelbaha/links/projects/def-glatard/arelbaha/data"
CAT12_BASE_DIR = os.path.join(DATA_DIR, "inputs")
LABELS_PATH = os.path.join(DATA_DIR, "processed_cohort_with_mri.csv")
BRAINIAC_MAPPING_CSV = os.path.join(DATA_DIR, "processed_files_mapping.csv")
# Directory holding the AnatCL fold checkpoints (fold0.pth ... fold4.pth).
ANATCL_ENCODER_PATH = "/home/arelbaha/.venvs/jupyter_py3/bin"
BRAINIAC_CKPT = "/home/arelbaha/.venvs/jupyter_py3/bin/BrainIAC.ckpt"
DROPOUT_RATE = 0.3

device = "cpu"

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"Random seed: {SEED}\n")

# Load demographic labels
print("Loading demographics")

labels_df = pd.read_csv(LABELS_PATH)
id_sex_dict = {}        # PATNO string -> 1 (F) / 0 (M)
id_parkinson_dict = {}  # PATNO string -> 1 (PD) / 0 (HC)
id_age_dict = {}        # PATNO string -> Age value from the CSV

# Only these codes are mapped; any other value leaves the subject out of that dict,
# which later excludes it from the matched cohort.
sex_codes = {'F': 1, 'M': 0}
group_codes = {'PD': 1, 'HC': 0}

for _, row in labels_df.iterrows():
    # A row without PATNO cannot be matched to any image (and int(NaN) would raise).
    if pd.isna(row['PATNO']):
        continue
    patno = str(int(row['PATNO']))

    # str() first so missing (NaN) cells are skipped instead of raising on .strip().
    sex = str(row['Sex']).strip().upper()
    if sex in sex_codes:
        id_sex_dict[patno] = sex_codes[sex]

    group = str(row['Group']).strip()
    if group in group_codes:
        id_parkinson_dict[patno] = group_codes[group]

    # Skip missing ages: NaN would otherwise reach the regression targets and the
    # min/max summary below.
    if not pd.isna(row['Age']):
        id_age_dict[patno] = row['Age']

print(f"Loaded demographics for {len(id_sex_dict)} subjects")
print(f"Sex distribution: {sum(id_sex_dict.values())} Female, {len(id_sex_dict) - sum(id_sex_dict.values())} Male")
print(f"PD distribution: {sum(id_parkinson_dict.values())} PD, {len(id_parkinson_dict) - sum(id_parkinson_dict.values())} HC")
if id_age_dict:
    print(f"Age range: {min(id_age_dict.values()):.1f} - {max(id_age_dict.values()):.1f} years\n")
else:
    print("Age range: no valid ages found\n")

def extract_patno_from_path(filepath):
    """Return the subject number from the first ``sub-<PATNO>`` component of *filepath*.

    The path is split on ``os.sep``; the ``sub-`` prefix is removed from the first
    component that starts with it. Returns None when no component matches.
    """
    prefix = 'sub-'
    matches = (segment[len(prefix):] for segment in filepath.split(os.sep) if segment.startswith(prefix))
    return next(matches, None)

# Load CAT12 VBM data
print("Finding CAT12 (s6mwp1) files")

# Sorted so that, when a subject has several s6mwp1 files (e.g. several sessions), the one
# kept -- the last in iteration order, since later entries overwrite earlier ones -- is
# deterministic instead of depending on filesystem listing order.
cat12_files = sorted(glob.glob(os.path.join(CAT12_BASE_DIR, "**", "*s6mwp1*.nii*"), recursive=True))
cat12_data = {}
for f in cat12_files:
    if not os.path.isfile(f):
        continue
    patno = extract_patno_from_path(f)
    # Keep only subjects that have sex, group and age labels.
    if patno and patno in id_sex_dict and patno in id_parkinson_dict and patno in id_age_dict:
        cat12_data[patno] = f

print(f"Found {len(cat12_data)} CAT12 files\n")

# Load BrainIAC data
print("Finding BrainIAC preprocessed files")


def _normalize_subject_id(value):
    """Return *value* as a PATNO string, mapping integral floats such as 3001.0 to '3001'.

    pandas reads an ID column containing missing values as float64, and str(3001.0) is
    '3001.0', which would never match the '3001' keys built from int(PATNO).
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value).strip()


brainiac_df = pd.read_csv(BRAINIAC_MAPPING_CSV).dropna(subset=["processed_file", "Age", "subject_id"])
brainiac_data = {}
for _, row in brainiac_df.iterrows():
    patno = _normalize_subject_id(row['subject_id'])
    # Keep only labelled subjects whose preprocessed file actually exists on disk.
    if patno in id_sex_dict and patno in id_parkinson_dict and patno in id_age_dict:
        if os.path.exists(row['processed_file']):
            brainiac_data[patno] = row['processed_file']

print(f"Found {len(brainiac_data)} BrainIAC files\n")

# Load FreeSurfer data
print("Extracting FreeSurfer features")

# Read the cortical-thickness (CTH) and surface-area (SA) APARC tables, in that order.
fs_cth_df, fs_sa_df = (
    pd.read_csv(os.path.join(CAT12_BASE_DIR, csv_name))
    for csv_name in ("FS7_APARC_CTH_23Oct2025.csv", "FS7_APARC_SA_23Oct2025.csv")
)

# Keep baseline visits only and use string PATNOs to match the demographic dict keys.
fs_cth_df, fs_sa_df = (
    table.loc[table['EVENT_ID'] == 'BL'].assign(PATNO=lambda frame: frame['PATNO'].astype(str))
    for table in (fs_cth_df, fs_sa_df)
)

# Every column other than the identifiers is treated as a feature.
id_columns = ('PATNO', 'EVENT_ID')
cth_features = [column for column in fs_cth_df.columns if column not in id_columns]
sa_features = [column for column in fs_sa_df.columns if column not in id_columns]

# Overall Common Subjects
print("Overall common subjects across all modalities")

common_subjects = sorted(set(cat12_data) & set(brainiac_data))

# Index the baseline FreeSurfer tables by PATNO once instead of boolean-filtering the full
# tables for every subject. keep='first' guarantees exactly one row per subject: duplicated
# baseline rows would otherwise be flattened into a longer vector and make the feature
# matrix below ragged.
cth_by_patno = fs_cth_df.drop_duplicates('PATNO', keep='first').set_index('PATNO')[cth_features]
sa_by_patno = fs_sa_df.drop_duplicates('PATNO', keep='first').set_index('PATNO')[sa_features]

fs_data = {}
for patno in common_subjects:
    if patno not in cth_by_patno.index or patno not in sa_by_patno.index:
        continue
    # Feature vector = thickness features followed by surface-area features.
    combined = np.concatenate([cth_by_patno.loc[patno].to_numpy(),
                               sa_by_patno.loc[patno].to_numpy()])
    if not np.any(np.isnan(combined)):
        fs_data[patno] = combined

common_subjects = sorted(set(cat12_data) & set(brainiac_data) & set(fs_data))

print(f"Subjects with CAT12: {len(cat12_data)}")
print(f"Subjects with BrainIAC: {len(brainiac_data)}")
print(f"Subjects with FreeSurfer: {len(fs_data)}")
print(f"Common subjects (all modalities): {len(common_subjects)}\n")

if not common_subjects:
    print("ERROR: No common subjects found across all modalities")
    raise SystemExit(1)

# All per-subject arrays below share the order of common_subjects.
cat12_paths = [cat12_data[p] for p in common_subjects]
brainiac_paths = [brainiac_data[p] for p in common_subjects]
fs_features = np.array([fs_data[p] for p in common_subjects])
sex_labels = np.array([id_sex_dict[p] for p in common_subjects])
parkinson_labels = np.array([id_parkinson_dict[p] for p in common_subjects])
age_labels = np.array([id_age_dict[p] for p in common_subjects])

print(f"Final matched dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex distribution: {np.sum(sex_labels)} Female, {len(sex_labels) - np.sum(sex_labels)} Male")
print(f"PD distribution: {np.sum(parkinson_labels)} PD, {len(parkinson_labels) - np.sum(parkinson_labels)} HC")
print(f"Age: mean={age_labels.mean():.1f}, range={age_labels.min():.1f}-{age_labels.max():.1f}\n")

pd.DataFrame({
    'subject_id': common_subjects,
    'cat12_path': cat12_paths,
    'brainiac_path': brainiac_paths,
    'sex': sex_labels,
    'parkinson': parkinson_labels,
    'age': age_labels
}).to_csv('ppmi_matched_subjects.csv', index=False)
print("Saved: ppmi_matched_subjects.csv\n")

#Dataset Classes

class CAT12VBMDataset(Dataset):
    """Pairs CAT12 VBM NIfTI volumes with scalar targets.

    Item ``idx`` loads ``data[idx]`` with nibabel, applies ``transform`` to the voxel
    array (the result must be a tensor, since a leading channel axis is then added with
    ``unsqueeze(0)``), and returns it with ``labels[idx]`` as a float32 tensor.
    """

    def __init__(self, data, labels, transform):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        voxels = nib.load(self.data[idx]).get_fdata()
        volume = self.transform(voxels)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return volume.unsqueeze(0), target

class BrainIACDataset(Dataset):
    """Pairs BrainIAC-preprocessed NIfTI volumes with scalar targets.

    Each volume is loaded as float32; if it is not already 96x96x96 it is resampled to
    that size with trilinear interpolation. ``transform`` is applied to the resulting
    tensor and a leading channel axis is added. The target is ``labels[idx]`` as float32.
    """

    def __init__(self, paths, labels, transform):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            # interpolate expects (N, C, D, H, W): add two singleton axes, then drop them again.
            batched = torch.from_numpy(volume)[None, None]
            resized = F.interpolate(batched, size=target_shape, mode='trilinear', align_corners=False)
            volume = resized.squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return normalized.unsqueeze(0), target

class CNNDataset(Dataset):
    """Unlabelled NIfTI volumes prepared for CNN3D.

    Each volume is loaded as float32, resampled with linear interpolation
    (``scipy.ndimage.zoom``, order 1) to 92x110x92 when its shape differs, z-scored over
    all voxels (epsilon 1e-6 avoids division by zero), and returned as a float tensor
    with a leading channel axis.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (92, 110, 92)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            factors = [wanted / actual for wanted, actual in zip(target_shape, volume.shape)]
            volume = ndi.zoom(volume, factors, order=1)
        volume = (volume - volume.mean()) / (volume.std() + 1e-6)
        return torch.from_numpy(volume[None]).float()

class CNN3D(nn.Module):
    """Small 3-D CNN producing a 512-dimensional embedding per input volume.

    Four stages, each Conv3d(kernel 3, padding 1) -> ReLU -> MaxPool3d(2) -> BatchNorm3d
    -> Dropout3d, with 64, 64, 128 and 256 output channels. A global average pool,
    Linear(256, 512), ReLU and Dropout then give a (batch, 512) output. Submodules are
    registered as conv1..conv4, bn1..bn4, pool1..pool4, drop1..drop5, global_pool, fc1.
    """

    def __init__(self, dropout_rate=DROPOUT_RATE):
        super().__init__()
        stage_channels = [(1, 64), (64, 64), (64, 128), (128, 256)]
        # Modules are created in the order conv1, bn1, pool1, drop1, conv2, ..., so weight
        # initialisation consumes the torch RNG in the same sequence.
        for stage, (in_channels, out_channels) in enumerate(stage_channels, start=1):
            setattr(self, f"conv{stage}", nn.Conv3d(in_channels, out_channels, 3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm3d(out_channels))
            setattr(self, f"pool{stage}", nn.MaxPool3d(2))
            setattr(self, f"drop{stage}", nn.Dropout3d(dropout_rate))
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 512)
        self.drop5 = nn.Dropout(dropout_rate)

    def forward(self, x):
        for stage in range(1, 5):
            x = F.relu(getattr(self, f"conv{stage}")(x))
            x = getattr(self, f"pool{stage}")(x)
            x = getattr(self, f"bn{stage}")(x)
            x = getattr(self, f"drop{stage}")(x)
        pooled = self.global_pool(x).flatten(start_dim=1)
        return self.drop5(F.relu(self.fc1(pooled)))

def load_brainiac_vit(ckpt_path, device):
    """Build a MONAI ViT backbone and load BrainIAC weights into it.

    The ViT takes 1-channel 96x96x96 volumes split into 16x16x16 patches (768-wide,
    12 layers, 12 heads, ``classification=False``). Weights come from
    ``ckpt["state_dict"]`` when present, otherwise from the checkpoint dict itself.
    Only keys starting with ``backbone.`` are used (prefix removed), and they are loaded
    with ``strict=False``, so missing or unexpected keys are ignored silently.

    Returns:
        The ViT moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): recent PyTorch versions default torch.load to weights_only=True, which may
    # reject checkpoints that contain non-tensor objects -- confirm with the installed version.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Resolve the weights mapping once and use it for both the membership test and the
    # lookup, so flat checkpoints (no "state_dict" entry) do not raise KeyError.
    state = ckpt.get("state_dict", ckpt)

    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, classification=False)

    prefix = "backbone."
    pe_key = prefix + "patch_embedding.position_embeddings"
    if pe_key in state:
        # Replace the parameter with one of the checkpoint's shape first; load_state_dict
        # raises on size mismatches even with strict=False (presumably why this is explicit).
        vit.patch_embedding.position_embeddings = nn.Parameter(state[pe_key].clone())

    # Strip only the leading prefix (str.replace would also rewrite inner occurrences).
    backbone_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    vit.load_state_dict(backbone_state, strict=False)
    return vit.to(device).eval()

# Extracting AnatCL Features
print("Extracting AnatCL features (averaging across 5 cross-validation folds)")

# Converts the numpy volume to a float tensor. Normalize(mean=0.0, std=1.0) computes
# (x - 0) / 1, i.e. it leaves the values unchanged.
anatcl_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
    transforms.Normalize(mean=0.0, std=1.0)
])

num_folds = 5
fold_paths = [os.path.join(ANATCL_ENCODER_PATH, f"fold{fold_idx}.pth") for fold_idx in range(num_folds)]

# Check every checkpoint up front so a missing fold fails before any slow extraction runs.
missing_folds = [(fold_idx, path) for fold_idx, path in enumerate(fold_paths) if not os.path.exists(path)]
if missing_folds:
    for fold_idx, path in missing_folds:
        print(f"ERROR: fold{fold_idx}.pth not found at {path}")
    raise SystemExit(1)

# The input volumes are the same for every fold, so the loader is built once.
dl = DataLoader(CAT12VBMDataset(cat12_paths, age_labels, anatcl_transform),
                batch_size=32, num_workers=0)

# NOTE(review): each fold is a separately loaded checkpoint; averaging their embeddings
# element-wise assumes the folds' feature dimensions are aligned -- confirm this is intended.
all_fold_features = []

for fold_idx, path in enumerate(fold_paths):
    print(f"Loading fold {fold_idx}...")
    encoder = AnatCL(descriptor="global", fold=fold_idx, pretrained=False).to(device).eval()
    # weights_only=False unpickles arbitrary objects: only use with trusted checkpoint files.
    encoder.backbone.load_state_dict(torch.load(path, map_location=device, weights_only=False)['model'])
    encoder.requires_grad_(False)

    with torch.no_grad():
        fold_features = torch.cat([encoder(vol.to(device)).cpu() for vol, _ in dl]).numpy()

    all_fold_features.append(fold_features)
    print(f"  Fold {fold_idx} features: {fold_features.shape}")

    del encoder
    torch.cuda.empty_cache()

anatcl_features = np.mean(all_fold_features, axis=0)

print(f"\nAnatCL features (averaged across {num_folds} folds): {anatcl_features.shape}\n")

# Extracting BrainIAC Features
print("Extracting BrainIAC features")

brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)


def brainiac_transform(x):
    """Z-score a volume tensor over all voxels (1e-6 avoids division by zero)."""
    return (x - x.mean()) / (x.std() + 1e-6)


dl = DataLoader(BrainIACDataset(brainiac_paths, age_labels, brainiac_transform),
                batch_size=16, num_workers=0)

# One row per subject: token 0 of the ViT output sequence.
# NOTE(review): with classification=False MONAI's ViT may not prepend a CLS token, in which
# case token 0 is the first patch embedding -- confirm against BrainIAC's own feature extraction.
token_batches = []
with torch.no_grad():
    for volumes, _targets in dl:
        hidden = brainiac_vit(volumes.to(device))
        # The model may return a tuple whose first element is the token sequence.
        sequence = hidden[0] if isinstance(hidden, tuple) else hidden
        token_batches.append(sequence[:, 0].cpu().numpy())

brainiac_features = np.vstack(token_batches)
del token_batches

print(f"BrainIAC features: {brainiac_features.shape}")
del brainiac_vit
torch.cuda.empty_cache()

# Extracting CNN Features
print("Extracting CNN features")

# Randomly initialised (seeded) CNN used in eval mode as a feature extractor.
cnn_model = CNN3D().to(device).eval()
dl = DataLoader(CNNDataset(brainiac_paths), batch_size=8, num_workers=0)

cnn_batches = []
with torch.no_grad():
    for volumes in dl:
        cnn_batches.append(cnn_model(volumes.to(device)).cpu().numpy())
cnn_features = np.vstack(cnn_batches)
del cnn_batches

print(f"CNN features: {cnn_features.shape}\n")
del cnn_model

# Organize Features: one (n_subjects, n_features) matrix per model, rows in common_subjects order.
features_dict = dict(
    AnatCL=anatcl_features,
    BrainIAC=brainiac_features,
    CNN=cnn_features,
    FreeSurfer=fs_features,
)
print("Feature extraction complete")
print("\n".join(f"{name}: {feats.shape}" for name, feats in features_dict.items()))
print()

# Correlation Analysis
print("Computing cross-model feature correlations")

def compute_correlation_within_models(features_dict):
    """Build one feature-by-feature correlation matrix spanning all models.

    Models are visited in the fixed order AnatCL, BrainIAC, CNN, FreeSurfer (absent ones
    are skipped). For each, the (n_subjects, n_features) matrix is standardised, NaN or
    constant columns are dropped, and the remaining columns are reordered by Ward
    hierarchical clustering of the within-model distance 1 - |r|. The reordered blocks
    are concatenated and a single Pearson correlation matrix is computed.

    Returns:
        corr: (F, F) correlation matrix with NaN/inf replaced and a unit diagonal.
        boundaries: cumulative column offsets; model k spans boundaries[k]:boundaries[k+1].
        model_names: names of the models actually included, aligned with ``boundaries``.
        reordered_indices: per-model column order, as indices into that model's columns
            AFTER NaN/constant filtering (not its raw columns).

    Raises:
        ValueError: if no model contributes any usable feature.
    """
    all_features = []
    boundaries = [0]
    included_models = []
    reordered_indices = {}

    for model in ['AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer']:
        if model not in features_dict:
            continue

        feats = StandardScaler().fit_transform(features_dict[model])

        # Drop columns that are NaN anywhere or (numerically) constant.
        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        feats = feats[:, valid]

        if feats.shape[1] == 0:
            continue

        print(f"  {model}: {feats.shape[1]} features")

        if feats.shape[1] == 1:
            # A single feature has nothing to cluster (linkage needs >= 2 observations).
            reorder_idx = [0]
        else:
            corr = np.corrcoef(feats.T)
            corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
            np.fill_diagonal(corr, 1.0)

            # Strongly (anti-)correlated features are "close".
            dist = np.abs(1 - np.abs(corr))
            np.fill_diagonal(dist, 0)
            condensed = np.nan_to_num(squareform((dist + dist.T) / 2, checks=False),
                                      nan=0.0, posinf=1.0, neginf=0.0)

            # NOTE(review): Ward linkage assumes Euclidean distances; 1 - |r| is not one.
            reorder_idx = dendrogram(linkage(condensed, method='ward'), no_plot=True)['leaves']

        all_features.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + feats.shape[1])
        reordered_indices[model] = reorder_idx
        included_models.append(model)

    if not all_features:
        raise ValueError("No model in features_dict has any usable (non-NaN, non-constant) feature")

    combined = np.hstack(all_features)
    raw_corr = np.atleast_2d(np.corrcoef(combined.T))
    corr = np.nan_to_num((raw_corr + raw_corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)

    print(f"\nTotal features: {combined.shape[1]}")

    # Names of the models actually included, so labels stay aligned with boundaries.
    return corr, boundaries, included_models, reordered_indices

corr_matrix, boundaries, model_names, reorder_indices = compute_correlation_within_models(features_dict)
pd.DataFrame(corr_matrix).to_csv('ppmi_feature_correlation_matrix.csv', index=False)
print(f"Correlation matrix: {corr_matrix.shape}")
print(f"Model boundaries: {boundaries}\n")

# Plot correlation matrix (FULL SIZE)
print("Plotting cross-model correlation matrix...")

plt.figure(figsize=(40, 38))
sns.heatmap(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, center=0, square=True,
            linewidths=0, cbar_kws={"shrink": 0.3},
            xticklabels=False, yticklabels=False)

# Thick lines separating the per-model blocks.
for b in boundaries[1:-1]:
    plt.axhline(y=b, color='black', linewidth=4)
    plt.axvline(x=b, color='black', linewidth=4)

# Label each model block at its midpoint, just outside the heatmap (-30 in data coordinates).
# model_names is expected to be aligned with consecutive boundary pairs.
for (start, stop), name in zip(zip(boundaries[:-1], boundaries[1:]), model_names):
    mid = (start + stop) / 2
    plt.text(mid, -30, name, ha='center', fontsize=24, weight='bold')
    plt.text(-30, mid, name, ha='center', va='center', fontsize=24, weight='bold', rotation=90)

plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig('ppmi_cross_model_correlation.png', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: ppmi_cross_model_correlation.png\n")

# Model Training Functions

def _stratified_subsample(idx, y, n):
    """Draw a stratified subset of ``n`` indices from ``idx``, or return None if impossible.

    ``train_test_split`` raises ValueError when the request cannot be satisfied, for
    example when ``n`` is not smaller than ``len(idx)`` (the alpha == 1.0 case) or a
    class is too small to stratify. The caller chooses the fallback.
    """
    try:
        subset, _ = train_test_split(idx, train_size=n, stratify=y[idx], random_state=SEED)
    except ValueError:
        return None
    return subset


def _fit_scaled_rf_classifier(X_train, y_train):
    """Fit a StandardScaler and a class-balanced 500-tree RandomForest; return (scaler, rf)."""
    scaler = StandardScaler()
    rf = RandomForestClassifier(n_estimators=500, random_state=SEED, n_jobs=-1,
                                max_features='sqrt', class_weight='balanced')
    rf.fit(scaler.fit_transform(X_train), y_train)
    return scaler, rf


def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Compare feature sets with RandomForest classifiers against a stratified dummy baseline.

    A stratified hold-out of ``int(test_size * len(y))`` subjects is set aside. On the
    remaining subjects, 5-fold stratified CV is run for several training fractions
    ``alpha``: each fold's training split is subsampled to ``max(4, int(alpha * n_train))``
    subjects. The alpha with the best mean validation AUC is selected. A final model is
    trained on that fraction of all CV subjects and scored on the hold-out set.

    Args:
        X_dict: mapping model name -> feature matrix whose rows are aligned with ``y``.
        y: array of binary labels (AUC uses ``predict_proba(...)[:, 1]``).
        test_size: fraction of subjects placed in the hold-out test set.
        task_name: label used in progress messages.

    Returns:
        dict keyed by 'DummyClassifier' and by each key of ``X_dict``. Each entry holds
        hold-out balanced accuracy / AUC, CV summaries at the selected alpha and per-fold
        values; real models also carry the train/val gap and the full alpha sweep.

    Raises:
        ValueError: if, for some model, no training subset contained both classes.
    """
    print(f"Running {task_name}...")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)),
                                        stratify=y, random_state=SEED)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    # Fractions of each training split used for fitting (a data-efficiency sweep),
    # not a regularisation strength.
    alphas = [0.1, 0.25, 0.5, 0.75, 1.0]

    # Baseline: a stratified DummyClassifier that ignores its (all-zero) features.
    print("Training DummyClassifier baseline...")
    dummy_val_auc = []
    dummy_val_acc = []

    for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyClassifier(strategy='stratified', random_state=SEED)
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])

        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        pred_proba = dummy.predict_proba(np.zeros((len(val_idx), 1)))[:, 1]

        dummy_val_acc.append(balanced_accuracy_score(y[val_idx], pred))
        dummy_val_auc.append(roc_auc_score(y[val_idx], pred_proba))

    dummy = DummyClassifier(strategy='stratified', random_state=SEED)
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    dummy_test_proba = dummy.predict_proba(np.zeros((len(test_idx), 1)))[:, 1]

    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], dummy_test_pred),
        'test_auc': roc_auc_score(y[test_idx], dummy_test_proba),
        'cv_auc_mean': np.mean(dummy_val_auc),
        'cv_auc_std': np.std(dummy_val_auc),
        'cv_acc_mean': np.mean(dummy_val_acc),
        'cv_acc_std': np.std(dummy_val_acc),
        'fold_results': {'val_auc': dummy_val_auc, 'val_acc': dummy_val_acc}
    }
    print(f"  DummyClassifier: AUC={results['DummyClassifier']['test_auc']:.4f}, Bal_Acc={results['DummyClassifier']['test_balanced_accuracy']:.4f}")

    for model, X in X_dict.items():
        val_auc = {a: [] for a in alphas}
        val_acc = {a: [] for a in alphas}
        train_auc = {a: [] for a in alphas}
        train_acc = {a: [] for a in alphas}

        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                n = max(4, int(alpha * len(train_idx)))
                tr_idx = _stratified_subsample(train_idx, y, n)
                if tr_idx is None:
                    # Unstratified fallback drawn from the (seeded) global NumPy RNG.
                    tr_idx = np.random.choice(train_idx, n, replace=False)

                # AUC and balanced accuracy need both classes in the training subset.
                if len(np.unique(y[tr_idx])) < 2:
                    continue

                scaler, rf = _fit_scaled_rf_classifier(X[tr_idx], y[tr_idx])
                X_val = scaler.transform(X[val_idx])
                X_tr = scaler.transform(X[tr_idx])

                val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                val_auc[alpha].append(roc_auc_score(y[val_idx], rf.predict_proba(X_val)[:, 1]))

                train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

        avg_auc = {a: np.mean(val_auc[a]) for a in alphas if val_auc[a]}
        if not avg_auc:
            raise ValueError(f"{task_name}: no training subset for model '{model}' contained both classes")
        best_alpha = max(avg_auc, key=avg_auc.get)

        # Final model: best_alpha fraction of all CV subjects (all of them when a stratified
        # subset of that size is impossible, e.g. best_alpha == 1.0).
        final_tr = _stratified_subsample(cv_idx, y, max(4, int(best_alpha * len(cv_idx))))
        if final_tr is None:
            final_tr = cv_idx

        scaler, rf = _fit_scaled_rf_classifier(X[final_tr], y[final_tr])
        X_test = scaler.transform(X[test_idx])
        test_pred_proba = rf.predict_proba(X_test)[:, 1]

        train_val_gap = np.mean(train_auc[best_alpha]) - np.mean(val_auc[best_alpha])
        is_overfitting = train_val_gap > 0.25

        results[model] = {
            'best_alpha': best_alpha,
            'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], rf.predict(X_test)),
            'test_auc': roc_auc_score(y[test_idx], test_pred_proba),
            'cv_auc_mean': np.mean(val_auc[best_alpha]),
            'cv_auc_std': np.std(val_auc[best_alpha]),
            'cv_acc_mean': np.mean(val_acc[best_alpha]),
            'cv_acc_std': np.std(val_acc[best_alpha]),
            'train_auc_mean': np.mean(train_auc[best_alpha]),
            'train_val_gap': train_val_gap,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_auc': val_auc[best_alpha],
                'val_acc': val_acc[best_alpha],
                'train_auc': train_auc[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'acc': {a: np.mean(val_acc[a]) for a in alphas if val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(val_acc[a]) for a in alphas if val_acc[a]},
                'auc_std': {a: np.std(val_auc[a]) for a in alphas if val_auc[a]}
            }
        }
        print(f"  {model}: AUC={results[model]['test_auc']:.4f}, Bal_Acc={results[model]['test_balanced_accuracy']:.4f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

def _make_age_regressor():
    """Return a fresh, seeded random-forest regressor with the age-prediction settings."""
    return RandomForestRegressor(n_estimators=200, max_depth=10, min_samples_split=5,
                                 random_state=SEED, n_jobs=-1)


def _regression_subsample(idx, alpha, min_size=40):
    """Return a seeded subsample of ``max(min_size, int(alpha * len(idx)))`` entries of ``idx``.

    When that size is not smaller than ``len(idx)``, ``idx`` itself is returned.
    """
    n = max(min_size, int(alpha * len(idx)))
    if n >= len(idx):
        return idx
    subset, _ = train_test_split(idx, train_size=n, random_state=SEED)
    return subset


def run_regression(X_dict, y, test_size=0.1, overfit_threshold=0.35):
    """Evaluate every feature set on age regression with a training-size sweep.

    One hold-out test split is drawn and shared by all models. On the remaining
    CV pool, 5-fold ``KFold`` is run; inside each fold the training set is
    subsampled to ``max(40, int(alpha * n_train))`` subjects for every ``alpha``.
    The ``alpha`` with the best mean validation R² is used to train a final random
    forest on the CV pool, which is then scored on the hold-out set. A
    mean-predicting ``DummyRegressor`` is evaluated on the same splits as baseline.

    Args:
        X_dict: mapping model name -> feature matrix whose rows align with ``y``.
        y: 1-D array of regression targets.
        test_size: fraction of subjects reserved for the hold-out test set.
        overfit_threshold: a model is flagged as overfitting when its mean
            train R² exceeds its mean validation R² by more than this amount.

    Returns:
        dict mapping ``'DummyRegressor'`` and each model name to a metrics dict
        (test R²/MAE/RMSE, CV means/stds, per-fold values and per-alpha CV curves).
    """
    print("Running age prediction...")
    results = {}
    # train_test_split rejects test_size=0, so always hold out at least one subject.
    n_test = max(1, int(test_size * len(y)))
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=n_test, random_state=SEED)
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.1, 0.25, 0.5, 0.75, 1.0]

    print("  Training DummyRegressor baseline...")
    dummy_val_r2 = []
    dummy_val_mae = []

    for train_rel, val_rel in kf.split(cv_idx):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyRegressor(strategy='mean')
        # DummyRegressor ignores X; a zero column only satisfies the fit/predict API.
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])

        pred = dummy.predict(np.zeros((len(val_idx), 1)))
        dummy_val_r2.append(r2_score(y[val_idx], pred))
        dummy_val_mae.append(mean_absolute_error(y[val_idx], pred))

    dummy = DummyRegressor(strategy='mean')
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))

    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2': r2_score(y[test_idx], dummy_test_pred),
        'test_mae': mean_absolute_error(y[test_idx], dummy_test_pred),
        'test_rmse': np.sqrt(mean_squared_error(y[test_idx], dummy_test_pred)),
        'cv_r2_mean': np.mean(dummy_val_r2),
        'cv_r2_std': np.std(dummy_val_r2),
        'cv_mae_mean': np.mean(dummy_val_mae),
        'cv_mae_std': np.std(dummy_val_mae),
        'fold_results': {'val_r2': dummy_val_r2, 'val_mae': dummy_val_mae}
    }
    print(f"  DummyRegressor: R2={results['DummyRegressor']['test_r2']:.4f}, MAE={results['DummyRegressor']['test_mae']:.2f}")

    for model, X in X_dict.items():
        val_r2 = {a: [] for a in alphas}
        val_mae = {a: [] for a in alphas}
        train_r2 = {a: [] for a in alphas}
        train_mae = {a: [] for a in alphas}

        for train_rel, val_rel in kf.split(cv_idx):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                tr_idx = _regression_subsample(train_idx, alpha)

                # The scaler is fit on the training subsample only (no leakage into validation).
                scaler = StandardScaler()
                rf = _make_age_regressor()
                rf.fit(scaler.fit_transform(X[tr_idx]), y[tr_idx])

                pred = rf.predict(scaler.transform(X[val_idx]))
                val_r2[alpha].append(r2_score(y[val_idx], pred))
                val_mae[alpha].append(mean_absolute_error(y[val_idx], pred))

                train_pred = rf.predict(scaler.transform(X[tr_idx]))
                train_r2[alpha].append(r2_score(y[tr_idx], train_pred))
                train_mae[alpha].append(mean_absolute_error(y[tr_idx], train_pred))

        # Select the training fraction with the best mean validation R².
        avg_r2 = {a: np.mean(val_r2[a]) for a in alphas if val_r2[a]}
        best_alpha = max(avg_r2, key=avg_r2.get)

        final_tr = _regression_subsample(cv_idx, best_alpha)

        scaler = StandardScaler()
        rf = _make_age_regressor()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        test_pred = rf.predict(scaler.transform(X[test_idx]))

        train_val_gap = np.mean(train_r2[best_alpha]) - np.mean(val_r2[best_alpha])
        is_overfitting = train_val_gap > overfit_threshold

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2': r2_score(y[test_idx], test_pred),
            'test_mae': mean_absolute_error(y[test_idx], test_pred),
            'test_rmse': np.sqrt(mean_squared_error(y[test_idx], test_pred)),
            'predictions': test_pred,
            'actual': y[test_idx],
            'cv_r2_mean': np.mean(val_r2[best_alpha]),
            'cv_r2_std': np.std(val_r2[best_alpha]),
            'cv_mae_mean': np.mean(val_mae[best_alpha]),
            'cv_mae_std': np.std(val_mae[best_alpha]),
            'train_r2_mean': np.mean(train_r2[best_alpha]),
            'train_val_gap': train_val_gap,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_r2': val_r2[best_alpha],
                'val_mae': val_mae[best_alpha],
                'train_r2': train_r2[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'r2': avg_r2,
                'mae': {a: np.mean(val_mae[a]) for a in alphas if val_mae[a]},
                'r2_std': {a: np.std(val_r2[a]) for a in alphas if val_r2[a]},
                'mae_std': {a: np.std(val_mae[a]) for a in alphas if val_mae[a]}
            }
        }
        print(f"  {model}: R2={results[model]['test_r2']:.4f}, MAE={results[model]['test_mae']:.2f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

def _run_task(header, runner, *args, **kwargs):
    """Print a section header, run one evaluation task, print a blank line, return its results."""
    print(header)
    outcome = runner(*args, **kwargs)
    print()
    return outcome


# Run Sex Classification, Parkinson Classification and Age Prediction in turn.
sex_results = _run_task("Sex Classification", run_classification, features_dict, sex_labels,
                        test_size=0.1, task_name="Sex classification")
parkinson_results = _run_task("Parkinson Classification", run_classification, features_dict,
                              parkinson_labels, test_size=0.1, task_name="Parkinson classification")
age_results = _run_task("Age Prediction", run_regression, features_dict, age_labels, test_size=0.1)

# Overfitting analysis: tabulate each model's train-validation gap on every task
# and list the (model, task) pairs whose results are flagged as overfitting.
print("Overfitting Analysis")

# (column prefix, per-task results, printed task label, gap metric name)
_gap_sources = (
    ('Sex', sex_results, "Sex classification", "AUC"),
    ('Parkinson', parkinson_results, "Parkinson classification", "AUC"),
    ('Age', age_results, "Age prediction", "R²"),
)

overfitting_report = []
for model in features_dict:
    report_row = {'Model': model}
    flagged_lines = []
    for prefix, task_results, task_label, metric in _gap_sources:
        entry = task_results[model]
        report_row[f'{prefix}_Overfitting'] = entry['is_overfitting']
        report_row[f'{prefix}_Train_Val_Gap'] = entry['train_val_gap']
        if entry['is_overfitting']:
            flagged_lines.append(f"  {task_label}: Train-Val {metric} gap = {entry['train_val_gap']:.3f}")
    overfitting_report.append(report_row)

    if flagged_lines:
        print(f"{model}:")
        for flagged_line in flagged_lines:
            print(flagged_line)

_flag_columns = ('Sex_Overfitting', 'Parkinson_Overfitting', 'Age_Overfitting')
if not any(any(r[col] for col in _flag_columns) for r in overfitting_report):
    print("No significant overfitting detected")

pd.DataFrame(overfitting_report).to_csv('ppmi_overfitting_analysis.csv', index=False)
print("\nSaved: ppmi_overfitting_analysis.csv\n")

# Saving Results: per-task summaries, a cross-task comparison and per-task rankings.
_banner = "=" * 80
print(_banner)
print("Saving Results")
print(_banner)


def _classification_summary(task_results, model_names):
    """Return a DataFrame of test and CV metrics, one row per model, for one classification task."""
    rows = []
    for m in model_names:
        r = task_results[m]
        rows.append({'Model': m,
                     'Best_Alpha': r['best_alpha'],
                     'Test_Balanced_Accuracy': r['test_balanced_accuracy'],
                     'Test_AUC': r['test_auc'],
                     'CV_AUC_Mean': r['cv_auc_mean'],
                     'CV_AUC_Std': r['cv_auc_std']})
    return pd.DataFrame(rows)


all_models = ['DummyClassifier'] + list(features_dict.keys())
_classification_summary(sex_results, all_models).to_csv('ppmi_sex_classification_summary.csv', index=False)
_classification_summary(parkinson_results, all_models).to_csv('ppmi_parkinson_classification_summary.csv', index=False)

all_models_reg = ['DummyRegressor'] + list(features_dict.keys())
# (output column, key in age_results[model]) in output column order.
_age_columns = (
    ('Best_Alpha', 'best_alpha'),
    ('Test_R2', 'test_r2'),
    ('Test_MAE', 'test_mae'),
    ('Test_RMSE', 'test_rmse'),
    ('CV_R2_Mean', 'cv_r2_mean'),
    ('CV_R2_Std', 'cv_r2_std'),
    ('CV_MAE_Mean', 'cv_mae_mean'),
    ('CV_MAE_Std', 'cv_mae_std'),
)
_age_rows = [{'Model': m, **{col: age_results[m][key] for col, key in _age_columns}}
             for m in all_models_reg]
pd.DataFrame(_age_rows).to_csv('ppmi_age_prediction_summary.csv', index=False)

all_results = []
for m in features_dict:
    s_res, p_res, a_res = sex_results[m], parkinson_results[m], age_results[m]
    all_results.append({'Model': m,
                        'Sex_AUC': s_res['test_auc'],
                        'Sex_Bal_Acc': s_res['test_balanced_accuracy'],
                        'Parkinson_AUC': p_res['test_auc'],
                        'Parkinson_Bal_Acc': p_res['test_balanced_accuracy'],
                        'Age_R2': a_res['test_r2'],
                        'Age_MAE': a_res['test_mae'],
                        'Age_RMSE': a_res['test_rmse']})
pd.DataFrame(all_results).to_csv('ppmi_detailed_comparison.csv', index=False)

# Rank models per task, best (highest score) first; ties keep model order.
rankings = [
    {'Task': task.replace('_', ' '), 'Rank': rank, 'Model': item['Model'], 'Score': item[task]}
    for task in ('Sex_AUC', 'Parkinson_AUC', 'Age_R2')
    for rank, item in enumerate(sorted(all_results, key=lambda x: x[task], reverse=True), 1)
]
pd.DataFrame(rankings).to_csv('ppmi_model_rankings.csv', index=False)

for _saved in ("ppmi_sex_classification_summary.csv",
               "ppmi_parkinson_classification_summary.csv",
               "ppmi_age_prediction_summary.csv",
               "ppmi_detailed_comparison.csv"):
    print(f"Saved: {_saved}")
print("Saved: ppmi_model_rankings.csv\n")

# Generate Learning Curve Visualizations
print("="*80)
print("Generating Learning Curves")
print("="*80)

models = list(features_dict.keys())
colors = ['#3498db', '#2ecc71', '#e74c3c', '#9b59b6']

# x-axis: approximate per-fold training-set size for each alpha, i.e. the CV pool
# (all subjects minus the 10% hold-out) x 4/5 (5-fold training share) x alpha.
# NOTE(review): run_regression floors every subsample at 40 subjects, which this
# approximation ignores for small alphas; check run_classification's rule before
# changing the formula.
n_cv = len(age_labels) - int(0.1 * len(age_labels))
alphas = [0.1, 0.25, 0.5, 0.75, 1.0]
training_sizes = [int(alpha * n_cv * 0.8) for alpha in alphas]


def _plot_cv_learning_curves(ax, task_results, mean_key, std_key, model_names, sizes, alpha_grid, palette):
    """Plot each model's mean CV score against training size with a ±1 std band.

    Args:
        ax: matplotlib Axes to draw on.
        task_results: per-model results dict from run_classification / run_regression.
        mean_key, std_key: keys into ``task_results[model]['cv_results']`` holding
            the per-alpha means and standard deviations (e.g. 'acc' / 'acc_std').
        model_names: models to plot, in legend order.
        sizes: x position for each entry of ``alpha_grid``.
        alpha_grid: training fractions; fractions without a CV result are skipped.
        palette: line colours, cycled so more models than colours do not raise IndexError.
    """
    for idx, model in enumerate(model_names):
        cv = task_results[model]['cv_results']
        kept = [(size, a) for size, a in zip(sizes, alpha_grid) if a in cv[mean_key]]
        x = np.array([size for size, _ in kept])
        means = np.array([cv[mean_key][a] for _, a in kept])
        stds = np.array([cv[std_key][a] for _, a in kept])
        color = palette[idx % len(palette)]
        ax.plot(x, means, 'o-', linewidth=2.5, markersize=8, label=model, color=color)
        ax.fill_between(x, means - stds, means + stds, alpha=0.25, color=color)


# Learning curves WITH SHADED ERROR REGIONS
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 6))

# Sex classification learning curve
_plot_cv_learning_curves(ax1, sex_results, 'acc', 'acc_std', models, training_sizes, alphas, colors)
dummy_acc_cv = sex_results['DummyClassifier']['cv_acc_mean']
ax1.axhline(y=dummy_acc_cv, color='red', linestyle='--', linewidth=2,
            label=f'Dummy Classifier ({dummy_acc_cv:.3f})', alpha=0.7)
ax1.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
ax1.set_ylabel('Balanced Accuracy', fontsize=12, weight='bold')
ax1.set_title('Sex Classification: CV Learning Curve', fontsize=14, weight='bold')
ax1.legend(fontsize=10, loc='lower right')
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1])

# Parkinson classification learning curve
_plot_cv_learning_curves(ax2, parkinson_results, 'acc', 'acc_std', models, training_sizes, alphas, colors)
dummy_acc_cv_pd = parkinson_results['DummyClassifier']['cv_acc_mean']
ax2.axhline(y=dummy_acc_cv_pd, color='red', linestyle='--', linewidth=2,
            label=f'Dummy Classifier ({dummy_acc_cv_pd:.3f})', alpha=0.7)
ax2.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
ax2.set_ylabel('Balanced Accuracy', fontsize=12, weight='bold')
ax2.set_title('Parkinson Classification: CV Learning Curve', fontsize=14, weight='bold')
ax2.legend(fontsize=10, loc='lower right')
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# Age prediction learning curve (lower MAE is better)
_plot_cv_learning_curves(ax3, age_results, 'mae', 'mae_std', models, training_sizes, alphas, colors)
dummy_mae_cv = age_results['DummyRegressor']['cv_mae_mean']
ax3.axhline(y=dummy_mae_cv, color='red', linestyle='--', linewidth=2,
            label=f'Dummy Regressor ({dummy_mae_cv:.2f})', alpha=0.7)
ax3.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
ax3.set_ylabel('MAE (years)', fontsize=12, weight='bold')
ax3.set_title('Age Prediction: CV MAE Learning Curve', fontsize=14, weight='bold')
ax3.legend(fontsize=10, loc='upper right')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('ppmi_cv_learning_curves.png', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: ppmi_cv_learning_curves.png\n")

# Final cohort summary (sex and PD labels are counted as the number of 1s).
_n_positive_sex = np.sum(sex_labels)
_n_pd = np.sum(parkinson_labels)
print(f"Dataset: {len(common_subjects)} subjects")
print(f"Age range: {age_labels.min():.1f} - {age_labels.max():.1f} years")
print(f"Sex: {_n_positive_sex} Female, {len(sex_labels) - _n_positive_sex} Male")
print(f"PD: {_n_pd} PD, {len(parkinson_labels) - _n_pd} HC")

Overwriting /home/arelbaha/links/projects/def-glatard/arelbaha/data/inputs/ppmi_multimodel_comparison.py


## BrainIAC and CNN Preprocessing | HD-BET | HBN

In [None]:
%%sbatch --array=1-100
#!/bin/bash
#SBATCH --account=def-glatard
#SBATCH --time=12:00:00
#SBATCH --mem=16G
#SBATCH --cpus-per-task=4
#SBATCH --job-name=brainiac_batch
#SBATCH --output=logs/prep_batch_%a.out
#SBATCH --error=logs/prep_batch_%a.err
# NOTE: SLURM opens the log files above before this script starts, so logs/
# must already exist in the submission directory.

export BASE_DIR="/home/arelbaha/links/projects/rrg-glatard/arelbaha"
export RAW_DIR="${BASE_DIR}/brainiac_p_files"
export OUTPUT_DIR="${BASE_DIR}/brainiac_p_outputs"

# BrainIAC preprocessing script and head template shipped inside the virtualenv.
BRAINIAC_PREPRO="/home/arelbaha/.venvs/brainiac_env/lib/python3.11/site-packages/BrainIAC/src/preprocessing"

module load python/3.11
module load opencv
source /home/arelbaha/.venvs/brainiac_env/bin/activate

# Batch: array task N processes the zero-padded directory batch_NNN.
BATCH_NUM=$(printf "%03d" "${SLURM_ARRAY_TASK_ID}")
BATCH_DIR="${RAW_DIR}/batch_${BATCH_NUM}"
BATCH_OUTPUT="${OUTPUT_DIR}/batch_${BATCH_NUM}"

if [ ! -d "$BATCH_DIR" ]; then
    echo "ERROR: input directory ${BATCH_DIR} does not exist" >&2
    exit 1
fi

mkdir -p "$BATCH_OUTPUT"
echo "Processing batch ${BATCH_NUM} from ${BATCH_DIR}"

# Preprocessing
python "${BRAINIAC_PREPRO}/mri_preprocess_3d_simple.py" \
    --temp_img "${BRAINIAC_PREPRO}/atlases/temp_head.nii.gz" \
    --input_dir "$BATCH_DIR" \
    --output_dir "$BATCH_OUTPUT"
status=$?

# Propagate failures so SLURM marks the array task as FAILED instead of reporting success.
if [ "$status" -ne 0 ]; then
    echo "ERROR: preprocessing failed for batch ${BATCH_NUM} (exit code ${status})" >&2
    exit "$status"
fi

echo "Completed batch ${BATCH_NUM}"

## HBN Multi-task Evaluation | AnatCL vs BrainIAC vs CNN vs FreeSurfer

In [3]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/debug_hbn_multimodel_comparison.py

import os
import glob
import random
import numpy as np
import pandas as pd
import fnmatch
from nilearn import plotting, datasets

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from anatcl import AnatCL
from monai.networks.nets import ViT
import torch.nn.functional as F
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.ndimage as ndi
from matplotlib.patches import Patch

# Reproducibility: seed every RNG in use, then force deterministic,
# single-threaded torch execution.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# PATHS (all data lives under the project root; model weights under the venv bin dir)
_PROJECT_ROOT = "/home/arelbaha/links/projects/rrg-glatard/arelbaha"
_VENV_BIN = "/home/arelbaha/.venvs/jupyter_py3/bin"

HBN_BIDS = f"{_PROJECT_ROOT}/HBN_BIDS"
HBN_BIDS_LOWER = f"{_PROJECT_ROOT}/hbn_bids"
BRAINIAC_OUT = f"{_PROJECT_ROOT}/brainiac_p_outputs"
FREESURFER_DIR = f"{_PROJECT_ROOT}/HBN_FreeSurfer/freesurfer"
DEMO_FILE = os.path.join(HBN_BIDS, "final_preprocessed_subjects_with_demographics.tsv")
ANATCL_ENCODER_PATH = _VENV_BIN
BRAINIAC_CKPT = f"{_VENV_BIN}/BrainIAC.ckpt"

# Model / feature settings
DROPOUT_RATE = 0.3
PARCELLATION = "Schaefer2018_400Parcels_17Networks_order"

device = "cpu"

#Loading Demographics
print("Loading demographics")

demo_df = pd.read_csv(DEMO_FILE, sep='\t')
demo_df['participant_id'] = demo_df['participant_id'].astype(str)

# Binary sex label. Rows whose sex is anything else (including missing) get no
# label, and the matched cohort below is drawn from id_sex_dict's keys.
_SEX_CODES = {'Female': 1, 'Male': 0}

id_sex_dict = {}
id_age_dict = {}

for subject_id, sex, age in zip(demo_df['participant_id'], demo_df['sex'], demo_df['age']):
    # str() first: a missing value (NaN float) has no .strip() and would crash.
    code = _SEX_CODES.get(str(sex).strip())
    if code is not None:
        id_sex_dict[subject_id] = code
    id_age_dict[subject_id] = age

_ages = list(id_age_dict.values())
print(f"Loaded demographics for {len(id_sex_dict)} subjects")
print(f"Sex distribution: {sum(id_sex_dict.values())} Female, {len(id_sex_dict) - sum(id_sex_dict.values())} Male")
# nanmin/nanmax: builtin min/max give order-dependent results when an age is NaN.
print(f"Age range: {np.nanmin(_ages):.1f} - {np.nanmax(_ages):.1f} years\n")

#Finding CAT12 Files
print("Finding CAT12 (s6mwp1) files")

cat12_data = {}

# Search the HBN_BIDS tree first; the lower-case tree only fills in subjects
# that were not found there.
for search_root in (HBN_BIDS, HBN_BIDS_LOWER):
    print(f"Searching in {search_root}...")
    for subject_id in id_sex_dict:
        if subject_id in cat12_data:
            continue
        pattern = os.path.join(search_root, f"sub-{subject_id}", "ses-*", "anat", "mri", "s6mwp1sub*.nii")
        # glob's order is filesystem-dependent; sort so the file picked when a
        # subject has several matches (e.g. sessions) is reproducible.
        files = sorted(glob.glob(pattern))
        if files:
            cat12_data[subject_id] = files[0]

print(f"Found {len(cat12_data)} CAT12 files\n")

#Finding BrainIAC Files
print("Finding BrainIAC preprocessed files")

brainiac_data = {}
batch_dirs = glob.glob(os.path.join(BRAINIAC_OUT, "batch_*"))
print(f"Searching in {len(batch_dirs)} batch directories...")

for batch_dir in sorted(batch_dirs):
    # glob's order is filesystem-dependent; sort within each batch too so the
    # "first file per subject wins" rule below is reproducible.
    for f in sorted(glob.glob(os.path.join(batch_dir, "sub-*_0000.nii.gz"))):
        # sub-<ID>_..._0000.nii.gz -> <ID>; count=1 strips only the leading prefix.
        subject_id = os.path.basename(f).split('_')[0].replace('sub-', '', 1)

        if subject_id in id_sex_dict and subject_id not in brainiac_data:
            brainiac_data[subject_id] = f

print(f"Found {len(brainiac_data)} BrainIAC files\n")

#Extracting FreeSurfer Features
print("Extracting FreeSurfer features")

# Each subject contributes n_fs_parcels SurfArea + n_fs_parcels ThickAvg values.
n_fs_parcels = 400
fs_data = {}
n_fs_incomplete = 0

for subject_id in id_sex_dict:
    subject_dir = os.path.join(FREESURFER_DIR, f"sub-{subject_id}")
    stats_file = os.path.join(subject_dir, f"sub-{subject_id}_regionsurfacestats.tsv")

    if not os.path.exists(stats_file):
        continue

    try:
        df = pd.read_csv(stats_file, sep='\t')
        filtered_df = df[df["atlas"] == PARCELLATION]
        if filtered_df.empty:
            continue

        # Sort by region name so every subject's vector has the same column order.
        filtered_df = filtered_df.sort_values("StructName")

        if "SurfArea" not in filtered_df.columns or "ThickAvg" not in filtered_df.columns:
            continue

        # Fewer regions would give a shorter vector and make the later
        # np.array(...) over fs_data ragged, so such subjects are skipped.
        # NOTE(review): rows beyond the first n_fs_parcels (after sorting) are
        # dropped; confirm no non-parcel rows sort ahead of real parcels.
        if len(filtered_df) < n_fs_parcels:
            n_fs_incomplete += 1
            continue

        surf_area = filtered_df["SurfArea"].values[:n_fs_parcels]
        thick_avg = filtered_df["ThickAvg"].values[:n_fs_parcels]
        combined = np.concatenate([surf_area, thick_avg])

        if not np.any(np.isnan(combined)):
            fs_data[subject_id] = combined
    except Exception as e:
        # Best effort: a malformed stats file drops that subject but is reported.
        print(f"Error reading {stats_file}: {e}")

if n_fs_incomplete:
    print(f"Skipped {n_fs_incomplete} FreeSurfer subjects with fewer than {n_fs_parcels} {PARCELLATION} regions")
print(f"Found {len(fs_data)} FreeSurfer subjects with 800 features (400 SurfArea + 400 ThickAvg)\n")

#Overall Common Subjects
print("Overall common subjects across all modalities")

# Subjects available in every modality; sorting fixes the row order shared by
# every array built below.
common_subjects = sorted(set(cat12_data).intersection(brainiac_data, fs_data))

for modality, found in (("CAT12", cat12_data), ("BrainIAC", brainiac_data), ("FreeSurfer", fs_data)):
    print(f"Subjects with {modality}: {len(found)}")
print(f"Common subjects (all modalities): {len(common_subjects)}\n")

if not common_subjects:
    print("ERROR: No common subjects found across all modalities")
    exit(1)

# Row i of each array refers to common_subjects[i].
cat12_paths = [cat12_data[s] for s in common_subjects]
brainiac_paths = [brainiac_data[s] for s in common_subjects]
fs_features = np.array([fs_data[s] for s in common_subjects])
sex_labels = np.array([id_sex_dict[s] for s in common_subjects])
age_labels = np.array([id_age_dict[s] for s in common_subjects])

_n_female = np.sum(sex_labels)
print(f"Final matched dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex labels: {_n_female} Female, {len(sex_labels) - _n_female} Male")
print(f"Age: mean={age_labels.mean():.1f}, range={age_labels.min():.1f}-{age_labels.max():.1f}\n")

_matched_table = pd.DataFrame({
    'subject_id': common_subjects,
    'cat12_path': cat12_paths,
    'brainiac_path': brainiac_paths,
    'sex': sex_labels,
    'age': age_labels
})
_matched_table.to_csv('hbn_matched_subjects.csv', index=False)
print("Saved: hbn_matched_subjects.csv\n")

#Dataset Classes

class CAT12VBMDataset(Dataset):
    """Pairs CAT12 VBM NIfTI files with scalar targets.

    Each item is ``(volume, target)``. ``volume`` is ``transform`` applied to the
    loaded voxel array, with a leading channel axis added. ``target`` is a
    float32 tensor.
    """

    def __init__(self, data, labels, transform):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        voxels = nib.load(self.data[idx]).get_fdata()
        volume = self.transform(voxels).unsqueeze(0)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return volume, target

class BrainIACDataset(Dataset):
    """NIfTI volumes resized to 96x96x96 for the BrainIAC ViT, paired with targets.

    A volume whose shape is not 96x96x96 is trilinearly resampled. ``transform``
    is then applied to the float32 tensor. Items are ``(volume[None], target)``,
    with ``target`` as a float32 tensor.
    """

    def __init__(self, paths, labels, transform):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (96, 96, 96)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            # F.interpolate expects (batch, channel, D, H, W).
            batched = torch.from_numpy(volume[None, None])
            volume = F.interpolate(batched, size=target_shape, mode='trilinear',
                                   align_corners=False).squeeze().numpy()
        normalized = self.transform(torch.from_numpy(volume))
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return normalized[None], label

class CNNDataset(Dataset):
    """Unlabelled volumes resampled to 92x110x92 and z-scored per volume.

    Each item is a float32 tensor of shape (1, 92, 110, 92).
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        target_shape = (92, 110, 92)
        volume = nib.load(self.paths[idx]).get_fdata().astype(np.float32)
        if volume.shape != target_shape:
            factors = [want / have for want, have in zip(target_shape, volume.shape)]
            volume = ndi.zoom(volume, factors, order=1)  # linear interpolation
        # Per-volume z-score; the epsilon guards against an all-constant volume.
        volume = (volume - volume.mean()) / (volume.std() + 1e-6)
        return torch.from_numpy(volume[None]).float()

class CNN3D(nn.Module):
    """Four-stage 3-D convolutional encoder that outputs a 512-d feature vector.

    Each stage applies conv (3x3x3, padding 1), ReLU, 2x max-pool, batch-norm and
    3-D dropout, in that order. A global average pool follows, then a 256->512
    linear layer with ReLU and dropout.
    """

    # (in_channels, out_channels) of the convolutional stages, in order.
    _STAGES = ((1, 64), (64, 64), (64, 128), (128, 256))

    def __init__(self, dropout_rate=DROPOUT_RATE):
        super().__init__()
        # Submodules are registered as conv1, bn1, pool1, drop1, conv2, ... so the
        # module order and the RNG draws for conv weight init match stage order.
        for stage, (c_in, c_out) in enumerate(self._STAGES, start=1):
            setattr(self, f"conv{stage}", nn.Conv3d(c_in, c_out, 3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm3d(c_out))
            setattr(self, f"pool{stage}", nn.MaxPool3d(2))
            setattr(self, f"drop{stage}", nn.Dropout3d(dropout_rate))
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 512)
        self.drop5 = nn.Dropout(dropout_rate)

    def forward(self, x):
        for stage in range(1, len(self._STAGES) + 1):
            conv = getattr(self, f"conv{stage}")
            pool = getattr(self, f"pool{stage}")
            bn = getattr(self, f"bn{stage}")
            drop = getattr(self, f"drop{stage}")
            x = drop(bn(pool(F.relu(conv(x)))))
        flat = self.global_pool(x).view(x.size(0), -1)
        return self.drop5(F.relu(self.fc1(flat)))

def load_brainiac_vit(ckpt_path, device):
    """Build a MONAI 3-D ViT and load the BrainIAC backbone weights into it.

    Args:
        ckpt_path: path to the BrainIAC checkpoint. Either a dict holding a
            ``"state_dict"`` entry or a bare state dict is accepted. Only keys
            prefixed with ``"backbone."`` are loaded, with that prefix removed.
        device: device the model is moved to.

    Returns:
        The ViT in eval mode on ``device``.
    """
    # weights_only=False matches how the AnatCL fold checkpoints are loaded in this
    # script. Full unpickling can run arbitrary code: only use on trusted files.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Resolve the state dict once. The old code tested membership here but then
    # indexed ckpt["state_dict"], which raised KeyError for bare state dicts.
    state = ckpt.get("state_dict", ckpt)

    vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16),
              hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, classification=False)

    pe_key = "backbone.patch_embedding.position_embeddings"
    if pe_key in state:
        # Adopt the checkpoint's positional-embedding tensor (and thus its shape)
        # before load_state_dict, which would otherwise fail on a size mismatch.
        vit.patch_embedding.position_embeddings = nn.Parameter(state[pe_key].clone())

    backbone_state = {k.replace("backbone.", ""): v for k, v in state.items()
                      if k.startswith("backbone.")}
    missing, unexpected = vit.load_state_dict(backbone_state, strict=False)
    # strict=False would otherwise hide a partially loaded backbone.
    if missing or unexpected:
        print(f"BrainIAC checkpoint: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
    return vit.to(device).eval()

#Extracting AnatCL Features (All 5 folds, averaged)
print("Extracting AnatCL features (averaging across 5 cross-validation folds)")

# Converts the NIfTI voxel array to a float32 tensor. With mean=0 and std=1 the
# Normalize step is numerically an identity mapping.
anatcl_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x.copy()).float()),
    transforms.Normalize(mean=0.0, std=1.0)
])

num_folds = 5
fold_paths = [os.path.join(ANATCL_ENCODER_PATH, f"fold{fold_idx}.pth") for fold_idx in range(num_folds)]

# Check every fold checkpoint before extracting anything, so a missing file aborts
# immediately instead of after the earlier folds have already been processed.
for fold_idx, path in enumerate(fold_paths):
    if not os.path.exists(path):
        print(f"ERROR: fold{fold_idx}.pth not found at {path}")
        exit(1)

all_fold_features = []

for fold_idx, path in enumerate(fold_paths):
    print(f"Loading fold {fold_idx}...")
    encoder = AnatCL(descriptor="global", fold=fold_idx, pretrained=False).to(device).eval()
    # weights_only=False: assumes the fold checkpoints are local, trusted files.
    encoder.backbone.load_state_dict(torch.load(path, map_location=device, weights_only=False)['model'])
    encoder.requires_grad_(False)  # frozen encoder: used for feature extraction only

    # The labels given to the dataset are discarded below; only volumes are encoded.
    dl = DataLoader(CAT12VBMDataset(cat12_paths, age_labels, anatcl_transform),
                    batch_size=32, num_workers=0)

    with torch.no_grad():
        fold_features = torch.cat([encoder(vol.to(device)).cpu() for vol, _ in dl]).numpy()

    all_fold_features.append(fold_features)
    print(f"  Fold {fold_idx} features: {fold_features.shape}")

    del encoder

# Ensemble the fold encoders by averaging their embeddings subject by subject.
anatcl_features = np.mean(all_fold_features, axis=0)

print(f"\nAnatCL features (averaged across {num_folds} folds): {anatcl_features.shape}\n")

#Extracting BrainIAC Features
print("Extracting BrainIAC features")

brainiac_vit = load_brainiac_vit(BRAINIAC_CKPT, device)


def brainiac_transform(x):
    """Z-score a volume with its own mean and std (epsilon avoids division by zero)."""
    return (x - x.mean()) / (x.std() + 1e-6)


dl = DataLoader(BrainIACDataset(brainiac_paths, age_labels, brainiac_transform),
                batch_size=16, num_workers=0)

# Token 0 of the ViT output sequence is used as each subject's embedding.
# NOTE(review): the ViT is built with classification=False. In MONAI's ViT that
# means no CLS token is prepended, so token 0 may be the first image patch rather
# than a learned summary token. Confirm against the BrainIAC reference code.
_token_batches = []
with torch.no_grad():
    for x, _ in dl:
        out = brainiac_vit(x.to(device))
        # The model may return (hidden_states, ...) or the hidden states directly.
        hidden = out[0] if isinstance(out, tuple) else out
        _token_batches.append(hidden[:, 0].cpu().numpy())

brainiac_features = np.vstack(_token_batches)

print(f"BrainIAC features: {brainiac_features.shape}")
del brainiac_vit

#Extracting CNN Features
print("Extracting CNN features")

# No weights are loaded into CNN3D here, so its features come from the seeded
# random initialisation (an untrained random-feature baseline).
cnn_model = CNN3D().to(device).eval()
dl = DataLoader(CNNDataset(brainiac_paths), batch_size=8, num_workers=0)

_cnn_batches = []
with torch.no_grad():
    for volumes in dl:
        _cnn_batches.append(cnn_model(volumes.to(device)).cpu().numpy())
cnn_features = np.vstack(_cnn_batches)

print(f"CNN features: {cnn_features.shape}\n")
del cnn_model

#Organize Features
# Insertion order determines the order in which models are iterated below.
features_dict = dict(
    AnatCL=anatcl_features,
    BrainIAC=brainiac_features,
    CNN=cnn_features,
    FreeSurfer=fs_features,
)
print("Feature extraction complete")
for _model_name, _matrix in features_dict.items():
    print(f"{_model_name}: {_matrix.shape}")
print()

#Correlation Analysis
print("Computing cross-model feature correlations")

def compute_correlation_within_models(features_dict,
                                      model_order=('AnatCL', 'BrainIAC', 'CNN', 'FreeSurfer')):
    """Build one feature-by-feature correlation matrix spanning all feature sets.

    Models are processed in ``model_order``; any model absent from
    ``features_dict`` is skipped. For each model the features are
    standardized, and columns containing NaN or with (near-)zero variance are
    dropped. The remaining columns are reordered by Ward hierarchical clustering
    on ``1 - |corr|`` so that correlated features sit together. The reordered
    blocks are concatenated and a single Pearson correlation matrix is computed.

    Args:
        features_dict: mapping model name -> (n_subjects, n_features) array.
        model_order: models to include, in display order.

    Returns:
        A 4-tuple:
            corr: combined correlation matrix.
            boundaries: cumulative column offsets of each model's block, starting at 0.
            names: names of the models actually included, aligned with the blocks.
            reordered_indices: dict of per-model column orders, indexing the filtered columns.

    Raises:
        ValueError: if no model contributes any usable feature.
    """
    all_features = []
    boundaries = [0]
    included = []
    reordered_indices = {}

    for model in model_order:
        if model not in features_dict:
            continue

        feats = StandardScaler().fit_transform(features_dict[model])

        valid = ~np.isnan(feats).any(axis=0) & (np.std(feats, axis=0) > 1e-10)
        feats = feats[:, valid]

        if feats.shape[1] == 0:
            continue

        print(f"  {model}: {feats.shape[1]} features")

        if feats.shape[1] == 1:
            # linkage needs at least two observations; one column needs no reordering.
            reorder_idx = [0]
        else:
            corr = np.corrcoef(feats.T)
            corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
            np.fill_diagonal(corr, 1.0)

            # Strongly correlated or anti-correlated features are treated as "close".
            dist = np.abs(1 - np.abs(corr))
            np.fill_diagonal(dist, 0)
            condensed = np.nan_to_num(squareform((dist + dist.T) / 2, checks=False),
                                      nan=0.0, posinf=1.0, neginf=0.0)
            # NOTE(review): Ward linkage formally assumes Euclidean distances and
            # 1 - |r| is not one; treat the ordering as a visual aid only.
            reorder_idx = dendrogram(linkage(condensed, method='ward'), no_plot=True)['leaves']

        all_features.append(feats[:, reorder_idx])
        boundaries.append(boundaries[-1] + feats.shape[1])
        included.append(model)
        reordered_indices[model] = reorder_idx

    if not all_features:
        raise ValueError("No model provided any usable (finite, non-constant) feature")

    combined = np.hstack(all_features)
    # Compute the (large) combined matrix once; atleast_2d covers the 1-feature case.
    corr = np.atleast_2d(np.corrcoef(combined.T))
    corr = np.nan_to_num((corr + corr.T) / 2, nan=0.0, posinf=1.0, neginf=-1.0)
    np.fill_diagonal(corr, 1.0)

    print(f"\nTotal features: {combined.shape[1]}")

    # Return only the included names so labels stay aligned with `boundaries`.
    return corr, boundaries, included, reordered_indices

corr_matrix, boundaries, model_names, reorder_indices = compute_correlation_within_models(features_dict)
pd.DataFrame(corr_matrix).to_csv('hbn_feature_correlation_matrix.csv', index=False)
print(f"Correlation matrix: {corr_matrix.shape}")
print(f"Model boundaries: {boundaries}\n")

# Plot correlation matrix (FULL SIZE)
print("Plotting cross-model correlation matrix...")

plt.figure(figsize=(40, 38))
sns.heatmap(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, center=0, square=True,
            linewidths=0, cbar_kws={"shrink": 0.3},
            xticklabels=False, yticklabels=False)

# Thick separators between model blocks (the outer edges 0 and N are skipped).
for edge in boundaries[1:-1]:
    plt.axhline(y=edge, color='black', linewidth=4)
    plt.axvline(x=edge, color='black', linewidth=4)

# Model names centred on each block, above the matrix and to its left.
_block_centres = [(start + stop) / 2 for start, stop in zip(boundaries[:-1], boundaries[1:])]
for centre, label in zip(_block_centres, model_names):
    plt.text(centre, -30, label, ha='center', fontsize=24, weight='bold')
    plt.text(-30, centre, label, ha='center', va='center', fontsize=24, weight='bold', rotation=90)

plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.savefig('hbn_cross_model_correlation.png', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: hbn_cross_model_correlation.png\n")

#Model Training Functions

def _subsample_indices(indices, labels, n):
    """Return ``n`` entries of ``indices``, stratified on ``labels[indices]`` when possible.

    When ``n`` is not smaller than ``len(indices)`` all of ``indices`` is returned,
    because ``train_test_split`` rejects ``train_size >= n_samples``. If a stratified
    split is impossible, an unstratified draw from NumPy's global RNG is used instead.
    """
    if n >= len(indices):
        return indices
    try:
        subset, _ = train_test_split(indices, train_size=n, stratify=labels[indices],
                                     random_state=SEED)
    except ValueError:
        # e.g. a class with too few members for the requested stratified split
        subset = np.random.choice(indices, n, replace=False)
    return subset


def _make_classifier():
    """Build the random forest used for every feature set in ``run_classification``."""
    return RandomForestClassifier(n_estimators=500, random_state=SEED, n_jobs=-1,
                                  max_features='sqrt', class_weight='balanced')


def run_classification(X_dict, y, test_size=0.1, task_name="Classification"):
    """Benchmark each feature set on a binary classification task with a random forest.

    Procedure:
      1. Hold out ``int(test_size * len(y))`` subjects as a test split stratified on ``y``.
      2. Run 5-fold ``StratifiedKFold`` CV on the remaining subjects. Within each fold,
         for every training fraction ``alpha``, fit on ``max(4, int(alpha * fold_size))``
         training subjects (the whole fold at most) and score on the validation fold.
      3. Pick the alpha with the highest mean validation AUC, refit on the same fraction
         of all CV subjects and score once on the test split.
    A ``DummyClassifier(strategy='stratified')`` baseline is scored on the same folds/split.

    Parameters
    ----------
    X_dict : dict
        Maps a feature-set name to its feature matrix; rows must align with ``y``.
    y : array-like supporting integer-array indexing (e.g. ``numpy.ndarray``)
        Binary labels; AUC uses the probability of ``classes_[1]``.
    test_size : float
        Fraction of subjects reserved for the test split.
    task_name : str
        Used only in the progress message.

    Returns
    -------
    dict
        ``'DummyClassifier'`` and each key of ``X_dict`` mapped to test/CV metrics.
        Feature-set entries also carry per-alpha CV summaries, the train-minus-validation
        AUC gap and an ``is_overfitting`` flag (gap > 0.25).
    """
    print(f"Running {task_name}...")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)),
                                        stratify=y, random_state=SEED)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    alphas = [0.1, 0.25, 0.5, 0.75, 1.0]

    print("  Training DummyClassifier baseline...")
    dummy_val_auc = []
    dummy_val_acc = []

    for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
        train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]
        dummy = DummyClassifier(strategy='stratified', random_state=SEED)
        # DummyClassifier ignores features; zero placeholders only supply the row count.
        dummy.fit(np.zeros((len(train_idx), 1)), y[train_idx])

        X_placeholder = np.zeros((len(val_idx), 1))
        pred = dummy.predict(X_placeholder)
        pred_proba = dummy.predict_proba(X_placeholder)[:, 1]

        dummy_val_acc.append(balanced_accuracy_score(y[val_idx], pred))
        dummy_val_auc.append(roc_auc_score(y[val_idx], pred_proba))

    dummy = DummyClassifier(strategy='stratified', random_state=SEED)
    dummy.fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    dummy_test_pred = dummy.predict(np.zeros((len(test_idx), 1)))
    dummy_test_proba = dummy.predict_proba(np.zeros((len(test_idx), 1)))[:, 1]

    results['DummyClassifier'] = {
        'best_alpha': 1.0,
        'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], dummy_test_pred),
        'test_auc': roc_auc_score(y[test_idx], dummy_test_proba),
        'cv_auc_mean': np.mean(dummy_val_auc),
        'cv_auc_std': np.std(dummy_val_auc),
        'cv_acc_mean': np.mean(dummy_val_acc),
        'cv_acc_std': np.std(dummy_val_acc),
        'fold_results': {'val_auc': dummy_val_auc, 'val_acc': dummy_val_acc}
    }
    print(f"  DummyClassifier: AUC={results['DummyClassifier']['test_auc']:.4f}, Bal_Acc={results['DummyClassifier']['test_balanced_accuracy']:.4f}")

    for model, X in X_dict.items():
        val_auc = {a: [] for a in alphas}
        val_acc = {a: [] for a in alphas}
        train_auc = {a: [] for a in alphas}
        train_acc = {a: [] for a in alphas}

        for train_rel, val_rel in skf.split(cv_idx, y[cv_idx]):
            train_idx, val_idx = cv_idx[train_rel], cv_idx[val_rel]

            for alpha in alphas:
                # Learning-curve subsample: alpha of the training fold, at least 4 subjects.
                tr_idx = _subsample_indices(train_idx, y, max(4, int(alpha * len(train_idx))))

                if len(np.unique(y[tr_idx])) < 2:
                    continue  # a single-class training set cannot be fitted/scored by AUC

                scaler = StandardScaler()
                rf = _make_classifier()
                X_tr = scaler.fit_transform(X[tr_idx])
                X_val = scaler.transform(X[val_idx])
                rf.fit(X_tr, y[tr_idx])

                val_acc[alpha].append(balanced_accuracy_score(y[val_idx], rf.predict(X_val)))
                val_auc[alpha].append(roc_auc_score(y[val_idx], rf.predict_proba(X_val)[:, 1]))

                train_acc[alpha].append(balanced_accuracy_score(y[tr_idx], rf.predict(X_tr)))
                train_auc[alpha].append(roc_auc_score(y[tr_idx], rf.predict_proba(X_tr)[:, 1]))

        avg_auc = {a: np.mean(val_auc[a]) for a in alphas if val_auc[a]}
        best_alpha = max(avg_auc, key=avg_auc.get)

        # Final model uses the same subsampling rule on all CV subjects.
        final_tr = _subsample_indices(cv_idx, y, max(4, int(best_alpha * len(cv_idx))))

        scaler = StandardScaler()
        rf = _make_classifier()
        rf.fit(scaler.fit_transform(X[final_tr]), y[final_tr])
        X_test = scaler.transform(X[test_idx])
        test_pred_proba = rf.predict_proba(X_test)[:, 1]

        train_val_gap = np.mean(train_auc[best_alpha]) - np.mean(val_auc[best_alpha])
        is_overfitting = train_val_gap > 0.25

        results[model] = {
            'best_alpha': best_alpha,
            'test_balanced_accuracy': balanced_accuracy_score(y[test_idx], rf.predict(X_test)),
            'test_auc': roc_auc_score(y[test_idx], test_pred_proba),
            'cv_auc_mean': np.mean(val_auc[best_alpha]),
            'cv_auc_std': np.std(val_auc[best_alpha]),
            'cv_acc_mean': np.mean(val_acc[best_alpha]),
            'cv_acc_std': np.std(val_acc[best_alpha]),
            'train_auc_mean': np.mean(train_auc[best_alpha]),
            'train_val_gap': train_val_gap,
            'is_overfitting': is_overfitting,
            'fold_results': {
                'val_auc': val_auc[best_alpha],
                'val_acc': val_acc[best_alpha],
                'train_auc': train_auc[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'acc': {a: np.mean(val_acc[a]) for a in alphas if val_acc[a]},
                'auc': avg_auc,
                'acc_std': {a: np.std(val_acc[a]) for a in alphas if val_acc[a]},
                'auc_std': {a: np.std(val_auc[a]) for a in alphas if val_auc[a]}
            }
        }
        print(f"  {model}: AUC={results[model]['test_auc']:.4f}, Bal_Acc={results[model]['test_balanced_accuracy']:.4f}" +
              (f" [OVERFITTING: train-val gap={train_val_gap:.3f}]" if is_overfitting else ""))

    return results

def run_regression(X_dict, y, test_size=0.1):
    """Benchmark every feature set on age regression with a random forest.

    Procedure:
      1. Hold out ``int(test_size * len(y))`` subjects as a random test split.
      2. Run 5-fold ``KFold`` CV on the remaining subjects. Within each fold, for every
         training fraction ``alpha``, fit on ``max(40, int(alpha * fold_size))`` training
         subjects (the whole fold when that is not smaller than the fold).
      3. Pick the alpha with the highest mean validation R², refit on the same fraction
         of all CV subjects and score once on the test split.
    A ``DummyRegressor(strategy='mean')`` baseline is scored on the same folds and split.

    Returns
    -------
    dict
        ``'DummyRegressor'`` and each key of ``X_dict`` mapped to test/CV metrics.
        Feature-set entries also carry the test ``predictions``/``actual`` values, the
        train-minus-validation R² gap and an ``is_overfitting`` flag (gap > 0.35).
    """
    alphas = [0.1, 0.25, 0.5, 0.75, 1.0]

    def subsample(pool, size):
        # train_test_split cannot take every sample as "train"; keep the pool whole then.
        if size < len(pool):
            picked, _ = train_test_split(pool, train_size=size, random_state=SEED)
            return picked
        return pool

    def new_forest():
        return RandomForestRegressor(n_estimators=200, max_depth=10, min_samples_split=5,
                                     random_state=SEED, n_jobs=-1)

    print("Running age prediction...")
    results = {}
    cv_idx, test_idx = train_test_split(np.arange(len(y)), test_size=int(test_size * len(y)), random_state=SEED)
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)

    print("  Training DummyRegressor baseline...")
    baseline_r2, baseline_mae = [], []
    for fold_train, fold_val in kf.split(cv_idx):
        pool, held_out = cv_idx[fold_train], cv_idx[fold_val]
        # Features are ignored by the dummy model; zero columns only carry the row count.
        baseline = DummyRegressor(strategy='mean').fit(np.zeros((len(pool), 1)), y[pool])
        guess = baseline.predict(np.zeros((len(held_out), 1)))
        baseline_r2.append(r2_score(y[held_out], guess))
        baseline_mae.append(mean_absolute_error(y[held_out], guess))

    baseline = DummyRegressor(strategy='mean').fit(np.zeros((len(cv_idx), 1)), y[cv_idx])
    baseline_test = baseline.predict(np.zeros((len(test_idx), 1)))

    results['DummyRegressor'] = {
        'best_alpha': 1.0,
        'test_r2': r2_score(y[test_idx], baseline_test),
        'test_mae': mean_absolute_error(y[test_idx], baseline_test),
        'test_rmse': np.sqrt(mean_squared_error(y[test_idx], baseline_test)),
        'cv_r2_mean': np.mean(baseline_r2),
        'cv_r2_std': np.std(baseline_r2),
        'cv_mae_mean': np.mean(baseline_mae),
        'cv_mae_std': np.std(baseline_mae),
        'fold_results': {'val_r2': baseline_r2, 'val_mae': baseline_mae}
    }
    print(f"  DummyRegressor: R2={results['DummyRegressor']['test_r2']:.4f}, MAE={results['DummyRegressor']['test_mae']:.2f}")

    for model, X in X_dict.items():
        val_r2, val_mae, train_r2, train_mae = ({a: [] for a in alphas} for _ in range(4))

        for fold_train, fold_val in kf.split(cv_idx):
            pool, held_out = cv_idx[fold_train], cv_idx[fold_val]
            for alpha in alphas:
                subset = subsample(pool, max(40, int(alpha * len(pool))))

                scaler = StandardScaler()
                forest = new_forest()
                forest.fit(scaler.fit_transform(X[subset]), y[subset])

                held_pred = forest.predict(scaler.transform(X[held_out]))
                val_r2[alpha].append(r2_score(y[held_out], held_pred))
                val_mae[alpha].append(mean_absolute_error(y[held_out], held_pred))

                fit_pred = forest.predict(scaler.transform(X[subset]))
                train_r2[alpha].append(r2_score(y[subset], fit_pred))
                train_mae[alpha].append(mean_absolute_error(y[subset], fit_pred))

        avg_r2 = {a: np.mean(val_r2[a]) for a in alphas if val_r2[a]}
        best_alpha = max(avg_r2, key=avg_r2.get)

        final_subset = subsample(cv_idx, max(40, int(best_alpha * len(cv_idx))))
        scaler = StandardScaler()
        forest = new_forest()
        forest.fit(scaler.fit_transform(X[final_subset]), y[final_subset])
        test_pred = forest.predict(scaler.transform(X[test_idx]))

        gap = np.mean(train_r2[best_alpha]) - np.mean(val_r2[best_alpha])
        overfit = gap > 0.35
        y_test = y[test_idx]

        results[model] = {
            'best_alpha': best_alpha,
            'test_r2': r2_score(y_test, test_pred),
            'test_mae': mean_absolute_error(y_test, test_pred),
            'test_rmse': np.sqrt(mean_squared_error(y_test, test_pred)),
            'predictions': test_pred,
            'actual': y_test,
            'cv_r2_mean': np.mean(val_r2[best_alpha]),
            'cv_r2_std': np.std(val_r2[best_alpha]),
            'cv_mae_mean': np.mean(val_mae[best_alpha]),
            'cv_mae_std': np.std(val_mae[best_alpha]),
            'train_r2_mean': np.mean(train_r2[best_alpha]),
            'train_val_gap': gap,
            'is_overfitting': overfit,
            'fold_results': {
                'val_r2': val_r2[best_alpha],
                'val_mae': val_mae[best_alpha],
                'train_r2': train_r2[best_alpha]
            },
            'cv_results': {
                'alphas': alphas,
                'r2': avg_r2,
                'mae': {a: np.mean(val_mae[a]) for a in alphas if val_mae[a]},
                'r2_std': {a: np.std(val_r2[a]) for a in alphas if val_r2[a]},
                'mae_std': {a: np.std(val_mae[a]) for a in alphas if val_mae[a]}
            }
        }
        note = f" [OVERFITTING: train-val gap={gap:.3f}]" if overfit else ""
        print(f"  {model}: R2={results[model]['test_r2']:.4f}, MAE={results[model]['test_mae']:.2f}" + note)

    return results

# Evaluate every feature set on both tasks.
print("Sex Classification")
sex_results = run_classification(features_dict, sex_labels, test_size=0.1, task_name="Sex classification")
print()

print("Age Prediction")
age_results = run_regression(features_dict, age_labels, test_size=0.1)
print()

# Overfitting report: one row per feature set, built from the train-minus-validation
# gaps and flags stored by run_classification / run_regression.
print("Overfitting Analysis")

overfitting_report = [
    {
        'Model': name,
        'Sex_Overfitting': sex_results[name]['is_overfitting'],
        'Sex_Train_Val_Gap': sex_results[name]['train_val_gap'],
        'Age_Overfitting': age_results[name]['is_overfitting'],
        'Age_Train_Val_Gap': age_results[name]['train_val_gap'],
    }
    for name in features_dict
]

for row in overfitting_report:
    if not (row['Sex_Overfitting'] or row['Age_Overfitting']):
        continue
    print(f"{row['Model']}:")
    if row['Sex_Overfitting']:
        print(f"  Sex classification: Train-Val AUC gap = {row['Sex_Train_Val_Gap']:.3f}")
    if row['Age_Overfitting']:
        print(f"  Age prediction: Train-Val R² gap = {row['Age_Train_Val_Gap']:.3f}")

if not any(row['Sex_Overfitting'] or row['Age_Overfitting'] for row in overfitting_report):
    print("No significant overfitting detected")

pd.DataFrame(overfitting_report).to_csv('hbn_overfitting_analysis.csv', index=False)
print("\nSaved: hbn_overfitting_analysis.csv\n")

# Persist per-task summaries (baselines included), a side-by-side comparison of the
# feature sets, and per-task rankings (higher score ranks first).
print("="*80)
print("Saving Results")
print("="*80)

all_models = ['DummyClassifier', *features_dict.keys()]
sex_columns = {
    'Best_Alpha': 'best_alpha',
    'Test_Balanced_Accuracy': 'test_balanced_accuracy',
    'Test_AUC': 'test_auc',
    'CV_AUC_Mean': 'cv_auc_mean',
    'CV_AUC_Std': 'cv_auc_std',
}
pd.DataFrame([{'Model': m, **{col: sex_results[m][key] for col, key in sex_columns.items()}}
              for m in all_models]).to_csv('hbn_sex_classification_summary.csv', index=False)

all_models_reg = ['DummyRegressor', *features_dict.keys()]
age_columns = {
    'Best_Alpha': 'best_alpha',
    'Test_R2': 'test_r2',
    'Test_MAE': 'test_mae',
    'Test_RMSE': 'test_rmse',
    'CV_R2_Mean': 'cv_r2_mean',
    'CV_R2_Std': 'cv_r2_std',
    'CV_MAE_Mean': 'cv_mae_mean',
    'CV_MAE_Std': 'cv_mae_std',
}
pd.DataFrame([{'Model': m, **{col: age_results[m][key] for col, key in age_columns.items()}}
              for m in all_models_reg]).to_csv('hbn_age_prediction_summary.csv', index=False)

all_results = []
for m in features_dict:
    sex_res, age_res = sex_results[m], age_results[m]
    all_results.append({
        'Model': m,
        'Sex_AUC': sex_res['test_auc'],
        'Sex_Bal_Acc': sex_res['test_balanced_accuracy'],
        'Age_R2': age_res['test_r2'],
        'Age_MAE': age_res['test_mae'],
        'Age_RMSE': age_res['test_rmse'],
    })
pd.DataFrame(all_results).to_csv('hbn_detailed_comparison.csv', index=False)

rankings = [
    {'Task': task.replace('_', ' '), 'Rank': rank, 'Model': row['Model'], 'Score': row[task]}
    for task in ('Sex_AUC', 'Age_R2')
    for rank, row in enumerate(sorted(all_results, key=lambda r: r[task], reverse=True), start=1)
]
pd.DataFrame(rankings).to_csv('hbn_model_rankings.csv', index=False)

for saved in ('hbn_sex_classification_summary.csv', 'hbn_age_prediction_summary.csv',
              'hbn_detailed_comparison.csv'):
    print(f"Saved: {saved}")
print("Saved: hbn_model_rankings.csv\n")

#Generate Learning Curve Visualizations
print("="*80)
print("Generating Learning Curves")
print("="*80)

models = list(features_dict.keys())
colors = ['#3498db', '#2ecc71', '#e74c3c', '#9b59b6']

# Both tasks hold out int(0.1 * N) test subjects and run 5-fold CV on the rest, so a
# training fold holds about 4/5 of the CV subjects (fold sizes differ by at most one).
n_cv = len(age_labels) - int(0.1 * len(age_labels))
n_fold_train = n_cv - n_cv // 5
alphas = [0.1, 0.25, 0.5, 0.75, 1.0]
# Mirror the subsampling rules used in training: at least 4 subjects for sex
# classification, at least 40 for age regression, never more than the whole fold.
sex_training_sizes = [min(max(4, int(a * n_fold_train)), n_fold_train) for a in alphas]
age_training_sizes = [min(max(40, int(a * n_fold_train)), n_fold_train) for a in alphas]


def _plot_cv_learning_curve(ax, task_results, mean_key, std_key, alpha_grid, sizes,
                            model_names, palette):
    """Draw one mean ± std CV curve per feature set on *ax*.

    ``sizes[i]`` is the x position for ``alpha_grid[i]``. Alphas missing from a model's
    ``cv_results[mean_key]`` (no fold produced a score) are skipped; a model with no
    scored alpha at all is not drawn. Colors cycle when there are more models than
    entries in *palette*.
    """
    for idx, model in enumerate(model_names):
        summary = task_results[model]['cv_results']
        points = [(size, summary[mean_key][a], summary[std_key][a])
                  for a, size in zip(alpha_grid, sizes) if a in summary[mean_key]]
        if not points:
            continue
        xs, means, stds = (np.array(column) for column in zip(*points))
        color = palette[idx % len(palette)]
        ax.plot(xs, means, 'o-', linewidth=2.5, markersize=8, label=model, color=color)
        ax.fill_between(xs, means - stds, means + stds, alpha=0.25, color=color)


# Learning curves WITH SHADED ERROR REGIONS (±1 std across CV folds)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

# Sex classification learning curve
_plot_cv_learning_curve(ax1, sex_results, 'acc', 'acc_std', alphas, sex_training_sizes,
                        models, colors)

dummy_acc_cv = sex_results['DummyClassifier']['cv_acc_mean']
ax1.axhline(y=dummy_acc_cv, color='red', linestyle='--', linewidth=2,
            label=f'Dummy Classifier ({dummy_acc_cv:.3f})', alpha=0.7)

ax1.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
ax1.set_ylabel('Balanced Accuracy', fontsize=12, weight='bold')
ax1.set_title('Sex Classification: CV Learning Curve', fontsize=14, weight='bold')
ax1.legend(fontsize=10, loc='lower right')
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1])

# Age prediction learning curve
_plot_cv_learning_curve(ax2, age_results, 'mae', 'mae_std', alphas, age_training_sizes,
                        models, colors)

dummy_mae_cv = age_results['DummyRegressor']['cv_mae_mean']
ax2.axhline(y=dummy_mae_cv, color='red', linestyle='--', linewidth=2,
            label=f'Dummy Regressor ({dummy_mae_cv:.2f})', alpha=0.7)

ax2.set_xlabel('Number of Training Subjects', fontsize=12, weight='bold')
ax2.set_ylabel('MAE (years)', fontsize=12, weight='bold')
ax2.set_title('Age Prediction: CV MAE Learning Curve', fontsize=14, weight='bold')
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('hbn_cv_learning_curves.png', dpi=300, bbox_inches='tight')
plt.close()
print("Saved: hbn_cv_learning_curves.png\n")


# Final dataset summary (sex is encoded 1 = Female, 0 = Male).
n_female = np.sum(sex_labels)
print(f"Dataset: {len(common_subjects)} subjects")
print(f"Age range: {age_labels.min():.1f} - {age_labels.max():.1f} years")
print(f"Sex: {n_female} Female, {len(sex_labels) - n_female} Male")

Overwriting /home/arelbaha/links/projects/rrg-glatard/arelbaha/debug_hbn_multimodel_comparison.py


In [None]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/hbn_fs_save_freesurfer_data.py

import os
import glob
import random
import numpy as np
import pandas as pd
import pickle
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error, balanced_accuracy_score, roc_auc_score
from sklearn.inspection import permutation_importance

# One seed shared by Python's and NumPy's global RNGs and by every estimator/splitter.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Input/output locations (all under the project root) and the cortical atlas whose
# per-region stats are used as features.
PROJECT_ROOT = "/home/arelbaha/links/projects/rrg-glatard/arelbaha"
FREESURFER_DIR = f"{PROJECT_ROOT}/HBN_FreeSurfer/freesurfer"
DEMO_FILE = f"{PROJECT_ROOT}/HBN_BIDS/final_preprocessed_subjects_with_demographics.tsv"
OUTPUT_DIR = f"{PROJECT_ROOT}/freesurfer_importance_outputs"
PARCELLATION = "Schaefer2018_400Parcels_17Networks_order"

os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading demographics...")
demo_df = pd.read_csv(DEMO_FILE, sep='\t')
demo_df['participant_id'] = demo_df['participant_id'].astype(str)

sex_codes = {'Female': 1, 'Male': 0}
id_sex_dict = {}
id_age_dict = {}

for _, row in demo_df.iterrows():
    subject_id = row['participant_id']
    # pandas reads blank cells as float NaN, which has no .strip(): treat them as unknown.
    sex = row['sex'].strip() if isinstance(row['sex'], str) else None
    age = row['age']
    # Keep a subject only when both labels are usable: every subject kept here is later
    # looked up in both dicts, and a NaN age would make the age regressor fit fail.
    if sex not in sex_codes or pd.isna(age):
        continue
    id_sex_dict[subject_id] = sex_codes[sex]
    id_age_dict[subject_id] = age

if not id_sex_dict:
    print("ERROR: No subjects with valid sex and age in demographics file!")
    raise SystemExit(1)

print(f"Loaded demographics for {len(id_sex_dict)} subjects")
print(f"Sex distribution: {sum(id_sex_dict.values())} Female, {len(id_sex_dict) - sum(id_sex_dict.values())} Male")
print(f"Age range: {min(id_age_dict.values()):.1f} - {max(id_age_dict.values()):.1f} years\n")

print("Extracting FreeSurfer features...")
fs_data = {}

for subject_id in id_sex_dict:
    stats_file = os.path.join(FREESURFER_DIR, f"sub-{subject_id}",
                              f"sub-{subject_id}_regionsurfacestats.tsv")
    if not os.path.exists(stats_file):
        continue

    # Best effort: a malformed stats file is reported and the subject is skipped.
    try:
        df = pd.read_csv(stats_file, sep='\t')
        # Sorting by region name means that, assuming every subject lists the same
        # regions, column i refers to the same region for all subjects.
        filtered_df = df[df["atlas"] == PARCELLATION].sort_values("StructName")
        if filtered_df.empty or not {"SurfArea", "ThickAvg"}.issubset(filtered_df.columns):
            continue

        surf_area = filtered_df["SurfArea"].values[:400]
        thick_avg = filtered_df["ThickAvg"].values[:400]
        if len(surf_area) != 400 or len(thick_avg) != 400:
            continue

        # Feature layout: 400 surface-area values followed by 400 mean-thickness values.
        combined = np.concatenate([surf_area, thick_avg])
        if not np.any(np.isnan(combined)):
            fs_data[subject_id] = combined
    except Exception as e:
        print(f"Error reading {subject_id}: {e}")

print(f"Found {len(fs_data)} subjects with FreeSurfer data\n")

if not fs_data:
    print("ERROR: No FreeSurfer data found!")
    # SystemExit rather than the site-provided exit(), which is meant for interactive use.
    raise SystemExit(1)

common_subjects = sorted(fs_data)
fs_features = np.array([fs_data[s] for s in common_subjects])
sex_labels = np.array([id_sex_dict[s] for s in common_subjects])
age_labels = np.array([id_age_dict[s] for s in common_subjects])

print(f"Final dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex: {np.sum(sex_labels)} Female, {len(sex_labels) - np.sum(sex_labels)} Male")
print(f"Age: mean={age_labels.mean():.1f}, range={age_labels.min():.1f}-{age_labels.max():.1f}\n")

print("Loading region names...")
# Region labels come from the first subject's stats table, filtered and sorted the same
# way as during feature extraction so that index i matches feature column i.
reference_subject = common_subjects[0]
reference_file = os.path.join(FREESURFER_DIR, f'sub-{reference_subject}',
                              f'sub-{reference_subject}_regionsurfacestats.tsv')
region_df = pd.read_csv(reference_file, sep='\t')
region_df = region_df.loc[region_df['atlas'] == PARCELLATION].sort_values('StructName')
region_names, hemisphere = (region_df[col].values[:400] for col in ('StructName', 'hemisphere'))

print("\n" + "="*80)
print("SPLITTING DATA")
print("="*80)

# A single 90/10 hold-out, stratified by sex, shared by sex classification and age prediction.
all_indices = np.arange(len(common_subjects))
cv_indices, test_indices = train_test_split(all_indices, test_size=0.1,
                                            stratify=sex_labels, random_state=SEED)

for split_name, split_idx in (("CV set", cv_indices), ("Test set", test_indices)):
    print(f"{split_name}: {len(split_idx)} subjects")

print("\n" + "="*80)
print("CROSS-VALIDATION WITH ALPHA SUBSAMPLING")
print("="*80)

# Training fractions of each CV training fold used to build the learning curve.
alphas = [0.1, 0.25, 0.5, 0.75, 1.0]
n_folds = 5

# Per-task, per-metric, per-alpha lists of validation-fold scores.
cv_results = {
    'sex': {'acc': {a: [] for a in alphas}, 'auc': {a: [] for a in alphas}},
    'age': {'mae': {a: [] for a in alphas}, 'r2': {a: [] for a in alphas}}
}

# SEX CLASSIFICATION CV
print("\nSex classification cross-validation...")
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)

for fold_idx, (train_rel, val_rel) in enumerate(skf.split(cv_indices, sex_labels[cv_indices])):
    train_idx = cv_indices[train_rel]
    val_idx = cv_indices[val_rel]

    for alpha in alphas:
        n_train = max(4, int(alpha * len(train_idx)))
        if n_train >= len(train_idx):
            # train_test_split rejects train_size >= n_samples; this means "whole fold".
            train_sub = train_idx
        else:
            try:
                train_sub, _ = train_test_split(
                    train_idx,
                    train_size=n_train,
                    stratify=sex_labels[train_idx],
                    random_state=SEED
                )
            except ValueError:
                # Stratification impossible (too few members of a class): unstratified draw.
                train_sub = np.random.choice(train_idx, n_train, replace=False)

        if len(np.unique(sex_labels[train_sub])) < 2:
            continue  # a single-class training set cannot be fitted/scored by AUC

        scaler = StandardScaler()
        X_train = scaler.fit_transform(fs_features[train_sub])
        X_val = scaler.transform(fs_features[val_idx])

        rf = RandomForestClassifier(
            n_estimators=500,
            max_depth=8,
            min_samples_split=20,
            min_samples_leaf=10,
            random_state=SEED,
            max_features='sqrt',
            class_weight='balanced',
            n_jobs=-1
        )
        rf.fit(X_train, sex_labels[train_sub])

        val_pred = rf.predict(X_val)
        val_proba = rf.predict_proba(X_val)[:, 1]

        cv_results['sex']['acc'][alpha].append(
            balanced_accuracy_score(sex_labels[val_idx], val_pred)
        )
        cv_results['sex']['auc'][alpha].append(
            roc_auc_score(sex_labels[val_idx], val_proba)
        )

    print(f"  Fold {fold_idx+1}/{n_folds} complete")

# AGE PREDICTION CV: plain KFold (no stratification); each training subset is
# max(40, alpha * fold) subjects, or the whole fold when that is not smaller.
print("\nAge prediction cross-validation...")
kf = KFold(n_splits=n_folds, shuffle=True, random_state=SEED)
age_scores = cv_results['age']

for fold_idx, (train_rel, val_rel) in enumerate(kf.split(cv_indices)):
    fold_train, fold_val = cv_indices[train_rel], cv_indices[val_rel]

    for alpha in alphas:
        requested = max(40, int(alpha * len(fold_train)))
        train_sub = (train_test_split(fold_train, train_size=requested, random_state=SEED)[0]
                     if requested < len(fold_train) else fold_train)

        scaler = StandardScaler()
        X_train = scaler.fit_transform(fs_features[train_sub])
        X_val = scaler.transform(fs_features[fold_val])

        rf = RandomForestRegressor(n_estimators=200, max_depth=8, min_samples_split=20,
                                   min_samples_leaf=10, random_state=SEED, n_jobs=-1)
        rf.fit(X_train, age_labels[train_sub])
        val_pred = rf.predict(X_val)

        age_scores['mae'][alpha].append(mean_absolute_error(age_labels[fold_val], val_pred))
        age_scores['r2'][alpha].append(r2_score(age_labels[fold_val], val_pred))

    print(f"  Fold {fold_idx+1}/{n_folds} complete")

print("\n" + "="*80)
print("SELECTING BEST ALPHA FROM CV")
print("="*80)

# Alphas with no CV scores (every fold skipped) get -inf, so they are chosen only if no
# alpha was scored at all; a 0 placeholder could outrank genuinely negative mean R².
avg_auc_sex = {a: np.mean(cv_results['sex']['auc'][a]) if cv_results['sex']['auc'][a] else float('-inf')
               for a in alphas}
avg_r2_age = {a: np.mean(cv_results['age']['r2'][a]) if cv_results['age']['r2'][a] else float('-inf')
              for a in alphas}

best_alpha_sex = max(avg_auc_sex, key=avg_auc_sex.get)
best_alpha_age = max(avg_r2_age, key=avg_r2_age.get)

print(f"Best alpha - Sex: {best_alpha_sex} (CV AUC: {avg_auc_sex[best_alpha_sex]:.4f})")
print(f"Best alpha - Age: {best_alpha_age} (CV R²: {avg_r2_age[best_alpha_age]:.4f})")

print("FINAL TEST SET EVALUATION & PERMUTATION IMPORTANCE EXTRACTION")

# SEX CLASSIFICATION
print("\nSex classification final model...")
# Same subsampling rule as in CV: best_alpha of the CV subjects, never fewer than 4.
n_train_sex = max(4, int(best_alpha_sex * len(cv_indices)))
if n_train_sex >= len(cv_indices):
    # train_test_split rejects train_size >= n_samples; this means "all CV subjects".
    final_train_sex = cv_indices
else:
    try:
        final_train_sex, _ = train_test_split(
            cv_indices,
            train_size=n_train_sex,
            stratify=sex_labels[cv_indices],
            random_state=SEED
        )
    except ValueError:
        # Stratification impossible: fall back to all CV subjects.
        final_train_sex = cv_indices

scaler_sex = StandardScaler()
X_train_sex = scaler_sex.fit_transform(fs_features[final_train_sex])
X_test_sex = scaler_sex.transform(fs_features[test_indices])

rf_sex = RandomForestClassifier(
    n_estimators=500,
    max_depth=8,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=SEED,
    max_features='sqrt',
    class_weight='balanced',
    n_jobs=-1
)
rf_sex.fit(X_train_sex, sex_labels[final_train_sex])

test_pred_sex = rf_sex.predict(X_test_sex)
test_proba_sex = rf_sex.predict_proba(X_test_sex)[:, 1]
test_acc_sex = balanced_accuracy_score(sex_labels[test_indices], test_pred_sex)
test_auc_sex = roc_auc_score(sex_labels[test_indices], test_proba_sex)

# Permutation importance on TEST SET (drop in balanced accuracy per shuffled feature)
print("  Computing permutation importance on test set...")
perm_sex = permutation_importance(
    rf_sex, X_test_sex, sex_labels[test_indices],
    n_repeats=10,
    random_state=SEED,
    scoring='balanced_accuracy',
    n_jobs=-1
)
sex_importance = perm_sex.importances_mean

print(f"  Test Balanced Accuracy: {test_acc_sex:.4f}")
print(f"  Test AUC: {test_auc_sex:.4f}")

# AGE PREDICTION
print("\nAge prediction final model...")
# Same subsampling rule as in CV: best_alpha of the CV subjects, never fewer than 40
# (this also prevents train_size=0 on very small datasets).
n_train_age = max(40, int(best_alpha_age * len(cv_indices)))
if n_train_age < len(cv_indices):
    final_train_age, _ = train_test_split(cv_indices, train_size=n_train_age, random_state=SEED)
else:
    final_train_age = cv_indices

scaler_age = StandardScaler()
X_train_age = scaler_age.fit_transform(fs_features[final_train_age])
X_test_age = scaler_age.transform(fs_features[test_indices])

rf_age = RandomForestRegressor(
    n_estimators=200,
    max_depth=8,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=SEED,
    n_jobs=-1
)
rf_age.fit(X_train_age, age_labels[final_train_age])

test_pred_age = rf_age.predict(X_test_age)
test_r2_age = r2_score(age_labels[test_indices], test_pred_age)
test_mae_age = mean_absolute_error(age_labels[test_indices], test_pred_age)

# Permutation importance on TEST SET (estimator's default score, i.e. R² for a regressor)
print("  Computing permutation importance on test set...")
perm_age = permutation_importance(
    rf_age, X_test_age, age_labels[test_indices],
    n_repeats=10,
    random_state=SEED,
    n_jobs=-1
)
age_importance = perm_age.importances_mean

print(f"  Test R²: {test_r2_age:.4f}")
print(f"  Test MAE: {test_mae_age:.2f} years")

# Feature columns 0-399 are surface area and 400-799 are cortical thickness.
age_importance_sa, age_importance_thick = age_importance[:400], age_importance[400:]
sex_importance_sa, sex_importance_thick = sex_importance[:400], sex_importance[400:]

print("\n" + "="*80)
print("PERMUTATION IMPORTANCE STATISTICS")
print("="*80)

importance_by_task = (
    ("Age prediction", age_importance_sa, age_importance_thick),
    ("Sex classification", sex_importance_sa, sex_importance_thick),
)
for task_title, sa_scores, thick_scores in importance_by_task:
    print(f"\n{task_title}:")
    print(f"  Surface area: {sa_scores.sum():.4f} ({sa_scores.sum()*100:.1f}%)")
    print(f"  Thickness: {thick_scores.sum():.4f} ({thick_scores.sum()*100:.1f}%)")

print("\n" + "="*80)
print("TOP 10 MOST IMPORTANT REGIONS (PERMUTATION IMPORTANCE)")
print("="*80)

ranked_views = (
    ("Age prediction - Surface area", age_importance_sa),
    ("Age prediction - Thickness", age_importance_thick),
    ("Sex classification - Surface area", sex_importance_sa),
    ("Sex classification - Thickness", sex_importance_thick),
)
for heading, scores in ranked_views:
    print(f"\n{heading}:")
    # Descending order: the 10 largest importances, highest first.
    for rank, idx in enumerate(np.argsort(scores)[::-1][:10], 1):
        print(f"  {rank:2d}. {region_names[idx]:50s} = {scores[idx]:.6f}")

print("\n" + "="*80)
print("Saving results")
print("="*80)

# Everything needed to re-plot or re-rank importances later, bundled in one pickle.
fs_importance_data = {
    'age_importance_surface_area': age_importance_sa,
    'age_importance_thickness': age_importance_thick,
    'sex_importance_surface_area': sex_importance_sa,
    'sex_importance_thickness': sex_importance_thick,
    'region_names': region_names,
    'hemisphere': hemisphere,
    'n_subjects': len(common_subjects),
    'test_results': {
        'sex': {'accuracy': float(test_acc_sex), 'auc': float(test_auc_sex), 'alpha': best_alpha_sex},
        'age': {'r2': float(test_r2_age), 'mae': float(test_mae_age), 'alpha': best_alpha_age}
    },
    'cv_results': cv_results
}

output_file = os.path.join(OUTPUT_DIR, 'freesurfer_feature_importance.pkl')
with open(output_file, 'wb') as f:
    pickle.dump(fs_importance_data, f)
print(f"Saved: {output_file}")

# Long-format table: the 400 surface-area rows followed by the 400 thickness rows.
importance_df = pd.DataFrame({
    'ROI_Index': list(range(400)) * 2,
    'Region_Name': list(region_names) * 2,
    'Hemisphere': list(hemisphere) * 2,
    'Feature_Type': ['Surface Area'] * 400 + ['Thickness'] * 400,
    'Age_Permutation_Importance': np.concatenate([age_importance_sa, age_importance_thick]),
    'Sex_Permutation_Importance': np.concatenate([sex_importance_sa, sex_importance_thick]),
})

for task, sort_col in (('age', 'Age_Permutation_Importance'), ('sex', 'Sex_Permutation_Importance')):
    csv_file = os.path.join(OUTPUT_DIR, f'freesurfer_permutation_importance_{task}_sorted.csv')
    importance_df.sort_values(sort_col, ascending=False).to_csv(csv_file, index=False)
    print(f"Saved: {csv_file}")

print("\n" + "="*80)
print("COMPLETE!")
print(f"Analyzed {len(common_subjects)} subjects")
print(f"90/10 split: {len(cv_indices)} CV / {len(test_indices)} Test")
print("Permutation importance extracted on test set")
print("="*80)

In [None]:
%%writefile /home/arelbaha/links/projects/rrg-glatard/arelbaha/ppmi_fs_save_freesurfer_data.py

import os
import glob
import numpy as np
import pandas as pd
import pickle
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error, balanced_accuracy_score, roc_auc_score

# Global seed for NumPy's legacy RNG (used by np.random.choice fallbacks below)
SEED = 42
np.random.seed(SEED)

# All inputs and outputs live under a single project directory on the cluster.
PROJECT_ROOT = "/home/arelbaha/links/projects/def-glatard/arelbaha"
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
# Holds the CAT12 outputs and also the FreeSurfer CSV exports read later.
CAT12_BASE_DIR = os.path.join(DATA_DIR, "inputs")
LABELS_PATH = os.path.join(DATA_DIR, "processed_cohort_with_mri.csv")
BRAINIAC_MAPPING_CSV = os.path.join(DATA_DIR, "processed_files_mapping.csv")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "ppmi_freesurfer_importance_outputs")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Loading Demographics
print("Loading demographics...")
labels_df = pd.read_csv(LABELS_PATH)

# Lookups keyed by str(int(PATNO)). A subject is only used downstream if it is
# present in all three dicts, so a missing or unrecognised value simply drops
# that label instead of aborting the whole script.
SEX_CODES = {'F': 1, 'M': 0}
GROUP_CODES = {'PD': 1, 'HC': 0}

id_sex_dict = {}
id_parkinson_dict = {}
id_age_dict = {}

for _, row in labels_df.iterrows():
    # int(NaN) raises, so rows without a PATNO cannot be keyed at all
    if pd.isna(row['PATNO']):
        continue
    patno = str(int(row['PATNO']))

    # str() first: a missing value is a float NaN, which has no .strip()
    sex = str(row['Sex']).strip().upper()
    if sex in SEX_CODES:
        id_sex_dict[patno] = SEX_CODES[sex]

    group = str(row['Group']).strip()
    if group in GROUP_CODES:
        id_parkinson_dict[patno] = GROUP_CODES[group]

    # A NaN age would end up in age_labels; scikit-learn rejects NaN regression targets
    if pd.notna(row['Age']):
        id_age_dict[patno] = row['Age']

print(f"Loaded demographics for {len(id_sex_dict)} subjects")

def extract_patno_from_path(filepath):
    """Return the subject id from the first ``sub-<id>`` component of *filepath*.

    The ``sub-`` prefix is stripped; ``None`` is returned when no path
    component starts with ``sub-``.
    """
    prefix = 'sub-'
    candidates = (
        component[len(prefix):]
        for component in filepath.split(os.sep)
        if component.startswith(prefix)
    )
    return next(candidates, None)

print("\nLoading CAT12 data")
cat12_files = glob.glob(os.path.join(CAT12_BASE_DIR, "**", "*s6mwp1*.nii*"), recursive=True)

# Keep only subjects with complete demographics (sex, PD group and age).
# When a subject has several matching files, the last one globbed wins.
_label_dicts = (id_sex_dict, id_parkinson_dict, id_age_dict)
cat12_data = {}
for nifti_path in filter(os.path.isfile, cat12_files):
    subject = extract_patno_from_path(nifti_path)
    if subject and all(subject in lookup for lookup in _label_dicts):
        cat12_data[subject] = nifti_path

print(f"Found {len(cat12_data)} CAT12 subjects")

print("Loading BrainIAC data")
brainiac_df = pd.read_csv(BRAINIAC_MAPPING_CSV).dropna(subset=["processed_file", "Age", "subject_id"])


def _as_patno_key(value):
    """Return *value* in the ``str(int(PATNO))`` form used as demographics keys.

    ``read_csv`` parses an integer column with missing cells as float64, and
    ``dropna`` keeps that dtype, so a bare ``str()`` turns 3001 into ``'3001.0'``,
    which never matches a demographics key. Non-numeric ids are returned as
    stripped strings.
    """
    try:
        as_float = float(value)
    except (TypeError, ValueError):
        return str(value).strip()
    return str(int(as_float)) if as_float.is_integer() else str(value).strip()


brainiac_data = {}
for _, row in brainiac_df.iterrows():
    patno = _as_patno_key(row['subject_id'])
    if patno in id_sex_dict and patno in id_parkinson_dict and patno in id_age_dict:
        if os.path.exists(row['processed_file']):
            brainiac_data[patno] = row['processed_file']

print(f"Found {len(brainiac_data)} BrainIAC subjects")

print("Loading FreeSurfer data")

# Identifier columns; every other column is treated as a FreeSurfer feature.
_fs_id_columns = ['PATNO', 'EVENT_ID']


def _load_fs_baseline(csv_name):
    """Read a FreeSurfer CSV from CAT12_BASE_DIR, keep baseline ('BL') rows, stringify PATNO.

    NOTE(review): astype(str) on a float PATNO column would give '3001.0' and
    never match the demographics keys; presumably PATNO is integer here — confirm.
    """
    table = pd.read_csv(os.path.join(CAT12_BASE_DIR, csv_name))
    table = table.loc[table['EVENT_ID'] == 'BL'].copy()
    table['PATNO'] = table['PATNO'].astype(str)
    return table


fs_cth_df = _load_fs_baseline("FS7_APARC_CTH_23Oct2025.csv")
fs_sa_df = _load_fs_baseline("FS7_APARC_SA_23Oct2025.csv")

cth_features = [col for col in fs_cth_df.columns if col not in _fs_id_columns]
sa_features = [col for col in fs_sa_df.columns if col not in _fs_id_columns]

print(f"CTH features: {len(cth_features)}")
print(f"SA features: {len(sa_features)}")

print("Finding common subjects across all modalities")

# First get subjects with CAT12 and BrainIAC
common_subjects = sorted(set(cat12_data) & set(brainiac_data))
print(f"Subjects with CAT12 + BrainIAC: {len(common_subjects)}")

# Index each baseline table by PATNO once instead of scanning the whole
# DataFrame for every subject. keep='first' guarantees a single row per
# subject: concatenating several baseline rows would yield feature vectors of
# inconsistent length that cannot be stacked into fs_features.
cth_by_patno = fs_cth_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')
sa_by_patno = fs_sa_df.drop_duplicates(subset='PATNO', keep='first').set_index('PATNO')

# Feature vector layout: [all CTH features..., all SA features...]
fs_data = {}
for patno in common_subjects:
    if patno not in cth_by_patno.index or patno not in sa_by_patno.index:
        continue
    combined = np.concatenate([
        cth_by_patno.loc[patno, cth_features].to_numpy(dtype=float),
        sa_by_patno.loc[patno, sa_features].to_numpy(dtype=float),
    ])
    # Drop subjects with any missing FreeSurfer measurement
    if not np.isnan(combined).any():
        fs_data[patno] = combined

# Final common subjects with ALL modalities
common_subjects = sorted(set(cat12_data) & set(brainiac_data) & set(fs_data))

print(f"Subjects with CAT12 + BrainIAC + FreeSurfer: {len(common_subjects)}")
print(f"\nThese {len(common_subjects)} subjects have ALL modalities and will be used for analysis")

if not common_subjects:
    # Everything below (statistics, stratified splits) needs at least one subject
    raise SystemExit("No subjects have CAT12, BrainIAC and FreeSurfer data; nothing to analyse.")

fs_features = np.array([fs_data[p] for p in common_subjects])
sex_labels = np.array([id_sex_dict[p] for p in common_subjects])
parkinson_labels = np.array([id_parkinson_dict[p] for p in common_subjects])
age_labels = np.array([id_age_dict[p] for p in common_subjects])

print(f"\nFinal dataset: {len(common_subjects)} subjects")
print(f"FreeSurfer features shape: {fs_features.shape}")
print(f"Sex: {np.sum(sex_labels)} Female, {len(sex_labels) - np.sum(sex_labels)} Male")
print(f"PD: {np.sum(parkinson_labels)} PD, {len(parkinson_labels) - np.sum(parkinson_labels)} HC")
print(f"Age: mean={age_labels.mean():.1f}, range={age_labels.min():.1f}-{age_labels.max():.1f}\n")

print("Parsing region names and hemispheres")


def _split_hemisphere(feature_names, measure_suffix):
    """Split FreeSurfer column names into parallel (region, hemisphere) lists.

    ``lh_``/``rh_`` prefixes map to 'L'/'R', and the prefix plus every
    occurrence of *measure_suffix* is removed from the region name. Columns
    without a hemisphere prefix keep their full name and are labelled 'L'.
    NOTE(review): labelling non-hemispheric columns 'L' looks arbitrary — confirm.
    """
    regions, hemis = [], []
    for name in feature_names:
        prefix = name[:3]
        if prefix in ('lh_', 'rh_'):
            hemis.append('L' if prefix == 'lh_' else 'R')
            regions.append(name[3:].replace(measure_suffix, ''))
        else:
            hemis.append('L')
            regions.append(name)
    return regions, hemis


cth_region_names, cth_hemispheres = _split_hemisphere(cth_features, '_thickness')
sa_region_names, sa_hemispheres = _split_hemisphere(sa_features, '_area')

_rule = "=" * 80
print(f"\n{_rule}\nSPLITTING DATA\n{_rule}")

all_indices = np.arange(len(common_subjects))

# Sex classification and age regression share one 90/10 split, stratified by sex.
cv_indices, test_indices = train_test_split(
    all_indices, test_size=0.1, stratify=sex_labels, random_state=SEED
)

# Parkinson's classification gets its own 80/20 split, stratified by diagnosis,
# so more PD/HC subjects are held out. The two splits are drawn independently,
# so PD test subjects may fall in the sex/age CV set and vice versa.
cv_indices_pd, test_indices_pd = train_test_split(
    all_indices, test_size=0.2, stratify=parkinson_labels, random_state=SEED
)

for split_name, cv_set, test_set in (
    ("Sex/Age", cv_indices, test_indices),
    ("PD", cv_indices_pd, test_indices_pd),
):
    print(f"{split_name} CV set: {len(cv_set)} subjects, Test set: {len(test_set)} subjects")

print("\n" + "="*80)
print("CROSS-VALIDATION WITH ALPHA SUBSAMPLING")
print("="*80)

# Fractions of each CV training fold actually used for training
alphas = [0.1, 0.25, 0.5, 0.75, 1.0]
n_folds = 5

# cv_results[task][metric][alpha] -> list of per-fold validation scores
cv_results = {
    'sex': {'acc': {a: [] for a in alphas}, 'auc': {a: [] for a in alphas}},
    'pd': {'acc': {a: [] for a in alphas}, 'auc': {a: [] for a in alphas}},
    'age': {'mae': {a: [] for a in alphas}, 'r2': {a: [] for a in alphas}}
}


def _subsample_train(train_idx, alpha, min_size, stratify=None):
    """Return a seeded subset holding ``alpha`` of ``train_idx`` (at least ``min_size``).

    If the requested size covers the whole fold, the fold itself is returned.
    train_test_split rejects a train_size equal to the number of samples, so
    this case must not reach it. ``stratify`` is the full label array; it is
    indexed by ``train_idx``. If a stratified split is impossible (ValueError,
    e.g. a class too small for the requested size), an unstratified draw from
    the global NumPy RNG is used instead.
    """
    n_train = max(min_size, int(alpha * len(train_idx)))
    if n_train >= len(train_idx):
        return train_idx
    try:
        train_sub, _ = train_test_split(
            train_idx,
            train_size=n_train,
            stratify=None if stratify is None else stratify[train_idx],
            random_state=SEED
        )
    except ValueError:
        train_sub = np.random.choice(train_idx, n_train, replace=False)
    return train_sub


def _cv_classification(task, labels, cv_idx):
    """Stratified k-fold CV of a random-forest classifier at every alpha.

    Appends balanced accuracy and ROC AUC per fold and alpha to
    ``cv_results[task]``. Alphas whose subsample contains a single class are
    skipped for that fold.
    """
    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)

    for fold_idx, (train_rel, val_rel) in enumerate(splitter.split(cv_idx, labels[cv_idx])):
        train_idx = cv_idx[train_rel]
        val_idx = cv_idx[val_rel]

        for alpha in alphas:
            train_sub = _subsample_train(train_idx, alpha, 4, stratify=labels)
            if len(np.unique(labels[train_sub])) < 2:
                continue

            # Scaler is fitted on the training subsample only (no validation leakage)
            scaler = StandardScaler()
            X_train = scaler.fit_transform(fs_features[train_sub])
            X_val = scaler.transform(fs_features[val_idx])

            rf = RandomForestClassifier(
                n_estimators=500,
                random_state=SEED,
                max_features='sqrt',
                class_weight='balanced',
                n_jobs=-1
            )
            rf.fit(X_train, labels[train_sub])

            val_pred = rf.predict(X_val)
            val_proba = rf.predict_proba(X_val)[:, 1]

            cv_results[task]['acc'][alpha].append(
                balanced_accuracy_score(labels[val_idx], val_pred)
            )
            cv_results[task]['auc'][alpha].append(
                roc_auc_score(labels[val_idx], val_proba)
            )

        print(f"  Fold {fold_idx+1}/{n_folds} complete")


# SEX CLASSIFICATION CV
print("\nSex classification cross-validation...")
_cv_classification('sex', sex_labels, cv_indices)

# PARKINSON'S CLASSIFICATION CV
print("\nParkinson's classification cross-validation...")
_cv_classification('pd', parkinson_labels, cv_indices_pd)

# AGE PREDICTION CV
print("\nAge prediction cross-validation...")
kf = KFold(n_splits=n_folds, shuffle=True, random_state=SEED)

for fold_idx, (train_rel, val_rel) in enumerate(kf.split(cv_indices)):
    train_idx = cv_indices[train_rel]
    val_idx = cv_indices[val_rel]

    for alpha in alphas:
        # Regression uses a larger minimum training size (40) than classification (4)
        train_sub = _subsample_train(train_idx, alpha, 40)

        scaler = StandardScaler()
        X_train = scaler.fit_transform(fs_features[train_sub])
        X_val = scaler.transform(fs_features[val_idx])

        rf = RandomForestRegressor(
            n_estimators=200,
            max_depth=10,
            min_samples_split=5,
            random_state=SEED,
            n_jobs=-1
        )
        rf.fit(X_train, age_labels[train_sub])

        val_pred = rf.predict(X_val)

        cv_results['age']['mae'][alpha].append(
            mean_absolute_error(age_labels[val_idx], val_pred)
        )
        cv_results['age']['r2'][alpha].append(
            r2_score(age_labels[val_idx], val_pred)
        )

    print(f"  Fold {fold_idx+1}/{n_folds} complete")


print("\n" + "="*80)
print("SELECTING BEST ALPHA FROM CV")
print("="*80)


def _mean_score(scores):
    """Mean of the per-fold scores, or -inf when no fold produced a score.

    -inf (not 0) ensures an alpha without any valid fold can never win the
    argmax; a 0 sentinel would beat every alpha whose mean R² is negative.
    """
    return float(np.mean(scores)) if scores else float('-inf')


avg_auc_sex = {a: _mean_score(cv_results['sex']['auc'][a]) for a in alphas}
avg_auc_pd = {a: _mean_score(cv_results['pd']['auc'][a]) for a in alphas}
avg_r2_age = {a: _mean_score(cv_results['age']['r2'][a]) for a in alphas}

# max() returns the first maximal key, i.e. the smallest alpha on ties
best_alpha_sex = max(avg_auc_sex, key=avg_auc_sex.get)
best_alpha_pd = max(avg_auc_pd, key=avg_auc_pd.get)
best_alpha_age = max(avg_r2_age, key=avg_r2_age.get)

print(f"Best alpha - Sex: {best_alpha_sex} (CV AUC: {avg_auc_sex[best_alpha_sex]:.4f})")
print(f"Best alpha - PD: {best_alpha_pd} (CV AUC: {avg_auc_pd[best_alpha_pd]:.4f})")
print(f"Best alpha - Age: {best_alpha_age} (CV R²: {avg_r2_age[best_alpha_age]:.4f})")

print("\n" + "="*80)
print("FINAL TEST SET EVALUATION & FEATURE IMPORTANCE EXTRACTION")
print("="*80)

# SEX CLASSIFICATION
print("\nSex classification final model...")
n_train_sex = int(best_alpha_sex * len(cv_indices))
if n_train_sex >= len(cv_indices):
    # Best alpha covers the whole CV set; train_test_split rejects a
    # train_size equal to the number of samples, so use it directly.
    final_train_sex = cv_indices
else:
    try:
        final_train_sex, _ = train_test_split(
            cv_indices,
            train_size=n_train_sex,
            stratify=sex_labels[cv_indices],
            random_state=SEED
        )
    except ValueError:
        # Stratified subsample impossible (e.g. too few samples): use the full CV set
        final_train_sex = cv_indices

scaler_sex = StandardScaler()
X_train_sex = scaler_sex.fit_transform(fs_features[final_train_sex])
X_test_sex = scaler_sex.transform(fs_features[test_indices])

rf_sex = RandomForestClassifier(
    n_estimators=500,
    random_state=SEED,
    max_features='sqrt',
    class_weight='balanced',
    n_jobs=-1
)
rf_sex.fit(X_train_sex, sex_labels[final_train_sex])

test_pred_sex = rf_sex.predict(X_test_sex)
test_proba_sex = rf_sex.predict_proba(X_test_sex)[:, 1]
test_acc_sex = balanced_accuracy_score(sex_labels[test_indices], test_pred_sex)
test_auc_sex = roc_auc_score(sex_labels[test_indices], test_proba_sex)

# Impurity-based importances of the fitted forest (one value per feature column)
sex_importance = rf_sex.feature_importances_

print(f"  Test Balanced Accuracy: {test_acc_sex:.4f}")
print(f"  Test AUC: {test_auc_sex:.4f}")

# PARKINSON'S CLASSIFICATION
print("\nParkinson's classification final model...")
n_train_pd = int(best_alpha_pd * len(cv_indices_pd))
if n_train_pd >= len(cv_indices_pd):
    # Best alpha covers the whole CV set; train_test_split rejects a
    # train_size equal to the number of samples, so use it directly.
    final_train_pd = cv_indices_pd
else:
    try:
        final_train_pd, _ = train_test_split(
            cv_indices_pd,
            train_size=n_train_pd,
            stratify=parkinson_labels[cv_indices_pd],
            random_state=SEED
        )
    except ValueError:
        # Stratified subsample impossible (e.g. too few samples): use the full CV set
        final_train_pd = cv_indices_pd

scaler_pd = StandardScaler()
X_train_pd = scaler_pd.fit_transform(fs_features[final_train_pd])
X_test_pd = scaler_pd.transform(fs_features[test_indices_pd])

rf_pd = RandomForestClassifier(
    n_estimators=500,
    random_state=SEED,
    max_features='sqrt',
    class_weight='balanced',
    n_jobs=-1
)
rf_pd.fit(X_train_pd, parkinson_labels[final_train_pd])

test_pred_pd = rf_pd.predict(X_test_pd)
test_proba_pd = rf_pd.predict_proba(X_test_pd)[:, 1]
test_acc_pd = balanced_accuracy_score(parkinson_labels[test_indices_pd], test_pred_pd)
test_auc_pd = roc_auc_score(parkinson_labels[test_indices_pd], test_proba_pd)

# Impurity-based importances of the fitted forest (one value per feature column)
pd_importance = rf_pd.feature_importances_

print(f"  Test Balanced Accuracy: {test_acc_pd:.4f}")
print(f"  Test AUC: {test_auc_pd:.4f}")

# AGE PREDICTION
print("\nAge prediction final model...")
n_train_age = int(best_alpha_age * len(cv_indices))
# train_test_split cannot take every sample as training data, so a full-size
# request (alpha == 1.0) uses the CV indices directly.
final_train_age = (
    train_test_split(cv_indices, train_size=n_train_age, random_state=SEED)[0]
    if n_train_age < len(cv_indices)
    else cv_indices
)

scaler_age = StandardScaler()
X_train_age = scaler_age.fit_transform(fs_features[final_train_age])
X_test_age = scaler_age.transform(fs_features[test_indices])

rf_age = RandomForestRegressor(
    n_estimators=200, max_depth=10, min_samples_split=5, random_state=SEED, n_jobs=-1
)
rf_age.fit(X_train_age, age_labels[final_train_age])

test_pred_age = rf_age.predict(X_test_age)
y_test_age = age_labels[test_indices]
test_r2_age = r2_score(y_test_age, test_pred_age)
test_mae_age = mean_absolute_error(y_test_age, test_pred_age)

# Impurity-based importances of the fitted forest (one value per feature column)
age_importance = rf_age.feature_importances_

print(f"  Test R²: {test_r2_age:.4f}")
print(f"  Test MAE: {test_mae_age:.2f} years")

n_cth = len(cth_features)
n_sa = len(sa_features)

# fs_features columns are [thickness..., surface area...], so each importance
# vector splits at n_cth.
age_importance_cth, age_importance_sa = age_importance[:n_cth], age_importance[n_cth:]
sex_importance_cth, sex_importance_sa = sex_importance[:n_cth], sex_importance[n_cth:]
pd_importance_cth, pd_importance_sa = pd_importance[:n_cth], pd_importance[n_cth:]

print("\n" + "="*80)
print("FEATURE IMPORTANCE STATISTICS")
print("="*80)

# Share of total importance carried by each measure type, per task
for task_heading, cth_part, sa_part in (
    ("Age prediction", age_importance_cth, age_importance_sa),
    ("Sex classification", sex_importance_cth, sex_importance_sa),
    ("Parkinson's disease", pd_importance_cth, pd_importance_sa),
):
    print(f"\n{task_heading}:")
    for measure, part in (("CTH", cth_part), ("SA", sa_part)):
        total = part.sum()
        print(f"  {measure}: {total:.4f} ({total*100:.1f}%)")

print("\n" + "="*80)
print("TOP 10 REGIONS PER TASK")
print("="*80)

# (section title, importance vector, matching region names, matching hemispheres)
_top_region_sections = (
    ("Age - Thickness", age_importance_cth, cth_region_names, cth_hemispheres),
    ("Age - Surface Area", age_importance_sa, sa_region_names, sa_hemispheres),
    ("Sex - Surface Area", sex_importance_sa, sa_region_names, sa_hemispheres),
    ("Parkinson's - Thickness", pd_importance_cth, cth_region_names, cth_hemispheres),
)

for section_title, scores, names, hemis in _top_region_sections:
    print(f"\n{section_title}:")
    # Indices of the 10 largest scores, in descending order
    ranked = np.argsort(scores)[::-1][:10]
    for rank, idx in enumerate(ranked, start=1):
        value = scores[idx]
        print(f"  {rank:2d}. {names[idx]:30s} ({hemis[idx]}) = {value:.6f} ({value*100:.2f}%)")

print("\n" + "="*80)
print("Saving results")
print("="*80)

ppmi_importance_data = {
    'age_importance_cth': age_importance_cth,
    'age_importance_sa': age_importance_sa,
    'sex_importance_cth': sex_importance_cth,
    'sex_importance_sa': sex_importance_sa,
    'pd_importance_cth': pd_importance_cth,
    'pd_importance_sa': pd_importance_sa,
    'cth_region_names': cth_region_names,
    'cth_hemispheres': cth_hemispheres,
    'sa_region_names': sa_region_names,
    'sa_hemispheres': sa_hemispheres,
    'n_subjects': len(common_subjects),
    'n_pd': int(np.sum(parkinson_labels)),
    'n_hc': int(len(parkinson_labels) - np.sum(parkinson_labels)),
    'age_mean': float(age_labels.mean()),
    'age_range': (float(age_labels.min()), float(age_labels.max())),
    # Add test set performance
    'test_results': {
        'sex': {'accuracy': float(test_acc_sex), 'auc': float(test_auc_sex), 'alpha': best_alpha_sex},
        'pd': {'accuracy': float(test_acc_pd), 'auc': float(test_auc_pd), 'alpha': best_alpha_pd},
        'age': {'r2': float(test_r2_age), 'mae': float(test_mae_age), 'alpha': best_alpha_age}
    },
    # Add CV results for visualization
    'cv_results': cv_results
}

output_file = os.path.join(OUTPUT_DIR, 'ppmi_freesurfer_feature_importance.pkl')
with open(output_file, 'wb') as f:
    pickle.dump(ppmi_importance_data, f)
print(f"Saved: {output_file}")

# Save CSVs: one row per FreeSurfer feature, thickness (CTH) first then
# surface area (SA), matching the column order of fs_features.
all_features = cth_features + sa_features
all_importance_age = np.concatenate([age_importance_cth, age_importance_sa])
all_importance_sex = np.concatenate([sex_importance_cth, sex_importance_sa])
all_importance_pd = np.concatenate([pd_importance_cth, pd_importance_sa])
all_types = ['CTH'] * n_cth + ['SA'] * n_sa

importance_df = pd.DataFrame({
    'Feature': all_features,
    'Type': all_types,
    'Age_Importance': all_importance_age,
    'Age_Percent': all_importance_age * 100,
    'Sex_Importance': all_importance_sex,
    'Sex_Percent': all_importance_sex * 100,
    'PD_Importance': all_importance_pd,
    'PD_Percent': all_importance_pd * 100
})

# One CSV per task, sorted by that task's importance (most important first)
for task, sort_col in [('age', 'Age_Importance'), ('sex', 'Sex_Importance'), ('pd', 'PD_Importance')]:
    df_sorted = importance_df.sort_values(sort_col, ascending=False)
    csv_file = os.path.join(OUTPUT_DIR, f'ppmi_feature_importance_{task}_sorted.csv')
    df_sorted.to_csv(csv_file, index=False)
    print(f"Saved: {csv_file}")

print("\n" + "="*80)
print("COMPLETE!")
print(f"Analyzed {len(common_subjects)} subjects with ALL modalities")
print("Feature importance extracted from models trained on CV set (best alpha)")
# Sex/age and PD use different held-out splits (10% vs 20%), so report both sizes
print(f"Test set performance evaluated on held-out subjects: "
      f"{len(test_indices)} (sex/age), {len(test_indices_pd)} (PD)")
print("Download .pkl file for visualization")
print("="*80)