# EEGDash example for sex classification

The code below shows how to use the *EEGDash* library together with Braindecode, LightGBM, and PyTorch to classify participant sex from resting-state EEG recorded in a collection of 136 subjects.

1. **Data Retrieval Using EEGDash**: An instance of *EEGDashDataset* is created to search for and retrieve resting-state data for 136 subjects (dataset ds005505). At this step, only the metadata is transferred.

2. **Data Preprocessing Using Braindecode**: The EEG data are preprocessed with Braindecode by selecting 24 channels, resampling, filtering, and cutting the recordings into 2-second windows. This takes about 2 minutes.

3. **Feature Extraction**: Signal, spectral, and coherence features are computed for every window with `eegdash.features`. Missing values are then imputed and the features are z-scored.

4. **Creating Training and Test Sets**: The subjects are split into training (90%) and test (10%) sets. The labels are balanced, with as many males as females, and no subject appears in both sets. The datasets are then wrapped in DataLoader objects for mini-batch training.

5. **Model Definition**: A LightGBM gradient-boosted tree classifier serves as a baseline. A small fully connected network (MLP) maps each window's feature vector to 2 output classes (male and female).

6. **Model Training and Evaluation Process**: The network is trained with cross-entropy loss for two epochs, and classification accuracy on the test set is evaluated after each epoch. This takes from under 10 seconds to a couple of minutes, depending on your device.



## Data Retrieval Using EEGDash

First, we look up a resting-state dataset that covers a collection of subjects. Dataset ds005505 contains 136 subjects, both male and female.

In [1]:

# from eegdash import EEGDashDataset

# ds_sexdata = EEGDashDataset({'dataset': 'ds005505', 'task': 'RestingState'}, target_name='sex')

## Data Preprocessing Using Braindecode

[Braindecode](https://braindecode.org/stable/install/install.html) is an open-source deep learning toolbox for EEG and MEG decoding. It also provides the preprocessing utilities used here.

We apply three preprocessing steps in Braindecode:
1. **Selection** of 24 specific EEG channels from the original 128.
2. **Resampling** the EEG data to a frequency of 128 Hz.
3. **Filtering** the EEG signals to retain frequencies between 1 Hz and 55 Hz.

When calling the **preprocess** function, the data is retrieved from the remote repository.

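Finally, we use **create_fixed_length_windows** to cut the continuous recordings into non-overlapping 2-second windows: 256 samples at 128 Hz, with a stride of 256 samples. Any incomplete window at the end of a recording is dropped. These windows serve as the dataset samples.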

In [2]:
# from braindecode.preprocessing import (preprocess, Preprocessor, create_fixed_length_windows)
# import os

# # Preprocessing steps: channel selection, resampling to 128 Hz, and 1-55 Hz filtering
# preprocessors = [
#     Preprocessor('pick_channels', ch_names=['E22', 'E9', 'E33', 'E24', 'E11', 'E124', 'E122', 'E29', 'E6', 'E111', 'E45', 'E36', 'E104', 'E108', 'E42', 'E55', 'E93', 'E58', 'E52', 'E62', 'E92', 'E96', 'E70', 'Cz']),
#     Preprocessor("resample", sfreq=128),
#     Preprocessor("filter", l_freq=1, h_freq=55)
# ]
# preprocess(ds_sexdata, preprocessors, n_jobs=-1)  # adding save_dir='xxxx' saves the result to disk and sets preload to False

# # extract windows and save to disk
# windows_ds = create_fixed_length_windows(ds_sexdata, start_offset_samples=0, stop_offset_samples=None,
#         window_size_samples=256, window_stride_samples=256, drop_last_window=True, preload=False) 
# os.makedirs('data/hbn_preprocessed_restingstate', exist_ok=True)
# windows_ds.save('data/hbn_preprocessed_restingstate', overwrite=True)
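The windowing parameters can be made concrete with a minimal NumPy sketch. It assumes a continuous recording stored as a `(n_channels, n_times)` array. The helper `fixed_length_windows` is hypothetical. It only mimics the effect of `window_size_samples=256`, `window_stride_samples=256` and `drop_last_window=True`, and it is not Braindecode's implementation.

```python
import numpy as np


def fixed_length_windows(data, window_size, stride):
    """Cut a (n_channels, n_times) array into windows of `window_size` samples.

    Windows start every `stride` samples; a trailing partial window is dropped,
    similar in spirit to drop_last_window=True (hypothetical helper).
    """
    n_channels, n_times = data.shape
    if n_times < window_size:
        return np.empty((0, n_channels, window_size), dtype=data.dtype)
    starts = range(0, n_times - window_size + 1, stride)
    return np.stack([data[:, s:s + window_size] for s in starts])


sfreq = 128                 # sampling rate after resampling (Hz)
window_size = 2 * sfreq     # 2-second windows -> 256 samples
recording = np.random.randn(24, 10 * sfreq + 50)  # 24 channels, ~10.4 s of fake data

windows = fixed_length_windows(recording, window_size, stride=window_size)
print(windows.shape)  # (5, 24, 256): five full windows, the last 50 samples are dropped
```

Setting the stride equal to the window size yields non-overlapping windows. A smaller stride would produce overlapping windows and more samples per recording.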

## Plotting a Single Channel for One Sample

It’s always a good practice to verify that the data has been properly loaded and processed. Here, we plot a single channel from one sample to ensure the signal is present and looks as expected.

In [3]:
# import matplotlib.pyplot as plt
# plt.figure()
# plt.plot(windows_ds[1000][0][0,:].transpose()) # first channel of window 1000
# plt.show()
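The plot above uses sample indices on the x-axis. The minimal variant below converts the axis to seconds. It assumes that a window is a `(n_channels, n_times)` NumPy array sampled at 128 Hz. `plot_channel` is a hypothetical helper. The Agg backend is selected only so that the sketch also runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend; remove this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_channel(window, sfreq, ch_idx=0, ch_name=None):
    """Plot one channel of a (n_channels, n_times) window against time in seconds."""
    times = np.arange(window.shape[1]) / sfreq
    fig, ax = plt.subplots()
    ax.plot(times, window[ch_idx])
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.set_title(ch_name if ch_name is not None else f"Channel {ch_idx}")
    return fig, ax


window = np.random.randn(24, 256)  # stand-in for windows_ds[1000][0]
fig, ax = plot_channel(window, sfreq=128)
```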

## Load pre-saved data

If you have run the previous steps, the windowed data were saved to disk and can be reloaded here. The cells above are commented out, so this cell is also what defines `windows_ds` for the rest of the notebook. Loading is quick and lets you check that the data were saved properly.

In [4]:
from braindecode.datautil import load_concat_dataset

print("Loading data from disk")
windows_ds = load_concat_dataset(path='data/hbn_preprocessed_restingstate', preload=False)


Loading data from disk


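## Feature Extraction

For each 2-second window, `extract_features` computes three groups of features:

- per-channel signal statistics (`sig`: mean, variance, standard deviation, skewness, kurtosis, RMS, peak-to-peak amplitude, quantiles, and line length);
- spectral features (`spec`: root total power, band power, spectral moment, entropy, 80% spectral edge, and slope);
- magnitude-squared coherence between channel pairs (`coher`).

The sampling frequency passed to the spectral and coherence extractors is read from the preprocessed recordings.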

In [5]:
from eegdash.features import extract_features
from eegdash.features import FeatureExtractor, ByChannelFeatureExtractor
from eegdash.features import signal, spectral, connectivity
from functools import partial

features = FeatureExtractor(
    {
        "sig": ByChannelFeatureExtractor(
            {
                "mean": signal.signal_mean,
                "var": signal.signal_variance,
                "std": signal.signal_std,
                "skew": signal.signal_skewness,
                "kurt": signal.signal_kurtosis,
                "rms": signal.signal_rms,
                "ptp": signal.signal_amp_ptp,
                "quan.5": partial(signal.signal_quantile, q=0.5),
                "quan.9": partial(signal.signal_quantile, q=0.9),
                "line_len": signal.signal_line_length,
            },
        ),
        "spec": spectral.SpectralFeatureExtractor(
            {
                "rtot_power": spectral.root_total_power,
                "band_power": spectral.spectral_bands_power,
                0: spectral.NormalizedSpectralFeatureExtractor(
                    {
                        "moment": spectral.spectral_moment,
                        "entropy": spectral.spectral_entropy,
                        "edge": partial(spectral.spectral_edge, edge=0.8),
                    },
                ),
                1: spectral.DBSpectralFeatureExtractor(
                    {
                        "slope": spectral.spectral_slope,
                    },
                ),
            },
            fs=windows_ds.datasets[0].raw.info['sfreq'],
        ),
        "coher": connectivity.CoherenceFeatureExtractor(
            {
                "msc": connectivity.connectivity_magnitude_square_coherence,
            },
            fs=windows_ds.datasets[0].raw.info['sfreq'],
        )
    },
)
features_ds = extract_features(windows_ds, features, n_jobs=-1)

Extracting features: 100%|██████████| 136/136 [01:16<00:00,  1.77it/s]
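To illustrate what some of these features measure, the sketch below computes a few of them by hand for a single window with NumPy and SciPy: mean, variance, line length, and relative band power from a Welch periodogram. The definitions are assumptions: line length is taken as the summed absolute difference between consecutive samples, and the band limits are chosen for illustration. This is not the EEGDash implementation.

```python
import numpy as np
from scipy.signal import welch

# Assumed band limits (Hz); EEGDash's defaults may differ.
BANDS = {"delta": (1, 4), "theta": (4, 8), "alpha": (8, 13), "beta": (13, 30)}


def window_features(window, sfreq):
    """Compute a few per-channel features for a (n_channels, n_times) window."""
    feats = {
        "mean": window.mean(axis=1),
        "var": window.var(axis=1),
        # line length: total absolute sample-to-sample change
        "line_len": np.abs(np.diff(window, axis=1)).sum(axis=1),
    }
    # One Welch segment spanning the whole window (0.5 Hz resolution for 2 s)
    freqs, psd = welch(window, fs=sfreq, nperseg=window.shape[1], axis=1)
    total = psd[:, (freqs >= 1) & (freqs <= 55)].sum(axis=1)
    for name, (lo, hi) in BANDS.items():
        in_band = (freqs >= lo) & (freqs < hi)
        feats[f"rel_power_{name}"] = psd[:, in_band].sum(axis=1) / total
    return feats


sfreq = 128
t = np.arange(2 * sfreq) / sfreq
# two channels: a 10 Hz (alpha) sine and a 20 Hz (beta) sine
window = np.vstack([np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 20 * t)])
feats = window_features(window, sfreq)
print(feats["rel_power_alpha"].round(2), feats["rel_power_beta"].round(2))
```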


In [6]:
import os

os.makedirs('data/hbn_features_restingstate', exist_ok=True)
features_ds.save('data/hbn_features_restingstate', overwrite=True)

In [7]:
from eegdash.features import load_features_concat_dataset

print("Loading features from disk")
features_ds = load_features_concat_dataset(path='data/hbn_features_restingstate', n_jobs=-1)

Loading features from disk


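Missing feature values are filled with the mean of the corresponding feature, and any values that are still missing are set to 0. Every feature is then z-scored.

In [8]: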
mean = features_ds.mean()
features_ds.fillna(mean)
features_ds.fillna(0)
features_ds.zscore(eps=1e-7)

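In the cell above, the imputation mean and the z-score statistics are computed on the full `features_ds`. This happens before the train/test split, so the validation recordings contribute to them. The minimal pandas sketch below shows how to estimate the statistics on training rows only and reuse them for the validation rows. It assumes a plain feature `DataFrame` and applies the same three steps: mean imputation, zero-filling of whatever is still missing, and z-scoring with a small `eps`. The helper name and the exact formula `(x - mean) / (std + eps)` are assumptions, not EEGDash's implementation.

```python
import numpy as np
import pandas as pd


def impute_and_zscore(train_df, val_df, eps=1e-7):
    """Impute and z-score features using statistics from the training rows only."""
    mean = train_df.mean()
    # Fill NaNs with the training mean, then set anything still missing
    # (e.g. a column that is entirely NaN in the training set) to 0.
    train_df = train_df.fillna(mean).fillna(0)
    val_df = val_df.fillna(mean).fillna(0)
    mu, sd = train_df.mean(), train_df.std(ddof=0)
    return (train_df - mu) / (sd + eps), (val_df - mu) / (sd + eps)


train = pd.DataFrame({"a": [1.0, 2.0, np.nan, 5.0], "b": [np.nan] * 4})
val = pd.DataFrame({"a": [3.0, np.nan], "b": [np.nan, np.nan]})
train_z, val_z = impute_and_zscore(train, val)
print(train_z.round(3))
```

The `eps` term keeps constant columns, such as `b` here, from causing a division by zero.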


In [9]:
features_ds.to_dataframe()

| | sig_mean_E22 | sig_mean_E9 | sig_mean_E33 | sig_mean_E24 | sig_mean_E11 | sig_mean_E124 | sig_mean_E122 | sig_mean_E29 | sig_mean_E6 | sig_mean_E111 | ... | coher_msc_beta_E62<>E96 | coher_msc_beta_E62<>E70 | coher_msc_beta_E62<>Cz | coher_msc_beta_E92<>E96 | coher_msc_beta_E92<>E70 | coher_msc_beta_E92<>Cz | coher_msc_beta_E96<>E70 | coher_msc_beta_E96<>Cz | coher_msc_beta_E70<>Cz | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.144647 | 0.048857 | 0.015319 | 0.052162 | 0.209428 | -0.018073 | 0.004150 | 0.054689 | 0.032073 | 0.000455 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 1 | -0.074877 | -0.129201 | 0.018310 | 0.023309 | 0.050833 | 0.064574 | 0.013648 | 0.022104 | 0.016217 | 0.002118 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 2 | -0.023045 | -0.056217 | -0.008971 | -0.010753 | -0.033605 | -0.096856 | -0.002726 | -0.012868 | -0.007039 | -0.002244 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 3 | 0.081790 | 0.153533 | 0.000220 | 0.009420 | 0.081392 | 0.136290 | 0.014210 | 0.011162 | 0.008168 | 0.002876 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 4 | -0.217537 | -0.177577 | -0.011244 | -0.028816 | -0.131727 | -0.029716 | 0.006665 | -0.019649 | -0.007627 | -0.000441 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012045 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 28787 | 0.125200 | 0.119488 | 0.028479 | 0.065078 | 0.409021 | 0.197757 | 0.012932 | 0.053864 | 0.099700 | 0.003196 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 28788 | -0.044270 | -0.030968 | -0.011111 | -0.008408 | -0.087369 | 0.002570 | 0.008641 | -0.004710 | -0.017775 | 0.000842 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 28789 | -0.007891 | 0.016211 | -0.002247 | -0.005089 | 0.013947 | 0.002533 | 0.004072 | -0.003454 | -0.003185 | -0.000214 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |
| 28790 | 0.098026 | -0.068835 | 0.028472 | 0.048080 | 0.139886 | -0.125748 | -0.005607 | 0.027155 | 0.012778 | -0.001464 | ... | 0.01166 | 1.110223e-09 | 0.0 | 0.012212 | 1.110223e-09 | 0.0 | 0.011267 | 0.0 | 0.0 | F |


## Creating a Training and Test Set

The code below creates a training set and a test set. We first split the subjects with scikit-learn's **train_test_split** function. We then build a **FeaturesConcatDataset** for each set from the recordings of the corresponding subjects.

1. **Set Random Seed** – The random seed is fixed with `np.random.seed(random_state)` and `torch.manual_seed(random_state)`. The same `random_state` is passed to `train_test_split`, so that dataset splitting and model training are reproducible.
2. **Get Balanced Indices for Male and Female Subjects** – To obtain a 50/50 split of male and female subjects, we keep the first *n* subjects of each sex, where *n* is the size of the smaller group. Here *n* is 46, so 92 of the 136 subjects are used. The split is done at the subject level, so no subject appears in both sets (no subject leakage). `train_test_split()` assigns 90% of these subjects to training and 10% to testing, stratified by sex.
3. **Convert Data to PyTorch Tensors** – The `DataLoader` collates each batch of feature vectors into a tensor. The labels remain the strings `'M'` and `'F'` and are converted to integer class indices (`torch.long`) inside the training loop.
4. **Create DataLoaders** – The datasets are wrapped in PyTorch `DataLoader` objects with a batch size of 100, and the training windows are shuffled for mini-batch training. Although there are only 136 subjects, the full dataset contains more than 28,000 2-second windows, 16,660 of them in the training set.


In [10]:
from eegdash.features import FeaturesConcatDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import numpy as np
import torch

# random seed for reproducibility
random_state = 0
np.random.seed(random_state)
torch.manual_seed(random_state)

# Get balanced indices for male and female subjects and create a balanced dataset
male_subjects   = features_ds.description['subject'][features_ds.description['sex'] == 'M']
female_subjects = features_ds.description['subject'][features_ds.description['sex'] == 'F']
n_samples = min(len(male_subjects), len(female_subjects))
balanced_subjects = np.concatenate([male_subjects[:n_samples], female_subjects[:n_samples]])
balanced_gender = ['M'] * n_samples + ['F'] * n_samples
train_subj, val_subj, train_gender, val_gender = train_test_split(balanced_subjects, balanced_gender, train_size=0.9, stratify=balanced_gender, random_state=random_state)

# Create datasets
train_ds = FeaturesConcatDataset([ds for ds in features_ds.datasets if ds.description.subject in train_subj])
val_ds = FeaturesConcatDataset([ds for ds in features_ds.datasets if ds.description.subject in val_subj])

# Check the balance of the dataset
assert len(balanced_subjects) == len(balanced_gender)
print(f"Number of subjects in balanced dataset: {len(balanced_subjects)}")
print(f"Gender distribution in balanced dataset: {np.unique(balanced_gender, return_counts=True)}")

Number of subjects in balanced dataset: 92
Gender distribution in balanced dataset: (array(['F', 'M'], dtype='<U1'), array([46, 46]))
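Because the split is done on subject IDs, it is worth asserting explicitly that no subject ends up in both sets and that each set is balanced. The sketch below reproduces the balancing and stratified-split logic on a small `description` table and adds those checks. The table is made up, and the subject IDs and counts are invented for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up metadata: one row per recording, as in features_ds.description
description = pd.DataFrame({
    "subject": [f"sub-{i:02d}" for i in range(26)],
    "sex": ["M"] * 10 + ["F"] * 16,
})

male = description.loc[description["sex"] == "M", "subject"].to_numpy()
female = description.loc[description["sex"] == "F", "subject"].to_numpy()
n = min(len(male), len(female))
subjects = np.concatenate([male[:n], female[:n]])
sexes = np.array(["M"] * n + ["F"] * n)

train_subj, val_subj, train_sex, val_sex = train_test_split(
    subjects, sexes, train_size=0.9, stratify=sexes, random_state=0
)

# No subject may appear in both sets (prevents subject leakage)
assert set(train_subj).isdisjoint(val_subj)
# Report the class counts in each set
for name, labels in [("train", train_sex), ("val", val_sex)]:
    values, counts = np.unique(labels, return_counts=True)
    print(name, dict(zip(values.tolist(), counts.tolist())))
```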


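As a first baseline, we fit a gradient-boosted decision tree classifier (LightGBM, default hyperparameters) on the feature table. We report accuracy on individual 2-second windows for the training and validation sets.

In [11]: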
from lightgbm import LGBMClassifier

target_name = train_ds.datasets[0].target_name
train_df = train_ds.to_dataframe()
X_train, y_train = train_df.drop(target_name, axis=1), train_df[target_name]
val_df = val_ds.to_dataframe()
X_val, y_val = val_df.drop(target_name, axis=1), val_df[target_name]

clf = LGBMClassifier()
clf.fit(X_train, y_train)

y_hat_train = clf.predict(X_train)
correct_train = (y_train == y_hat_train).mean()
y_hat_val = clf.predict(X_val)
correct_val = (y_val == y_hat_val).mean()
print(f'Train accuracy: {correct_train:.2f}, Validation accuracy: {correct_val:.2f}\n')

[LightGBM] [Info] Number of positive: 8485, number of negative: 8175
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.164134 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 115478
[LightGBM] [Info] Number of data points in the train set: 16660, number of used features: 1426
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.509304 -> initscore=0.037219
[LightGBM] [Info] Start training from score 0.037219
Train accuracy: 1.00, Validation accuracy: 0.88
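The accuracies above are computed per 2-second window. Sex is a property of the subject, so one common way to obtain a subject-level score is a majority vote over each subject's windows. The sketch below does this with pandas on made-up predictions. It assumes that an array of subject IDs, one per window, is available alongside the predictions; building that array from the datasets' `description` is not shown here.

```python
import pandas as pd


def subject_level_accuracy(subject_ids, y_true, y_pred):
    """Majority-vote window predictions per subject and return subject accuracy."""
    df = pd.DataFrame({"subject": subject_ids, "true": y_true, "pred": y_pred})
    per_subject = df.groupby("subject").agg(
        true=("true", "first"),  # the label is constant within a subject
        # most frequent prediction; ties go to the alphabetically first label
        pred=("pred", lambda p: p.mode().iloc[0]),
    )
    return (per_subject["true"] == per_subject["pred"]).mean()


# Made-up example: three subjects with four windows each
subjects = ["s1"] * 4 + ["s2"] * 4 + ["s3"] * 4
y_true = ["F"] * 4 + ["M"] * 4 + ["M"] * 4
y_pred = ["F", "F", "M", "F", "M", "F", "M", "M", "F", "F", "F", "M"]
print(subject_level_accuracy(subjects, y_true, y_pred))
```

In this toy example, 7 of 12 windows are correct, but two of the three subjects are classified correctly after voting.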



In [12]:

# Create dataloaders (shuffling is only needed for training)
train_loader = DataLoader(train_ds, batch_size=100, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=100, shuffle=False)

## Check labels

It is good practice to check the labels and confirm that the training batches are shuffled. A batch that contains only 'M' (male) or only 'F' (female) labels may indicate a problem with data loading, shuffling, or stratification that requires further investigation.

In [13]:
# get the first batch to check the labels
dataiter = iter(train_loader)
first_item, label = next(dataiter)
np.array(label)

array(['F', 'F', 'M', 'F', 'F', 'M', 'M', 'M', 'M', 'F', 'M', 'M', 'F',
       'M', 'M', 'M', 'F', 'M', 'M', 'M', 'M', 'M', 'M', 'F', 'F', 'M',
       'M', 'F', 'M', 'M', 'F', 'M', 'M', 'M', 'M', 'M', 'F', 'M', 'F',
       'F', 'M', 'M', 'F', 'F', 'M', 'M', 'F', 'M', 'M', 'F', 'M', 'F',
       'F', 'M', 'F', 'M', 'F', 'M', 'F', 'M', 'F', 'M', 'F', 'M', 'M',
       'F', 'F', 'F', 'M', 'M', 'M', 'F', 'F', 'F', 'M', 'F', 'M', 'M',
       'F', 'F', 'F', 'F', 'F', 'F', 'F', 'M', 'M', 'M', 'F', 'M', 'F',
       'F', 'M', 'F', 'F', 'F', 'F', 'F', 'F', 'M'], dtype='<U1')
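Instead of reading the array by eye, you can count the class proportions in a batch. The small sketch below uses `collections.Counter` and runs on a made-up list of labels that stands in for `label`. The helper `label_fractions` is hypothetical.

```python
from collections import Counter


def label_fractions(labels):
    """Return the fraction of each label in a batch of string labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {lab: n / total for lab, n in sorted(counts.items())}


batch_labels = ["F", "M", "M", "F", "M"]  # stand-in for `label` from the DataLoader
fractions = label_fractions(batch_labels)
print(fractions)  # {'F': 0.4, 'M': 0.6}
if len(fractions) == 1:
    print("Warning: the batch contains a single class")
```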

## Create model

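The model is a small fully connected network (a multilayer perceptron, MLP). It takes the feature vector of each 2-second window as input (1,584 features) and has 2 output classes (male vs. female). No activation functions are placed between the linear layers, so the network as written computes an affine function of its input, making it equivalent to a single linear classifier. The reference below compares a deep convolutional network applied to raw EEG with one applied to spectral features.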

[1] Truong, D., Milham, M., Makeig, S., & Delorme, A. (2021). Deep Convolutional Neural Network Applied to Electroencephalography: Raw Data vs Spectral Features. IEEE Engineering in Medicine and Biology Society. Annual International Conference, 2021, 1039–1042. https://doi.org/10.1109/EMBC46164.2021.9630708



In [14]:
# create model
from torchinfo import summary
from torch import nn

# MLP: four linear layers with no activation functions in between (an affine map overall)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(features_ds.datasets[0].n_features, 100),
    nn.Linear(100, 100),
    nn.Linear(100, 100),
    nn.Linear(100, 2),
)

print(summary(model, input_size=first_item.shape))


Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [100, 2]                  --
├─Flatten: 1-1                           [100, 1584]               --
├─Linear: 1-2                            [100, 100]                158,500
├─Linear: 1-3                            [100, 100]                10,100
├─Linear: 1-4                            [100, 100]                10,100
├─Linear: 1-5                            [100, 2]                  202
Total params: 178,902
Trainable params: 178,902
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 17.89
Input size (MB): 0.63
Forward/backward pass size (MB): 0.24
Params size (MB): 0.72
Estimated Total Size (MB): 1.59
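The model above has no activation functions, so stacking linear layers adds parameters but not expressive power. The minimal sketch below shows the same architecture with ReLU nonlinearities. ReLU is an assumed choice; other activations or dropout would also work. ReLU layers have no parameters, so the parameter count stays at 178,902 for 1,584 input features. The results reported in this notebook were obtained with the model above, not with this variant.

```python
import torch
from torch import nn


def make_mlp(n_features, hidden=100, n_classes=2):
    """Fully connected network with ReLU activations between linear layers."""
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(n_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, n_classes),
    )


mlp = make_mlp(n_features=1584)
n_params = sum(p.numel() for p in mlp.parameters())
scores = mlp(torch.randn(100, 1584))  # a fake batch of 100 feature vectors
print(n_params, tuple(scores.shape))  # 178902 (100, 2)
```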


## Model Training and Evaluation Process

This section trains the neural network with the Adamax optimizer. It computes the cross-entropy loss, updates the model parameters, and tracks accuracy over two epochs.

1. **Set Up the Optimizer** – The `Adamax` optimizer is initialized with a learning rate of 0.0005 and a weight decay of 0.001 for regularization. No learning rate scheduler is used.

2. **Allocate Model to Device** – The model is moved to the best available device for faster computation: a CUDA GPU, MPS on Apple silicon, or the CPU.

3. **Input Normalization** – No per-batch normalization is applied, because the features were already imputed and z-scored with `features_ds.zscore` before the split.

4. **Train the Model for Two Epochs** – The training loop iterates over the batches with the model in training mode. For each batch, it computes predictions, converts the `'M'`/`'F'` labels to class indices, and calculates the cross-entropy loss. It then backpropagates and updates the model parameters. It also counts correct predictions to compute training accuracy.

5. **Evaluate on Test Data** – After each epoch, the model is switched to evaluation mode and run on the test set. Test accuracy is computed by comparing its predictions with the true labels.


In [15]:

from torch.nn import functional as F

optimizer = torch.optim.Adamax(model.parameters(), lr=0.0005, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model.to(device=device)

# dictionary for converting the 'M'/'F' sample labels to class indices
gender_dict = {'M': 0, 'F': 1}

epochs = 2
for e in range(epochs):

    # training
    model.train()  # put model in training mode
    correct_train = 0
    for t, (x, y) in enumerate(train_loader):
        x = x.to(device=device, dtype=torch.float32)  # move the batch to the model's device
        y = torch.tensor([gender_dict[gender] for gender in y], device=device, dtype=torch.long)
        scores = model(x)
        _, preds = scores.max(1)
        correct_train += (preds == y).sum().item()

        # compute the cross-entropy loss and backpropagate
        loss = F.cross_entropy(scores, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if t % 50 == 0:
            print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))

    # validation
    model.eval()  # put model in evaluation mode
    correct_test = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for x, y in val_loader:
            x = x.to(device=device, dtype=torch.float32)
            y = torch.tensor([gender_dict[gender] for gender in y], device=device, dtype=torch.long)
            _, preds = model(x).max(1)
            correct_test += (preds == y).sum().item()

    print(f'Epoch {e}, Train accuracy: {correct_train / len(train_ds):.2f}, Test accuracy: {correct_test / len(val_ds):.2f}\n')


Epoch 0, Iteration 0, loss = 0.6971
Epoch 0, Iteration 50, loss = 0.4521
Epoch 0, Iteration 100, loss = 0.2760
Epoch 0, Iteration 150, loss = 0.3525
Epoch 0, Train accuracy: 0.86, Test accuracy: 0.73

Epoch 1, Iteration 0, loss = 0.2299
Epoch 1, Iteration 50, loss = 0.2266
Epoch 1, Iteration 100, loss = 0.2581
Epoch 1, Iteration 150, loss = 0.1543
Epoch 1, Train accuracy: 0.91, Test accuracy: 0.74
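The evaluation step can also be factored into a reusable function. The sketch below assumes, as in this notebook, a `DataLoader` that yields `(features, labels)` batches with string labels and a `{'M': 0, 'F': 1}` mapping. The name `evaluate_accuracy` is hypothetical. The function disables gradient tracking with `torch.no_grad()`, and it is checked on a toy model that always predicts class 0.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def evaluate_accuracy(model, loader, label_map, device="cpu"):
    """Return the fraction of correctly classified samples in `loader`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)
            y = torch.tensor([label_map[lab] for lab in y], device=device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += len(y)
    return correct / total


# Toy check: a linear layer whose bias always favors class 0 ('M')
toy_model = nn.Linear(4, 2)
with torch.no_grad():
    toy_model.weight.zero_()
    toy_model.bias.copy_(torch.tensor([1.0, 0.0]))

# The default collate function stacks the tensors and keeps the string labels as a list
toy_data = [(torch.randn(4), lab) for lab in ["M", "M", "M", "F"]]
loader = DataLoader(toy_data, batch_size=2)
print(evaluate_accuracy(toy_model, loader, {"M": 0, "F": 1}))  # 0.75
```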

