In [None]:
# export CXX=g++-8 CC=gcc-8
# pip install torcheeg

In [1]:
import torch
from torcheeg.models import ATCNet
from sklearn.model_selection import train_test_split
import pandas as pd 
import typing as tp
from pathlib import Path
import numpy as np

In [2]:
train = pd.read_csv('train.csv')

In [3]:
FilePath = tp.Union[str, Path]
Label = tp.Union[int, float, np.ndarray]
from tqdm import tqdm 
CLASSES = ["seizure_vote", "lpd_vote", "gpd_vote", "lrda_vote", "grda_vote", "other_vote"]
N_CLASSES = len(CLASSES)


class HMSHBACSpecDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        image_paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label]
    ):
        self.image_paths = image_paths
        self.labels = labels
        
        #load em all cowboy 
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index: int):
        
        label = self.labels[index]
        eeg = pd.read_parquet(self.image_paths[index])
        rows = len(eeg)
        offset = (rows-10_000)//2
        eeg = eeg.iloc[offset:offset+10_000]
        eeg = eeg[['Fp1', 'O1', 'Fp2', 'O2']]
        for j, col in enumerate(eeg.columns ) : 
            x = eeg[col].values.astype('float32')
            m = np.nanmean(x)
            if np.isnan(x).mean()<1:
                x = np.nan_to_num(x,nan=m)
            else: 
                x[:] = 0

            eeg[col] = x

        eeg = torch.tensor(eeg.values).T.unsqueeze(0)

        return eeg, label 
    
def get_path_label( train_all: pd.DataFrame):
    """Get file path and target info."""
    
    train, test = train_test_split(train_all, test_size=0.2,  stratify=train_all[CLASSES].values.argmax(axis=1)  )
                                        
    train_paths = []
    test_paths = [] 
                                                                 
    trainlabels = train[CLASSES].values                                               
    for label_id in train["eeg_id"].values:
        img_path = f"train_eegs/{label_id}.parquet"
        train_paths.append(img_path)
                             
    testlabels = test[CLASSES].values                                               
    for label_id in test["eeg_id"].values:
        img_path = f"train_eegs/{label_id}.parquet"
        test_paths.append(img_path)

    val_data = {
        "image_paths": test_paths,
        "labels": testlabels.astype('float32')}

    train_data = {
        "image_paths": train_paths,
        "labels": trainlabels.astype('float32')}
                                                                 
    return train_data, val_data 


In [4]:
train_path_label, val_path_label  = get_path_label(train)

train_dataset = HMSHBACSpecDataset(**train_path_label)
val_dataset = HMSHBACSpecDataset(**val_path_label)

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, num_workers=6, pin_memory=True,  persistent_workers=True,  shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=64, num_workers=2,  pin_memory=True, persistent_workers=True, shuffle=False, drop_last=True)

In [6]:
#setup basic train loop
device='cuda:0'
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    record_step = int(len(dataloader) / 20)

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Compute prediction error
        pred = model(X)
        loss = torch.nn.functional.kl_div  (pred.softmax(1), y , reduction='batchmean') 

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % record_step == 0:
            loss, current = loss.item(), batch_idx * len(X)
            print(f"Loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return loss


# validation process
def valid(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)

            pred = model(X)
            loss += torch.nn.functional.kl_div  (pred.softmax(1), y , reduction='batchmean').item()
    loss /= num_batches
    print(f"Valid Error: \n Avg loss: {loss:>8f} \n" )
    return loss


In [None]:

model = ATCNet(
    in_channels=1,
    num_classes=6,
    num_windows=8, 
    num_electrodes=4,
    chunk_size=10000,
)

model = model.to('cuda:0')
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=9e-4
)
epochs=1000

best_val_loss=100. 
for t in range(epochs):
    train_loss = train(train_loader, model, None, optimizer) 
    valid_loss = valid(val_loader, model, None)
 
    if valid_loss < best_val_loss:
        _ = torch.save(model.state_dict(),f'/storage/modeling/atcn/model{t}.pt')


  return F.conv2d(input, weight, bias, self.stride,


Loss: 10.135015  [    0/85440]
Loss: 9.621571  [ 4224/85440]
Loss: 14.351191  [ 8448/85440]
Loss: 12.452726  [12672/85440]
Loss: 8.562347  [16896/85440]
Loss: 11.460200  [21120/85440]
Loss: 11.767521  [25344/85440]
Loss: 10.154109  [29568/85440]
Loss: 11.102566  [33792/85440]
Loss: 10.552280  [38016/85440]
Loss: 10.502416  [42240/85440]
Loss: 9.232741  [46464/85440]
Loss: 8.892500  [50688/85440]
Loss: 11.117680  [54912/85440]
Loss: 10.803099  [59136/85440]
Loss: 9.676501  [63360/85440]
Loss: 10.622193  [67584/85440]
Loss: 9.440865  [71808/85440]
Loss: 10.427096  [76032/85440]
Loss: 10.394041  [80256/85440]
Loss: 11.740202  [84480/85440]
Valid Error: 
 Avg loss: 10.999078 

Loss: 10.798790  [    0/85440]
Loss: 10.092026  [ 4224/85440]
Loss: 10.518681  [ 8448/85440]
Loss: 11.969891  [12672/85440]
Loss: 10.131868  [16896/85440]
Loss: 12.109968  [21120/85440]
Loss: 9.401199  [25344/85440]
Loss: 9.719956  [29568/85440]
Loss: 8.239702  [33792/85440]
Loss: 10.159458  [38016/85440]
Loss: 9.966

In [None]:
a = torch.tensor([[1,2,3.]])


In [None]:
torch.nn.functional.kl_div(a.softmax(1), torch.tensor([0.1,0.2,0.7]))

In [None]:
for b,v in train_loader: break