# imports

In [1]:
from datetime import datetime
import einops
import wandb
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset, DataLoader, random_split
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from functools import lru_cache

# Prefer the GPU whenever torch can see one.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# utils

In [2]:
import gc 
def GC():
    """Run Python's garbage collector, then release torch's cached CUDA memory."""
    for release in (gc.collect, t.cuda.empty_cache):
        release()

In [3]:
@t.no_grad()
def eval(model, x, y, do_eval=True):
    """Return the batch-mean KL divergence between the model's predictions and ``y``.

    NOTE: shadows the builtin ``eval``; the name is kept because other cells call it.

    Args:
        model: module whose output is treated as logits (log-softmaxed over the last dim).
        x: input batch; moved to the device the model's parameters live on.
        y: target probability distributions with the same shape as the model output.
        do_eval: True runs the forward pass in eval mode (dropout off, BatchNorm uses
            running stats); False runs it in train mode. In train mode BatchNorm layers
            still update their running statistics, even under no_grad.

    Returns:
        0-dim loss tensor (no autograd graph).

    The model's previous train/eval mode is restored afterwards, also if the forward
    pass raises.
    """
    assert not x.isnan().any()
    assert not y.isnan().any()
    was_training = model.training
    model.train(not do_eval)
    try:
        # Follow the model's own device; fall back to the notebook-wide `device` for
        # parameter-less models.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else device
        logs = model(x.to(target_device)).log_softmax(-1)
        kl_loss = nn.KLDivLoss(reduction="batchmean")
        loss = kl_loss(logs, y.to(target_device))
    finally:
        model.train(was_training)
    return loss

# config

In [4]:
# DataLoader settings shared by the train and validation loaders.
batch_size, prefetch_factor, num_workers = 50, 10, 3

# data

In [5]:
BASE_PATH = "./hms-harmful-brain-activity-classification"
# Derived from BASE_PATH so the three paths cannot disagree.
test_path = f'{BASE_PATH}/test_eegs/'
train_path = f'{BASE_PATH}/train_eegs/'
class_names = ['Seizure', 'LPD', 'GPD', 'LRDA', 'GRDA', 'Other']
FEATS_FOR_REAL = ['Fp1', 'F3', 'C3', 'P3', 'F7', 'T3', 'T5', 'O1', 'Fz', 'Cz', 'Pz', 'Fp2', 'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG']
#                   0      1     2     3     4     5     6     7     8     9    10     11    12    13    14    15    16    17    18    19
# Electrode chains grouped by montage region (LP, LL, RP, RR):
# https://raw.githubusercontent.com/cdeotte/Kaggle_Images/main/Jan-2024/montage.png
GROUPS = [
    ['Fp1', 'F3', 'C3', 'P3', 'O1'],  # LP
    ['Fp1', 'F7', 'T3', 'T5', 'O1'],  # LL
    ['Fp2', 'F4', 'C4', 'P4', 'O2'],  # RP
    ['Fp2', 'F8', 'T4', 'T6', 'O2'],  # RR
    # ['Fz', 'Cz', 'Pz', 'EKG'],  # TODO: try with leftovers?
]
# Column indices into FEATS_FOR_REAL, computed from the names above instead of being
# maintained by hand: [[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [11, 12, 13, 14, 18], [11, 15, 16, 17, 18]].
GROUPS_IDS = [[FEATS_FOR_REAL.index(channel) for channel in group] for group in GROUPS]
# TODO: add frequency domain with a Fourier transform
# TODO: add spectrogram to process with conv2d
# TODO: merge several models together

# Class name -> class index, in class_names order.
TARS = {name: idx for idx, name in enumerate(class_names)}
# Vote columns of train.csv, in class_names order (same rule submit() uses for its columns).
TARGETS = [f'{name.lower()}_vote' for name in class_names]

In [6]:
# Test metadata plus the parquet locations of each row's EEG and spectrogram.
test_df = (
    pd.read_csv(f'{BASE_PATH}/test.csv')
    .assign(
        eeg_path=lambda frame: f'{BASE_PATH}/test_eegs/' + frame['eeg_id'].astype(str) + '.parquet',
        spec_path=lambda frame: f'{BASE_PATH}/test_spectrograms/' + frame['spectrogram_id'].astype(str) + '.parquet',
    )
)

In [7]:
train_df = pd.read_csv(f'{BASE_PATH}/train.csv')
# Per-row EEG parquet path (Dataset builds its own path from `train_path`; kept for inspection).
eeg_path = train_df['eeg_id'].astype(str).map(lambda eeg_id: f'{BASE_PATH}/train_eegs/{eeg_id}.parquet')
# Independent copy of the experts' consensus label per row.
class_name = train_df['expert_consensus'].copy()

In [8]:
class Dataset(t.utils.data.Dataset):
    """Map-style dataset of labelled raw-EEG windows described by the rows of a dataframe.

    Inherits from ``torch.utils.data.Dataset`` by its full path: this class shadows the
    imported ``Dataset`` name, so a bare ``class Dataset(Dataset)`` would subclass the
    previous definition of itself every time the cell is re-run.

    Each item is ``(samples, labels)``:
      * samples: tensor of the FEATS_FOR_REAL columns for up to WINDOW_SAMPLES rows,
        starting at the row's ``eeg_label_offset_seconds``; gaps are forward-filled and
        any remaining NaNs set to 0. NOTE(review): a recording shorter than
        offset + window yields fewer rows, which would break default batching.
      * labels: float64 tensor of the TARGETS vote counts normalised to sum to 1.
    """

    SAMPLE_RATE_HZ = 200      # converts eeg_label_offset_seconds to a row offset
    WINDOW_SAMPLES = 10_000   # rows per window (50 s at SAMPLE_RATE_HZ)

    def __init__(self, transform=None, dataframe=None, eeg_dir=None):
        """
        Args:
            transform: accepted for API compatibility; currently unused.
            dataframe: rows to serve (needs eeg_id, eeg_label_offset_seconds and the
                TARGETS columns); defaults to the global ``train_df``.
            eeg_dir: directory prefix of the ``{eeg_id}.parquet`` files; defaults to the
                global ``train_path``.
        """
        super().__init__()
        self.transform = transform
        self.dataframe = train_df if dataframe is None else dataframe
        self.eeg_dir = train_path if eeg_dir is None else eeg_dir

    def __len__(self):
        return len(self.dataframe)

    # @lru_cache(maxsize=None)
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        eeg = pd.read_parquet(f'{self.eeg_dir}{row["eeg_id"]}.parquet')
        start = int(row['eeg_label_offset_seconds'] * self.SAMPLE_RATE_HZ)
        window = eeg.iloc[start:start + self.WINDOW_SAMPLES]
        # Forward-fill gaps, then zero the leading NaNs that ffill cannot reach.
        window = window.ffill(axis=0).fillna(0)

        votes = row[TARGETS].values.astype(np.float64)
        labels = votes / np.sum(votes)
        samples = t.tensor(window[FEATS_FOR_REAL].values)
        return samples, t.tensor(labels, dtype=t.float64)

In [9]:
SPLIT_SEED = 0  # fixed so every kernel (and every reloaded checkpoint) sees the same split


def _split_by_recording(ds, train_frac=0.9, seed=SPLIT_SEED):
    """Return (train_subset, test_subset) with every eeg_id on exactly one side.

    Each row of the dataframe is one labelled window with its own offset into a
    recording, so several rows may share an eeg_id. A row-level random split puts
    windows of the same recording in both train and validation, which makes validation
    look better than the leaderboard. Grouping by eeg_id removes that overlap.
    TODO: grouping by patient (if train.csv has such a column) would be stricter still.
    """
    eeg_ids = ds.dataframe['eeg_id'].to_numpy()
    shuffled_ids = np.random.default_rng(seed).permutation(np.unique(eeg_ids))
    in_train = np.isin(eeg_ids, shuffled_ids[:int(len(shuffled_ids) * train_frac)])
    return (t.utils.data.Subset(ds, np.flatnonzero(in_train).tolist()),
            t.utils.data.Subset(ds, np.flatnonzero(~in_train).tolist()))


dataset = Dataset()
train_dataset, test_dataset = _split_by_recording(dataset)
train_size, test_size = len(train_dataset), len(test_dataset)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, shuffle=True)

# model 👯‍♀️

## conv1d + GRU

In [10]:
class ConvBlock(nn.Module):
    """Three same-padded Conv1d -> ReLU -> Dropout stages followed by a 2x max-pool.

    Maps (batch, d_in, seq) to (batch, d_out, seq // 2).
    """
    def __init__(self, d_in, d_out, kernel_size, drop):
        super().__init__()
        layers = []
        # Only the first convolution changes the channel count.
        for channels_in in (d_in, d_out, d_out):
            layers += [
                nn.Conv1d(channels_in, d_out, kernel_size=kernel_size, padding='same', stride=1),
                nn.ReLU(),
                nn.Dropout(drop),
            ]
        layers.append(nn.MaxPool1d(kernel_size=2, stride=2, padding=0))  # halves the sequence length
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # TODO: add skip for training speed
        return self.model(x)
        
class Model(nn.Module):
    """Conv1d front-end + bidirectional GRU classifier over all EEG channels.

    Input (batch, seq, in_channels) -> logits (batch, 6).
    """
    def __init__(self, in_channels=20, gru_hidden_size=128, drop=0.2):
        super().__init__()
        self.pre_out = in_channels * 4
        self.gru_hidden_size = gru_hidden_size

        # Each ConvBlock halves the sequence length, so pre_process divides it by 8.
        self.pre_process = nn.Sequential(
            nn.BatchNorm1d(in_channels, momentum=None),
            # use conv1d as a denoiser
            # block 1
            ConvBlock(in_channels, in_channels * 2, kernel_size=3, drop=drop),
            nn.BatchNorm1d(in_channels * 2, momentum=None),

            # block 2
            ConvBlock(in_channels * 2, in_channels * 4, kernel_size=5, drop=drop),
            nn.BatchNorm1d(self.pre_out, momentum=None),

            # block 3
            ConvBlock(in_channels * 4, in_channels * 4, kernel_size=7, drop=drop),
            nn.BatchNorm1d(self.pre_out, momentum=None),
        )

        # TODO: add a learnable first state for GRU (nn.GRU defaults h_0 to zeros)
        self.gru = nn.GRU(self.pre_out, self.gru_hidden_size, num_layers=1, batch_first=True, bidirectional=True)

        self.head = nn.Sequential(
            nn.Linear(self.gru_hidden_size * 2, self.gru_hidden_size * 4),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(self.gru_hidden_size * 4, 6)
        )

    def forward(self, x: ('batch', 'seq', 'channel')):
        # pre_process is channels-first: (batch, channel, seq) -> (batch, 4 * channel, seq // 8)
        x = self.pre_process(x.permute((0, 2, 1))).permute((0, 2, 1))

        # h_n: (2 * num_layers, batch, hidden). Its last two rows are the top layer's final
        # forward state (after the last step) and final backward state (after the first
        # step). Slicing output[:, -1] instead would give a backward half that has seen
        # only one timestep.
        _, h_n = self.gru(x)
        x = t.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, 2 * hidden_size)

        # head: (batch, 2 * hidden_size) -> (batch, 6)
        return self.head(x)

def scope():
    """Smoke test: run one training batch through a fresh Model and print the output shape."""
    net = Model().to(device)
    batch_x, _batch_y = next(iter(train_dataloader))
    r = net(batch_x.to(device))
    print(f'{r.shape=}')
    
# scope()

## transformer

In [11]:
class Transformer(nn.Module):
    """Encoder-only transformer classifier over clumps of raw EEG samples.

    Every ``d_clump`` consecutive time steps are flattened into one token, a learned
    start token is prepended, and its encoder output feeds the 6-way head.
    Input (batch, seq, d_chan) with seq divisible by d_clump -> logits (batch, 6).

    Self-attention is order-blind, so without position information the encoder would
    see an unordered bag of clumps. A fixed sinusoidal position table is therefore
    added to the tokens (``use_pos_enc=False`` restores the old order-blind model).
    The table has no parameters, so state_dict keys are unchanged.
    """
    def __init__(self, d_chan=20, d_model=256, d_clump=4, use_pos_enc=True):
        super().__init__()
        self.d_clump = d_clump
        self.use_pos_enc = use_pos_enc

        self.start = nn.Parameter(t.randn(1, 1, d_model))
        self.bn = nn.BatchNorm1d(d_chan)
        self.emb = nn.Linear(d_chan * d_clump, d_model)
        # Only llm.encoder is used below; nhead=8 requires d_model % 8 == 0.
        self.llm = nn.Transformer(d_model=d_model, nhead=8, num_encoder_layers=3, num_decoder_layers=0, dim_feedforward=d_model * 2, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 6)
        )

    @staticmethod
    def _sinusoidal(n_pos, d_model, device, dtype):
        """Sin/cos position table of shape (n_pos, d_model) (Vaswani et al., 2017)."""
        pos = t.arange(n_pos, device=device, dtype=dtype).unsqueeze(1)
        inv_freq = 1.0 / (10000.0 ** (t.arange(0, d_model, 2, device=device, dtype=dtype) / d_model))
        angles = pos * inv_freq  # (n_pos, ceil(d_model / 2))
        table = t.zeros(n_pos, d_model, device=device, dtype=dtype)
        table[:, 0::2] = t.sin(angles)
        table[:, 1::2] = t.cos(angles[:, : d_model // 2])
        return table

    def forward(self, x):
        batch, seq, chan = x.shape
        if seq % self.d_clump:
            raise ValueError(f'sequence length {seq} is not divisible by d_clump={self.d_clump}')
        x = self.bn(x.permute((0, 2, 1))).permute((0, 2, 1))
        # Same as einops 'batch (seq clump) channels -> batch seq (clump channels)'.
        x = x.reshape(batch, seq // self.d_clump, self.d_clump * chan)
        x = self.emb(x)
        # Prepend the learned start token; its encoder output summarises the sequence.
        x = t.cat([self.start.expand(batch, -1, -1), x], dim=1)
        if self.use_pos_enc:
            x = x + self._sinusoidal(x.shape[1], x.shape[2], x.device, x.dtype)
        x = self.llm.encoder(x)[:, 0]
        return self.head(x)

def scope():
    """Smoke test: run one training batch through a fresh Transformer (output discarded)."""
    batch_x, _batch_y = next(iter(train_dataloader))
    net = Transformer().to(device)
    _ = net(batch_x.to(device))

# scope()

## separated GRU

In [14]:
class ConvBlock(nn.Module):
    """``n_convs`` same-padded Conv1d -> ReLU -> Dropout stages followed by a max-pool.

    Maps (batch, d_in, seq) to (batch, d_out, seq // pool). The defaults (3 convs, pool 2)
    build exactly the original layer stack, so state_dict keys (model.0/3/6.*) and
    existing checkpoints stay compatible.

    Args:
        d_in: input channels (used by the first convolution only).
        d_out: output channels of every convolution.
        kernel_size: kernel width of every convolution.
        drop: dropout probability after each activation.
        n_convs: number of conv stages (>= 1).
        pool: max-pool kernel and stride; the sequence shrinks by this factor.
    """
    def __init__(self, d_in, d_out, kernel_size, drop, n_convs=3, pool=2):
        super().__init__()
        if n_convs < 1:
            raise ValueError(f'n_convs must be >= 1, got {n_convs}')
        stages = []
        for i in range(n_convs):
            stages.append(nn.Conv1d(d_in if i == 0 else d_out, d_out, kernel_size=kernel_size, padding='same', stride=1))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(drop))
        stages.append(nn.MaxPool1d(kernel_size=pool, stride=pool, padding=0))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        # TODO: add skip for training speed
        return self.model(x)
        
class SeparatedGRU(nn.Module):
    """Shared conv + bidirectional-GRU encoder applied to each electrode group, then a joint head.

    The input channels are split into the index lists in ``groups`` (default: GROUPS_IDS).
    The groups are folded into the batch axis, so all of them share the same
    pre_process / GRU / post_gru weights. Each group must contain ``in_channels`` indices.
    Input (batch, seq, channel) -> logits (batch, 6).
    """
    def __init__(self, in_channels=5, gru_hidden_size=128, drop=0.1, groups=None):
        super().__init__()
        # Snapshot the grouping so forward() always agrees with d_split, even if the
        # global GROUPS_IDS is edited after the model is built.
        self.groups = [list(g) for g in (GROUPS_IDS if groups is None else groups)]
        if any(len(g) != in_channels for g in self.groups):
            raise ValueError(f'every group needs {in_channels} channel indices, got sizes {[len(g) for g in self.groups]}')
        self.d_split = len(self.groups)
        self.pre_out = in_channels * 4
        self.gru_hidden_size = gru_hidden_size

        # Each ConvBlock halves the sequence length, so pre_process divides it by 8.
        self.pre_process = nn.Sequential(
            nn.BatchNorm1d(in_channels, momentum=None),
            # use conv1d as a denoiser
            # block 1
            ConvBlock(in_channels, in_channels * 2, kernel_size=3, drop=drop),
            # nn.BatchNorm1d(in_channels * 2, momentum=None),

            # block 2
            ConvBlock(in_channels * 2, in_channels * 4, kernel_size=5, drop=drop),
            # nn.BatchNorm1d(self.pre_out, momentum=None),

            # block 3
            ConvBlock(in_channels * 4, in_channels * 4, kernel_size=7, drop=drop),
            # nn.BatchNorm1d(self.pre_out, momentum=None),
        )

        # TODO: add a learnable first state for GRU (nn.GRU defaults h_0 to zeros)
        self.gru = nn.GRU(self.pre_out, self.gru_hidden_size, num_layers=1, batch_first=True, bidirectional=True)

        self.post_gru = nn.Sequential(
            nn.Linear(self.gru_hidden_size * 2, self.gru_hidden_size),
            nn.ReLU(),
            nn.Linear(self.gru_hidden_size, self.gru_hidden_size),
        )

        self.head = nn.Sequential(
            nn.Linear(self.gru_hidden_size * self.d_split, self.gru_hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(self.gru_hidden_size * 2, 6)
        )

    def forward(self, x: ('batch', 'seq', 'channel')):
        # Select each group's channels and fold the groups into the batch axis (group-major)
        # so one encoder pass processes all of them.
        splits = [x[:, :, group] for group in self.groups]
        x = einops.rearrange(t.stack(splits, dim=0), 'group batch seq channel -> (group batch) seq channel')

        # pre_process is channels-first:
        # (group*batch, in_channels, seq) -> (group*batch, 4 * in_channels, seq // 8)
        x = self.pre_process(x.permute((0, 2, 1))).permute((0, 2, 1))

        # Use the final hidden states h_n: (2 * num_layers, group*batch, hidden). Slicing
        # output[:, -1] would give a backward half that has seen only one timestep.
        _, h_n = self.gru(x)
        x = t.cat([h_n[-2], h_n[-1]], dim=-1)  # (group*batch, 2 * hidden)

        # MLP post GRU: (group*batch, 2 * hidden) -> (group*batch, hidden)
        x = self.post_gru(x)

        # Unfold the groups back out of the batch axis and concatenate their features.
        x = einops.rearrange(x, '(group batch) hidden -> batch (hidden group)', group=self.d_split)

        # head: (batch, hidden * d_split) -> (batch, 6)
        return self.head(x)

def scope():
    """Smoke test: run one training batch through a fresh SeparatedGRU and print the output shape."""
    net = SeparatedGRU().to(device)
    batch_x, _batch_y = next(iter(train_dataloader))
    r = net(batch_x.to(device))
    print(f'{r.shape=}')
    
# scope()

# train

In [15]:
GC()
# Alternative architectures defined above:
#   model = Model().to(device)
#   model = Transformer().to(device)
model = SeparatedGRU().to(device)
# TODO: try cranking the weight decay
# TODO: try using a scheduler
opt = t.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
print('model has', sum(param.numel() for param in model.parameters()), 'params')

model has 311788 params


In [17]:
def train(model, opt, wnb=True, do_eval=True, epochs=100, steps_per_batch=1,
          weights_dir='weights', ckpt_prefix='gru-4-splits'):
    """Train `model` on `train_dataloader` with a batch-mean KL-divergence loss.

    After every epoch a state_dict checkpoint is written to
    ``{weights_dir}/{ckpt_prefix}_{YYYY-mm-dd_HHhMM}.pt``.

    Args:
        model: network whose logits are log-softmaxed and compared with the label distributions.
        opt: optimizer over `model`'s parameters.
        wnb: log to Weights & Biases. A run is started if none is active and is
            finished when training ends, including on error or interrupt.
        do_eval: when `wnb` is set, log after each epoch the loss on one fixed test
            batch and one fixed train batch, each in eval mode and in train mode.
        epochs: number of passes over `train_dataloader`.
        steps_per_batch: optimizer steps per loaded batch. Values > 1 reuse each batch
            when data loading, not the GPU, is the bottleneck.
        weights_dir: checkpoint directory; created if missing.
        ckpt_prefix: checkpoint filename prefix.
    """
    import os

    model.train()
    # Draw the monitoring batches once so per-epoch numbers are comparable.
    validation_test, validation_test_label = next(iter(test_dataloader))
    validation_train, validation_train_label = next(iter(train_dataloader))
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    os.makedirs(weights_dir, exist_ok=True)

    # wandb.log() raises when no run is active, so start one unless the caller already has.
    if wnb and wandb.run is None:
        wandb.init(project='kaggle-eeg-rc')
    try:
        for _ in range(epochs):
            tq = tqdm(train_dataloader)
            for x_train, y_train in tq:
                x_train, y_train = x_train.to(device), y_train.to(device)
                for _ in range(steps_per_batch):
                    logs = model(x_train).log_softmax(-1)
                    loss = kl_loss(logs, y_train)
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    loss_value = loss.item()
                    tq.set_description(f'loss = {loss_value:.4f}')
                    if wnb:
                        wandb.log({'loss': loss_value})
            if wnb and do_eval:
                # One call, so the four monitoring losses share a wandb step; log floats, not tensors.
                wandb.log({
                    'validation_test-eval': eval(model, validation_test, validation_test_label, do_eval=True).item(),
                    'validation_test-train': eval(model, validation_test, validation_test_label, do_eval=False).item(),
                    'validation_train-eval': eval(model, validation_train, validation_train_label, do_eval=True).item(),
                    'validation_train-train': eval(model, validation_train, validation_train_label, do_eval=False).item(),
                })
            now = datetime.now().strftime("%Y-%m-%d_%Hh%M")
            t.save(model.state_dict(), os.path.join(weights_dir, f'{ckpt_prefix}_{now}.pt'))
    finally:
        if wnb:
            wandb.finish()

train(model, opt, wnb=True)

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

  0%|          | 0/1923 [00:00<?, ?it/s]

# save / load

In [None]:
# t.save(model.state_dict(),'model-weights4.pt')

In [None]:
# model = Model().to(device)
# model.load_state_dict(t.load('model-weights3.pt', map_location=device))

In [None]:
# One fixed batch from each split, scored first in eval mode and then in train mode
# (dropout on, BatchNorm batch statistics).
x_train, y_train = next(iter(train_dataloader))
x_val, y_val = next(iter(test_dataloader))

for mode_label, use_eval in (('eval', True), ('train', False)):
    if not use_eval:
        print('--')
    print(f'train.{mode_label}(): {eval(model, x_train, y_train, do_eval=use_eval)}')
    print(f'val.{mode_label}():  {eval(model, x_val, y_val, do_eval=use_eval)}')

In [None]:
def _load_test_window(path, window=10_000):
    """Read one EEG parquet and return its first `window` rows of FEATS_FOR_REAL as float32.

    Uses the same cleaning as the training Dataset (forward-fill, then zero-fill).
    NOTE(review): assumes every test recording has at least `window` rows (50 s at 200 Hz);
    a shorter one would make the batch stack in submit() fail. Confirm against the test data.
    """
    eeg = pd.read_parquet(path).iloc[:window].ffill(axis=0).fillna(0)
    return t.tensor(eeg[FEATS_FOR_REAL].values, dtype=t.float32)


@t.no_grad()
def submit():
    """Predict class probabilities for every row of `test_df` and write ``submission.csv``.

    The previous version predicted one batch of the internal validation split, whose
    length does not match `test_df`. Predictions are now made from each test row's own
    EEG (`test_df['eeg_path']`), in `test_df` order.

    Returns:
        The submission DataFrame (eeg_id + one ``*_vote`` column per class).
    """
    model.eval()
    paths = test_df['eeg_path'].tolist()
    res = []
    for start in range(0, len(paths), batch_size):
        batch = t.stack([_load_test_window(p) for p in paths[start:start + batch_size]])
        prob = model(batch.to(device)).softmax(-1)
        res.append(prob.detach().cpu())
    res = t.cat(res, dim=0) if res else t.zeros(0, len(class_names))
    if len(res):
        print(res[0])

    pred_df = test_df[["eeg_id"]].copy()
    target_cols = [x.lower() + '_vote' for x in class_names]
    pred_df[target_cols] = res.tolist()
    sub_df = pd.read_csv(f'{BASE_PATH}/sample_submission.csv')
    sub_df = sub_df[["eeg_id"]].copy()
    sub_df = sub_df.merge(pred_df, on="eeg_id", how="left")
    sub_df.to_csv("submission.csv", index=False)
    return sub_df
    
# submit()