In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import seaborn as sns

import torch
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
import glob
from tqdm.notebook import tqdm, trange
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Preprocessing and data-loading utilities

In [None]:
train_df = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/train.csv')

# One row per EEG recording, carrying its (first) patient id.
processed_train = train_df.groupby('eeg_id')[['patient_id']].first()

# The last six columns of train.csv are summed over all labelled rows of each EEG
# (presumably the per-class expert vote counts — confirm against train.csv), then
# each row is divided by its total so the six values become proportions summing to 1.
class_score_train = train_df.groupby('eeg_id')[train_df.columns[-6:]].sum()
class_score_train = class_score_train.div(class_score_train.sum(axis=1), axis=0)

# Join ids with targets and turn eeg_id back into an ordinary column.
processed_train = pd.concat([processed_train, class_score_train], axis=1).reset_index()

In [None]:
# Directory holding one <eeg_id>.parquet file per training recording. The trailing '/'
# matters: EEGDataset builds file paths by plain string concatenation.
EEG_PATH = '/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/'

# NOTE(review): not referenced anywhere in this notebook section — EEGDataset loads every
# column of each parquet file, so this selection currently has no effect.
COLS_INTEREST = [
    'Fp1', 'C3', 'F7', 'T5', 'Fz', 'Cz', 'Pz', 'Fp2',
    'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG',
]

In [None]:
from torch.utils.data import Dataset, DataLoader
from scipy.signal import butter, sosfilt
from sklearn.preprocessing import MinMaxScaler
from sklearn.impute import SimpleImputer

class EEGDataset(Dataset):
    """Map-style dataset yielding one randomly placed, fixed-length EEG window per row.

    Each item is ``(X, y)``:
      * ``X`` — float64 array ``(1, n_columns, WINDOW)`` produced by :meth:`read_eeg`
        from ``<eeg_path><eeg_id>.parquet``;
      * ``y`` — the row's last six columns of ``csv`` as a 1-D array.

    Args:
        csv: DataFrame with an ``eeg_id`` column whose last six columns are the
            targets. Rows are addressed positionally, so any index works.
        eeg_path: directory prefix (ending in a separator) of the parquet files.
    """

    # Number of time samples per returned window; must match the model's input length.
    WINDOW = 10_000

    def __init__(self, csv, eeg_path):
        super().__init__()
        self.csv = csv
        self.eeg_path = eeg_path
        # NOTE(review): not used by this class; kept for any caller that reads it.
        self.eeg_file_paths = glob.glob(self.eeg_path + '*')

    def __len__(self):
        return len(self.csv)

    def standarize(self, eeg):
        """Min-max scale every column of ``eeg`` to [0, 1] using that recording's own range."""
        scaler = MinMaxScaler()
        eeg = scaler.fit_transform(eeg)

        return eeg

    def butter_lowpass_filter(self, eeg, cutoff=30, fs=200, order=4):
        """Apply a causal Butterworth low-pass filter independently to every row.

        Args:
            eeg: 2-D array ``(n_channels, n_samples)``.
            cutoff: cut-off frequency, in the same units as ``fs``.
            fs: sampling frequency.
            order: filter order.

        Returns:
            New float64 array with the same shape as ``eeg``.
        """
        # Design the filter once; sosfilt filters each row along the last axis.
        # NOTE(review): zero initial state means a start-up transient on the first
        # samples (the input is min-max scaled, so it has a non-zero mean).
        sos = butter(order, cutoff, output='sos', fs=fs)
        return sosfilt(sos, np.asarray(eeg, dtype=np.float64), axis=-1)

    @staticmethod
    def _impute(eeg_df):
        """Fill NaNs with each column's mean; columns that are entirely NaN become 0.

        ``SimpleImputer()`` would instead drop all-NaN columns, changing the channel
        count per recording and breaking batching; this keeps every column.
        """
        return eeg_df.fillna(eeg_df.mean()).fillna(0.0).to_numpy(dtype=np.float64)

    @classmethod
    def _crop(cls, eeg, window=None):
        """Return a ``window``-sample slice of ``eeg`` along axis 1 (time).

        The start is drawn from a standard normal clipped to [-2, 2] and mapped
        linearly onto ``[0, n_samples - window]``, so crops favour the middle of the
        recording. Recordings shorter than ``window`` are right-padded with zeros.
        """
        window = cls.WINDOW if window is None else window
        n_samples = eeg.shape[1]
        if n_samples < window:
            eeg = np.pad(eeg, ((0, 0), (0, window - n_samples)))
            n_samples = window
        offset = int((np.clip(np.random.randn(), -2, 2) + 2) * (n_samples - window) / 4)
        return eeg[:, offset:offset + window]

    def read_eeg(self, eeg_file_path):
        """Load one parquet recording and turn it into a model input.

        Steps: impute NaNs (keeping every column), min-max scale each column over the
        whole recording, transpose to (columns, time), take a random ``WINDOW``-sample
        crop, low-pass filter it, and add a leading axis of size 1.

        Returns:
            float64 array ``(1, n_columns, WINDOW)``.
        """
        eeg = pd.read_parquet(eeg_file_path)
        eeg = self._impute(eeg)
        eeg = self.standarize(eeg)
        eeg = self._crop(eeg.T)
        eeg = self.butter_lowpass_filter(eeg)
        return np.expand_dims(eeg, axis=0)

    def __getitem__(self, idx):
        """Return ``(X, y)`` for the ``idx``-th row (by position) of ``csv``."""
        # Positional access: DataLoader passes 0..len-1, which need not match index labels.
        y = self.csv.iloc[idx, -6:].to_numpy().reshape(-1)
        eeg_id = self.csv['eeg_id'].iloc[idx]

        eeg_file_path = self.eeg_path + str(eeg_id) + '.parquet'
        X = self.read_eeg(eeg_file_path)

        return X, y
        
        
# Build a dataset over every processed recording and inspect one item's shapes.
dataset = EEGDataset(csv=processed_train, eeg_path=EEG_PATH)
X, y = dataset[10]
(X.shape, y.shape)

# Visualizing one preprocessed EEG sample

In [None]:
X, _ = dataset[100]

# X is (1, n_channels, n_samples) — read_eeg adds the leading axis — so iterate over
# X[0] to get one full time series per channel (the old loop ran once and plotted
# the first time sample of every channel instead).
channels = X[0]
fig, axes = plt.subplots(len(channels), 1, figsize=(15, 10), sharex=True, squeeze=False)
for ax, signal in zip(axes[:, 0], channels):
    ax.plot(signal, linewidth=0.5)
axes[-1, 0].set_xlabel('sample')
fig.suptitle('Preprocessed EEG window, dataset item 100')
plt.show()

In [None]:
%%time
import torch.nn as nn

class EEGNet(nn.Module):
    """Multi-scale convolutional classifier returning log-probabilities or probabilities.

    Architecture, as built below (each conv block = Conv2d -> ELU(0.5) -> AvgPool -> Dropout(0.2)):
      * ``num_temporal_filters`` parallel branches. Branch ``i`` applies a single
        ``(1, 4 * 2**i)`` temporal conv, then a "valid" ``(eeg_channels, 1)`` conv to
        64 maps (collapsing the height axis to 1), then a ``(1, 4 * 2**i)`` conv to
        ``num_spacial_channels`` maps with ``(1, 2)`` average pooling.
      * Branch outputs are concatenated on the feature axis and passed through two
        ``(1, 64)`` conv blocks that pool time by 4 and then by 2 (block 2 emits 16 maps).
      * The result is flattened and mapped ``-> 64 -> bottleneck_dim -> num_classes``.

    Args:
        eeg_channels: height of the input (number of signal rows).
        num_temporal_filters: number of parallel branches.
        num_spacial_channels: feature maps produced by each branch.
        num_classes: output size.
        input_length: time samples in the input. Time is pooled by 2 * 4 * 2 = 16 in
            total, so ``linear1`` takes ``16 * (input_length // 16)`` features. The
            default reproduces the original hard-coded ``Linear(10_000, 64)``.
        bottleneck_dim: width between ``linear2`` and ``linear3``.
            NOTE(review): the default of 1 (kept for checkpoint compatibility) sends
            every prediction through one ReLU unit, so all logit vectors lie on a
            single ray; a larger value is probably what was intended.

    Input: ``(batch, 1, eeg_channels, input_length)`` float tensor.
    """

    def __init__(self, eeg_channels, num_temporal_filters, num_spacial_channels, num_classes,
                 input_length=10_000, bottleneck_dim=1):
        super().__init__()
        self.num_temporal_filters = num_temporal_filters
        self.num_spacial_channels = num_spacial_channels
        self.eeg_channels = eeg_channels
        self.num_classes = num_classes
        self.input_length = input_length

        # Branch i sees a different time scale: kernel widths 4, 8, 16, ...
        self.parallel_networks = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1, 1, (1, 4 * (2 ** i)), padding='same'),
            self._spacial_filter(4 * (2 ** i))
        ) for i in range(num_temporal_filters)])

        self.conv_block1 = self._conv_block(num_temporal_filters * num_spacial_channels, 64, (1, 64), (1, 4))
        self.conv_block2 = self._conv_block(64, 16, (1, 64), (1, 2))

        # Height is 1 after the "valid" spatial conv; time shrinks to input_length // 16
        # (floor pooling by 2, 4, 2 composes to floor division by 16); block 2 has 16 maps.
        flat_features = 16 * (input_length // 16)
        self.linear1 = nn.Linear(flat_features, 64)
        self.relu = nn.ReLU()

        self.linear2 = nn.Linear(64, bottleneck_dim)
        self.linear3 = nn.Linear(bottleneck_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def _conv_block(self, in_features, out_features, kernel, pool=False, padding='same'):
        """Conv2d -> ELU(alpha=0.5) -> AvgPool2d(pool) (identity when pool is falsy) -> Dropout(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_features, out_features, kernel, padding=padding),
            nn.ELU(0.5),
            nn.AvgPool2d(pool) if pool else nn.AvgPool2d(1),
            nn.Dropout(0.2)
        )

    def _spacial_filter(self, kernel_width):
        """Collapse the channel axis with a valid conv, then filter in time and pool by 2."""
        return nn.Sequential(
            self._conv_block(1, 64, (self.eeg_channels, 1), padding='valid'),
            self._conv_block(64, self.num_spacial_channels, (1, kernel_width), pool=(1, 2))
        )

    def forward(self, x, is_train=True):
        """Run the network.

        Args:
            x: tensor ``(batch, 1, eeg_channels, input_length)``.
            is_train: if True return log-probabilities (the form ``nn.KLDivLoss``
                expects); otherwise probabilities. Independent of ``self.training``.

        Returns:
            ``(batch, num_classes)`` tensor.
        """
        # Each branch yields (batch, num_spacial_channels, 1, input_length // 2).
        x = torch.concat([branch(x) for branch in self.parallel_networks], dim=1)
        x = self.conv_block1(x)
        x = self.conv_block2(x)

        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        if not is_train:
            return self.softmax(x)
        return nn.functional.log_softmax(x, dim=1)
            
model = EEGNet(eeg_channels=20, num_temporal_filters=5, num_spacial_channels=4, num_classes=6)

def test_model(model):
    x = torch.randn((10, 1, 20, 10_000))
    y = model(x)
    
    return y.shape, y, y.sum(axis=1)
            
test_model(model=model)
            

In [None]:
# Prefer the GPU whenever PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

from sklearn.model_selection import GroupShuffleSplit

SPLIT_SEED = 42


def _group_split(df, test_size, seed):
    """Split ``df`` in two so that no ``patient_id`` appears on both sides."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    keep_idx, hold_idx = next(splitter.split(df, groups=df['patient_id']))
    return df.iloc[keep_idx], df.iloc[hold_idx]


# Split by patient rather than by row: recordings of the same patient are likely
# correlated, so a row-wise split would leak patients into validation/test and
# overstate generalization. test_size is a fraction of patients, not of rows.
train_df, test_val_df = _group_split(processed_train, test_size=0.25, seed=SPLIT_SEED)
test_df, val_df = _group_split(test_val_df, test_size=0.5, seed=SPLIT_SEED)
assert not set(train_df['patient_id']) & set(test_val_df['patient_id'])
assert not set(test_df['patient_id']) & set(val_df['patient_id'])

# drop=True: give each frame a clean 0..n-1 index without adding an 'index' column.
train_data = EEGDataset(train_df.reset_index(drop=True), EEG_PATH)
test_data = EEGDataset(test_df.reset_index(drop=True), EEG_PATH)
val_data = EEGDataset(val_df.reset_index(drop=True), EEG_PATH)

# Reshuffle the training order every epoch; evaluation order does not matter.
train_data = DataLoader(train_data, batch_size=64, shuffle=True)
test_data = DataLoader(test_data, batch_size=64)
val_data = DataLoader(val_data, batch_size=64)

len(train_data), len(test_data), len(val_data)

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

SEED = 42
torch.manual_seed(SEED)  # reproducible weight initialisation and dropout masks

model = EEGNet(20, 5, 4, 6)
model.to(device)
# NOTE(review): 1e-2 is on the high side for Adam; consider 1e-3 if training is unstable.
lr = 1e-2
optimizer = Adam(model.parameters(), lr)
# 'batchmean' matches the mathematical KL divergence (sum over classes, mean over the
# batch); 'mean' also divides by the number of classes and PyTorch warns against it.
criterion = nn.KLDivLoss(reduction='batchmean')

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
EPOCHS = 5

for epoch in trange(EPOCHS):
    model.train()
    for i, (X, y) in enumerate(tqdm(train_data)):
        X = X.to(device=device, dtype=torch.float)
        y = y.to(device=device, dtype=torch.float)
        optimizer.zero_grad()
        y_pred = model(X)  # log-probabilities, the input form KLDivLoss expects
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f'epoch {epoch} batch {i}: train loss {loss.item():.4f}')

    # Validation: eval() disables dropout; no_grad() avoids building an autograd graph.
    model.eval()
    avg_loss = 0.0
    with torch.no_grad():
        for X, y in tqdm(val_data):
            X = X.to(device=device, dtype=torch.float)
            y = y.to(device=device, dtype=torch.float)
            y_pred = model(X)
            avg_loss += criterion(y_pred, y).item()

    print(f'EPOCH:{epoch}, VAL_LOSS: {avg_loss / len(val_data)}')

In [None]:
# Persist only the learned parameters (state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f='/kaggle/working/model')