In [1]:
import numpy as np
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery, LeftRightImagery
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from mne.decoding import CSP
from moabb.evaluations import CrossSubjectEvaluation
from sklearn.pipeline import make_pipeline
from scipy import signal
from scipy.io import loadmat
import os
import mne

In [2]:
# Annotation-description -> integer-code maps passed to MNE when extracting events.
# Two-class subset (presumably left/right hand, going by the name -- confirm against the dataset docs).
active_lr_event_ids = {'769': 769, '770': 770}
# All four cue codes used for the training sessions: the two above plus 771 and 772.
active_all_event_ids = {**active_lr_event_ids, '771': 771, '772': 772}
# Single cue code used for the evaluation ("E") sessions, where the class is not in the annotations.
unknown_event_id = {'783': 783}

In [3]:
# Directory holding the A0xT.gdf / A0xE.gdf recordings.
# Set the BCICIV_2A_DIR environment variable to run this notebook on another machine;
# the fallback is the original author's local path, so existing runs are unaffected.
data_dir = os.environ.get('BCICIV_2A_DIR', '/home/vishwa/eeg_tl/Recreating papers/BCICIV_2a')

In [4]:
def causal_bandpass_filter(data, lowcut=8, highcut=30, fs=250, order=50, axis=-1):
    """Band-pass filter ``data`` causally with a Hamming-windowed FIR filter.

    The filter is a linear-phase FIR design (``scipy.signal.firwin``), not a
    Butterworth filter. ``order + 1`` taps are used, mirroring Matlab's
    ``fir1(order, ...)``, which also returns ``order + 1`` coefficients.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter. May have any number of dimensions; filtering is
        applied independently along ``axis``.
    lowcut, highcut : float
        Pass-band edges in Hz. Must satisfy ``0 < lowcut < highcut < fs / 2``.
    fs : float
        Sampling frequency in Hz.
    order : int
        FIR filter order (number of taps minus one).
    axis : int
        Axis of ``data`` that holds time. Defaults to the last axis, which is
        what the original single-signal implementation used.

    Returns
    -------
    numpy.ndarray
        Filtered data with the same shape as ``data``.

    Raises
    ------
    ValueError
        If the cutoff frequencies are not strictly increasing inside
        ``(0, fs / 2)``.

    Notes
    -----
    The filter is applied one-way with ``lfilter`` starting from a zero initial
    state, so the output is delayed by ``order / 2`` samples and the first
    ``order`` output samples contain the start-up transient.
    """
    nyq = 0.5 * fs
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f"Cutoffs must satisfy 0 < lowcut < highcut < fs/2; "
            f"got lowcut={lowcut}, highcut={highcut}, fs={fs}."
        )
    # firwin expects cutoffs normalised by the Nyquist frequency (as Matlab's fir1 does).
    taps = signal.firwin(order + 1, [lowcut / nyq, highcut / nyq],
                         window='hamming', pass_zero=False)
    # One-way (causal) filtering; an FIR filter has denominator [1.0].
    return signal.lfilter(taps, [1.0], data, axis=axis)

In [5]:
# Epoch window shared by every "active" epoching call below.
# At 250 Hz an inclusive window of 4.0 s spans 4.0 * 250 + 1 = 1001 samples.
sfreq = 250
# Start at cue onset (tmin = 0) and stop 1000 samples later (tmax = 4.0 s).
tmin, tmax = 0, 1000 / sfreq

# Train - Active - done

In [6]:
train_active_X = []         # Per-subject filtered epochs, shape (n_trials, n_channels, n_times).
train_active_y = []         # Per-subject event codes, one per epoch.
train_active_metadata = []  # Per-subject MNE events arrays for the epochs that were kept.

# Training sessions are expected as "A01T.gdf", "A02T.gdf", ..., "A09T.gdf" inside data_dir.
for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}T.gdf')

    # Read the GDF file into memory, tagging the three EOG channels so they can be excluded.
    train_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False, eog=['EOG-left', 'EOG-central', 'EOG-right'])
    # Keep only the EEG channels. pick() is the replacement for the legacy pick_types();
    # exclude='bads' mirrors the pick_types() default.
    train_raw.pick('eeg', exclude='bads')

    # Extract only the events for the four cue codes.
    train_active_events, _ = mne.events_from_annotations(train_raw, event_id=active_all_event_ids)

    # Cut epochs from tmin to tmax relative to each cue, without baseline correction.
    train_active_epochs = mne.Epochs(train_raw, train_active_events, event_id=active_all_event_ids, tmin=tmin, tmax=tmax,
                        baseline=None, preload=True, verbose=False)

    # Epoch data as an array of shape (n_epochs, n_channels, n_times).
    train_active_data = train_active_epochs.get_data()

    # Print the shape so the number of extracted epochs can be verified.
    print(f"Subject {subj}: Epoch data shape {train_active_data.shape}")

    # Sampling frequency as reported by the recording itself.
    fs = int(train_raw.info['sfreq'])
    # Kept as notebook-level names in case later cells read them.
    n_trials, n_channels, n_times = train_active_data.shape

    # Causal band-pass over the sensorimotor band (8-30 Hz). The filter works along the
    # last (time) axis, so one call filters every trial and channel independently and
    # designs the FIR taps once instead of once per (trial, channel) pair.
    train_active_filtered_data = causal_bandpass_filter(
        train_active_data,
        lowcut=8,    # Lower bound of sensorimotor rhythm.
        highcut=30,  # Upper bound of sensorimotor rhythm.
        fs=fs,
        order=50
    )

    # Store the filtered data, the labels and the event metadata for this subject.
    train_active_X.append(train_active_filtered_data)
    train_active_y.append(train_active_epochs.events[:, 2])  # The third column holds the event code.
    train_active_metadata.append(train_active_epochs.events)

print("Loaded data for", len(train_active_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 1: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 2: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 3: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 4: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 5: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 6: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 7: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 8: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 9: Epoch data shape (288, 22, 1001)
Loaded data for 9 subjects.


# Train - Resting

In [7]:
train_resting_X = []         # Per-subject filtered resting epochs, shape (n_trials, n_channels, n_times).
train_resting_metadata = []  # Per-subject MNE events arrays for the resting epochs that were kept.

for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}T.gdf')

    # Load the GDF file and keep only the EEG channels
    # (pick() is the replacement for the legacy pick_types()).
    train_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False,
                              eog=['EOG-left', 'EOG-central', 'EOG-right'])
    train_raw.pick('eeg', exclude='bads')

    # Cue events of the four active classes; they only serve as time anchors here.
    train_cue_events, _ = mne.events_from_annotations(train_raw, event_id=active_all_event_ids)

    # Derive resting events by shifting every cue onset 6 s later.
    # Work on a copy so the cue events themselves are left untouched.
    fs = int(train_raw.info['sfreq'])
    train_resting_events = train_cue_events.copy()
    train_resting_events[:, 0] += int(6 * fs)

    # Resting epochs of 1.5 s starting at each shifted onset.
    # NOTE(review): the printed shapes show 287 epochs, one fewer than the 288 active
    # epochs -- presumably the last shifted window runs past the end of the recording
    # and is dropped by MNE; confirm before pairing resting and active trials by index.
    train_resting_epochs = mne.Epochs(train_raw, train_resting_events, tmin=0, tmax=1.5,
                                baseline=None, preload=True, verbose=False)

    train_resting_data = train_resting_epochs.get_data()
    print(f"Subject {subj}: Resting data shape {train_resting_data.shape}")

    # Kept as notebook-level names in case later cells read them.
    n_trials, n_channels, n_times = train_resting_data.shape

    # Same causal 8-30 Hz band-pass as for the active epochs. The filter runs along the
    # last (time) axis, so a single call covers every trial and channel.
    train_resting_filtered_data = causal_bandpass_filter(
        train_resting_data,
        lowcut=8,
        highcut=30,
        fs=fs,
        order=50
    )

    train_resting_X.append(train_resting_filtered_data)
    train_resting_metadata.append(train_resting_epochs.events)

print("Loaded resting data for", len(train_resting_X), "subjects.")


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 1: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 2: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 3: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 4: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 5: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 6: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 7: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 8: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 9: Resting data shape (287, 22, 376)
Loaded resting data for 9 subjects.


# Eval - Active

In [8]:
eval_active_X = []         # Per-subject filtered epochs, shape (n_trials, n_channels, n_times).
eval_active_y = []         # Per-subject true labels (from the .mat files), one per epoch.
eval_active_metadata = []  # Per-subject MNE events arrays for the epochs that were kept.

# Directory with the true-label files "A01E.mat" ... "A09E.mat".
# Set BCICIV_2A_LABELS_DIR to override; the fallback is the original author's local path.
true_labels_dir = os.environ.get('BCICIV_2A_LABELS_DIR',
                                 '/home/vishwa/eeg_tl/Recreating papers/BCICIV_2A true labels')

# Evaluation sessions are expected as "A01E.gdf", "A02E.gdf", ..., "A09E.gdf" inside data_dir.
for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}E.gdf')
    mat_data = loadmat(os.path.join(true_labels_dir, f'A{subj:02d}E.mat'))
    # Offset the stored class labels by 768 so they live in the same code space as the
    # training event codes (presumably 1-4 -> 769-772; confirm against the label files).
    class_labels = np.array(mat_data['classlabel'], dtype=np.int64).ravel() + 768

    # Read the GDF file into memory, tagging the three EOG channels so they can be excluded.
    eval_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False, eog=['EOG-left', 'EOG-central', 'EOG-right'])
    # Keep only the EEG channels. pick() is the replacement for the legacy pick_types();
    # exclude='bads' mirrors the pick_types() default.
    eval_raw.pick('eeg', exclude='bads')

    # The evaluation recordings mark every cue with the single "unknown" code.
    eval_active_events, _ = mne.events_from_annotations(eval_raw, event_id=unknown_event_id)

    # One label per cue is required; otherwise labels and epochs cannot be paired.
    if len(class_labels) != len(eval_active_events):
        raise ValueError(
            f"Subject {subj}: {len(class_labels)} labels but {len(eval_active_events)} cue events."
        )

    # Cut epochs from tmin to tmax relative to each cue, without baseline correction.
    eval_active_epochs = mne.Epochs(eval_raw, eval_active_events, event_id=unknown_event_id, tmin=tmin, tmax=tmax,
                        baseline=None, preload=True, verbose=False)

    # Keep only the labels of epochs that survived epoching, so data and labels stay aligned
    # even if MNE dropped an epoch (selection indexes into the events passed above).
    true_y = class_labels[eval_active_epochs.selection]

    # Epoch data as an array of shape (n_epochs, n_channels, n_times).
    eval_active_data = eval_active_epochs.get_data()

    # Print the shape so the number of extracted epochs can be verified.
    print(f"Subject {subj}: Epoch data shape {eval_active_data.shape}")

    # Sampling frequency as reported by the recording itself.
    fs = int(eval_raw.info['sfreq'])
    # Kept as notebook-level names in case later cells read them.
    n_trials, n_channels, n_times = eval_active_data.shape

    # Causal band-pass over the sensorimotor band (8-30 Hz). The filter works along the
    # last (time) axis, so one call filters every trial and channel independently and
    # designs the FIR taps once instead of once per (trial, channel) pair.
    eval_active_filtered_data = causal_bandpass_filter(
        eval_active_data,
        lowcut=8,    # Lower bound of sensorimotor rhythm.
        highcut=30,  # Upper bound of sensorimotor rhythm.
        fs=fs,
        order=50
    )

    # Store the filtered data, the true labels and the event metadata for this subject.
    eval_active_X.append(eval_active_filtered_data)
    eval_active_y.append(true_y)  # True class codes loaded from the .mat file.
    eval_active_metadata.append(eval_active_epochs.events)

print("Loaded data for", len(eval_active_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 1: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 2: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 3: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 4: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 5: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 6: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 7: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 8: Epoch data shape (288, 22, 1001)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 9: Epoch data shape (288, 22, 1001)
Loaded data for 9 subjects.


# Eval - Resting

In [9]:
eval_resting_X = []         # Per-subject filtered resting epochs, shape (n_trials, n_channels, n_times).
eval_resting_metadata = []  # Per-subject MNE events arrays for the resting epochs that were kept.

for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}E.gdf')

    # Load the GDF file and keep only the EEG channels
    # (pick() is the replacement for the legacy pick_types()).
    eval_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False,
                              eog=['EOG-left', 'EOG-central', 'EOG-right'])
    eval_raw.pick('eeg', exclude='bads')

    # Cue events (single "unknown" code in the evaluation files); they only serve as time anchors here.
    eval_cue_events, _ = mne.events_from_annotations(eval_raw, event_id=unknown_event_id)

    # Derive resting events by shifting every cue onset 6 s later.
    # Work on a copy so the cue events themselves are left untouched.
    fs = int(eval_raw.info['sfreq'])
    eval_resting_events = eval_cue_events.copy()
    eval_resting_events[:, 0] += int(6 * fs)

    # Resting epochs of 1.5 s starting at each shifted onset.
    # NOTE(review): the printed shapes show 287 epochs, one fewer than the 288 active
    # epochs -- presumably the last shifted window runs past the end of the recording
    # and is dropped by MNE; confirm before pairing resting and active trials by index.
    eval_resting_epochs = mne.Epochs(eval_raw, eval_resting_events, tmin=0, tmax=1.5,
                                baseline=None, preload=True, verbose=False)

    eval_resting_data = eval_resting_epochs.get_data()
    print(f"Subject {subj}: Resting data shape {eval_resting_data.shape}")

    # Kept as notebook-level names in case later cells read them.
    n_trials, n_channels, n_times = eval_resting_data.shape

    # Same causal 8-30 Hz band-pass as for the active epochs. The filter runs along the
    # last (time) axis, so a single call covers every trial and channel.
    eval_resting_filtered_data = causal_bandpass_filter(
        eval_resting_data,
        lowcut=8,
        highcut=30,
        fs=fs,
        order=50
    )

    eval_resting_X.append(eval_resting_filtered_data)
    # Resting epochs carry no class label; only the event metadata is stored.
    eval_resting_metadata.append(eval_resting_epochs.events)


print("Loaded resting data for", len(eval_resting_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 1: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 2: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 3: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 4: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 5: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 6: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 7: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 8: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 9: Resting data shape (287, 22, 376)
Loaded resting data for 9 subjects.
