# Imports & Device Setup

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from torchvision.transforms import v2
from tqdm import tqdm
from skorch.helper import predefined_split

from skorch.callbacks import LRScheduler

import pandas as pd
import mne

np.int = int
np.bool = bool
np.object = object
from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
    create_fixed_length_windows
)
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet, EEGConformer
from braindecode import EEGClassifier
from braindecode.datasets import BaseDataset, BaseConcatDataset, create_from_X_y


from braindecode.augmentation import (
    FTSurrogate,
    SmoothTimeMask,
    ChannelsDropout,
    AugmentedDataLoader
)

In [2]:
if torch.cuda.is_available():
  device = torch.device('cuda')
elif torch.backends.mps.is_available():
  device = torch.device('mps')
else:
  device = torch.device('cpu')

print("Using", device)

Using mps


# Important Variables

In [3]:
seed = 1
set_random_seeds(seed, str(device)=='cuda')
num_channels = 22
num_classes = 4
ch_names = [str(i) for i in range(num_channels)]
classes = list(range(num_classes))
input_window_samples = 800
lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 1000
folds = 10
sfreq = 250

# Load Data

In [4]:
X_train_valid = np.load("./project_data/X_train_valid.npy")
y_train_valid = np.load("./project_data/y_train_valid.npy") - 769

X_test = np.load("./project_data/X_test.npy")
y_test = np.load("./project_data/y_test.npy") - 769

person_train_valid = np.load("./project_data/person_train_valid.npy")
person_test = np.load("./project_data/person_test.npy")

In [5]:
indices = np.random.permutation(X_train_valid.shape[0])
split_idx = int(X_train_valid.shape[0] * ((folds-1)/folds))
X_train_valid = X_train_valid[:,:,0:800]
X_test = X_test[:,:,0:800]

X_train, X_valid = X_train_valid[indices[:split_idx]], X_train_valid[indices[split_idx:]]
y_train, y_valid = y_train_valid[indices[:split_idx]], y_train_valid[indices[split_idx:]]

## Create Braindecode Datasets

In [6]:
train_dataset = create_from_X_y(X_train, y_train, False, sfreq, ch_names=ch_names)
valid_dataset = create_from_X_y(X_valid, y_valid, False, sfreq, ch_names=ch_names)
test_dataset = create_from_X_y(X_test, y_test, False, sfreq, ch_names=ch_names)

Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.
Creating RawArray with float64 data, n_channels=22, n_times=800
    Range : 0 ... 799 =      0.000 ...     3.196 secs
Ready.


# Preprocessing

In [7]:
low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
factor_new = 1e-3
init_block_size = 800

preprocessors = [
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz, picks=ch_names, verbose=False),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
        picks=ch_names,
    ),
]

# Preprocess the data
preprocess(train_dataset, preprocessors)
preprocess(valid_dataset, preprocessors)
preprocess(test_dataset, preprocessors)

<braindecode.datasets.base.BaseConcatDataset at 0x2c6cedfd0>

# Augmentations

In [8]:
channels_dropout = ChannelsDropout(
    probability=0.5,
    p_drop=1
)

smooth_time_mask = SmoothTimeMask(
    probability=0.5,
    mask_len_samples=300
)

transforms = [smooth_time_mask, channels_dropout]

# Model Training

In [9]:
model = ShallowFBCSPNet(
    num_channels,
    num_classes,
    n_times=input_window_samples,
    final_conv_length="auto",
)
print(model)
model.to(device)

print(f'TRAIN LENGTH: {len(train_dataset)}')
print(f'VALID LENGTH: {len(valid_dataset)}')
print(f'TEST LENGTH: {len(test_dataset)}')

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 800]              [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 22, 800]              [1, 22, 800, 1]           --                        --
├─Rearrange (dimshuffle): 1-2            [1, 22, 800, 1]           [1, 1, 800, 22]           --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 800, 22]           [1, 40, 776, 1]           36,240                    --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 776, 1]           [1, 40, 776, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 776, 1]           [1, 40, 776, 1]           --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 776, 1]           [1, 40, 47, 1]            --                        [75, 1]
├─Express



In [None]:
clf = EEGClassifier(
    model,
    iterator_train=AugmentedDataLoader,
    iterator_train__transforms=transforms,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_dataset),
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
    ],
    device=device,
    classes=classes,
    max_epochs=n_epochs,
)
clf.fit(train_dataset, y=None)

# evaluated the model after training
y_test = test_dataset.get_metadata().target
test_acc = clf.score(test_dataset, y=y_test)
print(f"Test acc: {(test_acc * 100):.2f}%")

  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            [36m0.2696[0m       [32m10.9947[0m       [35m0.3066[0m            [31m0.3066[0m        [94m1.6806[0m  0.0006  2.5821
      2            0.2554        [32m8.7138[0m       0.2358            0.2358        1.7156  0.0006  2.1090
      3            [36m0.2701[0m        [32m8.2649[0m       0.2406            0.2406        [94m1.5180[0m  0.0006  2.1545
      4            [36m0.2848[0m        [32m7.3648[0m       0.2500            0.2500        [94m1.4565[0m  0.0006  2.1755
      5            [36m0.3106[0m        [32m7.0687[0m       0.2500            0.2500        1.4789  0.0006  2.1405
      6            0.2937        [32m6.6841[0m       [35m0.3113[0m            [31m0.3113[0m        1.4904  0.0006  2.1259
      7            [36m0.3242[0m     