In [None]:
%load_ext autoreload
%autoreload 2
from lib.ekyn import *
from sklearn.model_selection import train_test_split
from torch import bincount
from lib.utils import plot_eeg_and_labels
from lib.models import MLP
from tqdm import tqdm
from torch import optim
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader
import torch
import numpy as np
import random
from lib.utils import evaluate
from sklearn.metrics import ConfusionMatrixDisplay,classification_report

# Seed every RNG in play (torch, numpy, stdlib `random`) from one value so that
# weight init, DataLoader shuffling and any numpy/random sampling are reproducible.
SEED = 0
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

In [None]:
# Split the available ids into train/dev (25% held out). NOTE(review): presumably one id
# per recording/subject, which would keep recordings from leaking across splits — confirm in lib.ekyn.
idx = get_ekyn_ids()
train_idx, test_idx = train_test_split(idx, test_size=.25, random_state=0)
print(len(train_idx), len(test_idx))

# Debug-sized run: keep only the first N_TRAIN_IDS training ids (None = use all of them).
N_TRAIN_IDS = 2
train_idx = train_idx[:N_TRAIN_IDS]
print(train_idx, test_idx)

BATCH_SIZE = 512


def _load_split(ids):
    """Load (X, y) for `ids`, print shapes and per-class counts, and plot a short excerpt.

    Class indices are taken as y.argmax(axis=1), i.e. y holds one score/one-hot row per sample.
    """
    X, y = load_eeg_label_pairs(ids=ids)
    labels = y.argmax(axis=1)
    print(X.shape, y.shape)
    print(bincount(labels))
    plot_eeg_and_labels(X, labels, start=0, duration=50)
    return X, y


X, y = _load_split(train_idx)
trainloader = DataLoader(TensorDataset(X, y), batch_size=BATCH_SIZE, shuffle=True)

# After this line X, y hold the dev split (same as before the refactor).
X, y = _load_split(test_idx)
# Evaluation does not need shuffling; a fixed order makes per-batch evaluation reproducible.
devloader = DataLoader(TensorDataset(X, y), batch_size=BATCH_SIZE, shuffle=False)

In [324]:
from torch.nn.functional import relu
class MODEL(nn.Module):
    """1-D CNN classifier for fixed-length, single-channel signal windows.

    Pipeline: strided Conv1d -> LayerNorm -> ReLU -> MaxPool1d(2), then one residual
    block (two same-padded Conv1d layers, each followed by LayerNorm + ReLU, added back
    to the block input), then an average pool over the whole time axis, flatten and a
    Linear layer. forward() returns raw, un-normalized scores (logits).

    Args:
        input_length: samples per input window; inputs are reshaped to
            (-1, 1, input_length). Default 5000 reproduces the original hard-coded size.
        hidden_channels: number of conv feature channels (stored as latent_dims[0]).
        num_classes: number of output logits.
        Any other positional/keyword arguments are forwarded to nn.Module.

    Raises:
        ValueError: if input_length is too short to survive the conv + pooling stem.

    NOTE(review): each LayerNorm is built with the sequence length as normalized_shape,
    so it normalizes over the time axis (the last dim of a (batch, channels, time)
    tensor), not over channels. That ties the weights to one input_length; kept as-is
    to preserve the trained behaviour.
    """

    def __init__(self, *args, input_length: int = 5000, hidden_channels: int = 64,
                 num_classes: int = 3, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.input_length = input_length
        self.latent_dims = [hidden_channels]
        kernel_size, stride = 10, 2
        # Unpadded strided conv output length: floor((L - k) / s) + 1  (2496 for L=5000).
        conv_len = (input_length - kernel_size) // stride + 1
        # MaxPool1d(2, 2) halves the length (1248 for L=5000); same-padded convs keep it.
        pooled_len = conv_len // 2
        if pooled_len < 1:
            raise ValueError(f"input_length={input_length} is too short for this architecture")

        self.c1 = nn.Conv1d(in_channels=1, out_channels=hidden_channels,
                            kernel_size=kernel_size, stride=stride)
        self.ln1 = nn.LayerNorm(normalized_shape=conv_len)
        self.mp1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.c2 = nn.Conv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                            kernel_size=3, padding='same')
        self.ln2 = nn.LayerNorm(normalized_shape=pooled_len)
        self.c3 = nn.Conv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                            kernel_size=3, padding='same')
        self.ln3 = nn.LayerNorm(normalized_shape=pooled_len)

        # Averaging over the full pooled length acts as a global average pool over time.
        self.classifier = nn.Sequential(
            nn.AvgPool1d(kernel_size=pooled_len),
            nn.Flatten(start_dim=1),
            nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, x):
        """Map a batch of windows (anything reshapeable to (-1, 1, input_length)) to logits."""
        x = x.reshape(-1, 1, self.input_length)
        # Stem: strided conv, time-axis LayerNorm, ReLU, downsample by 2.
        x = self.c1(x)
        x = self.ln1(x)
        x = relu(x)
        x = self.mp1(x)
        # Residual block: shapes are preserved by the same-padded convs, so the skip add is valid.
        identity = x
        x = self.c2(x)
        x = self.ln2(x)
        x = relu(x)
        x = self.c3(x)
        x = self.ln3(x)
        x = relu(x)
        x = identity + x
        return self.classifier(x)
    
# Fresh model, loss, optimizer and loss-history buffers; re-running this cell resets training.
model = MODEL()
# MODEL.forward returns raw logits (no softmax), which is what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# NOTE(review): a saved "Params: 10499" output corresponds to latent_dims=[8];
# with latent_dims=[64] this cell prints 35587.
params = sum(p.numel() for p in model.parameters())
print("Params: ", params)

lossi = []       # training loss of every optimizer step
trainlossi = []  # full-pass train loss, recorded every 10 epochs by the training cell
devlossi = []    # full-pass dev loss, recorded every 10 epochs by the training cell

Params:  10499


In [325]:
# Smoke test: one training batch should yield logits of shape (batch_size, n_classes).
Xi, yi = next(iter(trainloader))
with torch.no_grad():
    # Follow the model's device so this cell also works after training moved it to 'mps'.
    out = model(Xi.to(next(model.parameters()).device))
out.shape

torch.Size([512, 3])

In [327]:
# Train for up to N_EPOCHS epochs; every EVAL_EVERY epochs, write the smoothed per-step
# loss curve to loss.jpg and the full-pass train/dev loss curves to dev.jpg.
DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
N_EPOCHS = 300
EVAL_EVERY = 10
SMOOTH_WINDOW = 10  # number of optimizer steps averaged per point in loss.jpg

model.to(DEVICE)
for epoch in tqdm(range(N_EPOCHS)):
    # Re-enter train mode every epoch: evaluate() below may put the model in eval mode
    # (TODO confirm in lib.utils.evaluate). Current layers are mode-independent, but
    # dropout/batch-norm added later would otherwise silently train in eval mode.
    model.train()
    for Xi, yi in trainloader:
        Xi, yi = Xi.to(DEVICE), yi.to(DEVICE)
        logits = model(Xi)
        loss = criterion(logits, yi)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossi.append(loss.item())

    if epoch % EVAL_EVERY == 0:
        # Drop the trailing remainder so the losses reshape exactly into full windows.
        n_full = len(lossi) - len(lossi) % SMOOTH_WINDOW
        plt.plot(torch.tensor(lossi[:n_full]).view(-1, SMOOTH_WINDOW).mean(dim=1))
        plt.savefig('loss.jpg')
        plt.close()

        train_loss, _, _, _, _ = evaluate(dataloader=trainloader, model=model, criterion=criterion, DEVICE=DEVICE)
        trainlossi.append(train_loss)
        dev_loss, _, _, _, _ = evaluate(dataloader=devloader, model=model, criterion=criterion, DEVICE=DEVICE)
        devlossi.append(dev_loss)

        plt.plot(trainlossi, label='train')
        plt.plot(devlossi, label='dev')
        plt.legend()
        plt.savefig('dev.jpg')
        plt.close()

100%|██████████| 68/68 [00:00<00:00, 82.93it/s]
100%|██████████| 135/135 [00:01<00:00, 82.73it/s]
100%|██████████| 68/68 [00:00<00:00, 78.75it/s]]
100%|██████████| 135/135 [00:01<00:00, 82.05it/s]
100%|██████████| 68/68 [00:00<00:00, 82.58it/s]]
100%|██████████| 135/135 [00:01<00:00, 83.84it/s]
100%|██████████| 68/68 [00:00<00:00, 81.97it/s]]
100%|██████████| 135/135 [00:01<00:00, 80.84it/s]
100%|██████████| 68/68 [00:00<00:00, 80.73it/s]]
100%|██████████| 135/135 [00:01<00:00, 80.43it/s]
100%|██████████| 68/68 [00:00<00:00, 82.33it/s]]
100%|██████████| 135/135 [00:01<00:00, 78.51it/s]
100%|██████████| 68/68 [00:00<00:00, 79.44it/s]]
100%|██████████| 135/135 [00:01<00:00, 82.98it/s]
100%|██████████| 68/68 [00:00<00:00, 87.29it/s]]
100%|██████████| 135/135 [00:01<00:00, 84.63it/s]
100%|██████████| 68/68 [00:00<00:00, 80.57it/s]]
100%|██████████| 135/135 [00:01<00:00, 79.07it/s]
100%|██████████| 68/68 [00:00<00:00, 87.96it/s]]
100%|██████████| 135/135 [00:01<00:00, 86.96it/s]
100%|██████

KeyboardInterrupt: 

In [None]:
# Report the latest smoothed training loss (mean of the last full window of 10 steps)
# and the best full-pass dev loss recorded so far; each is skipped if nothing is recorded yet.
n_full = len(lossi) - len(lossi) % 10
if n_full:
    print('smoothed train loss:', torch.tensor(lossi[:n_full]).view(-1, 10).mean(dim=1)[-1].item())
if devlossi:
    print('best dev loss:', min(devlossi))
# best dev loss recorded in an earlier run: 0.38113740152782866

In [None]:
# Plot each first-layer conv kernel, one row per output channel. This also leaves the model on CPU.
model.to('cpu')
# Conv1d weights are (out_channels, in_channels, kernel_size) and c1 has in_channels=1;
# indexing [:, 0, :] (instead of .squeeze()) keeps 2-D even when out_channels == 1.
kernels = model.c1.weight.detach()[:, 0, :]
fig, ax = plt.subplots(nrows=kernels.shape[0], ncols=1, figsize=(5, 10), squeeze=False)
for i, kernel in enumerate(kernels):
    ax[i, 0].plot(kernel)

In [None]:
# Plot the first-layer conv feature maps (one row per channel) for the first sample of a training batch.
Xi, yi = next(iter(trainloader))
device = next(model.parameters()).device
with torch.no_grad():
    # Run on the model's device so this works whether or not the model was moved back to CPU,
    # then bring the maps to CPU for plotting.
    feature_maps = model.c1(Xi.reshape(-1, 1, 5000).to(device))[0].cpu()
fig, ax = plt.subplots(nrows=feature_maps.shape[0], ncols=1, figsize=(5, 10), squeeze=False)
for i, feature_map in enumerate(feature_maps):
    ax[i, 0].plot(feature_map)

In [None]:
# Training-set evaluation on whatever device the model currently lives on ('cpu'/'mps').
# The confusion matrix is row-normalized, like the dev-set cell, so the two are directly comparable.
loss, report, y_true, y_pred, y_logits = evaluate(dataloader=trainloader, model=model, criterion=criterion,
                                                  DEVICE=next(model.parameters()).device.type)
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true').ax_.set_title('train')
print(classification_report(y_true, y_pred))
print(loss)

In [None]:
# Dev-set evaluation. The device comes from the model itself rather than a hard-coded 'mps':
# an earlier cell moves the model to CPU, and unless evaluate() also moves the model
# (TODO confirm in lib.utils), a fixed 'mps' would put inputs and weights on different devices.
loss, report, y_true, y_pred, y_logits = evaluate(dataloader=devloader, model=model, criterion=criterion,
                                                  DEVICE=next(model.parameters()).device.type)
# Row-normalized: each row shows how the samples of one true class were classified.
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true').ax_.set_title('dev')
print(classification_report(y_true, y_pred))
print(loss)