## Detecting epileptic seizures

In [None]:
import os

import numpy as np
import pandas as pd

import lightning as L
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader, random_split

from model import LSTMDetector
from dataset import EpilepsyDataset
from dataset_formatter import DatasetFormatter
from model_arguments import ModelArguments

### Defining parameters (such as data paths, signals, etc.)

In [None]:
# Location of the input data directory and the label table.
DATA_DIR = '/workspace/new_data/'
LABELS_CSV = '/workspace/labels.csv'

# Signal channels to use.
# NOTE(review): presumably these must match channel/column names in the data files — confirm.
SIGNALS = ['Acc x', 'Acc y', 'Acc z', 'Acc Mag', 'EDA', 'BVP']

# NOTE(review): the meaning of the trailing positional flag is defined by
# ModelArguments and is not visible here — consider passing it by keyword.
arguments = ModelArguments(DATA_DIR, LABELS_CSV, SIGNALS, True)

### Split the data into train and test sets

In [None]:
# Build the dataset from the configured paths and signals. It must support len(),
# since the split below sizes the train/test subsets from it.
epilepsy_dataset = EpilepsyDataset(arguments)

In [None]:
import torch

# Fraction of samples used for training; the remainder is held out for testing.
TRAIN_FRACTION = 0.7
# Fixed seed so the train/test split is identical across runs and kernel restarts,
# which keeps experiments comparable.
SPLIT_SEED = 42
BATCH_SIZE = 16
NUM_WORKERS = 4

train_size = int(TRAIN_FRACTION * len(epilepsy_dataset))
test_size = len(epilepsy_dataset) - train_size

# NOTE(review): random_split assigns individual samples at random. If several
# samples come from the same recording/patient, they can land in both splits —
# consider a grouped split if that applies to EpilepsyDataset.
train_dataset, test_dataset = random_split(
    epilepsy_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Train on train_dataset only: epilepsy_dataset also contains every held-out
# test sample, so training on it would leak the test set and inflate test metrics.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

### Defining our model

In [None]:
# Instantiate the detector, parameterized by the channel count from `arguments`.
# NOTE(review): number_of_channels presumably equals the length of the signal
# list passed to ModelArguments — confirm.
model = LSTMDetector(arguments.number_of_channels)

### Training

In [None]:
MAX_EPOCHS = 25

# Send training metrics to Weights & Biases.
wandb_logger = WandbLogger()

# Fit for a fixed number of epochs. Only a training loader is passed here;
# no val_dataloaders argument is given.
trainer = L.Trainer(logger=wandb_logger, max_epochs=MAX_EPOCHS)
trainer.fit(model, train_dataloaders=train_dataloader)

### Testing

In [None]:
# Evaluate on the held-out split, restoring weights from the checkpoint that
# Lightning tracked as "best" during fit.
# NOTE(review): no monitored metric or validation loader is configured above, so
# "best" presumably resolves to the last checkpoint saved by the default
# ModelCheckpoint — confirm, or add a monitored validation metric.
trainer.test(
    ckpt_path='best',
    dataloaders=test_dataloader,
)