## Classification with GAN-augmentation and Transformer

This notebook implements and evaluates a GAN-augmented Transformer pipeline for the classification of EEG error-related potentials (ErrPs).

Steps:

1. Loads preprocessed EEG epochs, labels, subject/session information, and trial indices from disk (prepared by data_preprocessing_aggregation.ipynb).
2. Splits data into error and correct trials to enable label-aware processing and augmentation.
3. Within each cross-validation fold, trains a Generative Adversarial Network (GAN) to synthesize additional error trials, addressing class imbalance in that fold's training set (the test set is never augmented).
4. Trains a Transformer-based neural network classifier on the real training data combined with the GAN-generated error trials.
5. Evaluates classification performance with three cross-validation schemes: stratified 5-fold, leave-one-subject-out, and leave-one-session-out.
6. Collects, summarizes, and saves metrics (e.g., balanced accuracy, F1, recall) along with experiment parameters to JSON files for reproducibility and further analysis.

Model definition, augmentation, cross-validation, and results management are delegated to the bci_utils.py toolkit, so this notebook only wires the steps together.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import bci_utils

from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score,
    balanced_accuracy_score, roc_auc_score, roc_curve, accuracy_score
)

In [None]:
# Load preprocessed data output from data_preprocessing_aggregation.ipynb.
# The directory can be overridden with the BCI_DATA_DIR environment variable.
# By default it is resolved relative to the notebook's working directory
# (assumed to be the Casus_BCI_classifier project root; adjust if needed).
DATA_DIR = os.environ.get("BCI_DATA_DIR", "data_preprocessed")

all_epochs = np.load(os.path.join(DATA_DIR, "all_epochs.npy"))
all_labels = np.load(os.path.join(DATA_DIR, "all_labels.npy"))
all_subjects = np.load(os.path.join(DATA_DIR, "all_subjects.npy"))
all_sessions = np.load(os.path.join(DATA_DIR, "all_sessions.npy"))
all_trials = np.load(os.path.join(DATA_DIR, "all_trials.npy"))

# Every per-trial array is later indexed/grouped against all_epochs, so they
# must all have one entry per epoch. A mismatch would silently mis-assign
# labels or CV groups downstream.
n_trials_loaded = len(all_epochs)
for array_name, per_trial_array in [
    ("all_labels", all_labels),
    ("all_subjects", all_subjects),
    ("all_sessions", all_sessions),
    ("all_trials", all_trials),
]:
    assert len(per_trial_array) == n_trials_loaded, (
        f"{array_name} has {len(per_trial_array)} entries, expected {n_trials_loaded}"
    )

In [None]:
# Partition the trials by class: label 1 -> error trial, label 0 -> correct trial.
# Labels must be strictly binary; any other value would match neither mask and
# be silently dropped from both splits, so fail loudly instead.
unexpected_labels = np.setdiff1d(np.unique(all_labels), [0, 1])
assert unexpected_labels.size == 0, f"Unexpected label values: {unexpected_labels}"

# Error and correct masks
error_mask = all_labels == 1
correct_mask = all_labels == 0

# Split arrays
error_epochs = all_epochs[error_mask]
correct_epochs = all_epochs[correct_mask]

error_labels = all_labels[error_mask]
correct_labels = all_labels[correct_mask]

print('Error epochs shape:', error_epochs.shape)
print('Correct epochs shape:', correct_epochs.shape)

print('Error labels shape:', error_labels.shape)
print('Correct labels shape:', correct_labels.shape)

### Stratified 5-fold cross-validation

In [None]:
%%time

# Run cross-validation for GAN-augmented Transformer classifier on EEG error potential data
# Only the training set in each fold is augmented with synthetic error trials (never the test set)
pooled_metrics, fold_means, fold_stds = bci_utils.crossval_gan_augmented_transformer(
    correct_epochs, error_epochs, correct_labels, error_labels,
    n_splits=5, latent_dim=32, n_gan_epochs=1000,
    transformer_epochs=20, batch_size=32, lr=1e-3, plot_roc=False, random_state=42
)

# Store parameters used for this experiment (for traceability in result files)
params = {
    "classifier": "Transformer",        
    "cv_method": "StratifiedKFold",        
    "n_splits": 5,                     
    "bandpass": "0.5-10 Hz",              
    "epoch_window": "209-600 ms",        
    "augmentation": "GAN"   
}

# Save cross-validation results (metrics, parameters, timestamp) to a JSON file for reproducibility
bci_utils.save_crossval_results(
    "crossval_metrics_stratified_kfold", pooled_metrics, fold_means, fold_stds, params
)

### Leave-one-subject-out cross-validation

In [None]:
%%time

# Run cross-validation for GAN-augmented Transformer classifier on EEG error potential data
# Only the training set in each fold is augmented with synthetic error trials (never the test set)
# For leave-one-subject-out:
pooled_metrics, fold_means, fold_stds = bci_utils.crossval_gan_augmented_transformer_logo(
    correct_epochs, error_epochs, correct_labels, error_labels,
    all_subjects,  # shape (n_trials,)
    latent_dim=32, n_gan_epochs=1000, transformer_epochs=20, batch_size=32, lr=1e-3, plot_roc=True
)
print(pooled_metrics)

params = {
    "classifier": "Transformer",
    "cv_method": "LOGO-subject",
    "bandpass": "0.5-10 Hz",
    "epoch_window": "209-600 ms",
    "augmentation": "GAN"
}

bci_utils.save_crossval_results(
    "crossval_metrics_stratified_kfold", pooled_metrics, fold_means, fold_stds, params
)

### Leave-one-session-out cross-validation

In [None]:
%%time

# Run cross-validation for GAN-augmented Transformer classifier on EEG error potential data
# Only the training set in each fold is augmented with synthetic error trials (never the test set)
# For leave-one-session-out:
pooled_metrics, fold_means, fold_stds = bci_utils.crossval_gan_augmented_transformer_logo(
    correct_epochs, error_epochs, correct_labels, error_labels,
    all_sessions,  # shape (n_trials,)
    latent_dim=32, n_gan_epochs=1000, transformer_epochs=20, batch_size=32, lr=1e-3, plot_roc=True
)
print(pooled_metrics)

params = {
    "classifier": "Transformer",
    "cv_method": "LOGO-session",
    "bandpass": "0.5-10 Hz",
    "epoch_window": "209-600 ms",
    "augmentation": "GAN"
}

bci_utils.save_crossval_results(
    "crossval_metrics_stratified_kfold", pooled_metrics, fold_means, fold_stds, params
)