# imports

In [1]:
import os
import random
from datetime import datetime
from functools import lru_cache

import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import wandb
# import einops
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from tqdm.notebook import tqdm

# Single place that decides where tensors and the model live for the rest of the notebook.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# utils

In [2]:
import gc 
def GC():
    """Run Python's garbage collector, then ask PyTorch to release its cached, unused CUDA memory."""
    for release in (gc.collect, t.cuda.empty_cache):
        release()

In [3]:
# validation / inference loss
@t.no_grad()
def eval(model, x, y, do_eval=True):
    """Return the batch-mean KL divergence between target distribution `y` and the model's prediction.

    Note: this shadows the builtin `eval` inside the notebook.

    Args:
        model: module returning logits; `log_softmax` is taken over the last dimension.
        x: input batch (must not contain NaN).
        y: target probabilities with the same shape as the model output (must not contain NaN).
        do_eval: if True run the forward pass in eval mode (dropout off), otherwise in train mode.

    Returns:
        0-dim loss tensor (no grad) on the model's device.

    The model's original train/eval mode is restored before returning, even if the forward pass
    raises. (The previous version always left the model in train mode.)
    """
    assert not x.isnan().any()
    assert not y.isnan().any()
    was_training = model.training
    # Put inputs where the model actually lives rather than trusting the global `device`.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:  # parameter-free model: fall back to the notebook-wide device
        target_device = device
    model.train(not do_eval)  # train(False) is exactly model.eval()
    try:
        logs = model(x.to(target_device)).log_softmax(-1)
        kl_loss = nn.KLDivLoss(reduction="batchmean")
        loss = kl_loss(logs, y.to(target_device))
    finally:
        model.train(was_training)
    return loss

In [4]:
def noise(data, alpha):
    """Return `data` plus zero-mean Gaussian noise scaled per (sample, channel) slice.

    For every [b, c] slice the noise std is `alpha` times that slice's std over dims 2 and 3, so
    `data` must have at least 4 dims (assumed (batch, channel, seq, freq) — TODO confirm).

    The noise tensor is created on `data`'s own device and dtype. The previous version forced the
    global `device`, which broke for a CPU tensor on a CUDA machine.
    """
    std = data.std(dim=(2, 3), keepdim=True)
    gaussian = t.randn_like(data) * std * alpha
    return data + gaussian

def roll(data, shift):
    """Circularly shift `data` by `shift` positions along dim 2 (assumed the time/seq axis — TODO confirm)."""
    return data.roll(shifts=shift, dims=2)

@t.no_grad()
def augment_data(data, alpha=0.1, shift=10):
    """Move a batch to `device`, add per-(sample, channel) Gaussian noise, then roll it along dim 2.

    data → (batch, channel, seq, frequency) (e.g. (200, 4, 300, 100))

    Args:
        alpha: noise std as a fraction of each slice's own std (see `noise`).
        shift: circular shift applied along dim 2 (see `roll`).
    """
    return roll(noise(data.to(device), alpha), shift)

# config

In [5]:
# DataLoader settings.
batch_size = 55
# batch_size = 200
prefetch_factor = 10  # currently only referenced by the commented-out multi-worker DataLoader calls
num_workers = 3       # idem

# Seed every RNG the notebook draws from (pandas row sampling uses numpy's global state, the
# train/test split uses np.random.shuffle, augmentation uses `random`, weight init and shuffling use torch).
# Makes runs repeatable; GPU kernels may still introduce small nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
t.manual_seed(SEED)

# data

In [6]:
# Dataset locations (relative to the notebook's working directory).
BASE_PATH = './hms-harmful-brain-activity-classification/'
train_path = './hms-harmful-brain-activity-classification/train_spectrograms/'
train_spec_path = './hms-harmful-brain-activity-classification/train_spectrograms/'  # same directory as train_path
PRE_PROCESSED_PATH = './spectrograms/'  # per-sample .pt tensors loaded by Dataset

# TODO: merge several models together
# TODO: when submitting round values close to 0 to exactly 0 and rebalance the rest to sum() == 1 for free boost

# Vote-count columns of train.csv; the model's 6 outputs are trained against them in this order.
TARGETS = [
    'seizure_vote',
    'lpd_vote',
    'gpd_vote',
    'lrda_vote',
    'grda_vote',
    'other_vote',
]

In [7]:
def remove_overlaps(df, key='spectrogram_id', random_state=None):
    ''' Keep one randomly chosen row per distinct `key` value.

    This makes the dataset 10x smaller, but it should be closer to the leaderboard.

    Args:
        df: dataframe containing a `key` column.
        key: column whose duplicated values are collapsed to a single row.
        random_state: optional seed / RandomState forwarded to pandas for a reproducible pick;
            None uses numpy's global RNG (the original behaviour).

    Returns:
        A new dataframe with all original columns, one row per group ordered by `key`, and a fresh
        0..n-1 index.

    `GroupBy.sample` replaces `groupby.apply(..., include_groups=True)`, which pandas deprecates.
    '''
    return df.groupby(key).sample(n=1, random_state=random_state).reset_index(drop=True)

In [8]:
# Labels table, reduced to a single (randomly chosen) row per spectrogram.
train_df = remove_overlaps(pd.read_csv(f'{BASE_PATH}/train.csv'))

  return df.groupby(key).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True)


In [9]:
class Dataset(t.utils.data.Dataset):
    """Map-style dataset over the rows of a labels dataframe (default: the global `train_df`).

    Each item is `(spectrogram, labels)`:
      - spectrogram: the object stored at `{PRE_PROCESSED_PATH}/{spectrogram_id}_{spectrogram_sub_id}.pt`;
      - labels: the `TARGETS` vote counts of that row normalised to sum to 1, as a float64 tensor.

    Items are cached in memory after the first access so later epochs skip disk I/O. The cache is
    per instance, not a class-level `lru_cache` on the method, which kept every instance alive forever.
    The base class is referenced explicitly (not the bare name `Dataset`) so re-running this cell
    does not subclass the previous definition of this very class.
    """

    def __init__(self, dataframe=None):
        super().__init__()
        # None keeps the old `Dataset()` behaviour of reading the notebook-wide train_df.
        self.dataframe = train_df if dataframe is None else dataframe
        self._cache = {}

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):  # preprocessed version
        if idx not in self._cache:
            self._cache[idx] = self._load(idx)
        return self._cache[idx]

    def _load(self, idx):
        """Read one preprocessed spectrogram and its normalised vote distribution from disk."""
        row = self.dataframe.iloc[idx]
        spec_id = row['spectrogram_id']
        sub_id = row['spectrogram_sub_id']
        path = f'{PRE_PROCESSED_PATH}/{spec_id}_{sub_id}.pt'
        data = t.load(path)
        labels = row[TARGETS].values.astype(np.float64)
        labels = labels / np.sum(labels)
        return data, t.tensor(labels, dtype=t.float64)

In [10]:
# def plot_spectogram(spec_df, prefixes, title = "Spectogram"):
#     fig = sp.make_subplots(rows=len(prefixes), cols=1, subplot_titles=prefixes)
#     for i, prefix in enumerate(prefixes):
#         prefix_df = spec_df.filter(regex=f'^{prefix}', axis=1)
#         epsilon = 1e-10
#         fig.add_trace(go.Heatmap(z=np.log(prefix_df + epsilon).T,
#                                  y=pd.to_numeric(prefix_df.columns.str.replace(f"{prefix}_", '')),
#                                  coloraxis="coloraxis"),
#                       row=i+1, col=1)
#          # Update x-axis and y-axis labels
#         fig.update_xaxes(title_text="Time(Seconds)", row=i+1, col=1)
#         fig.update_yaxes(title_text="Frequency(Hz)", row=i+1, col=1)
#         # update coloraxis
#         fig.update_layout(coloraxis = {'colorscale':'Jet'}, height=1500,title_text=title)
#     fig.show()

In [11]:
# Hold out 5% of spectrogram ids; splitting by id keeps every row of a spectrogram on one side.
dataset = Dataset()
ids = train_df['spectrogram_id'].unique()
np.random.shuffle(ids)
split = int(len(ids) * 0.95)

train_ids = ids[:split]
test_ids = ids[split:]

# Persist both id lists so the held-out set can be recovered later.
now = datetime.now().strftime("%Y-%m-%d_%Hh%M")
os.makedirs('./splits', exist_ok=True)
t.save(t.tensor(train_ids), f'./splits/{now}_train_ids.pt')
t.save(t.tensor(test_ids), f'./splits/{now}_test_ids.pt')  # previously saved train_ids here by mistake

# remove_overlaps resets train_df's index to 0..n-1, so these labels equal the iloc positions Dataset uses.
train_indices = train_df[train_df['spectrogram_id'].isin(train_ids)].index.tolist()
test_indices = train_df[train_df['spectrogram_id'].isin(test_ids)].index.tolist()

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, shuffle=True)

len(train_dataset), len(test_dataset)

(10581, 557)

# model 👯‍♀️

## 2D CNN (Conv2d blocks + MLP head)

In [12]:
# Scratch check: star-unpacking an iterable inside a tuple display (the trick CNNBlock uses).
(*((1, 2, 3) for _ in range(3)),)

((1, 2, 3), (1, 2, 3), (1, 2, 3))

In [22]:
class CNNSkipBlock(nn.Module):
    """Conv stem, then `repeat` (LeakyReLU → Dropout → Conv) units inside a single residual connection,
    then 2×2 max-pooling.

    Submodule names (`pre`, `convs`, `post`) define the state_dict keys, so they are fixed.
    NOTE(review): PyTorch rejects padding='same' for strided convolutions, so `stride` must stay 1.
    """

    def __init__(self, in_channels, out_channels, repeat, dropout, stride=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=stride, padding='same', padding_mode='reflect')
        self.pre = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        units = []
        for _ in range(repeat):
            units.append(nn.Sequential(
                nn.LeakyReLU(),
                nn.Dropout(dropout),
                nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            ))
        self.convs = nn.Sequential(*units)
        self.post = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        stem = self.pre(x)
        residual = stem + self.convs(stem)  # one skip around the whole stack ("weird skip")
        return self.post(residual)

class CNNBlock(nn.Module):
    """Conv(3×3, reflect 'same' padding) → LeakyReLU → Dropout, followed by 2×2 max-pooling
    only when the block changes the channel count.

    All layers live in `self.conv`, which fixes the state_dict keys (`conv.0.weight`, `conv.0.bias`).
    """

    def __init__(self, in_channels, out_channels, dropout, stride=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding='same', padding_mode='reflect'),
            # nn.BatchNorm2d(out_channels), # TODO: I hate batchnorm -_-
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        ]
        if in_channels != out_channels:  # channel change doubles as the spatial downsampling point
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
        
class Model(nn.Module):
    """CNN classifier: a stack of `CNNBlock`s followed by a 3-layer MLP head producing 6 logits.

    Args:
        convs: channel schedule; block i maps convs[i] -> convs[i + 1] channels. A CNNBlock max-pools
            (halving height and width) only when its input and output channel counts differ.
        hidden: width of the head's two hidden layers.
        dropout: dropout probability used in every CNNBlock and in the head.
        spatial_out: number of spatial positions (H * W) left after the conv stack, so the head's
            input size is convs[-1] * spatial_out. The default 4 reproduces the previously hard-coded
            `259 * 4`. It depends on the input resolution: e.g. 300×100 inputs through the default
            schedule's 6 poolings leave 4×1 (assumed input size — TODO confirm).
    """

    def __init__(self, convs=(4, 30, 30, 30, 30, 64, 64, 64, 64, 90, 90, 90, 90, 127, 127, 127, 128, 128, 128, 259),
                 hidden=256, dropout=0.3, spatial_out=4):
        super().__init__()
        self.cnn = nn.Sequential(
            *[CNNBlock(in_chan, out_chan, dropout=dropout) for in_chan, out_chan in zip(convs, convs[1:])]
        )
        self.head = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=-1),
            # Derived from the schedule instead of a literal so other `convs` choices still fit.
            nn.Linear(convs[-1] * spatial_out, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 6),
        )

    def forward(self, x):
        return self.head(self.cnn(x))

def test():
    """Smoke test: run one training batch through a freshly built Model and print both shapes."""
    net = Model().to(device)
    x, _ = next(iter(train_dataloader))
    print(f'{x.shape=}')
    r = net(x.to(device))
    print(f'{r.shape=}')
    
# test()

# train

In [23]:
GC()  # release memory held by a previously built model before allocating a new one
model = Model().to(device)
opt = t.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
print('model has {} params'.format(sum(param.numel() for param in model.parameters())))

model has 1891134 params


## in memory

In [24]:
def memory_train(model, opt, wnb=True, data_augmentation=False, epochs=100000, save_every=10):
    """Train `model` on `train_dataloader` with a batch-mean KL-divergence loss.

    Args:
        model: classifier returning logits for the 6 TARGETS.
        opt: optimizer over `model`'s parameters.
        wnb: log step losses and per-epoch validation losses to Weights & Biases.
        data_augmentation: apply `augment_data` with a random noise level and random shift in [0, 20).
        epochs: number of passes over `train_dataloader`.
        save_every: write `weights/cnn_<timestamp>.pt` every `save_every` epochs (starting at epoch 0).

    Validation uses one fixed batch from each of `test_dataloader` and `train_dataloader`, drawn once
    up front. The wandb run is closed in a `finally`, so interrupting the (long) loop does not leave
    it dangling.
    """
    model.train()
    validation_test, validation_test_label = next(iter(test_dataloader))
    validation_train, validation_train_label = next(iter(train_dataloader))
    kl_loss = nn.KLDivLoss(reduction="batchmean")  # stateless, so build it once instead of every step
    os.makedirs('weights', exist_ok=True)

    def log_validation(now):
        # .item() logs plain floats instead of (possibly GPU) tensors.
        wandb.log({'val_test':   eval(model, validation_test, validation_test_label, do_eval=True).item(), 'now': now})
        wandb.log({'val_train':  eval(model, validation_train, validation_train_label, do_eval=True).item(), 'now': now})

    if wnb:
        wandb.init(project='kaggle-eeg-rc')
    try:
        if wnb:
            log_validation(datetime.now().strftime("%Y-%m-%d_%Hh%M"))

        for epoch in tqdm(range(epochs)):
            for x_train, y_train in tqdm(train_dataloader):
                # TODO: use variable alpha based on mini_epoch
                if data_augmentation:
                    x_train = augment_data(x_train,
                                           alpha=random.choice([0.01, 0.05, 0.1, 0.15]),  # , 0.2, 0.3
                                           shift=random.randrange(20))
                logs = model(x_train.to(device)).log_softmax(-1)
                loss = kl_loss(logs, y_train.to(device))
                opt.zero_grad()
                loss.backward()
                opt.step()
                if wnb:
                    wandb.log({'loss': loss.item()})

            now = datetime.now().strftime("%Y-%m-%d_%Hh%M")
            if wnb:
                log_validation(now)
            if epoch % save_every == 0:
                t.save(model.state_dict(), f'weights/cnn_{now}.pt')
    finally:
        if wnb:
            wandb.finish()

memory_train(model, opt, wnb=True, data_augmentation=True)



VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,██▄▅▆▃▄▃▃▃▃▄▃▃▄▃▂▂▃▃▃▃▂▂▂▄▂▃▂▁▃▄▁▂▃▃▂▁▂▂
val_test,█▄▅▄▃▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂▂▂▁▂▂▁▂▃▂▁▃▁▂▂▃▂▁▃▃▂
val_train,█▅▆▅▅▃▃▂▃▃▃▃▃▂▃▃▂▃▃▂▃▂▃▂▂▁▃▁▂▁▃▃▂▃▂▃▂▁▂▁

0,1
loss,0.42017
now,2024-03-31_04h44
val_test,0.71752
val_train,0.66745


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112622676106791, max=1.0…

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

## staged

In [None]:
def populate_buffer(buffer_size=300, dataloader=None):
    """Endlessly yield lists of `buffer_size` `(x, y)` batches.

    Batches come from `dataloader` (default: the global `train_dataloader`), which is re-iterated as
    often as needed. A partially filled buffer is carried over into the next pass instead of being
    reset. As a result:
      - no batches are silently dropped at the end of a pass;
      - a loader with fewer than `buffer_size` batches still produces buffers. Previously that case
        looped over the data forever without ever yielding.

    Raises:
        ValueError: if a full pass over the loader produces no batches (it would otherwise spin forever).
    """
    loader = train_dataloader if dataloader is None else dataloader
    buffer = []
    while True:
        n_batches = 0
        for x_train, y_train in tqdm(loader):
            n_batches += 1
            buffer.append((x_train, y_train))
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []
        if n_batches == 0:
            raise ValueError('dataloader produced no batches; cannot fill the replay buffer')

def staged_train(model, opt, mini_epochs, wnb=True, data_augmentation=False):
    """Train `model` by repeatedly sweeping over chunks ("replay buffers") of training batches.

    Each buffer from `populate_buffer()` is trained on for `mini_epochs` passes. After every pass,
    validation losses are logged (when `wnb`) and a checkpoint is saved to
    `weights/gru-4-splits_<timestamp>.pt`. `populate_buffer` never ends, so this runs until interrupted.
    The wandb run is closed in a `finally` so an interrupt does not leave it dangling.

    Args:
        model: classifier returning logits for the 6 TARGETS.
        opt: optimizer over `model`'s parameters.
        mini_epochs: passes over each replay buffer.
        wnb: log to Weights & Biases.
        data_augmentation: apply `augment_data` with a random noise level and a random shift in
            [0, 20), matching `memory_train`. Previously the shift was always the default 10, a fixed
            roll that validation never sees.
    """
    model.train()
    validation_test, validation_test_label = next(iter(test_dataloader))
    validation_train, validation_train_label = next(iter(train_dataloader))
    kl_loss = nn.KLDivLoss(reduction="batchmean")  # stateless, so build it once instead of every step
    os.makedirs('weights', exist_ok=True)

    def log_validation(now):
        # .item() logs plain floats instead of (possibly GPU) tensors.
        wandb.log({'val_test':   eval(model, validation_test, validation_test_label, do_eval=True).item(), 'now': now})
        wandb.log({'val_train':  eval(model, validation_train, validation_train_label, do_eval=True).item(), 'now': now})

    if wnb:
        wandb.init(project='kaggle-eeg-rc')
    try:
        if wnb:
            log_validation(datetime.now().strftime("%Y-%m-%d_%Hh%M"))

        for replay_buffer in populate_buffer():
            for _ in tqdm(range(mini_epochs)):
                for x_train, y_train in replay_buffer:
                    # TODO: use variable alpha based on mini_epoch
                    if data_augmentation:
                        x_train = augment_data(x_train,
                                               alpha=random.choice([0.01, 0.05, 0.1]),  # , 0.15, 0.2, 0.3
                                               shift=random.randrange(20))
                    logs = model(x_train.to(device)).log_softmax(-1)
                    loss = kl_loss(logs, y_train.to(device))
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    if wnb:
                        wandb.log({'loss': loss.item()})

                now = datetime.now().strftime("%Y-%m-%d_%Hh%M")
                if wnb:
                    log_validation(now)
                # NOTE(review): the 'gru-4-splits' prefix looks stale for this CNN — consider renaming.
                t.save(model.state_dict(), f'weights/gru-4-splits_{now}.pt')
    finally:
        if wnb:
            wandb.finish()

# staged_train(model, opt, mini_epochs=20, wnb=True, data_augmentation=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeluche[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/1856 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

# save / load

In [None]:
# t.save(model.state_dict(),'model-weights4.pt')

In [None]:
# model = SeparatedGRU().to(device)
# model.load_state_dict(t.load('weights/gru-4-splits_2024-03-24_15h03.pt', map_location=device))

In [None]:
# One batch from each split, scored with dropout active (train mode) and disabled (eval mode).
x_train, y_train = next(iter(train_dataloader))
x_val, y_val = next(iter(test_dataloader))

for row, (label, xs, ys, flag) in enumerate([
    ('train(): train: ', x_train, y_train, False),
    ('train(): test:  ', x_val, y_val, False),
    ('eval(): train   ', x_train, y_train, True),
    ('eval(): test    ', x_val, y_val, True),
]):
    if row == 2:
        print('--')
    print(f'{label}{eval(model, xs, ys, do_eval=flag)}')


train(): train: 0.3067223824690105

train(): test:  0.7667477726504927

--

eval(): train   0.38911351894595714

eval(): test    0.7780192006451506


# submit

In [None]:
@t.no_grad()
def submit(model, test_dataloader, test_df):
    """Predict vote probabilities for `test_df` and write them to `submission.csv`.

    Args:
        model: classifier returning logits for the 6 TARGETS (softmaxed here).
        test_dataloader: iterable of input batches in the same row order as `test_df` (assumed —
            TODO confirm). Batches may be bare tensors or `(x, ...)` tuples/lists; in the latter
            case only the first element is fed to the model.
        test_df: dataframe with an `eeg_id` column, one row per predicted sample.

    Raises:
        ValueError: if the number of predictions differs from `len(test_df)`.
    """
    model.eval()
    res = []
    for batch in test_dataloader:
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        res.append(model(batch.to(device)).softmax(-1).cpu())

    res = t.cat(res, dim=0)
    if len(res) != len(test_df):
        raise ValueError(f'got {len(res)} predictions for {len(test_df)} rows of test_df')
    sub = test_df[["eeg_id"]].copy()
    # Assign a NumPy array: pandas column assignment does not reliably accept torch tensors.
    sub[TARGETS] = res.numpy()
    sub.to_csv('submission.csv', index=False)
    print('Submission shape', sub.shape)
    display(sub.head())