## Multimodal vs. Baseline Models
- Note: this notebook runs on a small sample of the data so that you can experiment with the code quickly. To get better results, change the configuration to point to the full dataset.

In [1]:
import os
import sys

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupKFold

from utils.config import Config
from utils.loss import KLDivLossWithLogits, CrossEntropyLossWithLogits
from eeg_dataset.data_loader import EEGDataset, EEGSpecDataset, EEGWaveDataset
from eeg_dataset.data_utils import get_fold_dls
from exp.experiment import Experiment

In [2]:
!pip install pytorchtools

Collecting pytorchtools
  Downloading pytorchtools-0.0.2-py2.py3-none-any.whl.metadata (2.2 kB)
Downloading pytorchtools-0.0.2-py2.py3-none-any.whl (3.1 kB)
Installing collected packages: pytorchtools
Successfully installed pytorchtools-0.0.2


In [3]:
# Create the output directory (and any missing parents). exist_ok=True makes
# the cell safe to re-run and avoids the check-then-create race of a separate
# os.path.exists() test.
os.makedirs(Config.output_dir, exist_ok=True)

In [4]:
# Fix the global seed from Config.seed so repeated runs are comparable.
# The call's return value is the cell output (the seed that was applied).
pl.seed_everything(Config.seed, workers=True)

42

### Load the data

In [5]:
# Point the shared Config at the sampled dataset. The two pre-computed .npy
# files are derived from data_root, so switching to the full data only needs
# the root changed in one place. Keep the trailing slash: data_root is also
# concatenated directly with a file name when the CSV is read.
Config.data_root = "/kaggle/input/hba-sampled-data/"
Config.PRE_LOADED_SPECTOGRAMS = os.path.join(Config.data_root, "eeg_specs.npy")
Config.PRE_LOADED_EEGS = os.path.join(Config.data_root, "eegs.npy")

In [6]:
# Build the per-EEG training table, load the cached signals and create the
# cross-validation data for every dataset variant.

# Single source of truth for the number of folds (used by the splitter and by
# the fold loop below).
N_FOLDS = 5

# Dataset class used for each entry of `kfold_data`.
DATASET_CLASSES = {
    "all": EEGDataset,
    "spec": EEGSpecDataset,
    "wave": EEGWaveDataset,
}

# Load the data
df = pd.read_csv(f"{Config.data_root}train_300_patients.csv")
TARGETS = Config.TARGETS  # presumably the per-class vote-count columns -- verify in Config

# Collapse the annotation rows to one row per eeg_id in a single pass:
# first spectrogram id, min/max label offset and first patient id.
# NOTE: the "spectogram_id" spelling is kept on purpose; code outside this
# notebook may look the column up by that name.
by_eeg = df.groupby("eeg_id")
train_df = by_eeg.agg(
    spectogram_id=("spectrogram_id", "first"),
    min=("spectrogram_label_offset_seconds", "min"),
    max=("spectrogram_label_offset_seconds", "max"),
    patient_id=("patient_id", "first"),
)

# Sum the target columns per eeg_id, then normalise each row so that it sums
# to 1 (a distribution over TARGETS).
vote_sums = by_eeg[TARGETS].sum()
vote_fractions = vote_sums.div(vote_sums.sum(axis=1), axis=0)
for label in TARGETS:
    train_df[label] = vote_fractions[label]

train_df["target"] = by_eeg["expert_consensus"].first()
train = train_df.reset_index()

# SECURITY: allow_pickle=True unpickles the file contents, which can execute
# arbitrary code. Only load these files from a source you trust.
all_spectrograms = np.load(Config.PRE_LOADED_SPECTOGRAMS, allow_pickle=True).item()
all_eegs = np.load(Config.PRE_LOADED_EEGS, allow_pickle=True).item()

# KFold split, grouped by patient_id so that one patient's recordings are
# never split between the training and validation side of a fold.
gkf = GroupKFold(n_splits=N_FOLDS)
train["fold"] = 0
for fold, (_, val_idx) in enumerate(gkf.split(train, train.target, train.patient_id)):
    train.loc[val_idx, "fold"] = fold

# For every fold, build the train/validation data once per dataset variant.
kfold_data = {name: [] for name in DATASET_CLASSES}
for fold_id in range(N_FOLDS):
    df_train = train[train.fold != fold_id]
    df_valid = train[train.fold == fold_id]
    for name, dataset_cls in DATASET_CLASSES.items():
        kfold_data[name].append(
            get_fold_dls(df_train, df_valid, all_eegs, all_spectrograms, dataset_cls)
        )

In [7]:
# Sanity check: expect one entry per cross-validation fold.
len(kfold_data["all"])

5

### Build and train the model (EEGModel, EEGWaveNet, or EEGSpecNet)

In [8]:
# Select PyTorch's "high" float32 matmul precision setting before training
# starts (see the PyTorch documentation for the exact speed/precision
# trade-off). `torch` is imported in the import cell at the top of the notebook.
torch.set_float32_matmul_precision("high")

In [9]:
# Choose the loss and the model to train, then run k-fold training.
loss_name = "KLDiv"
model_name = "EEGModel"

# "CrossEntropy" selects the cross-entropy loss; any other value falls back
# to the KL-divergence loss.
LOSS_CLASSES = {
    "CrossEntropy": CrossEntropyLossWithLogits,
    "KLDiv": KLDivLossWithLogits,
}
loss_cls = LOSS_CLASSES.get(loss_name, KLDivLossWithLogits)

# Each model is trained on the matching entry of `kfold_data`.
MODEL_DATA_KEYS = {
    "EEGModel": "all",
    "EEGWaveNet": "wave",
    "EEGSpecNet": "spec",
}

# run task
# Fail loudly on an unknown model name instead of leaving `exp` undefined
# (or silently reusing an `exp` left over from an earlier run of this cell).
if model_name not in MODEL_DATA_KEYS:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_DATA_KEYS)}"
    )
exp = Experiment(kfold_data[MODEL_DATA_KEYS[model_name]], train, model_name, loss_cls)

config = Config

print(f">>>>>>>> start training: {model_name} >>>>>>>>")
exp.kfold_train(config)

>>>>>>>> start training: EEGModel >>>>>>>>
500


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Running trainer.fit


2024-04-06 03:18:37.923066: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-06 03:18:37.923164: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-06 03:18:38.039874: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 0 validation loss 1.2857465744018555
Fold 0: Epoch 0 validation KDL score 1.2857465744018555


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 0 validation loss 1.0597150325775146
Fold 0: Epoch 0 validation KDL score 1.0597150325775146


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 1 validation loss 1.0514053106307983
Fold 0: Epoch 1 validation KDL score 1.0514053106307983


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 2 validation loss 1.0261166095733643
Fold 0: Epoch 2 validation KDL score 1.0261166095733643


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 3 validation loss 0.961030900478363
Fold 0: Epoch 3 validation KDL score 0.961030900478363


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 0: Epoch 4 validation loss 0.8951680064201355
Fold 0: Epoch 4 validation KDL score 0.8951680064201355


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 16/16 [00:50<00:00,  3.15s/it]


500
Running trainer.fit


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 0 validation loss 1.2724164724349976
Fold 1: Epoch 0 validation KDL score 1.2724164724349976


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 0 validation loss 1.1719779968261719
Fold 1: Epoch 0 validation KDL score 1.1719779968261719


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 0 validation loss 0.9886029958724976
Fold 1: Epoch 0 validation KDL score 0.9886029958724976


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 1 validation loss 0.9081205725669861
Fold 1: Epoch 1 validation KDL score 0.9081205725669861


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 1 validation loss 0.8578295707702637
Fold 1: Epoch 1 validation KDL score 0.8578295707702637


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 2 validation loss 0.8296566009521484
Fold 1: Epoch 2 validation KDL score 0.8296566009521484


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 3 validation loss 0.7530131340026855
Fold 1: Epoch 3 validation KDL score 0.7530131340026855


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 3 validation loss 0.7209646701812744
Fold 1: Epoch 3 validation KDL score 0.7209646701812744


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 4 validation loss 0.6777540445327759
Fold 1: Epoch 4 validation KDL score 0.6777540445327759


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 5 validation loss 0.6515794992446899
Fold 1: Epoch 5 validation KDL score 0.6515794992446899


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 7 validation loss 0.6423954963684082
Fold 1: Epoch 7 validation KDL score 0.6423954963684082


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 1: Epoch 8 validation loss 0.6406174898147583
Fold 1: Epoch 8 validation KDL score 0.6406174898147583


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 16/16 [00:51<00:00,  3.20s/it]


500
Running trainer.fit


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 0 validation loss 1.4559731483459473
Fold 2: Epoch 0 validation KDL score 1.4559731483459473


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 0 validation loss 1.1487455368041992
Fold 2: Epoch 0 validation KDL score 1.1487455368041992


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 0 validation loss 0.9781564474105835
Fold 2: Epoch 0 validation KDL score 0.9781564474105835


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 1 validation loss 0.8353769183158875
Fold 2: Epoch 1 validation KDL score 0.8353769183158875


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 3 validation loss 0.7744247317314148
Fold 2: Epoch 3 validation KDL score 0.7744247317314148


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 2: Epoch 4 validation loss 0.7471546530723572
Fold 2: Epoch 4 validation KDL score 0.7471546530723572


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 16/16 [00:54<00:00,  3.44s/it]


500
Running trainer.fit


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 0 validation loss 1.552046775817871
Fold 3: Epoch 0 validation KDL score 1.552046775817871


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 0 validation loss 1.2569184303283691
Fold 3: Epoch 0 validation KDL score 1.2569184303283691


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 0 validation loss 0.9798586368560791
Fold 3: Epoch 0 validation KDL score 0.9798586368560791


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 1 validation loss 0.937307596206665
Fold 3: Epoch 1 validation KDL score 0.937307596206665


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 1 validation loss 0.9172518849372864
Fold 3: Epoch 1 validation KDL score 0.9172518849372864


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 2 validation loss 0.8661419153213501
Fold 3: Epoch 2 validation KDL score 0.8661419153213501


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 3 validation loss 0.8442893028259277
Fold 3: Epoch 3 validation KDL score 0.8442893028259277


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 4 validation loss 0.7222583293914795
Fold 3: Epoch 4 validation KDL score 0.7222583293914795


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 5 validation loss 0.7162591814994812
Fold 3: Epoch 5 validation KDL score 0.7162591814994812


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 6 validation loss 0.7138615250587463
Fold 3: Epoch 6 validation KDL score 0.7138615250587463


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 7 validation loss 0.701764702796936
Fold 3: Epoch 7 validation KDL score 0.701764702796936


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 7 validation loss 0.6890671253204346
Fold 3: Epoch 7 validation KDL score 0.6890671253204346


Validation: |          | 0/? [00:00<?, ?it/s]

Fold 3: Epoch 8 validation loss 0.664002001285553
Fold 3: Epoch 8 validation KDL score 0.664002001285553


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 16/16 [00:53<00:00,  3.36s/it]


500
Running trainer.fit


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Fold 4: Epoch 0 validation loss 1.419403076171875
Fold 4: Epoch 0 validation KDL score 1.419403076171875


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 4: Epoch 0 validation loss 1.099717140197754
Fold 4: Epoch 0 validation KDL score 1.099717140197754


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 4: Epoch 2 validation loss 0.998495876789093
Fold 4: Epoch 2 validation KDL score 0.998495876789093


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Fold 4: Epoch 4 validation loss 0.942540168762207
Fold 4: Epoch 4 validation KDL score 0.942540168762207


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 16/16 [00:53<00:00,  3.33s/it]


OOF Score for solution = 0.779089081497595
