# EEGDash example

The code below shows how to use the *EEGDash* library together with PyTorch to train a deep learning model on EEG data, here to classify participants by gender from their EEG.

1. **Library Imports**: Essential libraries are imported, including *EEGDash* for data handling, *torch* for deep learning model definition, and other utilities such as *xarray*, *pandas*, and *numpy*.

2. **Data Retrieval Using EEGDash**: An *EEGDash* instance is created to search for and retrieve the EEG records of the "FaceRecognition" task (dataset ds002718). The first record is then printed for inspection.

3. **Model Definition**: A custom deep learning model (*VGGSSL*), based on a rescaled version of VGG-16, is defined to encode the EEG signal [1]. Its structure has two parts: (1) **VGG-16 rescaling**: the convolutional layers are adapted to EEG data, with the first convolution taking a single input channel (each channels × time window is treated as a one-channel image) and the number of filters in every convolution divided by 4; and (2) **Encoding layer**: the convolutional output is flattened into a feature vector suitable for downstream tasks.

4. **Custom Dataset Class**: *DeepLearningEEGDataset* structures the EEG data in a PyTorch-compatible format. It reads participant labels (here, gender) from a *participants.tsv* file, assigns a label to each subject, and splits each EEG recording into non-overlapping 2-second windows that serve as individual training samples.

5. **Training Loop**: The dataset is wrapped in a PyTorch DataLoader, and a training loop runs forward and backward passes to update the model, printing the training loss at every iteration of a short demonstration run (21 iterations).
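
The label lookup in step 4 can be sketched with the standard library alone. This is a minimal illustration, not the notebook's code: `parse_labels` is a hypothetical helper, the TSV rows are made up, and it assumes (as the notebook does) that `participants.tsv` has `participant_id` values of the form `sub-XXX` and a `gender` column coded `M`/`F`.

```python
import csv
import io

# Made-up participants.tsv content, for illustration only
PARTICIPANTS_TSV = """participant_id\tgender
sub-001\tM
sub-002\tF
sub-003\tn/a
"""

# 0 = male, 1 = female; any other value falls back to 2 (same coding as get_labels)
GENDER_CODES = {"M": 0, "F": 1}


def parse_labels(tsv_text):
    """Hypothetical helper: map subject numbers ('sub-002' -> 2) to gender codes."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    labels = {}
    for row in reader:
        subject = int(row["participant_id"].split("-")[1])  # int() drops leading zeros
        labels[subject] = GENDER_CODES.get(row["gender"], 2)
    return labels


labels = parse_labels(PARTICIPANTS_TSV)
print(labels)  # {1: 0, 2: 1, 3: 2}
```

Keying the dictionary by integer subject number is what lets the dataset look up a label directly from a record's `subject` attribute, which is stored without leading zeros.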

References:

[1] Truong, D., Milham, M., Makeig, S., & Delorme, A. (2021). Deep Convolutional Neural Network Applied to Electroencephalography: Raw Data vs Spectral Features. *Annual International Conference of the IEEE Engineering in Medicine and Biology Society (EMBC)*, 2021, 1039–1042. https://doi.org/10.1109/EMBC46164.2021.9630708


#### Import necessary libraries

```python
from eegdash import EEGDash
import torch
import torchvision.models as torchmodels
import torch.nn as nn
import xarray as xr
import numpy as np
import math
import pandas as pd
import torch.nn.functional as F
import torch.optim as optim
```

#### Use EEGDash to find data

```python
EEGDashInstance = EEGDash()
EEGDashInstance.find({'task': 'FaceRecognition'})
records = EEGDashInstance.get({'task': 'FaceRecognition'})
print(records[0])
print('Shape of one array recording data', records[0].shape)
```

```
Pinged your deployment. You successfully connected to MongoDB!
Found 18 records
Found 18 records
<xarray.DataArray 'eeg_signal__ds002718_sub-014_task-FaceRecognition_eeg.set' (
                                                                               channel: 74,
                                                                               time: 742500)> Size: 440MB
[54945000 values with dtype=float64]
Coordinates:
  * channel  (channel) object 592B 'EEG001' 'EEG002' ... 'EEG073' 'EEG074'
  * time     (time) float64 6MB 0.0 0.004 0.008 ... 2.97e+03 2.97e+03 2.97e+03
Attributes:
    data_name:           ds002718_sub-014_task-FaceRecognition_eeg.set
    dataset:             ds002718
    has_file:            True
    modality:            EEG
    run:                 
    sampling_frequency:  250
    schema_ref:          eeg_signal
    session:             
    subject:             14
    task:                FaceRecognition
    version_timestamp:   0
Shape of one array recording dat
```
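
The record printed above is a 74 × 742500 array sampled at 250 Hz, i.e. 2970 s of EEG. Cutting it into non-overlapping 2-second windows of 500 samples gives at most 742500 / 500 = 1485 complete windows. A minimal NumPy sketch of that windowing, using a synthetic array instead of a real record; `make_windows` is a hypothetical helper that keeps every complete window and drops any leftover samples at the end.

```python
import numpy as np


def make_windows(data, sfreq, window_s=2.0):
    """Hypothetical helper: cut a (channels, time) array into non-overlapping windows.

    Returns an array of shape (n_windows, channels, window_samples). Samples at
    the end that do not fill a complete window are dropped.
    """
    win = int(window_s * sfreq)
    n_windows = data.shape[1] // win
    trimmed = data[:, : n_windows * win]
    # (C, n_windows * win) -> (C, n_windows, win) -> (n_windows, C, win)
    return trimmed.reshape(data.shape[0], n_windows, win).transpose(1, 0, 2)


# Synthetic stand-in for a record: 74 channels, 10.5 s at 250 Hz
rng = np.random.default_rng(0)
fake = rng.standard_normal((74, 2625))
windows = make_windows(fake, sfreq=250)
print(windows.shape)  # (5, 74, 500)

# The same arithmetic for the recording printed above
duration_s = 742500 / 250                 # 2970.0 seconds
n_full_windows = 742500 // int(2 * 250)   # 1485 windows of 500 samples
print(duration_s, n_full_windows)
```

The reshape-and-transpose avoids a Python loop; the loop-based version in the dataset class is easier to stream lazily, one window at a time.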

#### Specify PyTorch Dataset and Deep Learning Model Architecture

```python
class VGGSSL(nn.Module):
    """Rescaled VGG-16 encoder for EEG windows, following [1]."""

    def __init__(self, model_params=None):
        super().__init__()
        default_params = {
            'task': 'RP',
            'weights': 'DEFAULT'
        }

        if model_params:
            default_params.update(model_params)
        for k, v in default_params.items():
            setattr(self, k, v)

        self.model: nn.Module = None
        self.projection: nn.Linear = None
        vgg = self.create_vgg_rescaled(weights=self.weights)
        # Only the convolutional features and the flatten layer are used;
        # the rescaled classifier built in create_vgg_rescaled() is discarded.
        self.encoder = nn.Sequential(vgg.features, vgg.flatten)

    def create_vgg_rescaled(self, subsample=4, feature='raw', weights='DEFAULT'):
        # torchvision's VGG-16 serves only as an architecture template: every
        # Conv2d and Linear layer is re-created below with fresh random weights,
        # so pretrained weights are downloaded but not reused
        # (pass model_params={'weights': None} to skip the download).
        tmp = torchmodels.vgg16(weights=weights)
        # Keep the first three convolutional blocks (layers 0-16, ending with a max-pool)
        tmp.features = tmp.features[0:17]
        vgg16_rescaled = nn.Sequential()
        modules = []

        if feature == 'raw':
            # Raw EEG: each (channel, time) window is fed as a one-channel image
            first_in_channels = 1
            first_in_features = 6144
        else:
            first_in_channels = 3
            first_in_features = 576

        # Copy the convolutional layers, dividing the number of filters by `subsample`
        for layer in tmp.features.children():
            if isinstance(layer, nn.Conv2d):
                if layer.in_channels == 3:
                    in_channels = first_in_channels
                else:
                    in_channels = int(layer.in_channels / subsample)
                out_channels = int(layer.out_channels / subsample)
                modules.append(nn.Conv2d(in_channels, out_channels, layer.kernel_size, layer.stride, layer.padding))
            else:
                modules.append(layer)
        vgg16_rescaled.add_module('features', nn.Sequential(*modules))
        vgg16_rescaled.add_module('flatten', nn.Flatten())

        # Rescaled classifier with 2 outputs (not used by VGGSSL.forward)
        modules = []
        for layer in tmp.classifier.children():
            if isinstance(layer, nn.Linear):
                if layer.in_features == 25088:
                    in_features = first_in_features
                else:
                    in_features = int(layer.in_features / subsample)
                if layer.out_features == 1000:
                    out_features = 2
                else:
                    out_features = int(layer.out_features / subsample)
                modules.append(nn.Linear(in_features, out_features))
            else:
                modules.append(layer)
        vgg16_rescaled.add_module('classifier', nn.Sequential(*modules))
        return vgg16_rescaled

    def forward(self, x):
        '''
        @param x: (batch_size, channel, time)
        '''
        if len(x.shape) == 3:
            # Add a singleton image channel: (batch_size, 1, channel, time)
            x = x.unsqueeze(1)
        return self.encode(x)

    def encode(self, x):
        return self.encoder(x)
```
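
The size of the vector returned by `VGGSSL` depends on the window shape, because the rescaled feature extractor is followed only by `nn.Flatten()`. The sketch below tracks that shape in plain Python; `rescaled_vgg_feature_size` is a hypothetical helper, and it assumes the standard torchvision VGG-16 layout (3x3 convolutions with padding 1 and 2x2 max-pooling with stride 2). For a 74-channel, 500-sample window it gives 35712 features, whereas the hard-coded `first_in_features = 6144` corresponds to a smaller input such as 24 × 256 and only matters for the classifier, which `VGGSSL` does not use. Since the training loop passes these features straight to `F.cross_entropy`, the loss is taken over 35712 "classes"; ln(35712) ≈ 10.48, which is roughly in line with the first logged loss of about 10.46.

```python
def rescaled_vgg_feature_size(n_channels, n_times, n_pools=3,
                              last_conv_filters=256, subsample=4):
    """Hypothetical helper: flattened size of VGGSSL's encoder output.

    Assumes the layout of VGG-16 features[0:17]: 3x3 convolutions with
    padding 1 (height/width unchanged) and three MaxPool2d(2, 2) layers,
    each of which halves both dimensions (rounding down).
    """
    h, w = n_channels, n_times
    for _ in range(n_pools):
        h, w = h // 2, w // 2
    return (last_conv_filters // subsample) * h * w


# One 2-s window of the FaceRecognition data: 74 channels x 500 samples
print(rescaled_vgg_feature_size(74, 500))  # 35712
# A 24-channel x 256-sample input gives the hard-coded first_in_features
print(rescaled_vgg_feature_size(24, 256))  # 6144
```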

```python
class DeepLearningEEGDataset(torch.utils.data.IterableDataset):
    """Streams (window, label) pairs from a list of EEGDash records."""

    def __init__(self, records, participants):
        self.records = records
        self.window_size = 2  # seconds
        participants = pd.read_csv(participants, sep='\t')
        self.labels = self.get_labels(participants)

    def get_labels(self, participants):
        # 'sub-014' -> '014' -> 14 (int() removes the leading zeros)
        subjects_str = participants['participant_id'].values
        subjects_str = [s.split('-')[1] for s in subjects_str]
        subjects = [int(s) for s in subjects_str]
        gender = participants['gender'].values

        def gender_to_int(g):
            # 0 = male, 1 = female, 2 = any other or missing value
            if g == 'M':
                return 0
            elif g == 'F':
                return 1
            else:
                return 2

        gender_int = list(map(gender_to_int, gender))
        labels = dict(zip(subjects, gender_int))
        return labels

    def __iter__(self):
        # Split the records across DataLoader workers, if any
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading: iterate over all records
            iter_start = 0
            iter_end = len(self.records)
        else:  # in a worker process: take a contiguous slice of the records
            per_worker = int(math.ceil(len(self.records) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.records))
            print(f'worker_id: {worker_id}, iter_start: {iter_start}, iter_end: {iter_end}\n')
        for i in range(iter_start, iter_end):
            record = self.records[i]
            data = record.values  # C x T NumPy array (loads the whole recording)
            window_size_in_samples = int(self.window_size * record.sampling_frequency)
            # Start indices of non-overlapping windows; the +1 keeps the last complete window
            indices = np.arange(0, data.shape[1] - window_size_in_samples + 1, window_size_in_samples)
            for idx in indices:
                yield data[:, idx:idx + window_size_in_samples], self.labels[record.subject]
```
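
`__iter__` splits the records into contiguous blocks, one per DataLoader worker, so that each recording is read by exactly one worker instead of every worker yielding the full dataset. A pure-Python sketch of the same arithmetic, applied to the 18 records found above with 4 workers; `worker_ranges` is a hypothetical helper, not part of PyTorch or EEGDash.

```python
import math


def worker_ranges(n_records, num_workers):
    """Hypothetical helper: contiguous [start, end) record ranges, one per worker."""
    per_worker = math.ceil(n_records / num_workers)
    ranges = []
    for worker_id in range(num_workers):
        start = worker_id * per_worker
        end = min(start + per_worker, n_records)
        ranges.append((start, end))
    return ranges


print(worker_ranges(18, 4))  # [(0, 5), (5, 10), (10, 15), (15, 18)]
```

With more workers than records, the trailing workers get ranges where start ≥ end; `range(start, end)` is then empty, so those workers simply yield nothing.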



#### Set up training loop and train model

```python
# participants.tsv of dataset ds002718, assumed to be in the working directory
dataset = DeepLearningEEGDataset(records, 'participants.tsv')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=0)
model = VGGSSL()
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.001)
model.train()
for e in range(1):
    for t, (samples, labels) in enumerate(dataloader):
        samples = samples.to(dtype=torch.float32)
        labels = labels.to(dtype=torch.long)
        # VGGSSL returns the flattened encoder features (35712 values for a
        # 74-channel x 500-sample window) without a classification head, so
        # cross_entropy treats each feature as a class logit.
        scores = model(samples)
        loss = F.cross_entropy(scores, labels)

        # Zero out all of the gradients for the variables which the optimizer
        # will update.
        optimizer.zero_grad()

        # This is the backward pass: compute the gradient of the loss with
        # respect to each parameter of the model.
        loss.backward()

        # Actually update the parameters of the model using the gradients
        # computed by the backward pass.
        optimizer.step()

        print(f"Epoch {e} - Iter {t} - Loss/train: {loss.item()}")
        # Stop after 21 iterations (t = 0..20). The dataset is not shuffled,
        # so all of these batches come from the first record.
        if t == 20:
            break
```

```
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/arno/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:15<00:00, 34.7MB/s] 

Epoch 0 - Iter 0 - Loss/train: 10.462847709655762
Epoch 0 - Iter 1 - Loss/train: 10.416534423828125
Epoch 0 - Iter 2 - Loss/train: 10.368654251098633
Epoch 0 - Iter 3 - Loss/train: 10.308568000793457
Epoch 0 - Iter 4 - Loss/train: 10.228968620300293
Epoch 0 - Iter 5 - Loss/train: 10.12086296081543
Epoch 0 - Iter 6 - Loss/train: 9.975360870361328
Epoch 0 - Iter 7 - Loss/train: 9.777841567993164
Epoch 0 - Iter 8 - Loss/train: 9.512860298156738
Epoch 0 - Iter 9 - Loss/train: 9.160260200500488
Epoch 0 - Iter 10 - Loss/train: 8.696739196777344
Epoch 0 - Iter 11 - Loss/train: 8.117931365966797
Epoch 0 - Iter 12 - Loss/train: 7.482291221618652
Epoch 0 - Iter 13 - Loss/train: 7.119117736816406
Epoch 0 - Iter 14 - Loss/train: 7.198959827423096
Epoch 0 - Iter 15 - Loss/train: 7.317335605621338
Epoch 0 - Iter 16 - Loss/train: 7.335928440093994
Epoch 0 - Iter 17 - Loss/train: 7.238356113433838
Epoch 0 - Iter 18 - Loss/train: 7.0485334396362305
Epoch 0 - Iter 19 - Loss/train: 6.806842803955078
Epoc
```
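
Because `VGGSSL` has no classification head, one way to obtain a proper 3-class gender classifier is to put a linear layer on top of the encoder. The sketch below is an assumption, not the notebook's method: `EEGGenderClassifier` is a hypothetical wrapper, `nn.LazyLinear` infers the encoder's output size on the first forward pass, and a tiny convolutional encoder stands in for `VGGSSL().encoder` so the example runs without downloading VGG-16 weights. The three outputs match the labels produced by `get_labels()` (0 = M, 1 = F, 2 = other).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class EEGGenderClassifier(nn.Module):
    """Hypothetical wrapper: any encoder followed by a 3-way linear head."""

    def __init__(self, encoder, n_classes=3):
        super().__init__()
        self.encoder = encoder
        # LazyLinear infers in_features from the first batch it sees
        self.head = nn.LazyLinear(n_classes)

    def forward(self, x):
        if x.dim() == 3:  # (batch, channels, time) -> (batch, 1, channels, time)
            x = x.unsqueeze(1)
        return self.head(self.encoder(x))


# Tiny stand-in encoder (assumption) so the sketch runs without VGG-16 weights
encoder = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
)
model = EEGGenderClassifier(encoder)

x = torch.randn(4, 74, 500)     # fake batch: 4 windows of 74 channels x 500 samples
y = torch.tensor([0, 1, 2, 1])  # fake labels in {0, 1, 2}

model(x)  # dry run so LazyLinear creates its weights before the optimizer sees them
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.001)

scores = model(x)               # shape (4, 3): one logit per class
loss = F.cross_entropy(scores, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(scores.shape, loss.item())
```

With the notebook's model, one could pass `VGGSSL({'weights': None}).encoder` as the encoder; the dry forward pass before building the optimizer is needed in either case, since the lazy layer's parameters do not exist until then.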