# Introduction

We are building a machine learning pipeline for classification of EEG signals.

Preprocessing files will be run separetely from this notebook, and we will import their variables.

This notebook will focus on creating the pipeline for assessing the best model to detect seizures in EEG signals. We will use three main strategies:

* Res2Net Transformer
* 1D-CNN + LSTM 
* Gated 2 Tower Transformer 

# Importing Packages

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
import pickle
import copy

print(torch.cuda.is_available())
print(torch.__version__)


plt.style.use("ggplot")

True
2.0.0+cu118


# Importing and Preprocessing Data

In [2]:
%run ./preprocessing.py

In [3]:
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
X_val = torch.FloatTensor(X_val)
y_train = torch.FloatTensor(y_train).unsqueeze(1)
y_test = torch.FloatTensor(y_test).unsqueeze(1)
y_val = torch.FloatTensor(y_val).unsqueeze(1)

In [4]:
y_train.unique(return_counts=True)

(tensor([0., 1.]), tensor([6440, 1610]))

In [5]:
class EEGDataset(Dataset):
    def __init__(self, features, target) -> None:
        super().__init__()
        self.features = features
        self.target = target

    def __getitem__(self, index):
        data = {}
        features = self.features[index]
        target = self.target[index]
        data["X"] = features
        data["y"] = target
        return data

    def __len__(self):
        return len(self.features)


BATCH_SIZE = 64
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
train_dataloader = DataLoader(EEGDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(EEGDataset(X_val, y_val), batch_size=BATCH_SIZE, shuffle=False)
final_train_dataloader = DataLoader(EEGDataset(torch.cat((X_train, X_val), 0), torch.cat((y_train, y_val), 0)), batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(EEGDataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=False)

# Creating Models

## Training code

In [6]:


def training(
        model, train_dataloader=None, val_dataloader=None,
        epochs=5, lr=0.001, device='cpu', earlystopping_tolerance=5):
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([9200/2300]).to(device))
    # criterion = nn.BCELoss()
    model_state = {
        "model": None,
        "train_loss": [],
        "val_loss": [],
    }
    best_validation = np.inf
    best_model = None
    count_tolerance = 0


    for epoch in range(epochs):
        training_loss = 0
        for i, data in enumerate(train_dataloader, 1):
            # get the inputs; data is a list of [inputs, labels]
            X, y = data["X"].to(device), data["y"].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(X)
            train_loss = criterion(outputs, y)
            train_loss.backward()
            optimizer.step()

            # print statistics
            training_loss += train_loss.item()
        
        training_loss /= i
        
        if isinstance(val_dataloader, DataLoader):
            validation_loss = 0
            for j, data in enumerate(val_dataloader, 1):
                # get the inputs; data is a list of [inputs, labels]
                X, y = data["X"].to(device), data["y"].to(device)


                # forward + backward + optimize
                with torch.no_grad():
                    outputs = model(X)
                    val_loss = criterion(outputs, y)
                    # print statistics
                    validation_loss += val_loss.item()

            validation_loss /= j
            
            if (validation_loss) < best_validation:
                count_tolerance = 0
                best_validation = validation_loss
                best_model = copy.deepcopy(model)
            
            count_tolerance += 1
            print(f"Epoch: {epoch}\tTraining loss: {training_loss:.5f}\t\t Validation Loss: {validation_loss:.5f}")
            model_state["train_loss"].append(training_loss)
            model_state["val_loss"].append(validation_loss)

            if count_tolerance >= earlystopping_tolerance:
                break
            
        else:
            print(f"Epoch: {epoch}\tTraining loss: {training_loss:.5f}")
            model_state["train_loss"].append(training_loss)
            best_model = copy.deepcopy(model)
        
    
    model_state["model"] = best_model
    save_model(model_state)
    return model_state

def save_model(model_state):
    with open(f"models/{ model_state['model'].to_string() }.pkl", "wb") as fp:
        model_state["model"] = model_state["model"].to("cpu").state_dict()
        pickle.dump(model_state, fp)
        print("Saved model successfully!")

## 1D CNN-LSTM Model

In [7]:
class CNN_LSTM_Classifier(pl.LightningModule):
    def __init__(self,):
        super().__init__()
        self.device_ = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.conv_1 = nn.Conv1d(1, 64, 3)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool1d(2, 2)
        self.conv_layers = nn.Sequential(
            nn.Conv1d(64, 128, 3),
            nn.ReLU(),
            nn.Conv1d(128, 512, 3),
            nn.ReLU(),
            nn.Conv1d(512, 1024, 3),
            nn.ReLU()
        )
        self.flatten_layer = nn.Linear(82, 256)
        dropout = 0.2
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(1024, 64, 2, batch_first=True, dropout=dropout)

        self.fc_out = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )


    def forward(self, X: torch.FloatTensor, y=None):
        X = X.transpose(1, 2)
        out = self.relu(self.conv_1(X))
        out = self.max_pool(out)
        out = self.conv_layers(out)
        out = self.flatten_layer(out)
        out = out.transpose(1, 2)
        out, (_, _) = self.lstm(out)
        out = out[:, -1, :]
        out = self.fc_out(out)
        return out
    
    def predict_batch(self, X: torch.FloatTensor):
        pred = (torch.sigmoid(self(X)) > 0.5).int()
        return pred

    def predict(self, dataloader: DataLoader):
        predictions = list()
        for i, data in enumerate(dataloader, 1):
            # get the inputs; data is a list of [inputs, labels]
            with torch.no_grad():
                X = data["X"]
                y_pred = self.predict_batch(X)
                predictions.append(y_pred)
        predictions = torch.cat(predictions, 0)
        return predictions

    def to_string(self):
        return "CNN_LSTM_Classifier"


    
model = CNN_LSTM_Classifier()
state = training(model, train_dataloader, val_dataloader, device=DEVICE, epochs=5, lr=0.0001)
# state["model"]
# model.predict(test_dataloader).shape

Epoch: 0	Training loss: 1.10903		 Validation Loss: 1.10591
Epoch: 1	Training loss: 1.04071		 Validation Loss: 0.97511
Epoch: 2	Training loss: 0.91864		 Validation Loss: 0.90266
Epoch: 3	Training loss: 0.86137		 Validation Loss: 0.88162
Epoch: 4	Training loss: 0.83003		 Validation Loss: 0.84592
Saved model successfully!


In [8]:
# plt.figure(figsize=(8, 6))
# plt.plot(state["train_loss"], label="Training Loss")
# plt.plot(state["val_loss"], label="Validation Loss")
# plt.legend(loc="best")
# plt.title("Training x Validation Losses")
# plt.xlabel("Epochs")
# plt.ylabel("Binary Cross Entropy Loss")
# plt.show()

## Positional Encoding

In [9]:
class PositionalEncoding(nn.Module):
    """
    The authors of the original transformer paper describe very succinctly what 
    the positional encoding layer does and why it is needed:
    
    "Since our model contains no recurrence and no convolution, in order for the 
    model to make use of the order of the sequence, we must inject some 
    information about the relative or absolute position of the tokens in the 
    sequence." (Vaswani et al, 2017)
    Adapted from: 
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(
        self, 
        dropout: float=0.1, 
        max_seq_len: int=5000, 
        d_model: int=512,
        batch_first: bool=False
        ):

        """
        Parameters:
            dropout: the dropout rate
            max_seq_len: the maximum length of the input sequences
            d_model: The dimension of the output of sub-layers in the model 
                     (Vaswani et al, 2017)
        """

        super().__init__()

        self.d_model = d_model
        
        self.dropout = nn.Dropout(p=dropout)

        self.batch_first = batch_first

        self.x_dim = 1 if batch_first else 0

        # copy pasted from PyTorch tutorial
        position = torch.arange(max_seq_len).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        
        pe = torch.zeros(max_seq_len, 1, d_model)
        
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, enc_seq_len, dim_val] or 
               [enc_seq_len, batch_size, dim_val]
        """

        x = x + self.pe[:x.size(self.x_dim)]

        return self.dropout(x)

## Gated Transformer Network

In [10]:
class GatedTransformerNet(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device 
        self.dropout = nn.Dropout(0.2)

        self.d_model = 512

        self.step_embedding = nn.Linear(1, self.d_model)
        self.channel_embedding = nn.Linear(1, self.d_model)
        self.positional_embedding = PositionalEncoding(d_model=self.d_model, dropout=0.2)
        self.tanh = nn.Tanh()

        self.step_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.d_model, nhead=8), num_layers=2)
        self.channel_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.d_model, nhead=8), num_layers=2)

        self.gating = nn.Linear(self.d_model*178*2, 2)

        self.fc_out = nn.Sequential(
            nn.Linear(self.d_model*178*2, 256),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
    
    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Generates an upper-triangular matrix of -inf, with zeros on diag."""
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1).to(self.device)


    def forward(self, X: torch.FloatTensor, y=None):
        batch_size = X.shape[0]
        seq_len = X.shape[1]
        mask = self.generate_square_subsequent_mask(batch_size)

        channel = self.tanh(self.channel_embedding(X))
        channel = self.dropout(self.channel_encoder(channel))

        step = self.tanh(self.step_embedding(X))
        step = self.positional_embedding(step)
        step = self.dropout(self.step_encoder(step, mask))

        channel = channel.reshape(batch_size, -1)
        step = step.reshape(batch_size, -1)

        concat = torch.cat([channel, step], -1)
        h = self.gating(concat)
        gate = torch.softmax(h, dim=-1)


        encoding = torch.cat([channel * gate[:, 0:1], step * gate[:, 1:2]], dim=-1)
        encoding = self.dropout(encoding)
        out = self.fc_out(encoding)

        return out
    
    def predict_batch(self, X: torch.FloatTensor):
        pred = (torch.sigmoid(self(X)) > 0.5).int()
        return pred

    def predict(self, dataloader: DataLoader):
        predictions = list()
        for i, data in enumerate(dataloader, 1):
            # get the inputs; data is a list of [inputs, labels]
            with torch.no_grad():
                X = data["X"]
                y_pred = self.predict_batch(X)
                predictions.append(y_pred)
        predictions = torch.cat(predictions, 0)
        return predictions

    def to_string(self):
        return "GatedTransformerNet"

model = GatedTransformerNet(device=DEVICE)
state = training(model, train_dataloader, val_dataloader, device=DEVICE, epochs=5, lr=0.0001)
# model(X_train[:32]).shape

Epoch: 0	Training loss: 1.12035		 Validation Loss: 1.07689
Epoch: 1	Training loss: 0.97601		 Validation Loss: 0.90811
Epoch: 2	Training loss: 0.85421		 Validation Loss: 0.82152
Epoch: 3	Training loss: 0.77639		 Validation Loss: 0.78973
Epoch: 4	Training loss: 0.70246		 Validation Loss: 0.76238
Saved model successfully!


In [11]:
X_train.shape

torch.Size([8050, 178, 1])

## Multilayer Perceptron Network

In [14]:
class MLPClassifier(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device_ = device

        self.dropout_1 = nn.Dropout(0.1)

        self.fc_net = nn.Sequential(
            nn.Linear(178, 500),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(500, 1)
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = X.squeeze(-1).to(self.device_)
        X = self.dropout_1(X)
        return self.fc_net(X)
    
    def predict_batch(self, X: torch.FloatTensor):
        pred = (torch.sigmoid(self(X)) > 0.5).int()
        return pred

    def predict(self, dataloader: DataLoader):
        predictions = list()
        for i, data in enumerate(dataloader, 1):
            # get the inputs; data is a list of [inputs, labels]
            with torch.no_grad():
                X = data["X"]
                y_pred = self.predict_batch(X)
                predictions.append(y_pred)
        predictions = torch.cat(predictions, 0)
        return predictions

    def to_string(self):
        return "MLPClassifier"
    

model = MLPClassifier(DEVICE)
state = training(model, train_dataloader, val_dataloader, device=DEVICE, epochs=5, lr=1e-4, earlystopping_tolerance=10)
model.predict_batch(X_train[:32])

Epoch: 0	Training loss: 1.08846		 Validation Loss: 1.05065
Epoch: 1	Training loss: 0.95194		 Validation Loss: 0.91078
Epoch: 2	Training loss: 0.84612		 Validation Loss: 0.82710
Epoch: 3	Training loss: 0.79448		 Validation Loss: 0.81759
Epoch: 4	Training loss: 0.76118		 Validation Loss: 0.78374
Saved model successfully!


tensor([[0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0]], device='cuda:0', dtype=torch.int32)

## FCN (Fully Convoluted)

In [41]:
class FCN(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device_ = device

        self.conv_fc = nn.Sequential(
            nn.Conv1d(1, 128, 3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, 3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 128, 3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, 1)
        )

    def forward(self, X):
        X = X.transpose(-1, -2).to(self.device_)
        conv = self.conv_fc(X)
        return conv
    
    def predict_batch(self, X: torch.FloatTensor):
        pred = (torch.sigmoid(self(X)) > 0.5).int()
        return pred

    def predict(self, dataloader: DataLoader):
        predictions = list()
        for i, data in enumerate(dataloader, 1):
            # get the inputs; data is a list of [inputs, labels]
            with torch.no_grad():
                X = data["X"]
                y_pred = self.predict_batch(X)
                predictions.append(y_pred)
        predictions = torch.cat(predictions, 0)
        return predictions

    def to_string(self):
        return "FCN"
    

model = FCN(DEVICE)
state = training(model, train_dataloader, val_dataloader, device=DEVICE, epochs=10, lr=1e-4, earlystopping_tolerance=10)

model(X_train[:32]).shape

Epoch: 0	Training loss: 0.78734		 Validation Loss: 0.62891
Epoch: 1	Training loss: 0.56710		 Validation Loss: 0.52135
Epoch: 2	Training loss: 0.47348		 Validation Loss: 0.44257
Epoch: 3	Training loss: 0.41519		 Validation Loss: 0.38716
Epoch: 4	Training loss: 0.37916		 Validation Loss: 0.36255
Epoch: 5	Training loss: 0.34183		 Validation Loss: 0.36619
Epoch: 6	Training loss: 0.32938		 Validation Loss: 0.31104
Epoch: 7	Training loss: 0.31243		 Validation Loss: 0.31013
Epoch: 8	Training loss: 0.29606		 Validation Loss: 0.29248
Epoch: 9	Training loss: 0.28701		 Validation Loss: 0.28065
Saved model successfully!


torch.Size([32, 1])

## Residual Network

In [43]:
class ResNet(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device_ = device

        self.conv_block_1 = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )

    def forward(self, X):
        X = X.transpose(-1, -2).to(self.device_)
        conv = self.conv_block_1(X)
        return conv
    

model = ResNet()
# state = training(model, train_dataloader, val_dataloader, device=DEVICE, epochs=10, lr=1e-4, earlystopping_tolerance=10)

model(X_train[:32]).shape

torch.Size([32, 64, 178])