In [2]:
# %% [code]
"""
ML Pipeline with Class Imbalance Correction

This script:
1. Loads aggregated data from all subjects.
2. One-hot encodes the ROI information from the 'roiNum' column.
3. Defines an explicit feature list (bandpower, FOOOF, entropy, catch22, and ROI dummies).
4. Splits the data at the subject level.
5. Standardizes features.
6. Trains models (Random Forest, Logistic Regression, SVM, and XGBoost) using group-aware CV,
   setting class_weight='balanced' (or scale_pos_weight for XGBoost) to correct for class imbalance.
7. Evaluates models on a held-out test set and prints detailed performance metrics.

Requirements:
    pip install numpy pandas scikit-learn matplotlib pycatch22 xgboost
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pycatch22

from sklearn.model_selection import train_test_split, GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from sklearn.inspection import permutation_importance

In [3]:
##################################
# 1) DATA LOADING & MERGING
##################################

def convert_spared_to_label(val):
    """Convert 'spared' to binary: 0 if spared (TRUE), 1 if resected (FALSE)."""
    if isinstance(val, bool):
        return 0 if val else 1
    elif isinstance(val, str):
        return 0 if val.strip().upper() == 'TRUE' else 1
    else:
        return 0 if bool(val) else 1

def load_all_subjects(subjects_dir, subject_list):
    """
    Loads each subject's aggregated pickle file (e.g. sub-RIDXXXX_features_averaged.pkl),
    adds a 'subject_id' column, converts 'spared' to a binary label, and concatenates all data.
    """
    all_dfs = []
    for subj in subject_list:
        subj_path = os.path.join(subjects_dir, subj)
        pkl_file = os.path.join(subj_path, f"{subj}_features_averaged.pkl")
        if not os.path.isfile(pkl_file):
            print(f"Warning: file not found: {pkl_file}")
            continue
        df = pd.read_pickle(pkl_file)
        df['subject_id'] = subj
        if 'spared' not in df.columns:
            raise ValueError(f"'spared' column missing for {subj}")
        df['label'] = df['spared'].apply(convert_spared_to_label)
        all_dfs.append(df)
    if not all_dfs:
        raise ValueError("No data loaded. Check paths or subject list.")
    combined_df = pd.concat(all_dfs, axis=0, ignore_index=True)
    return combined_df

# Define subjects directory and list (adjust paths as needed)
subjects_dir = "/Users/tereza/nishant/atlas/atlas_work_terez/atlas_harmonization/Data/hup/derivatives/clean"
subject_list = [
    "sub-RID0031", "sub-RID0032", "sub-RID0033", "sub-RID0050", "sub-RID0051", 
    "sub-RID0064", "sub-RID0089", "sub-RID0101", "sub-RID0117", "sub-RID0143",
    "sub-RID0167", "sub-RID0175", "sub-RID0179", "sub-RID0238", "sub-RID0301",
    "sub-RID0320", "sub-RID0381", "sub-RID0405", "sub-RID0424", "sub-RID0508",
    "sub-RID0562", "sub-RID0589", "sub-RID0658"
]

print("Loading data from all subjects...")
combined_df = load_all_subjects(subjects_dir, subject_list)
print(f"Combined data shape: {combined_df.shape}")

##################################
# 2) INCLUDE ROI AS A FEATURE
##################################
# We assume the ROI identifier is stored in 'roiNum'
# Convert 'roiNum' to a categorical variable (as string) and one-hot encode it.
combined_df['roiNum_cat'] = combined_df['roiNum'].astype(str)
roi_dummies = pd.get_dummies(combined_df['roiNum_cat'], prefix='roiNum')
combined_df = pd.concat([combined_df, roi_dummies], axis=1)
print("ROI one-hot encoding completed.")

##################################
# 3) DEFINE EXPLICIT FEATURE LIST
##################################
def get_explicit_feature_list():
    # Bandpower features: 5 bands x 3 metrics = 15 features
    band_names = ['delta', 'theta', 'alpha', 'beta', 'gamma']
    band_features = [f"{band}_{metric}" for band in band_names for metric in ['power', 'rel', 'log']]
    
    # FOOOF features: 5 features
    fooof_features = [
        'fooof_aperiodic_offset', 
        'fooof_aperiodic_exponent', 
        'fooof_r_squared', 
        'fooof_error', 
        'fooof_num_peaks'
    ]
    
    # Entropy feature: 1 feature
    entropy_features = ['entropy_5secwin']
    
    # catch22 features: typically ~22 features (depends on pycatch22 output)
    dummy = np.random.randn(100).tolist()
    res = pycatch22.catch22_all(dummy, catch24=False)
    catch22_features = [f"catch22_{nm}" for nm in res['names']]
    
    # ROI one-hot encoded features: all columns that start with "roiNum"
    roi_features = [col for col in combined_df.columns if col.startswith("roiNum")]
    
    return band_features + fooof_features + entropy_features + catch22_features + roi_features

explicit_feature_list = get_explicit_feature_list()
# Only include features present in combined_df:
present_features = [feat for feat in explicit_feature_list if feat in combined_df.columns]
print(f"Using explicit feature list with {len(present_features)} features.")

##################################
# 4) PREPROCESSING
##################################
# Drop rows with missing values in the explicit feature columns.
combined_df = combined_df.dropna(subset=present_features)
X_full = combined_df[present_features].values
y_full = combined_df['label'].values

##################################
# 5) TRAIN-TEST SPLIT AT SUBJECT LEVEL
##################################
unique_subjects = combined_df['subject_id'].unique()
print("Total unique subjects:", len(unique_subjects))
from sklearn.model_selection import train_test_split
train_subjects, test_subjects = train_test_split(unique_subjects, test_size=0.2, random_state=42)
print("Training subjects:", train_subjects)
print("Testing subjects:", test_subjects)

# Create masks for train and test based on subject_id
train_mask = combined_df['subject_id'].isin(train_subjects)
test_mask = combined_df['subject_id'].isin(test_subjects)

X_train = combined_df[train_mask][present_features].values
y_train = combined_df[train_mask]['label'].values
X_test = combined_df[test_mask][present_features].values
y_test = combined_df[test_mask]['label'].values

print(f"Training set shape: {X_train.shape}")
print(f"Test set shape: {X_test.shape}")

# Standardize features based on training data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

##################################
# 6) TRAIN MODELS USING GROUP-AWARE CROSS-VALIDATION
##################################
from sklearn.model_selection import GroupKFold

def train_and_evaluate_with_groups(X, y, groups, model_choice='random_forest'):
    if model_choice == 'random_forest':
        model = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
    elif model_choice == 'logistic':
        model = LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced')
    elif model_choice == 'svm':
        model = SVC(kernel='rbf', probability=True, random_state=42, class_weight='balanced')
    elif model_choice == 'xgboost':
        pos = np.sum(y == 1)
        neg = np.sum(y == 0)
        scale_pos_weight = neg / pos if pos > 0 else 1
        model = XGBClassifier(random_state=42, scale_pos_weight=scale_pos_weight,
                              use_label_encoder=False, eval_metric='logloss')
    else:
        raise ValueError("model_choice must be one of ['random_forest','logistic','svm','xgboost']")
    
    gkf = GroupKFold(n_splits=5)
    accuracies, aucs = [], []
    for train_idx, val_idx in gkf.split(X, y, groups=groups):
        X_train_cv, X_val_cv = X[train_idx], X[val_idx]
        y_train_cv, y_val_cv = y[train_idx], y[val_idx]
        model.fit(X_train_cv, y_train_cv)
        y_pred_cv = model.predict(X_val_cv)
        acc = accuracy_score(y_val_cv, y_pred_cv)
        accuracies.append(acc)
        if hasattr(model, 'predict_proba'):
            y_proba_cv = model.predict_proba(X_val_cv)[:, 1]
            aucs.append(roc_auc_score(y_val_cv, y_proba_cv))
    print(f"\n=== {model_choice.upper()} GROUP-CV Results ===")
    print(f"Accuracy: {np.mean(accuracies):.3f} ± {np.std(accuracies):.3f}")
    if aucs:
        print(f"ROC AUC: {np.mean(aucs):.3f} ± {np.std(aucs):.3f}")
    
    # Refit on full training data
    model.fit(X, y)
    return model

# Create a groups vector from training data (subject IDs for each electrode)
train_groups = combined_df[train_mask]['subject_id'].values

# List of models to train
models_to_run = ['random_forest', 'logistic', 'svm', 'xgboost']
trained_models = {}

for model_name in models_to_run:
    print("\n=======================================")
    print(f"Training Model: {model_name.upper()}")
    clf = train_and_evaluate_with_groups(X_train_scaled, y_train, groups=train_groups, model_choice=model_name)
    
    # Evaluate on test set
    y_pred_test = clf.predict(X_test_scaled)
    print(f"\n--- {model_name.upper()} TEST SET PERFORMANCE ---")
    print("Accuracy:", accuracy_score(y_test, y_pred_test))
    if hasattr(clf, 'predict_proba'):
        y_proba_test = clf.predict_proba(X_test_scaled)[:, 1]
        print("ROC AUC:", roc_auc_score(y_test, y_proba_test))
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred_test, zero_division=0))
    cm = confusion_matrix(y_test, y_pred_test)
    print("Confusion Matrix:")
    print(cm)
    
    trained_models[model_name] = clf

##################################
# 7) PRINT SUMMARY OF RESULTS
##################################
print("\n=== Detailed Summary of All Model Results on Test Set ===")

for model_name in models_to_run:
    print("\n------------------------------")
    print(f"Model: {model_name.upper()}")
    model = trained_models[model_name]
    y_pred = model.predict(X_test_scaled)
    if hasattr(model, 'predict_proba'):
        y_proba = model.predict_proba(X_test_scaled)[:, 1]
        roc_auc = roc_auc_score(y_test, y_proba)
    else:
        roc_auc = "N/A"
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.3f}")
    print(f"ROC AUC: {roc_auc}")
    print("Classification Report:")
    print(classification_report(y_test, y_pred, zero_division=0))
    cm = confusion_matrix(y_test, y_pred)
    print("Confusion Matrix:")
    print(cm)


Loading data from all subjects...
Combined data shape: (1675, 51)
ROI one-hot encoding completed.
Using explicit feature list with 118 features.
Total unique subjects: 23
Training subjects: ['sub-RID0179' 'sub-RID0032' 'sub-RID0238' 'sub-RID0064' 'sub-RID0033'
 'sub-RID0175' 'sub-RID0562' 'sub-RID0050' 'sub-RID0051' 'sub-RID0424'
 'sub-RID0381' 'sub-RID0589' 'sub-RID0658' 'sub-RID0101' 'sub-RID0167'
 'sub-RID0301' 'sub-RID0508' 'sub-RID0089']
Testing subjects: ['sub-RID0320' 'sub-RID0143' 'sub-RID0031' 'sub-RID0117' 'sub-RID0405']
Training set shape: (1251, 118)
Test set shape: (424, 118)

Training Model: RANDOM_FOREST

=== RANDOM_FOREST GROUP-CV Results ===
Accuracy: 0.857 ± 0.054
ROC AUC: 0.579 ± 0.127

--- RANDOM_FOREST TEST SET PERFORMANCE ---
Accuracy: 0.8349056603773585
ROC AUC: 0.6467701392490922

Classification Report:
              precision    recall  f1-score   support

           0       0.83      1.00      0.91       353
           1       1.00      0.01      0.03        7

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.




=== XGBOOST GROUP-CV Results ===
Accuracy: 0.834 ± 0.051
ROC AUC: 0.614 ± 0.091


Parameters: { "use_label_encoder" } are not used.




--- XGBOOST TEST SET PERFORMANCE ---
Accuracy: 0.8042452830188679
ROC AUC: 0.5937038662570323

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.93      0.89       353
           1       0.32      0.15      0.21        71

    accuracy                           0.80       424
   macro avg       0.58      0.54      0.55       424
weighted avg       0.76      0.80      0.77       424

Confusion Matrix:
[[330  23]
 [ 60  11]]

=== Detailed Summary of All Model Results on Test Set ===

------------------------------
Model: RANDOM_FOREST
Accuracy: 0.835
ROC AUC: 0.6467701392490922
Classification Report:
              precision    recall  f1-score   support

           0       0.83      1.00      0.91       353
           1       1.00      0.01      0.03        71

    accuracy                           0.83       424
   macro avg       0.92      0.51      0.47       424
weighted avg       0.86      0.83      0.76       424

Confusi