# EEGDash Example for Auditory Oddball Classification

This code demonstrates using the *EEGDash* library with PyTorch to classify EEG responses in an auditory oddball paradigm.

1. **Data Description**: Dataset contains EEG recordings during an auditory oddball task with two stimulus types:
   - Standard: 500 Hz tone
   - Oddball: 1000 Hz tone

2. **Data Preprocessing**: 
   - Applies bandpass filtering (1-55 Hz)
   - Selects 24 channels
   - Creates event-based windows
   - Processes data in batches for memory efficiency

3. **Dataset Preparation**: 
   - Remaps events into two classes: oddball, standard
   - Splits into training (80%) and test (20%) sets
   - Creates PyTorch DataLoaders

4. **Model**: 
   - ShallowFBCSPNet architecture
   - 24 input channels, 2 output classes
   - 256-sample input windows

5. **Training**: 
   - Adamax optimizer with an exponential learning-rate scheduler (`gamma=1`, so the rate stays constant)
   - 5 training epochs
   - Reports accuracy on train and test sets

## Data Retrieval Using EEGDash

The data come from dataset ds003061 on NEMAR: https://nemar.org/dataexplorer/detail?dataset_id=ds003061.

Download the dataset locally and change `data_dir` in the cell below to point to your copy.

In [1]:
from eegdash.data_utils import EEGBIDSDataset

dataset = EEGBIDSDataset(
    data_dir='d:/Users/vivian/Desktop/UCSD/EEG/ds003061/ds003061',
    dataset='ds003061'
)

all_files = dataset.get_files()
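
Since the path above is specific to one machine, it can help to check the local copy before loading it. The following is a minimal sketch using only the standard library. `find_set_files` and the `EEG_DATA_DIR` environment variable are hypothetical names introduced for illustration, not part of EEGDash, and the `sub-*/eeg/*_eeg.set` layout is assumed from the file names printed later in this notebook.

```python
import os
from pathlib import Path


def find_set_files(bids_root):
    """Return the sorted EEGLAB .set files found under sub-*/eeg/."""
    root = Path(bids_root)
    if not root.is_dir():
        raise FileNotFoundError(f"BIDS root not found: {root}")
    return sorted(root.glob("sub-*/eeg/*_eeg.set"))


# EEG_DATA_DIR is a hypothetical variable; fall back to a relative folder name.
data_dir = os.environ.get("EEG_DATA_DIR", "ds003061")
if Path(data_dir).is_dir():
    for path in find_set_files(data_dir)[:3]:
        print(path)
else:
    print(f"Dataset folder not found: {data_dir}")
```

Reading the location from an environment variable keeps the notebook itself free of user-specific paths.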

## Data Preprocessing Using Braindecode

[Braindecode](https://braindecode.org/) provides a powerful framework for EEG data preprocessing and analysis. This implementation processes EEG files in batches to efficiently manage memory usage while handling multiple recordings.

The preprocessing pipeline consists of several key steps:

1. **Batch Processing**: Files are processed in small batches of recordings to limit memory usage. Each batch is loaded, preprocessed, and converted to windows before the next batch is handled.

2. **Channel Selection**: 24 EEG channels are selected from the original 79.

3. **Signal Filtering**: A bandpass filter between 1 Hz and 55 Hz removes noise and unwanted frequency components.

4. **Event Processing**: For each recording:
   - Events are extracted from annotations using MNE's events_from_annotations
   - The last event is removed to prevent time duration issues
   - Events are then converted back to annotations with proper timing information

5. **Window Creation**: The create_windows_from_events function extracts epochs from the continuous data:
   - Each window spans from 128 samples before to 128 samples after the event (256 samples in total)
   - Data is loaded on demand (preload=False) to maintain memory efficiency
   - Each window is automatically associated with its corresponding event type
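
The window creation step can be pictured as slicing a continuous array around each event. The sketch below is plain NumPy on synthetic data, not Braindecode's implementation. It assumes zero-duration events, so both offsets are measured from the event onset, and `cut_windows` is a hypothetical helper. It also shows one plausible reason for dropping an event near the end of a recording: its window would run past the last sample.

```python
import numpy as np


def cut_windows(data, onsets, start_offset=-128, stop_offset=128):
    """Slice fixed-length windows around event onsets.

    data:   array of shape (n_channels, n_samples)
    onsets: event onsets as sample indices
    Windows that would extend past either end of the recording are skipped.
    """
    n_channels, n_samples = data.shape
    windows = []
    for onset in onsets:
        start, stop = onset + start_offset, onset + stop_offset
        if start < 0 or stop > n_samples:
            continue  # window does not fit inside the recording
        windows.append(data[:, start:stop])
    if not windows:
        return np.empty((0, n_channels, stop_offset - start_offset))
    return np.stack(windows)


rng = np.random.default_rng(0)
continuous = rng.standard_normal((24, 2000))  # synthetic stand-in for one recording
onsets = [300, 900, 1950]                     # the last onset is too close to the end
windows = cut_windows(continuous, onsets)
print(windows.shape)
```

With these offsets every window has 256 samples, which matches the window shape printed by the preprocessing cell.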

In [2]:
import logging
import warnings

import mne
from mne.io import read_raw_eeglab
from braindecode.datasets import BaseConcatDataset, BaseDataset
from braindecode.preprocessing import preprocess, Preprocessor, create_windows_from_events

# silence verbose output from MNE and joblib, and suppress Python warnings
mne.set_log_level('ERROR')
logging.getLogger('joblib').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

test_files = all_files[0:3] # Select files from a single subject
print("\ntest files:")
for i, file in enumerate(test_files):
    print(f"{i+1}. {file}")
    

batch_size = 1  # files per batch; with many subjects, batching keeps memory usage low
all_windows_datasets = [] 

preprocessors = [
    Preprocessor('pick_channels', ch_names=[
        'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 
        'O1', 'O2', 'F7', 'F8', 'T7', 'T8', 'P7', 'P8',
        'Fz', 'Cz', 'Pz', 'Oz', 'FC1', 'FC2', 'CP1', 'CP2'
    ]),
    Preprocessor("filter", l_freq=1, h_freq=55)
]

for batch_start in range(0, len(test_files), batch_size):
    batch_end = batch_start + batch_size
    batch_files = test_files[batch_start:batch_end]
    batch_windows_datasets = []
    
    for i, file in enumerate(batch_files):
        raw = read_raw_eeglab(file, preload=False)
        raw_dataset = BaseDataset(raw, target_name=None)
        single_ds = BaseConcatDataset([raw_dataset])
        ds_preprocessed = preprocess(single_ds, preprocessors)
        
        raw = ds_preprocessed.datasets[0].raw
        events, event_dict = mne.events_from_annotations(raw)
        
        # remove the last event to avoid time duration issues
        events = events[:-1]
        
        # create a reverse mapping for event descriptions
        reverse_event_dict = {v: k for k, v in event_dict.items()}
        annot_from_events = mne.annotations_from_events(
            events=events,
            event_desc=reverse_event_dict,  
            sfreq=raw.info['sfreq']
        )
        raw.set_annotations(annot_from_events)
        

        # create windows from events
        file_windows_ds = create_windows_from_events(
            ds_preprocessed,
            trial_start_offset_samples=-128,    
            trial_stop_offset_samples=128,     
            preload=False                 
        )
        batch_windows_datasets.extend(file_windows_ds.datasets)
    
    all_windows_datasets.extend(batch_windows_datasets)

# combine all datasets into a single dataset
windows_ds = BaseConcatDataset(all_windows_datasets)
print(f"\n All batches processed, total number of windows: {len(windows_ds)}") 
print(f"Window shape: {windows_ds[0][0].shape}")
print("\nevent mapping:")
print("Event number -> Event name:")
for event_name, event_number in event_dict.items():
    print(f"{event_number} -> {event_name}")


test files:
1. d:\Users\vivian\Desktop\UCSD\EEG\ds003061\ds003061\sub-001\eeg\sub-001_task-P300_run-1_eeg.set
2. d:\Users\vivian\Desktop\UCSD\EEG\ds003061\ds003061\sub-001\eeg\sub-001_task-P300_run-2_eeg.set
3. d:\Users\vivian\Desktop\UCSD\EEG\ds003061\ds003061\sub-001\eeg\sub-001_task-P300_run-3_eeg.set

 All batches processed, total number of windows: 2582
Window shape: (24, 256)

event mapping:
Event number -> Event name:
1 -> ignore
2 -> noise
3 -> oddball
4 -> oddball_with_reponse
5 -> response
6 -> standard
7 -> standard_with_reponse
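
The batch loop in the preprocessing cell walks through the file list in fixed-size slices. As a small illustration, the same pattern can be written as a generator. `iter_batches` is a hypothetical helper and the run names are placeholders, not real file paths.

```python
def iter_batches(items, batch_size):
    """Yield consecutive chunks of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


runs = ["run-1", "run-2", "run-3", "run-4", "run-5"]  # placeholder names
for batch in iter_batches(runs, batch_size=2):
    print(batch)
```

The last chunk may be shorter than `batch_size`, which the slicing handles without extra checks.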


## Creating Training and Test Sets

The FilteredDataset class processes the windowed data by remapping event labels into two categories: oddball (0) and standard (1). The data preparation pipeline consists of these key steps:

1. **Data Filtering and Label Remapping** - The FilteredDataset processes windows_ds by keeping only relevant events and mapping them to two categories. Labels 3,4 are mapped to oddball (0), and labels 6,7 to standard (1).

2. **Train-Test Split** - Using sklearn's train_test_split, the dataset is divided into 80% training and 20% testing sets. The split is stratified to maintain class proportions across both sets.

3. **PyTorch Data Preparation** - The split datasets are converted to PyTorch tensors and wrapped in DataLoader objects with a batch size of 10, enabling efficient training with shuffled mini-batches.

Because the split is stratified, the training and test sets preserve the class proportions of the filtered dataset (about 65% label 0 and 35% label 1 in this run), so both sets are representative for training and evaluation.
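
The integer ids 3, 4, 6, and 7 used for the remapping come from the event mapping printed after preprocessing. As a minimal sketch, the same remapping can be derived from the event names, which avoids hard-coding ids that might shift if a recording lacks one of the event types (an assumption worth checking against your own files). `build_label_mapping` is a hypothetical helper, not part of EEGDash or Braindecode, and the event names keep the dataset's own spelling (`reponse`).

```python
# Event name -> integer id, as printed by the preprocessing cell.
event_dict = {
    "ignore": 1,
    "noise": 2,
    "oddball": 3,
    "oddball_with_reponse": 4,
    "response": 5,
    "standard": 6,
    "standard_with_reponse": 7,
}


def build_label_mapping(event_dict):
    """Map event ids to 0 (oddball) or 1 (standard) based on the event name."""
    mapping = {}
    for name, event_id in event_dict.items():
        if name.startswith("oddball"):
            mapping[event_id] = 0
        elif name.startswith("standard"):
            mapping[event_id] = 1
        # all other events (ignore, noise, response) are left out
    return mapping


label_mapping = build_label_mapping(event_dict)
print(label_mapping)
```

For the mapping printed above, this gives the same dictionary as the hard-coded version in the next cell.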

In [3]:
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# set random seed 
random_state = 42
torch.manual_seed(random_state)
np.random.seed(random_state)

class FilteredDataset:
    def __init__(self, windows_ds):
        self.data = []
        self.labels = []
        
        # remap labels according to the specified mapping
        label_mapping = {
            3: 0, 4: 0,  # oddball
            6: 1, 7: 1   # standard
        }
        
        for i in range(len(windows_ds)):
            window = windows_ds[i]  # index once so each window is loaded only once
            label = window[1]
            if label in label_mapping:
                self.data.append(window[0])
                self.labels.append(label_mapping[label])
        
        self.data = np.array(self.data)
        self.labels = np.array(self.labels)
    
    def __len__(self):
        return len(self.labels)

# create a filtered dataset
filtered_ds = FilteredDataset(windows_ds)
print(f"Filtered dataset size:  {len(filtered_ds)}")
print(f"Data shape: {filtered_ds.data.shape}")
print("Distribution of labels after filtering and remapping:", np.unique(filtered_ds.labels, return_counts=True))
labels = filtered_ds.labels
print("Label meanings: 0=oddball, 1=standard")

# divide the dataset into training and testing sets
train_indices, test_indices = train_test_split(
    range(len(filtered_ds)),
    test_size=0.2,
    stratify=filtered_ds.labels,
    random_state=random_state
)

# convert data to PyTorch tensors
X_train = torch.FloatTensor(filtered_ds.data[train_indices])
X_test = torch.FloatTensor(filtered_ds.data[test_indices])
y_train = torch.LongTensor(filtered_ds.labels[train_indices])
y_test = torch.LongTensor(filtered_ds.labels[test_indices])

# create data loaders
dataset_train = TensorDataset(X_train, y_train)
dataset_test = TensorDataset(X_test, y_test)

train_loader = DataLoader(dataset_train, batch_size=10, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=10, shuffle=False)  # order does not matter for evaluation

# dataset information
print(f"\nDataset size:") 
print(f"Training set: {X_train.shape}, labels: {y_train.shape}") 
print(f"Test set: {X_test.shape}, labels: {y_test.shape}") 
print(f"\nProportion of samples of each class in training set:") 
for label in np.unique(labels):
    ratio = np.mean(y_train.numpy() == label)
    print(f"Category {label}: {ratio:.3f}")

Filtered dataset size:  1492
Data shape: (1492, 24, 256)
Distribution of labels after filtering and remapping: (array([0, 1]), array([970, 522]))
Label meanings: 0=oddball, 1=standard

Dataset size:
Training set: torch.Size([1193, 24, 256]), labels: torch.Size([1193])
Test set: torch.Size([299, 24, 256]), labels: torch.Size([299])

Proportion of samples of each class in training set:
Category 0: 0.650
Category 1: 0.350
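
Because the two classes are not equally frequent, accuracy is easier to interpret next to a majority-class baseline. The short calculation below uses only the label counts printed above. It also reproduces the split sizes, assuming the test-set size is rounded up as scikit-learn does for a fractional `test_size`.

```python
import math

counts = {0: 970, 1: 522}  # label counts printed by the cell above
total = sum(counts.values())

# Share of each class, and the accuracy of always predicting the larger class.
proportions = {label: n / total for label, n in counts.items()}
majority_baseline = max(proportions.values())

# Split sizes for test_size=0.2.
n_test = math.ceil(0.2 * total)
n_train = total - n_test

print(f"class proportions: {proportions[0]:.3f} / {proportions[1]:.3f}")
print(f"majority-class baseline accuracy: {majority_baseline:.3f}")
print(f"train/test sizes: {n_train} / {n_test}")
```

A classifier that always predicts label 0 would reach about 65% accuracy on this data, so reported accuracies are best compared against that value rather than against 50%.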


## Create Model

The model is a shallow convolutional neural network (ShallowFBCSPNet) with 24 input channels (EEG channels), 2 output classes (oddball, standard), and an input window of 256 samples (1 second of EEG data).

In [4]:
from braindecode.models import ShallowFBCSPNet
from torchinfo import summary

model = ShallowFBCSPNet(
    in_chans=24,        
    n_classes=2,         
    input_window_samples=256,  
    final_conv_length="auto"
)

summary(model, input_size=(1, 24, 256))

Layer (type:depth-idx)                   Output Shape              Param #
ShallowFBCSPNet                          [1, 2]                    --
├─Ensure4d: 1-1                          [1, 24, 256, 1]           --
├─Rearrange: 1-2                         [1, 1, 256, 24]           --
├─CombinedConv: 1-3                      [1, 40, 232, 1]           39,440
├─BatchNorm2d: 1-4                       [1, 40, 232, 1]           80
├─Expression: 1-5                        [1, 40, 232, 1]           --
├─AvgPool2d: 1-6                         [1, 40, 11, 1]            --
├─Expression: 1-7                        [1, 40, 11, 1]            --
├─Dropout: 1-8                           [1, 40, 11, 1]            --
├─Sequential: 1-9                        [1, 2]                    --
│    └─Conv2d: 2-1                       [1, 2, 1, 1]              882
│    └─LogSoftmax: 2-2                   [1, 2, 1, 1]              --
│    └─Expression: 2-3                   [1, 2]                    --
Total para
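
The shapes and parameter counts in the summary can be checked by hand. The sketch below assumes a temporal filter length of 25, a pooling length of 75, and a pooling stride of 15. These values are not shown in the notebook; they are assumed defaults that happen to reproduce the reported numbers. The split into a biased temporal convolution and an unbiased spatial convolution is likewise an assumption chosen to match the summary.

```python
def conv_out(n_in, kernel, stride=1):
    """Output length of an unpadded 1-D convolution or pooling step."""
    return (n_in - kernel) // stride + 1


n_chans, n_times, n_classes = 24, 256, 2
n_filters = 40
# Assumed layer settings (not shown in the notebook).
filter_time_length, pool_time_length, pool_time_stride = 25, 75, 15

after_conv = conv_out(n_times, filter_time_length)                      # temporal convolution
after_pool = conv_out(after_conv, pool_time_length, pool_time_stride)   # average pooling

# Temporal convolution with bias, plus spatial convolution over channels without bias.
conv_params = (n_filters * filter_time_length + n_filters) + n_filters * n_filters * n_chans
# The final convolution spans all pooled time steps (final_conv_length="auto").
clf_params = n_filters * after_pool * n_classes + n_classes

print(after_conv, after_pool, conv_params, clf_params)
```

Under these assumptions the time axis shrinks from 256 to 232 and then to 11 samples, and the two parameter counts come out as 39,440 and 882, in line with the summary.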

## Model Training and Evaluation Process

The training and evaluation pipeline runs for 5 epochs using Adamax optimization. Key components include:

1. **Hardware Setup** - The model is moved to the GPU when one is available and stays on the CPU otherwise.

2. **Data Processing** - Each channel of each window is normalized over time by subtracting its mean and dividing by its standard deviation.

3. **Training Process** - Each epoch performs forward passes, computes cross-entropy loss, updates parameters, and tracks accuracy.

4. **Evaluation** - Model performance is assessed on the test set after each training epoch.

The process monitors both training and test accuracy to track model learning progress.
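
The channel-wise normalization can be illustrated without PyTorch. The NumPy sketch below standardizes each channel of each window along the time axis. `normalize_windows` is a hypothetical name, the data are synthetic, and `ddof=1` is used on the assumption that it matches the default behavior of `torch.std`.

```python
import numpy as np


def normalize_windows(x, eps=1e-7):
    """Z-score each channel of each window along the time axis.

    x: array of shape (n_windows, n_channels, n_times)
    """
    mean = x.mean(axis=2, keepdims=True)
    std = x.std(axis=2, ddof=1, keepdims=True) + eps  # eps avoids division by zero
    return (x - mean) / std


rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=3.0, size=(10, 24, 256))  # synthetic mini-batch
normalized = normalize_windows(batch)
print(normalized.shape)
```

After this step every channel in every window has approximately zero mean and unit variance, regardless of its original amplitude.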

In [None]:
# set up the device, optimizer, and learning rate scheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.001)
# gamma=1 leaves the learning rate unchanged; choose gamma < 1 to actually decay it
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)

def normalize_data(x):
    mean = x.mean(dim=2, keepdim=True)
    std = x.std(dim=2, keepdim=True) + 1e-7
    x = (x - mean) / std
    x = x.to(device=device, dtype=torch.float32)
    return x

print("\nstart training...")
epochs = 5

for e in range(epochs):
    model.train()
    correct_train = 0
    for t, (x, y) in enumerate(train_loader):
        scores = model(normalize_data(x))
        y = y.to(device=device, dtype=torch.long)
        _, preds = scores.max(1)
        correct_train += (preds == y).sum()/len(dataset_train)
        
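        # the model ends in LogSoftmax; cross_entropy applies log_softmax again, which leaves log-probabilities unchanged
        loss = F.cross_entropy(scores, y)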
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    
    model.eval()
    correct_test = 0
    with torch.no_grad():
        for t, (x, y) in enumerate(test_loader):
            scores = model(normalize_data(x))
            y = y.to(device=device, dtype=torch.long)
            _, preds = scores.max(1)
            correct_test += (preds == y).sum()/len(dataset_test)
    
    print(f'epoch {e+1}, training accuracy: {correct_train:.3f}, test accuracy: {correct_test:.3f}')


start training...
epoch 1, training accuracy: 0.637, test accuracy: 0.753
epoch 2, training accuracy: 0.750, test accuracy: 0.789
epoch 3, training accuracy: 0.795, test accuracy: 0.836
epoch 4, training accuracy: 0.809, test accuracy: 0.846
epoch 5, training accuracy: 0.838, test accuracy: 0.870
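
The training cell steps the scheduler once per mini-batch, so the number of scheduler steps is much larger than the number of epochs. The arithmetic below is a sketch of what an exponential schedule would do over this run. `exponential_lr` is a hypothetical helper that applies the usual rule of multiplying the rate by `gamma` at each step, and `gamma=0.999` is an illustrative value, not one used in this notebook.

```python
import math


def exponential_lr(base_lr, gamma, n_steps):
    """Learning rate after n_steps steps of an exponential schedule."""
    return base_lr * gamma ** n_steps


n_train, batch_size, epochs = 1193, 10, 5
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs  # the scheduler is stepped once per batch

lr_constant = exponential_lr(0.002, 1.0, total_steps)   # gamma=1, as in the training cell
lr_decayed = exponential_lr(0.002, 0.999, total_steps)  # illustrative gamma < 1

print(f"steps per epoch: {steps_per_epoch}, total steps: {total_steps}")
print(f"final lr with gamma=1: {lr_constant}")
print(f"final lr with gamma=0.999: {lr_decayed:.6f}")
```

With `gamma=1` the rate stays at 0.002 for the whole run, whereas even a value slightly below 1 would roughly halve it over the 600 steps.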
