In [1]:
from natsort import natsorted
import numpy as np
import os
import random

#Helper functions
from create_dataset import make_dataset

#Deep Learning
import torch
from torch import nn, optim
from torch.utils.data import dataset as ds
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchmetrics import Accuracy

#Logging
#!wandb login <REDACTED>  # never commit the API key; authenticate interactively or via the WANDB_API_KEY environment variable
import wandb
from pytorch_lightning.loggers import WandbLogger

  np.save("train_{}/labels.npy".format(experiment), np.hstack(np.array(train_labels)))
  np.save("val_{}/labels.npy".format(experiment), np.hstack(np.array(val_labels)))
  from .autonotebook import tqdm as notebook_tqdm


1. Make a dataset from the sub_preprocessed.npy files, where the samples for each condition are in separate folders.


In [3]:
#Create dataset
# Calls the project helper imported from create_dataset with the data root folder.
# NOTE(review): the argument is an absolute, machine-specific path. Judging from the
# captured cell output, the helper writes labels.npy files into train_*/val_* folders -
# confirm against create_dataset.make_dataset.
make_dataset("C:/Users/Daydreamore/Desktop/Semester/BCI")

  np.save("train_{}/labels.npy".format(experiment), np.hstack(np.array(train_labels)))
  np.save("val_{}/labels.npy".format(experiment), np.hstack(np.array(val_labels)))


2. Create a Dataset class that can be indexed by the PyTorch DataLoader.
To save resources, each dataset only stores the file paths, so that the DataLoader (generator) loads the samples one by one.

Note: use "eeg_dataset" to load all data or "eeg_dataset_undersample" to have a balanced dataset

In [3]:
#Dataset class without undersampling
class eeg_dataset():
    """Map-style dataset that loads one EEG sample from disk per index.

    The folder ``path`` is expected to contain one ``.npy`` file per sample plus
    one label file. Following the original convention of this notebook, the
    label file is the LAST directory entry and sample ``k`` belongs to label
    ``k``. Because ``os.listdir`` returns entries in arbitrary order, the
    listing is sorted once so that this convention is deterministic.

    NOTE(review): sorting is lexicographic. If the sample files use unpadded
    numeric names, or names that sort after the label file, the sample/label
    alignment must be confirmed against ``make_dataset``.

    Args:
        path: Folder holding the sample files and the label file.
        num_classes: Width of the one-hot target returned by ``__getitem__``
            (default 3, the value that was previously hard-coded).
    """

    def __init__(self, path, num_classes = 3):
        self.path = path
        self.num_classes = num_classes
        # List the directory once and sort it: os.listdir() gives no ordering guarantee.
        entries = sorted(os.listdir(path))
        self.sample_list = entries[:-1]
        self.targets = torch.from_numpy(np.load(os.path.join(path, entries[-1])))

    def __len__(self):
        """Return the number of sample files (the label file is not counted)."""
        return len(self.sample_list)

    def __getitem__(self, idx: int):
        """
        Args:
            idx (int): Index into the sorted list of sample files.

        Returns:
            tuple: (eeg_data, target) where eeg_data is the array stored in the
            sample file converted to a tensor and target is the float one-hot
            encoding of the label at the same index.
        """
        sample = np.load(os.path.join(self.path, self.sample_list[idx]))
        # mean = np.mean(sample)
        # std = np.std(sample)
        # sample = (sample-mean)/std
        label = self.targets[idx].to(torch.int64)
        return torch.from_numpy(sample), nn.functional.one_hot(label, num_classes = self.num_classes).float()

# Folder of every split (train/val) and condition (visual/multi).
path_train_visual, path_train_multi, path_val_visual, path_val_multi = (
    "C:/Users/Daydreamore/Desktop/Semester/BCI/train_visual",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/train_multi",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/val_visual",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/val_multi",
)
# One dataset (all samples, no undersampling) per folder, built in the same order as the paths.
train_set_visual, train_set_multi, val_set_visual, val_set_multi = (
    eeg_dataset(folder)
    for folder in (path_train_visual, path_train_multi, path_val_visual, path_val_multi)
)

In [17]:
# Scratch/inspection cell: load the label file of the visual training folder and show
# the shape of a ones-tensor built from it (i.e. the number of labels).
# NOTE(review): relies on the label file being the last entry returned by os.listdir(),
# whose order is not guaranteed.
path = "C:/Users/Daydreamore/Desktop/Semester/BCI/train_visual"
targets = torch.from_numpy(np.load(os.path.join(path,os.listdir(path)[-1])))
torch.ones_like(targets).shape

torch.Size([1544])

In [2]:
#Dataset class with undersampling
class eeg_dataset_undersample():
    """Balanced two-class dataset: undersampled control (label 0) vs. the other classes.

    The folder ``path`` is expected to contain one ``.npy`` file per sample plus
    one label file. Following the original convention of this notebook, the
    label file is the LAST directory entry and sample ``k`` belongs to label
    ``k``. Because ``os.listdir`` returns entries in arbitrary order, the
    listing is sorted once so that this convention is deterministic.

    For label 0 a random subset is drawn without replacement whose size equals
    the number of samples labelled ``classes[1]``. Every other label in
    ``classes`` is kept completely and relabelled to 1, so targets are binary.

    NOTE(review): sorting is lexicographic. If the sample files use unpadded
    numeric names, or names that sort after the label file, the sample/label
    alignment must be confirmed against ``make_dataset``.

    Args:
        path: Folder holding the sample files and the label file.
        classes: Labels to include, e.g. ``[0, 1]``; ``classes[1]`` determines
            how many label-0 samples are drawn.
        seed: Optional seed for a reproducible draw. With ``None`` (default)
            NumPy's global random state is used, as before.
    """

    def __init__(self, path, classes, seed = None):
        self.path = path

        # List the directory once and sort it: os.listdir() gives no ordering guarantee.
        entries = sorted(os.listdir(path))
        samples = np.array(entries[:-1])
        targets = torch.from_numpy(np.load(os.path.join(path, entries[-1])))
        labels = targets.numpy()

        #number of samples for condition (used to sample same number from control)
        n_samp = int(np.count_nonzero(labels == classes[1]))
        choose = np.random.choice if seed is None else np.random.default_rng(seed).choice

        sample_parts = []
        target_parts = []
        for i in classes:
            class_ixs = np.where(labels == i)[0]
            if i == 0:
                # Cap the draw: sampling without replacement raises a ValueError
                # when fewer control samples than condition samples exist.
                n_keep = min(n_samp, class_ixs.shape[0])
                class_ixs = choose(class_ixs, size = n_keep, replace = False)
                class_targets = targets[torch.as_tensor(class_ixs, dtype = torch.long)]
            else:
                # Every non-control label collapses to class 1.
                class_targets = torch.ones_like(targets[torch.as_tensor(class_ixs, dtype = torch.long)])
            sample_parts.append(samples[class_ixs])
            target_parts.append(class_targets)

        self.sample_list = np.concatenate(sample_parts, axis=0)
        self.target_list = torch.cat(target_parts)

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sample_list)

    def __getitem__(self, idx: int):
        """
        Args:
            idx (int): Index into the undersampled sample list.

        Returns:
            tuple: (eeg_data, target) where eeg_data is the array stored in the
            sample file converted to a tensor and target is the float one-hot
            encoding (2 classes) of the binary label at the same index.
        """
        sample = np.load(os.path.join(self.path, self.sample_list[idx]))
        # mean = np.mean(sample)
        # std = np.std(sample)
        # sample = (sample-mean)/std
        return torch.from_numpy(sample), nn.functional.one_hot(self.target_list[idx].to(torch.int64), num_classes = 2).float()

# Folder of every split (train/val) and condition (visual/multi).
path_train_visual, path_train_multi, path_val_visual, path_val_multi = (
    "C:/Users/Daydreamore/Desktop/Semester/BCI/train_visual",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/train_multi",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/val_visual",
    "C:/Users/Daydreamore/Desktop/Semester/BCI/val_multi",
)
# Balanced class-0 vs. class-1 datasets, built in the same order as the paths.
train_set_visual_u_01, train_set_multi_u_01, val_set_visual_u_01, val_set_multi_u_01 = (
    eeg_dataset_undersample(folder, classes = [0,1])
    for folder in (path_train_visual, path_train_multi, path_val_visual, path_val_multi)
)

In [9]:
# Balanced class-0 vs. class-2 datasets, built in the same order as the paths.
train_set_visual_u_02, train_set_multi_u_02, val_set_visual_u_02, val_set_multi_u_02 = (
    eeg_dataset_undersample(folder, classes = [0,2])
    for folder in (path_train_visual, path_train_multi, path_val_visual, path_val_multi)
)

Check if values in the sample can be accessed:

In [5]:
# Index the dataset like a sequence: item 1 -> (sample, target); show the first value of the sample.
train_set_visual_u_01[1][0][0][0]

tensor(2.8946, dtype=torch.float64)

We have a highly unbalanced dataset. Therefore, we calculate a weight parameter that gives a higher penalty to mispredictions of the less frequent classes.

In [5]:
# Per-class weight = n_samples / (2 * samples_in_class), moved to the GPU.
class_weights = torch.from_numpy(
    len(train_set_visual_u_01) / (2 * np.bincount(train_set_visual_u_01.target_list))
).cuda()

3. 1D-ConvNet Setup (this could be further modularized)

In [10]:
class ConvNet(pl.LightningModule):
    """1D convolutional classifier wrapped as a LightningModule.

    Four Conv1d stages (each followed by batch norm, GELU and adaptive max
    pooling; the first three also by dropout) feed a global max pool and a
    lazily-sized linear layer with ``num_classes`` outputs. The module owns its
    datasets and builds its own train/val dataloaders.

    Args:
        train_set, val_set: Datasets used by ``train_dataloader`` / ``val_dataloader``.
        batch_size: Batch size of both dataloaders.
        epochs: Stored as a hyperparameter only; not used inside this class.
        learning_rate: Adam learning rate and base learning rate of the cyclic schedule.
        in_channels: Number of input channels of the first convolution.
        out_channels: Number of output channels of the first two convolutions.
        kernel_size: Kernel size of all convolutions.
        num_classes: Number of output classes.
        bn_alpha: Stored as a hyperparameter only; not used inside this class.
        pool_out1..pool_out4: Output lengths of the four adaptive max-pool layers.
        class_weights: Per-class weights passed to the cross-entropy loss.
        dropout: Probability used by the Dropout1d layer.

    NOTE(review): the default arguments are evaluated when the class is defined,
    so they bind to the notebook globals that exist at that moment (the default
    ``class_weights`` is computed from a different dataset than the default
    ``train_set``).
    """

    def __init__(
        self,
        train_set = train_set_visual_u_02,
        val_set = val_set_visual_u_02,
        batch_size = 16,
        epochs = 100,
        learning_rate = 0.00004,
        in_channels = 11,
        out_channels = 256,
        kernel_size = 5,
        num_classes = 2,
        bn_alpha = 0.1,
        pool_out1 = 120,
        pool_out2 = 60,
        pool_out3 = 30,
        pool_out4 = 15,
        class_weights = class_weights,
        dropout = 0.5
    ):
        super().__init__()
        #Log Hyperparameters (captures the __init__ arguments; one call is sufficient)
        self.save_hyperparameters()

        #Model Architecture Stuff
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size)
        self.conv3 = nn.Conv1d(out_channels, 128, kernel_size)
        self.conv4 = nn.Conv1d(128, 64, kernel_size)

        #self.pool = nn.MaxPool1d(kernel_size=pool_kernel) #stride = kernel_size
        self.pool1 = nn.AdaptiveMaxPool1d(pool_out1)
        self.pool2 = nn.AdaptiveMaxPool1d(pool_out2)
        self.pool3 = nn.AdaptiveMaxPool1d(pool_out3)
        self.pool4 = nn.AdaptiveMaxPool1d(pool_out4)
        self.pool_final = nn.AdaptiveMaxPool1d(1)

        self.lazy_linear = nn.LazyLinear(out_features = num_classes)
        self.lazy_bn = nn.LazyBatchNorm1d()
        self.lazy_bn2 = nn.LazyBatchNorm1d()
        self.lazy_bn3 = nn.LazyBatchNorm1d()
        self.GELU = nn.GELU()
        self.dropout = nn.Dropout1d(dropout)
        self.initialize_weights()

        #Hyperparameters
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train_set = train_set
        self.val_set = val_set
        #self.dataset = dataset
        #self.train_sampler = train_sampler
        self.stepsize = np.around(self.train_set.__len__()*0.8/self.batch_size) #for cycling lr
        #self.val_sampler = val_sampler
        self.class_weights = class_weights #torch.from_numpy(train_set.__len__() / (2 * np.bincount(train_set.targets)))
        self.loss = nn.CrossEntropyLoss(weight = self.class_weights)
        self.acc = Accuracy(task = "multiclass", num_classes = num_classes)

    def forward(self, x):
        """Map a batch of shape [batch, in_channels, length] to [batch, num_classes].

        The shape comments are the author's values for one particular input
        size and batch size of 32.
        """
        x = self.conv1(x) #[32, 256, 166]
        x = self.lazy_bn(x)
        x = self.GELU(x)
        x = self.pool1(x) #[32, 256, 120]
        x = self.dropout(x)
        x = self.conv2(x) #[32, 256, 114]
        # NOTE(review): the same batch-norm layer as after conv1 is reused here, so both
        # stages share affine parameters and running statistics. Left unchanged because a
        # dedicated layer would alter the architecture and existing checkpoints.
        x = self.lazy_bn(x)
        x = self.GELU(x)
        x = self.pool2(x) #[32, 256, 60]
        x = self.dropout(x)
        x = self.conv3(x) #[32, 128, 56]
        x = self.lazy_bn2(x)
        x = self.GELU(x)
        x = self.pool3(x) #[32, 128, 30]
        x = self.dropout(x)
        x = self.conv4(x) #[32, 64, 26]
        x = self.lazy_bn3(x)
        x = self.GELU(x)
        x = self.pool4(x) #[32, 64, 15]
        #x = self.dropout(x)
        x = self.pool_final(x) #[32, 64, 1]
        # Drop only the pooled length dimension; a bare squeeze() would also drop
        # the batch dimension for a batch of size 1.
        x = x.squeeze(-1)
        x = self.lazy_linear(x)
        # NOTE(review): GELU is applied to the outputs that are fed to the
        # cross-entropy loss; kept as in the original.
        x = self.GELU(x) #[batch, num_classes]
        return x

    def initialize_weights(self):
        """Kaiming-uniform init for Conv1d weights; weight 1 / bias 0 for BatchNorm1d.

        When called from ``__init__`` the batch-norm layers are still lazy
        (not yet ``nn.BatchNorm1d`` instances), so only the convolution branch
        applies at that point.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_uniform_(m.weight)

            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


    def configure_optimizers(self):
        """Return Adam plus a per-step triangular cyclic learning-rate schedule."""
        #weight_decay is set explicitly (torch.optim.Adam defaults to 0) #lr_before = 0.00043
        optimizer = optim.Adam(params = self.parameters(), lr = self.learning_rate, weight_decay = 0.01)
        #parameters for the cycling lr scheduler are chosen according to Smith (2015): https://arxiv.org/pdf/1506.01186.pdf
        lr = torch.optim.lr_scheduler.CyclicLR(
            optimizer, base_lr = self.learning_rate,
            max_lr = 4*self.learning_rate,
            step_size_up = 4*int(self.stepsize),
            mode = "triangular",
            cycle_momentum = False
            )
        #Fix pickling bug for cycling learning rate (https://github.com/pytorch/pytorch/issues/88684)
        #instantiate the WeakMethod in the lr scheduler object into the custom scale function attribute
        lr._scale_fn_custom = lr._scale_fn_ref()
        #remove the reference so there are no more WeakMethod references in the object
        lr._scale_fn_ref = None
        lr_scheduler = {
            "scheduler": lr,
            "interval": "step",
            "name": "Learning Rate Scheduling"
        }
        # return {"optimizer": optimizer,
        #         "lr_scheduler": lr_scheduler}
        return [optimizer], [lr_scheduler]

    def _shared_step(self, batch):
        """Compute the loss and the predicted/true class indices for one batch."""
        x,y = batch
        logit = self.forward(x.float())
        step_loss = self.loss(logit, y)
        _, y_pred = torch.max(logit, dim = 1)
        _, y_true = torch.max(y, dim = 1) #targets arrive one-hot encoded
        return {"loss": step_loss, "y_pred": y_pred, "y_true": y_true}

    def _log_epoch_metrics(self, outputs, stage):
        """Aggregate the step outputs of one epoch and log loss, accuracy and per-class accuracy.

        ``stage`` is "train" or "val" and only selects the names of the logged metrics.
        """
        loss_epoch = torch.stack([out["loss"] for out in outputs]).mean()
        # cat (instead of stack + flatten) also works when batches differ in size.
        y_pred_epoch = torch.cat([out["y_pred"] for out in outputs])
        y_true_epoch = torch.cat([out["y_true"] for out in outputs])
        # minlength guarantees one entry per class, so a class without any correct
        # prediction is logged as 0 instead of being silently omitted.
        class_acc = torch.bincount(y_true_epoch[y_true_epoch == y_pred_epoch], minlength = self.num_classes)
        class_count = torch.bincount(y_true_epoch, minlength = self.num_classes)

        #Log rel. amount of correctly predicted targets per class
        for ix, tclass in enumerate(class_acc):
            if tclass == 0:
                self.log(f"class{ix}_acc_{stage}", 0.0, on_epoch = True, prog_bar = False)
            else:
                self.log(f"class{ix}_acc_{stage}", tclass/class_count[ix], on_epoch = True, prog_bar = False)

        epoch_acc = self.acc(y_pred_epoch, y_true_epoch)
        self.log(f"{stage}/loss", loss_epoch, on_epoch = True, prog_bar = True)
        self.log(f"{stage}/acc", epoch_acc, on_epoch = True, prog_bar = True)

    def training_step(self, batch, batch_idx):
        """Return loss and predicted/true class indices for one training batch."""
        return self._shared_step(batch)

    def training_epoch_end(self, outputs):
        """Log train/loss, train/acc and class{i}_acc_train for the finished epoch."""
        self._log_epoch_metrics(outputs, "train")

    def validation_step(self, batch, batch_idx):
        """Return loss and predicted/true class indices for one validation batch."""
        return self._shared_step(batch)

    def validation_epoch_end(self, outputs):
        """Log val/loss, val/acc and class{i}_acc_val for the finished epoch."""
        self._log_epoch_metrics(outputs, "val")

    def train_dataloader(self):
        """Shuffled loader over the training set; the last incomplete batch is dropped."""
        train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = self.batch_size,
                                           shuffle = True, drop_last = True)
        return train_loader

    def val_dataloader(self):
        """Unshuffled loader over the validation set; the last incomplete batch is dropped."""
        val_loader = torch.utils.data.DataLoader(self.val_set, batch_size = self.batch_size,
                                         shuffle = False, drop_last = True)
        return val_loader

# Instantiate the network with all defaults, i.e. the datasets and class weights that
# were bound as default arguments when the ConvNet class was defined.
model = ConvNet()



4. Logging & Model Fitting

In [15]:
# Finish any W&B run that is still active before a new logger is created.
wandb.finish()
wandb_logger = WandbLogger(project="EEG_Analysis", log_model = True)
# Record the learning rate at every optimizer step (the schedule is stepped per batch).
lr_monitor = LearningRateMonitor(logging_interval='step')
# checkpoint_callback = ModelCheckpoint(
#     dirpath='C:/Users/Daydreamore/Desktop/Semester/BCI/model_checkpoints',
#     monitor='val/acc',
#     save_top_k=2
# )

# accelerator/devices replace the deprecated `gpus = 1` argument (same single-GPU setup).
trainer = pl.Trainer(max_epochs = 200, accelerator = "gpu", devices = 1, logger = wandb_logger,
                    auto_lr_find = False, callbacks = [lr_monitor])

0,1
Learning Rate Scheduling,▆▁▇▇▄▃▄▄▇▆▁▁▇▇▄▃▄█▇▂▁▆▆▅▄▃▃█▃▂▅▆▆▅▂▃██▃▂
class0_acc_train,▁▃▅▆▆▇▆▆▆▇▆▇▆▅▆▇▆▆▇▆▇▇▇▇▇█▇▇▇▇▇█▇▇▇▇▇███
class0_acc_val,███████████▇█▇▇█▇▇▇▇██▇▇▇▇▇▇▆▇▇▆▆▅▃▁▃▁▃▃
class1_acc_train,▁▄▄▄▅▅▆▆▆▆▇▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇█▇▇█▇█████▇███
class1_acc_val,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▅▄▅▅▄▅▄▅▄▅▆▅▆▅█▇▅▆▇▇▇▇▇███
class2_acc_train,▄▂▁▁▁▁▁▁▃▃▃▅▄▅▅▅▅▅▆▆▆▆▆▇▆▇▆▆▇▇▇▇▇▇▇▇▇██▇
class2_acc_val,▂▁▁▁▂▂▃▄▅▄▆▆▅▆▆▇▆▇▆▆▆▇▇▇▇▇▇▆▇▇▇▇▇█▇▇█▆▇▇
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/acc,▁▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█▇▇▇▇▇███
train/loss,█▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁

0,1
Learning Rate Scheduling,6e-05
class0_acc_train,0.88506
class0_acc_val,0.47619
class1_acc_train,0.80077
class1_acc_val,0.85714
class2_acc_train,0.68898
class2_acc_val,0.74138
epoch,199.0
train/acc,0.79253
train/loss,0.77883


  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
# Train the model; no dataloaders are passed because ConvNet defines
# train_dataloader() and val_dataloader() itself.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name        | Type               | Params
----------------------------------------------------
0  | conv1       | Conv1d             | 14.3 K
1  | conv2       | Conv1d             | 327 K 
2  | conv3       | Conv1d             | 163 K 
3  | conv4       | Conv1d             | 41.0 K
4  | pool1       | AdaptiveMaxPool1d  | 0     
5  | pool2       | AdaptiveMaxPool1d  | 0     
6  | pool3       | AdaptiveMaxPool1d  | 0     
7  | pool4       | AdaptiveMaxPool1d  | 0     
8  | pool_final  | AdaptiveMaxPool1d  | 0     
9  | lazy_linear | LazyLinear         | 0     
10 | lazy_bn     | LazyBatchNorm1d    | 0     
11 | lazy_bn2    | LazyBatchNorm1d    | 0     
12 | lazy_bn3    | LazyBatchNorm1d    | 0     
13 | GELU        | GELU               | 0     
14 | dropout     | Dropout1d          | 0     
15 | loss        | CrossEntropyLoss   | 0     
16 | acc         | MulticlassAccuracy | 0     
----------------------------------------------------
547 K

Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 20.41it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 199: 100%|██████████| 80/80 [00:03<00:00, 22.25it/s, loss=0.106, v_num=7yki, val/loss=0.285, val/acc=0.950, train/loss=0.139, train/acc=0.950] 

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|██████████| 80/80 [00:03<00:00, 21.42it/s, loss=0.106, v_num=7yki, val/loss=0.285, val/acc=0.950, train/loss=0.139, train/acc=0.950]
