## Classification with shrinkage-regularized Linear Discriminant Analysis

This notebook evaluates the performance of shrinkage-regularized Linear Discriminant Analysis (sLDA) for the classification of EEG error-related potentials (ErrPs) using features extracted by the Fisher Criterion Beamformer (FCB).

Steps:

1. Loads preprocessed EEG features, labels, subject/session IDs, and trials from disk (output of `feature_extraction_selection.ipynb`).
2. Searches for the optimal number of FCB projections (spatial filters) by cross-validating sLDA classifiers on the training data.
3. Runs stratified K-fold cross-validation to assess classification performance and saves results (metrics, parameters) for each configuration.
4. Evaluates generalization across subjects and sessions using leave-one-group-out (LOGO) cross-validation.
5. Stores results in JSON files.

Classification metrics and parameter settings for every configuration are saved for later analysis.
The cross-validation and result-saving helpers come from the `bci_utils.py` utility module.
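
For readers without access to `bci_utils.py`, the core idea of sLDA with stratified K-fold cross-validation can be sketched with scikit-learn alone. This is a minimal illustration on synthetic data, not the implementation inside `bci_utils`; the names `X_demo`, `y_demo`, and `slda` are placeholders, and the choice of Ledoit-Wolf shrinkage (`shrinkage="auto"`) is an assumption about how the shrinkage is set.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)

# Synthetic stand-in for flattened FCB features (hypothetical sizes)
n_epochs, n_features = 200, 40
y_demo = rng.integers(0, 2, size=n_epochs)
X_demo = rng.normal(size=(n_epochs, n_features))
X_demo[y_demo == 1, :5] += 1.0  # shift a few features for class 1

# shrinkage="auto" applies Ledoit-Wolf covariance shrinkage;
# it is only supported by the "lsqr" and "eigen" solvers
slda = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(slda, X_demo, y_demo, cv=cv, scoring="balanced_accuracy")

print(f"balanced accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```

Balanced accuracy is used here because ErrP datasets are typically class-imbalanced; the actual metrics reported by the notebook are those returned by `bci_utils`.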

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import bci_utils

from datetime import datetime

from sklearn.svm import LinearSVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score,
    balanced_accuracy_score, roc_auc_score, roc_curve, accuracy_score
)

In [None]:
# Load preprocessed data with feature extraction from feature_extraction_selection.ipynb
DATA_DIR = "/Users/Rosie/Documents/Applications/HRC_BCI_VU/Casus_BCI_classifier/data_preprocessed"

all_epochs_proj = np.load(os.path.join(DATA_DIR, "all_epochs_proj.npy"))
all_labels_proj = np.load(os.path.join(DATA_DIR, "all_labels_proj.npy"))
all_subjects_proj = np.load(os.path.join(DATA_DIR, "all_subjects_proj.npy"))
all_sessions_proj = np.load(os.path.join(DATA_DIR, "all_sessions_proj.npy"))
all_trials_proj = np.load(os.path.join(DATA_DIR, "all_trials_proj.npy"))

In [None]:
# Define labels for training and testing
y_fcb = all_labels_proj

### Stratified 5-fold cross-validation
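
The search over the number of projections amounts to slicing the last axis, flattening, cross-validating, and keeping the count with the best mean score. The sketch below shows that selection step on synthetic data; `epochs_demo`, `labels_demo`, and `mean_scores` are hypothetical names, and plain scikit-learn stands in for the `bci_utils` helper.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(1)

# Synthetic stand-in shaped (n_epochs, n_timepoints, n_projections)
epochs_demo = rng.normal(size=(120, 10, 6))
labels_demo = np.repeat([0, 1], 60)
epochs_demo[labels_demo == 1, :, 0] += 0.8  # make the first projection informative

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
mean_scores = {}

# Upper bound is inclusive so the full set of projections is also evaluated
for n_proj in range(1, epochs_demo.shape[2] + 1):
    # Keep the first n_proj projections and flatten time x projections
    X = epochs_demo[:, :, :n_proj].reshape(len(epochs_demo), -1)
    clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")
    fold_scores = cross_val_score(clf, X, labels_demo, cv=cv, scoring="balanced_accuracy")
    mean_scores[n_proj] = fold_scores.mean()

# Pick the projection count with the highest mean balanced accuracy
best_n_proj = max(mean_scores, key=mean_scores.get)
print(best_n_proj, round(mean_scores[best_n_proj], 3))
```

Note that a score selected this way is optimistic if it is also reported as the final performance, since the same folds were used for selection.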

In [None]:
# Search for the best number of FCB projections.
# The last axis of all_epochs_proj indexes projections; the upper bound is
# inclusive (+1) so that the full set of projections is evaluated as well.
n_projections_total = all_epochs_proj.shape[2]

for n_proj in range(1, n_projections_total + 1):

    # Keep the first n_proj projections
    all_epochs_proj_ = all_epochs_proj[:, :, :n_proj]

    # Flatten epochs for sLDA: (n_epochs, n_timepoints, n_projections) --> (n_epochs, n_timepoints * n_projections)
    X_fcb = all_epochs_proj_.reshape(all_epochs_proj_.shape[0], -1)

    metrics, fold_means, fold_stds = bci_utils.crossval_metrics_stratified_kfold(X_fcb, y_fcb, n_splits=5, plot_roc=True, random_state=42)

    params = {
        "classifier": "sLDA",
        "cv_method": "StratifiedKFold",
        "bandpass": "0.5-10 Hz",
        "epoch_window": "200-600 ms",
        "feature_extraction": f"FCB_{n_proj}"
    }

    bci_utils.save_crossval_results(
        "crossval_metrics_stratified_kfold", metrics, fold_means, fold_stds, params
    )

### Leave-one-subject-out cross-validation
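
Leave-one-group-out cross-validation holds out all epochs of one group (here, one subject) per fold, so the classifier is always tested on a subject it has not seen. The sketch below illustrates this with scikit-learn's `LeaveOneGroupOut` on synthetic data; `X_logo`, `y_logo`, and `groups_demo` are hypothetical names, and the internals of `bci_utils.crossval_metrics_leave_one_group` may differ.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import LeaveOneGroupOut

rng = np.random.default_rng(2)

# Synthetic stand-in: 4 subjects with 50 epochs each (hypothetical sizes)
n_subjects, n_per_subject, n_features = 4, 50, 20
groups_demo = np.repeat(np.arange(n_subjects), n_per_subject)
y_logo = np.tile([0, 1], n_subjects * n_per_subject // 2)
X_logo = rng.normal(size=(len(y_logo), n_features))
X_logo[y_logo == 1, :3] += 1.0  # shift a few features for class 1

logo = LeaveOneGroupOut()
fold_scores = []

# One fold per subject: train on the others, test on the held-out subject
for train_idx, test_idx in logo.split(X_logo, y_logo, groups_demo):
    clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")
    clf.fit(X_logo[train_idx], y_logo[train_idx])
    y_pred = clf.predict(X_logo[test_idx])
    fold_scores.append(balanced_accuracy_score(y_logo[test_idx], y_pred))

print(np.round(fold_scores, 3))
```

Passing session IDs instead of subject IDs as the groups gives the leave-one-session-out variant.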

In [None]:
# Continue with the number of projections that yielded the highest balanced accuracy
# in the stratified K-fold search above (55)
all_epochs_proj_ = all_epochs_proj[:, :, :55]
X_fcb = all_epochs_proj_.reshape(all_epochs_proj_.shape[0], -1)

In [None]:
%%time

metrics, fold_means, fold_stds = bci_utils.crossval_metrics_leave_one_group(X_fcb, y_fcb, all_subjects_proj, plot_roc=True)

params = {
    "classifier": "sLDA",
    "cv_method": "LOGO-subject",
    "bandpass": "0.5-10 Hz",
    "epoch_window": "200-600 ms",
    "feature_extraction": "FCB_55"
}

# Save under the leave-one-group name so these results are not mixed with the stratified K-fold results
bci_utils.save_crossval_results(
    "crossval_metrics_leave_one_group", metrics, fold_means, fold_stds, params
)

### Leave-one-session-out cross-validation

In [None]:
metrics, fold_means, fold_stds = bci_utils.crossval_metrics_leave_one_group(X_fcb, y_fcb, all_sessions_proj, plot_roc=True)

params = {
    "classifier": "sLDA",
    "cv_method": "LOGO-session",
    "bandpass": "0.5-10 Hz",
    "epoch_window": "200-600 ms",
    "feature_extraction": "FCB_55",
}

bci_utils.save_crossval_results(
    "crossval_metrics_leave_one_group", metrics, fold_means, fold_stds, params
)