In [1]:
%load_ext autoreload
%autoreload 2
from lib.ekyn import *
from sklearn.model_selection import train_test_split
from torch import bincount
from lib.utils import plot_eeg_and_labels
from lib.models import MLP
from tqdm import tqdm
from torch import optim
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader
import torch
import numpy as np
import random
from lib.utils import evaluate
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from lib.env import DEVICE
# Seed every random number generator the notebook relies on (Python, NumPy and
# PyTorch) with the same value so a Restart & Run All is repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Split the ids returned by get_ekyn_ids() into a training and a held-out group
# (presumably one id per recording/animal -- confirm against lib.ekyn), so the
# dev data comes from ids the model never trains on.
idx = get_ekyn_ids()
train_idx,test_idx = train_test_split(idx,test_size=.25,random_state=0)
print(len(train_idx),len(test_idx))
print(train_idx,test_idx)


def _load_split(ids, shuffle):
    """Load one split, print/plot a quick summary, and wrap it in a DataLoader.

    Args:
        ids: ids forwarded to ``load_eeg_label_pairs``.
        shuffle: whether the returned DataLoader reshuffles every pass.

    Returns:
        ``(X, y, loader)`` where ``X, y`` are whatever ``load_eeg_label_pairs``
        returns and ``loader`` batches them 512 at a time.
    """
    X, y = load_eeg_label_pairs(ids=ids)
    # y is reduced with argmax over axis 1 wherever class indices are needed,
    # i.e. it holds one score/indicator column per class.
    stage_idx = y.argmax(axis=1)
    print(X.shape, y.shape)
    print(bincount(stage_idx))  # class balance of this split
    plot_eeg_and_labels(X, stage_idx, start=0, duration=50)
    loader = DataLoader(TensorDataset(X, y), batch_size=512, shuffle=shuffle)
    return X, y, loader


# Only the training loader is shuffled; the dev loader is used purely for
# evaluation, where a fixed order keeps the returned predictions aligned with
# the data and avoids drawing from the global RNG.
X, y, trainloader = _load_split(train_idx, shuffle=True)
X, y, devloader = _load_split(test_idx, shuffle=False)

In [None]:
from torch.nn.functional import relu
class BLOCK(nn.Module):
    """Residual block of two width-3 1-D convolutions, each followed by LayerNorm and ReLU.

    When ``in_feature_maps != out_feature_maps`` the first convolution uses
    stride 2 (halving the temporal length while changing the channel count) and
    the skip path is projected with a strided 1x1 convolution so both branches
    have matching shapes. Otherwise the block keeps the shape and the skip path
    is the identity.

    Args:
        n_features: temporal length of the block's *output*; this is the size
            LayerNorm normalises over after each convolution.
        in_feature_maps: number of input channels.
        out_feature_maps: number of output channels.
    """
    def __init__(self,n_features,in_feature_maps,out_feature_maps) -> None:
        super().__init__()
        self.in_feature_maps = in_feature_maps
        self.out_feature_maps = out_feature_maps
        changes_width = in_feature_maps != out_feature_maps
        if changes_width:
            self.c1 = nn.Conv1d(in_channels=in_feature_maps,out_channels=out_feature_maps,kernel_size=3,stride=2,padding=1)
        else:
            self.c1 = nn.Conv1d(in_channels=in_feature_maps,out_channels=out_feature_maps,kernel_size=3,padding='same')
        self.ln1 = nn.LayerNorm(normalized_shape=(n_features))
        self.c2 = nn.Conv1d(in_channels=out_feature_maps,out_channels=out_feature_maps,kernel_size=3,padding='same')
        self.ln2 = nn.LayerNorm(normalized_shape=(n_features))

        # Only build the projection when the skip path actually needs it. A conv
        # created unconditionally would add parameters that forward() never
        # touches, inflating the parameter count and the optimizer state.
        if changes_width:
            self.downsample = nn.Conv1d(in_channels=in_feature_maps,out_channels=out_feature_maps,kernel_size=1,stride=2)
        else:
            self.downsample = nn.Identity()

    def forward(self,x):
        """Apply both conv/LayerNorm/ReLU stages and add the (projected) input.

        The skip connection is added after the second ReLU.
        """
        shortcut = self.downsample(x)
        x = relu(self.ln1(self.c1(x)))
        x = relu(self.ln2(self.c2(x)))
        return x + shortcut

class MODEL(nn.Module):
    """Small 1-D residual CNN mapping a 5000-sample signal to 3 class logits.

    Shape flow for a batch of N inputs (fixed by the layer sizes below):
        reshape        -> (N, 1, 5000)
        c1 + ln1 + mp1 -> (N, 4, 1250)
        block1, block2 -> (N, 4, 1250)
        block3, block4 -> (N, 8, 625)
        block5, block6 -> (N, 16, 313)
        classifier     -> (N, 3)
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Stem: kernel 10 / stride 2 / padding 4 halves 5000 -> 2500, max-pool halves again.
        self.c1 = nn.Conv1d(in_channels=1,out_channels=4,kernel_size=10,stride=2,padding=4)
        self.ln1 = nn.LayerNorm(normalized_shape=(2500))
        self.mp1 = nn.MaxPool1d(kernel_size=2,stride=2)

        self.block1 = BLOCK(n_features=1250,in_feature_maps=4,out_feature_maps=4)
        self.block2 = BLOCK(n_features=1250,in_feature_maps=4,out_feature_maps=4)

        self.block3 = BLOCK(n_features=625,in_feature_maps=4,out_feature_maps=8)
        self.block4 = BLOCK(n_features=625,in_feature_maps=8,out_feature_maps=8)

        self.block5 = BLOCK(n_features=313,in_feature_maps=8,out_feature_maps=16)
        self.block6 = BLOCK(n_features=313,in_feature_maps=16,out_feature_maps=16)

        self.classifier = nn.Sequential(
            # Global average over the whole time axis. A fixed AvgPool1d(157)
            # on a length-313 map yields a single window covering only the
            # first 157 positions, silently dropping the remaining 156.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(start_dim=1),
            nn.Linear(16,3),
        )

    def forward(self,x):
        """Return unnormalised class scores of shape (N, 3).

        ``x`` may have any shape whose elements regroup into rows of 5000
        samples; it is reshaped to (N, 1, 5000) before the first convolution.
        """
        x = x.reshape(-1,1,5000)

        # Stem
        x = self.c1(x)
        x = self.ln1(x)
        x = relu(x)
        x = self.mp1(x)

        # Residual stages: 4 -> 8 -> 16 channels, halving the length at each widening.
        x = self.block1(x)
        x = self.block2(x)

        x = self.block3(x)
        x = self.block4(x)

        x = self.block5(x)
        x = self.block6(x)

        # Classifier
        x = self.classifier(x)
        return x
    
# Fresh network, loss and optimizer. Re-running this cell also resets the loss
# logs below, which the training cell only ever appends to.
model = MODEL()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=3e-4)

# Total number of scalar parameters registered on the model.
params = sum(p.numel() for p in model.parameters())
print("Params: ",params)

# lossi: per-step training loss; trainlossi / devlossi: periodic evaluation losses.
lossi, trainlossi, devlossi = [], [], []

In [None]:
# Smoke test: pull a single training batch and display the shape of the model
# output. With the final Linear(16,3) this should be (batch_size, 3).
# NOTE(review): the model is only moved to DEVICE in the training cell, so this
# assumes the model is still on the same device as the loader's tensors.
Xi,yi = next(iter(trainloader))
model(Xi).shape

In [None]:
# Train for 100 epochs. Every 10th epoch (0, 10, ..., 90) the smoothed training
# curve is written to loss.jpg, and the train/dev evaluation losses are logged
# and written to dev.jpg.
model.to(DEVICE)
for epoch in tqdm(range(100)):
    # Re-assert train mode every epoch: evaluate() is defined elsewhere and may
    # leave the model in eval mode after the periodic evaluation below.
    model.train()
    for Xi,yi in trainloader:
        Xi,yi = Xi.to(DEVICE),yi.to(DEVICE)
        logits = model(Xi)
        loss = criterion(logits,yi)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossi.append(loss.item())
    if epoch % 10 == 0:
        # Average the per-step loss in windows of 10 steps, dropping the
        # incomplete tail. Skip the plot while there is no complete window yet:
        # view(-1,10) on an empty tensor raises.
        n_complete = len(lossi) - len(lossi)%10
        if n_complete:
            plt.plot(torch.tensor(lossi[:n_complete]).view(-1,10).mean(axis=1))
            plt.xlabel('training step (x10)')
            plt.ylabel('mean training loss')
            plt.savefig('loss.jpg')
            plt.close()

        # Separate names so the per-batch `loss` tensor is not shadowed.
        train_loss,_,_,_,_ = evaluate(dataloader=trainloader,model=model,criterion=criterion,DEVICE=DEVICE)
        trainlossi.append(train_loss)
        dev_loss,_,_,_,_ = evaluate(dataloader=devloader,model=model,criterion=criterion,DEVICE=DEVICE)
        devlossi.append(dev_loss)
        print(dev_loss)

        plt.plot(trainlossi,label='train')
        plt.plot(devlossi,label='dev')
        plt.xlabel('evaluation (every 10 epochs)')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig('dev.jpg')
        plt.close()

In [None]:
# Mean of the per-step training loss over the most recent complete window of
# 10 optimisation steps (any incomplete tail is discarded first).
n_complete = len(lossi) - len(lossi) % 10
window_means = torch.tensor(lossi[:n_complete]).view(-1, 10).mean(axis=1)
print(window_means[-1])
# Hand-kept log of results from earlier runs.
# NOTE(review): these are labelled "dev loss" while the value printed above is
# derived from the training loss -- presumably copied from the training cell's
# output; confirm.
# best dev loss : 0.38113740152782866
# best dev loss : 0.37259188603471827
# best dev loss : 0.3583390266806991
# best dev loss : 0.3157875328152268
# best dev loss : 0.299214439590772
# 0.2897572392666781
# 0.2731048853309066

In [None]:
# Inspect the first convolution: left column shows each learned kernel, right
# column shows the feature map that kernel produces for the first sample of a
# training batch. The model is moved to the CPU so it matches the loader's tensors.
model.to('cpu')
Xi,yi = next(iter(trainloader))

first_conv = model.c1
kernels = first_conv.weight.squeeze().detach()
fig,ax = plt.subplots(nrows=len(first_conv.weight),ncols=2,figsize=(8,10))
feature_maps = first_conv(Xi.reshape(-1,1,5000)).detach()[0]
# One row per output channel: kernel on the left, its response on the right.
for row,(kernel,feature_map) in enumerate(zip(kernels,feature_maps)):
    ax[row,0].plot(kernel)
    ax[row,1].plot(feature_map)

In [None]:
# Training-set metrics: confusion matrix (raw counts), per-class report, loss.
train_eval = evaluate(model=model,criterion=criterion,dataloader=trainloader,DEVICE=DEVICE)
loss,report,y_true,y_pred,y_logits = train_eval
ConfusionMatrixDisplay.from_predictions(y_true=y_true,y_pred=y_pred)
print(classification_report(y_true,y_pred))
print(loss)

In [None]:
# Dev-set metrics: confusion matrix normalised over each true class (rows sum
# to 1), per-class report, loss.
dev_eval = evaluate(model=model,criterion=criterion,dataloader=devloader,DEVICE=DEVICE)
loss,report,y_true,y_pred,y_logits = dev_eval
ConfusionMatrixDisplay.from_predictions(y_true=y_true,y_pred=y_pred,normalize='true')
print(classification_report(y_true,y_pred))
print(loss)

In [None]:
import matplotlib.pyplot as plt

# Evaluate on the first held-out id only. The loader is NOT shuffled, so the
# order seen by evaluate() is the order of X / y.
# NOTE(review): test_idx[:1] is a subset of the ids behind devloader, so this
# is not an independent test set.
X,y = load_eeg_label_pairs(ids=test_idx[:1])
stage_idx = y.argmax(axis=1)
print(X.shape,y.shape)
print(bincount(stage_idx))
plot_eeg_and_labels(X,stage_idx,start=0,duration=50)
testloader = DataLoader(TensorDataset(X,y),shuffle=False,batch_size=512)

test_eval = evaluate(model=model,criterion=criterion,dataloader=testloader,DEVICE=DEVICE)
loss,report,y_true,y_pred,y_logits = test_eval
ConfusionMatrixDisplay.from_predictions(y_true=y_true,y_pred=y_pred,normalize='true')
print(classification_report(y_true,y_pred))
print(loss)

In [None]:
import matplotlib.patches as patches

# Two stacked panels for `duration` consecutive epochs starting at `start`:
#   top    - stacked per-class model outputs (y_logits) for each epoch
#   bottom - the raw signal, shaded by the ground-truth stage of each epoch
# Each epoch is 5000 samples long.
start = 190
duration = 200
tick_step = max(1, duration // 20)  # ~20 ticks; never 0 for short windows
colors = ['red','green','blue']     # class index 0, 1, 2

# Compute the class indices once instead of re-running argmax over the whole
# label tensor on every loop iteration.
stages = y.argmax(axis=1)[start:start+duration]

fig, ax = plt.subplots(nrows=2,ncols=1,figsize=(16,5),dpi=200)

ax[1].plot(X[start:start+duration].flatten(),'black',linewidth=.2)
epochs = []  # sample index of each epoch's centre, used as tick positions
for i, stage in enumerate(stages):
    ax[1].fill_between([i*5000, (i+1)*5000], y1=-.0003, y2=.0003, color=colors[int(stage)], alpha=0.3)
    epochs.append(i*5000+2500)

red_patch = patches.Patch(color='red', alpha=0.5, label='Paradoxical')
green_patch = patches.Patch(color='green', alpha=0.5, label='Slow-wave')
blue_patch = patches.Patch(color='blue', alpha=0.5, label='Wakefulness')
ax[1].set_ylim([-.0003,.0003])
ax[1].margins(0,0)
# Explicit axes calls instead of plt.legend / plt.xlabel, which only hit this
# panel because it happened to be the current axes.
ax[1].legend(handles=[red_patch, green_patch,blue_patch],loc='upper left', bbox_to_anchor=(1.04, 1),
        fancybox=True, shadow=True, ncol=1)
ax[1].set_xlabel('epoch (index)')
# Volts measure electric potential, not potential energy.
ax[1].set_ylabel('potential (Volts)')
ax[1].set_xticks(epochs[::tick_step],range(len(epochs))[::tick_step])

# NOTE(review): a stack plot assumes non-negative series; whether evaluate()
# returns raw logits or probabilities in y_logits is not visible here -- confirm.
ax[0].stackplot(torch.linspace(0,duration-1,duration),y_logits[start:start+duration,0],y_logits[start:start+duration,1],y_logits[start:start+duration,2],colors=['#FF000080','#00FF0080','#0000FF80'])
ax[0].margins(0,0)
ax[0].set_xticks([]);