In [1]:
import numpy as np
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery, LeftRightImagery
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from mne.decoding import CSP
from moabb.evaluations import CrossSubjectEvaluation
from sklearn.pipeline import make_pipeline
from scipy import signal

In [None]:
import os
import numpy as np
import mne
from scipy.signal import butter, lfilter

# Define a causal bandpass filter function using a Butterworth design.
def causal_bandpass_filter(data, lowcut=8, highcut=30, fs=250, order=50):
    nyq = 0.5 * fs
    # Normalize the cutoff frequencies (Matlab's fir1 expects normalized cutoff frequencies
    low = lowcut / nyq
    high = highcut / nyq
    # Design the FIR filter. Note: order+1 coefficients are returned to match Matlab's fir1 which returns n+1 taps.
    b = signal.firwin(order + 1, [low, high], window='hamming', pass_zero=False)
    # Apply the filter causally using lfilter (this introduces a constant delay).
    filtered_data = signal.lfilter(b, [1.0], data)
    return filtered_data

# Set the data directory where your GDF files are stored.
data_dir = '/home/vishwa/eeg_tl/Recreating papers/BCICIV_2a'  # Replace with your actual path

# Containers to hold processed data for all subjects.
X = []         # List to hold numpy arrays with shape (n_trials, 22, 1001) per subject.
y = []         # List to hold event labels per subject.
metadata = []  # List to hold event metadata per subject.


event_id = {'769': 769, '770': 770}
# event_id = {'277': 277}
# Epoch parameters:

# With a sampling frequency of 250 Hz, 1001 samples equate to 1001/250 seconds.
sfreq = 250
tmin = 0       # Epoch start at cue onset.
# Set tmax so that n_samples = (tmax-tmin)*sfreq + 1 = 1001, i.e. 4 seconds long.
tmax = (1001 - 1) / sfreq  # This gives 4.0 seconds.

# Loop over subjects. Assume files are named "A01T.gdf", "A02T.gdf", ..., "A09T.gdf".
for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}T.gdf')
    
    # Read the GDF file (using preload=True to load data into memory).
    raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False, eog=['EOG-left', 'EOG-central', 'EOG-right'])
    # Retain only EEG channels (22 channels) and exclude EOG channels.
    raw.pick_types(eeg=True, eog=False)
    
    # Extract events corresponding only to the four desired types.
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    
    # Create epochs from tmin to tmax.
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                        baseline=None, preload=True, verbose=False)
    
    # Get the epoch data (num_epochs x 22 channels x 1001 samples).
    data = epochs.get_data()
    
    # Print the number of extracted epochs to verify
    print(f"Subject {subj}: Epoch data shape {data.shape}")
    
    # Sampling frequency from raw.info (should be 250).
    fs = int(raw.info['sfreq'])
    # print(fs)
    n_trials, n_channels, n_times = data.shape
    filtered_data = np.empty_like(data)
    
    # Apply the causal bandpass filter channel‐wise for each trial.
    for trial in range(n_trials):
        for ch in range(n_channels):
            filtered_data[trial, ch, :] = causal_bandpass_filter(
                data[trial, ch, :],
                lowcut=8,    # Lower bound of sensorimotor rhythm.
                highcut=30,  # Upper bound of sensorimotor rhythm.
                fs=fs,
                order=50     # Lower order for a smoother causal filter.
            )
    
    # Append the processed data, labels, and event metadata.
    X.append(filtered_data)
    y.append(epochs.events[:, 2])  # The third column holds the event code.
    metadata.append(epochs.events)

print("Loaded data for", len(X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 1: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 2: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 3: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 4: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 5: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 6: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 7: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 8: Epoch data shape (144, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770']
Subject 9: Epoch data shape (144, 22, 501)
Loaded data for 9 subjects.


In [3]:
X = np.array(X)
y = np.array(y)

In [4]:
print(X[0].shape)

(144, 22, 501)


In [5]:
print(y[0].shape)

(144,)


In [6]:
X_aligned = []
def EA(X):
    covs = np.array([trial @ trial.T for trial in X])  # shape [n_trials, n_channels, n_channels]
    mean_cov = np.mean(covs, axis=0)
    # Eigen-decomposition of mean_cov
    eigvals, eigvecs = np.linalg.eigh(mean_cov)
    D_inv_sqrt = np.diag(eigvals**(-0.5))
    R_inv_sqrt = eigvecs @ D_inv_sqrt @ eigvecs.T

    # Align each trial
    xaligned = np.array([R_inv_sqrt @ trial for trial in X])
    X_aligned.append(xaligned)

for i in X:
    EA(i)


In [7]:
n_components = 4
csp = CSP(n_components=n_components, reg=None, log=True, norm_trace=False)
lda = LinearDiscriminantAnalysis()

In [8]:
accuracies = []

# Assuming:
# X_aligned = [subject1_data, ..., subject9_data] where each subject_data.shape = (288, 22, 1001)
# y = [subject1_labels, ..., subject9_labels] where each labels.shape = (288,)

for subj_idx in range(len(X_aligned)):
    # print(subj_idx)
    # Split data into train/test using leave-one-subject-out
    X_test = X_aligned[subj_idx]
    y_test = y[subj_idx]
    
    # Concatenate data from other subjects
    X_train = np.concatenate([X_aligned[i] for i in range(len(X_aligned)) if i != subj_idx], axis=0)
    y_train = np.concatenate([y[i] for i in range(len(y)) if i != subj_idx], axis=0)
    # Create pipeline
    pipeline = Pipeline([
        ('CSP', CSP(n_components=n_components, reg=None, log=True, norm_trace=False)),
        ('LDA', LinearDiscriminantAnalysis())
    ])

    # Fit and predict
    pipeline.fit(X_train, y_train)
    accuracy = pipeline.score(X_test, y_test)
    accuracies.append(accuracy)
    
for subj_idx in range(len(X_aligned)):
    print(f"Subject {subj_idx+1} Test Accuracy: {accuracies[subj_idx]:.2f}")

print(f"\nMean Cross-Validation Accuracy: {np.mean(accuracies):.2f} ± {np.std(accuracies):.2f}")


Computing rank from data with rank=None
    Using tolerance 0.17 (2.2e-16 eps * 22 dim * 3.4e+13  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=769 covariance using EMPIRICAL
Done.
Estimating class=770 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.17 (2.2e-16 eps * 22 dim * 3.4e+13  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=769 covariance using EMPIRICAL
Done.
Estimating class=770 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.17 (2.2e-16 eps * 22 dim * 3.4e+13  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=769 covariance using EM