# Imports & Device Setup

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from torchvision.transforms import v2
from tqdm import tqdm
from skorch.helper import predefined_split
import copy

from skorch.callbacks import LRScheduler

import pandas as pd
import mne

np.int = int
np.bool = bool
np.object = object
from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
    create_fixed_length_windows
)
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet, EEGConformer
from braindecode import EEGClassifier
from braindecode.datasets import BaseDataset, BaseConcatDataset


from braindecode.augmentation import (
    FTSurrogate,
    SmoothTimeMask,
    ChannelsDropout,
    AugmentedDataLoader
)

In [2]:
if torch.cuda.is_available():
  device = torch.device('cuda')
elif torch.backends.mps.is_available():
  device = torch.device('mps')
else:
  device = torch.device('cpu')

print("Using", device)

Using mps


# Important Variables

In [3]:
seed = 1
set_random_seeds(seed, str(device)=='cuda')
num_channels = 22
num_classes = 4
ch_names = [str(i) for i in range(num_channels)]
classes = list(range(num_classes))
input_window_samples = 400
lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 15 # 50
folds = 10
sfreq = 250

# Load Data

In [4]:
X_train_valid = np.load("./project_data/X_train_valid.npy")
y_train_valid = np.load("./project_data/y_train_valid.npy") - 769

X_test = np.load("./project_data/X_test.npy")
y_test = np.load("./project_data/y_test.npy") - 769

person_train_valid = np.load("./project_data/person_train_valid.npy")
person_test = np.load("./project_data/person_test.npy")

In [5]:
def data_prep(X,
              y,
              sub_sample,
              average,
              noise,
              p_channel_dropoup=0,
              smooth_time_mask=False,
              mask_size=0,
              time_shift=0,
              clipping_max=800,
              noise_stdev=0.5):

    total_X = None
    total_y = None
    X = X[:,:,0:clipping_max]
    print('Shape of X after trimming: {X.shape}')
    X_max, _ = torch.max(X.view(X.size(0), X.size(1), -1, sub_sample), axis=3)

    total_X = X_max
    total_y = y
    print('Shape of X after maxpooling:',total_X.shape)

    X_average = torch.mean(X.view(X.size(0), X.size(1), -1, average), axis=3)
    X_average = X_average + torch.normal(0.0, noise_stdev, X_average.shape)

    total_X = torch.cat((total_X, X_average), dim=0)
    total_y = torch.cat((total_y, y))
    print('Shape of X after averaging+noise and concatenating:',total_X.shape)

    for i in range(sub_sample):
        X_subsample = X[:, :, i::sub_sample] + \
                            (torch.normal(0.0, 0.5, X[:, :,i::sub_sample].shape) if noise else 0.0)
        total_X = torch.cat((total_X, X_subsample), dim=0)
        print(total_y.view(-1,1).shape)
        print(y.view(-1,1).shape)
        total_y = torch.cat((total_y, y))

    print('Shape of X after subsampling and concatenating:',total_X.shape)
    print('Shape of Y:',total_y.shape)

    if p_channel_dropout != 0:
        mask = (torch.rand(total_X.shape[0], total_X.shape[1]) >= p_channel_dropout).unsqueeze(2)
        X_dropout = mask * total_X
        total_X = torch.cat((total_X, X_dropout))
        total_y = torch.cat((total_y, total_y))

        print(f'Shape of X after channel dropout {total_X.shape}')
        print(f'Shape of Y: {total_y.shape}')

    if smooth_time_mask:
        copy_X = copy.deepcopy(total_X)
        starts = ((torch.rand(copy_X.shape[0])*(copy_X.shape[2]-mask_size-1))).round()
        for idx, m in enumerate(copy_X):
            start = int(starts[idx])
            end = start+mask_size
            m[:,start:end] = 0
        total_X = torch.cat((total_X, copy_X))
        total_y = torch.cat((total_y, total_y))
        
        print(f'Shape of X after smooth time mask {total_X.shape}')
        print(f'Shape of Y: {total_y.shape}')

    if time_shift != 0:
        time_shift_X = copy.deepcopy(total_X)
        shifts = np.random.randint(low=-time_shift, high=time_shift+1, size=(total_X.shape[0],))
        time_shift_X = torch.Tensor(np.array([torch.roll(elem, shift, 1) for elem, shift in zip(time_shift_X, shifts)]))
        total_X = torch.cat((total_X, time_shift_X))
        total_y = torch.cat((total_y, total_y))
        
        print(f'Shape of X after time_shift {total_X.shape}')
        print(f'Shape of Y: {total_y.shape}')

    return total_X,total_y

def test_data_prep(X):
    total_X = None
    X = X[:,:,0:800]
    print('Shape of X after trimming:', X.shape)
    X_max, _ = torch.max(X.view(X.size(0), X.size(1), -1, 2), axis=3)
    total_X = X_max
    print('Shape of X after maxpooling:',total_X.shape)
    return total_X

In [6]:
indices = np.random.permutation(X_train_valid.shape[0])
split_idx = int(X_train_valid.shape[0] * ((folds-1)/folds))

folds = 10
split_seed = 1

subsample = 2
average = 2
noise = True
p_channel_dropout = 0.01
smooth_time_mask = False
mask_size = 120
time_shift = 30
clipping_max = 800
noise_stdev = 0.16 # 0.5

X_train_valid = torch.Tensor(X_train_valid)
y_train_valid = torch.Tensor(y_train_valid)

print(f'SOME {X_train_valid.shape}')
indices = torch.randperm(X_train_valid.shape[0])
split_idx = int(X_train_valid.shape[0] * ((folds-1)/folds))
X_train, X_valid = X_train_valid[indices[:split_idx]], X_train_valid[indices[split_idx:]]
y_train, y_valid = y_train_valid[indices[:split_idx]], y_train_valid[indices[split_idx:]]
X_test, y_test = torch.Tensor(X_test), torch.Tensor(y_test)

print('Prepping Training Data')
X_train, y_train = data_prep(X_train, y_train, subsample, average, noise, p_channel_dropout, smooth_time_mask, mask_size,
                            time_shift, clipping_max, noise_stdev)
print('\nPrepping Validation Data')
X_valid, y_valid = data_prep(X_valid, y_valid, subsample, average, noise, p_channel_dropout, smooth_time_mask, mask_size,
                            time_shift, clipping_max, noise_stdev)
print('\nPrepping Test Data')
X_test = test_data_prep(X_test)

X_train, y_train = np.array(X_train), np.array(y_train)
X_valid, y_valid = np.array(X_valid), np.array(y_valid)
X_test, y_test = np.array(X_test), np.array(y_test)

SOME torch.Size([2115, 22, 1000])
Prepping Training Data
Shape of X after trimming: {X.shape}
Shape of X after maxpooling: torch.Size([1903, 22, 400])
Shape of X after averaging+noise and concatenating: torch.Size([3806, 22, 400])
torch.Size([3806, 1])
torch.Size([1903, 1])
torch.Size([5709, 1])
torch.Size([1903, 1])
Shape of X after subsampling and concatenating: torch.Size([7612, 22, 400])
Shape of Y: torch.Size([7612])
Shape of X after channel dropout torch.Size([15224, 22, 400])
Shape of Y: torch.Size([15224])
Shape of X after time_shift torch.Size([30448, 22, 400])
Shape of Y: torch.Size([30448])

Prepping Validation Data
Shape of X after trimming: {X.shape}
Shape of X after maxpooling: torch.Size([212, 22, 400])
Shape of X after averaging+noise and concatenating: torch.Size([424, 22, 400])
torch.Size([424, 1])
torch.Size([212, 1])
torch.Size([636, 1])
torch.Size([212, 1])
Shape of X after subsampling and concatenating: torch.Size([848, 22, 400])
Shape of Y: torch.Size([848])
Shap

In [7]:
def create_from_X_y(X, y, drop_last_window, sfreq, ch_names):
  n_samples_per_x = []
  base_datasets = []
  for x, target in zip(X, y):
    n_samples_per_x.append(x.shape[1])
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq) # , ch_types=ch_types
    raw = mne.io.RawArray(x, info, verbose=False)
    base_dataset = BaseDataset(raw, pd.Series({"target": target}),
                               target_name="target")
    base_datasets.append(base_dataset)
  base_datasets = BaseConcatDataset(base_datasets)
  if not len(np.unique(n_samples_per_x)) == 1:
    raise ValueError("if 'window_size_samples' and "
                      "'window_stride_samples' are None, "
                      "all trials have to have the same length")
  window_size_samples = n_samples_per_x[0]
  window_stride_samples = n_samples_per_x[0]
  windows_datasets = create_fixed_length_windows(
    base_datasets,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    drop_last_window=drop_last_window
  )
  return windows_datasets

## Create Braindecode Datasets

In [8]:
train_dataset = create_from_X_y(X_train, y_train, False, sfreq, ch_names=ch_names);
print('DONE TRAIN')
valid_dataset = create_from_X_y(X_valid, y_valid, False, sfreq, ch_names=ch_names);
print('DONE VALID')
test_dataset = create_from_X_y(X_test, y_test, False, sfreq, ch_names=ch_names);
print('DONE TEST')

DONE TRAIN
DONE VALID
DONE TEST


# Preprocessing

In [9]:
low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
factor_new = 1e-3
init_block_size = 400

preprocessors = [
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz, picks=ch_names, verbose=False),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
        picks=ch_names,
    ),
]

# Preprocess the data
# preprocess(train_dataset, preprocessors)
# preprocess(valid_dataset, preprocessors)
# preprocess(test_dataset, preprocessors)

# Augmentations

In [10]:
# channels_dropout = ChannelsDropout(
#     probability=0.5,
#     p_drop=1
# )

# smooth_time_mask = SmoothTimeMask(
#     probability=0.5,
#     mask_len_samples=300
# )

# transforms = [smooth_time_mask, channels_dropout]

# Model Training

In [11]:
model = ShallowFBCSPNet(
    num_channels,
    num_classes,
    n_times=input_window_samples,
    final_conv_length="auto",
)
print(model)
model.to(device)

print(f'TRAIN LENGTH: {len(train_dataset)}')
print(f'VALID LENGTH: {len(valid_dataset)}')
print(f'TEST LENGTH: {len(test_dataset)}')

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 400]              [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 22, 400]              [1, 22, 400, 1]           --                        --
├─Rearrange (dimshuffle): 1-2            [1, 22, 400, 1]           [1, 1, 400, 22]           --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 400, 22]           [1, 40, 376, 1]           36,240                    --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 376, 1]           [1, 40, 376, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 376, 1]           [1, 40, 376, 1]           --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 376, 1]           [1, 40, 21, 1]            --                        [75, 1]
├─Express



In [None]:
clf = EEGClassifier(
    model,
    # iterator_train=AugmentedDataLoader,
    # iterator_train__transforms=transforms,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_dataset),
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
    ],
    device=device,
    classes=classes,
    max_epochs=n_epochs,
)
clf.fit(train_dataset, y=None)

# evaluated the model after training
y_test = test_dataset.get_metadata().target
test_acc = clf.score(test_dataset, y=y_test)
print(f"Test acc: {(test_acc * 100):.2f}%")

  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr      dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  -------
      1            [36m0.6791[0m        [32m1.2530[0m       [35m0.5825[0m            [31m0.5825[0m        [94m0.9553[0m  0.0006  27.4831
      2            [36m0.7573[0m        [32m0.8677[0m       [35m0.6624[0m            [31m0.6624[0m        [94m0.8642[0m  0.0006  19.7449
      3            [36m0.8376[0m        [32m0.7099[0m       [35m0.6787[0m            [31m0.6787[0m        [94m0.8385[0m  0.0006  19.7757
      4            [36m0.8512[0m        [32m0.6139[0m       0.6754            0.6754        0.8632  0.0006  20.3260
