In [1]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader,WeightedRandomSampler
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from sklearn.model_selection import train_test_split
import numpy as np

# Line colours shared by the loss plots in this notebook.
colors = {
    'Train': '#007AFF',  # Apple Blue
    'Test': '#FF9500'    # Apple Orange
}
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory holding the `{id}_{condition}.pt` recordings, relative to the notebook.
path_to_pt_ekyn = '../pt_ekyn'

def moving_average(data, window_size=10):
    """Return the simple moving average of `data` over fully-overlapping windows.

    Args:
        data: 1-D sequence of numbers.
        window_size: Number of consecutive samples averaged per output value.

    Returns:
        A NumPy array of length ``len(data) - window_size + 1``, or an empty
        array when `data` is shorter than `window_size`.

    Raises:
        ValueError: If `window_size` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')
    data = np.asarray(data)
    # np.convolve silently swaps its operands when the kernel is longer than the
    # signal, which would return `window_size - len(data) + 1` meaningless values.
    # No complete window exists in that case, so return nothing instead.
    if data.size < window_size:
        return np.empty(0, dtype=float)
    return np.convolve(data, np.ones(window_size), 'valid') / window_size

def load_ekyn(id,condition,path_to_pt_ekyn=path_to_pt_ekyn,downsample=10):
    """Load one recording and return its downsampled signal and labels.

    Reads ``{path_to_pt_ekyn}/{id}_{condition}.pt``, which must contain an
    ``(X, y)`` pair.

    Args:
        id: Recording identifier used as the file-name prefix.
        condition: Condition tag used as the file-name suffix (e.g. ``'PF'``).
        path_to_pt_ekyn: Directory containing the ``.pt`` files.
        downsample: Keep every `downsample`-th sample along the last axis.
            The default of 10 was annotated by the author as 500 Hz -> 50 Hz.

    Returns:
        ``(X, y)`` where `X` has a singleton dimension inserted at axis 1 and
        its last axis strided by `downsample`; `y` is returned as stored.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load files from a trusted source.
    X,y = torch.load(f'{path_to_pt_ekyn}/{id}_{condition}.pt',weights_only=False)
    X = X.unsqueeze(1)  # presumably (epochs, samples) -> (epochs, 1, samples); TODO confirm stored shape
    # Plain stride slicing: no anti-aliasing filter is applied before decimation.
    X = X[:,:,::downsample]
    return X,y

In [2]:
"""
(4) Normalize separately in training 
"""
# Recording ids are the file-name prefixes (text before the first underscore) found in the data directory.
ids = sorted({recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)})
print(ids)
# `trains` / `tests` are flat, interleaved lists: [X_0, y_0, X_1, y_1, ...],
# one (X, y) pair per rat, in the order of `ids`.
trains = []
tests = []
n_rats = 16

# `rat_id` rather than `id`: a module-level loop variable named `id` would
# shadow the builtin for every later cell.
for rat_id in ids[:n_rats]:
    print(rat_id)
    X,y = load_ekyn(rat_id,'PF')
    # Stratified 80/20 split per rat, with a fixed seed so the split is reproducible.
    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=.2,shuffle=True,stratify=y,random_state=0)
    trains.append(X_train)
    trains.append(y_train)
    tests.append(X_test)
    tests.append(y_test)

# TensorDataset requires every tensor to share the same first dimension, so this
# relies on every rat contributing the same number of train (and test) samples.
# Each batch is therefore the interleaved list [X_0, y_0, X_1, y_1, ...].
trainloader = DataLoader(TensorDataset(*trains),batch_size=512,shuffle=True)
# No shuffling for evaluation: the model normalises with per-batch statistics,
# so a fixed batch composition keeps the validation loss comparable across passes.
testloader = DataLoader(TensorDataset(*tests),batch_size=512,shuffle=False)

['A1-0', 'A1-1', 'A4-0', 'B1-0', 'B3-1', 'C1-0', 'C4-0', 'C4-1', 'D1-0', 'E1-0', 'E2-1', 'E4-0', 'E4-1', 'F1-0', 'F1-1', 'F5-1']
A1-0
A1-1
A4-0
B1-0
B3-1
C1-0
C4-0
C4-1
D1-0
E1-0
E2-1
E4-0
E4-1
F1-0
F1-1
F5-1


In [None]:
class SimpleNorm(nn.Module):
    """Standardise a tensor by its own global statistics, then apply a learnable affine map.

    The mean and standard deviation are computed over every element of the
    input passed to `forward`, so the result depends on the batch it is given.

    Args:
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self,eps):
        super().__init__()
        self.eps = eps
        # Scalar gain and offset, initialised to the identity transform.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self,x):
        """Return ``(x - mean) / (std + eps) * scale + shift`` using statistics of all of `x`."""
        flat = x.reshape(-1)
        centered = x - flat.mean()
        standardized = centered / (flat.std() + self.eps)
        return standardized * self.scale + self.shift
    
class DynamicNormCNNSleepStager(nn.Module):
    """1-D CNN classifier with a separate learnable normalisation layer per rat.

    Each input tensor is normalised by the `SimpleNorm` at the same position in
    `self.norms`, the results are concatenated along the batch axis, and the
    combined batch goes through three conv/ReLU/max-pool stages, global average
    pooling and two linear layers producing 3 logits per sample.

    Args:
        n_rats: Number of per-rat normalisation layers to create.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """
    def __init__(self, n_rats, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norms = nn.ModuleList([SimpleNorm(1e-5) for _ in range(n_rats)])

        self.c1 = nn.Conv1d(in_channels=1,out_channels=64,kernel_size=7,padding='same')
        self.c2 = nn.Conv1d(in_channels=64,out_channels=64,kernel_size=5,padding='same')
        self.c3 = nn.Conv1d(in_channels=64,out_channels=64,kernel_size=3,padding='same')

        self.mp = nn.MaxPool1d(kernel_size=2)
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(in_features=64,out_features=32)
        self.classifier = nn.Linear(in_features=32,out_features=3)

    def forward(self,x):
        """Classify a list of per-rat batches.

        Args:
            x: Sequence of tensors, where ``x[i]`` is normalised by
                ``self.norms[i]``. Each tensor must be ``(batch_i, 1, length)``
                because the first convolution has one input channel.

        Returns:
            Logits of shape ``(sum(batch_i), 3)``, ordered like the
            concatenation of `x`.

        Raises:
            IndexError: If `x` holds more tensors than there are norm layers.
        """
        if len(x) > len(self.norms):
            raise IndexError(
                f'got {len(x)} input tensors but the model only has {len(self.norms)} per-rat norm layers'
            )
        x = torch.cat([norm(xi) for norm, xi in zip(self.norms, x)])

        x = self.c1(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c2(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c3(x)
        x = relu(x)
        x = self.mp(x)

        x = self.gap(x)
        # Drop only the pooled length axis. A bare squeeze() would also drop the
        # batch axis when the batch has a single sample, returning (3,) logits.
        x = x.squeeze(-1)

        # NOTE(review): there is no non-linearity between fc1 and classifier, so
        # the two layers compose into a single linear map. Left unchanged so
        # trained behaviour is not altered; confirm whether this is intended.
        x = self.fc1(x)
        x = self.classifier(x)
        return x
    
# Move the model to the target device *before* building the optimizer, so the
# optimizer is constructed over the parameters that will actually be trained.
model = DynamicNormCNNSleepStager(n_rats).to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-3)
criterion = nn.CrossEntropyLoss()
model  # last expression of the cell: display the architecture

DynamicNormCNNSleepStager(
  (norms): ModuleList(
    (0-15): 16 x SimpleNorm()
  )
  (c1): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=same)
  (c2): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=same)
  (c3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same)
  (mp): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gap): AdaptiveAvgPool1d(output_size=1)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (classifier): Linear(in_features=32, out_features=3, bias=True)
)

: 

In [None]:
trainlossi = []  # one entry per optimisation step
testlossi = []   # one entry per validation pass
window_size = 10  # moving-average width, measured in optimisation steps
validation_frequency_epochs = 1
best_dev_loss = torch.inf
best_dev_loss_epoch = 0

plt.style.use('dark_background')
n_validations = 0  # never updated in this cell; kept in case another cell reads it


def _to_inputs_and_targets(batch):
    """Move one loader batch to `device` and split it into model inputs and targets.

    The loaders yield tensors interleaved as (X_0, y_0, X_1, y_1, ...). Even
    positions are the per-rat inputs (kept as a list) and odd positions are the
    targets, concatenated in the same order the model concatenates its inputs.
    """
    batch = [t.to(device) for t in batch]
    return batch[::2], torch.cat(batch[1::2])


def _mean_test_loss():
    """Return the mean of the per-batch losses over `testloader`, without gradients."""
    model.eval()
    with torch.no_grad():
        batch_losses = []
        for batch in testloader:
            inputs, targets = _to_inputs_and_targets(batch)
            batch_losses.append(criterion(model(inputs), targets).item())
    model.train()
    return sum(batch_losses) / len(batch_losses)


def _plot_losses(path='loss.jpg'):
    """Redraw the train/test loss curves and overwrite the image at `path`."""
    # Both x axes are in units of *completed epochs*, so step-level train losses
    # and epoch-level test losses line up: the last step of an epoch and the
    # validation that follows it share the same x position.
    steps_per_epoch = len(trainloader)
    x_trainlossi = np.arange(1, len(trainlossi) + 1) / steps_per_epoch
    x_testlossi = np.arange(len(testlossi)) * validation_frequency_epochs + 1

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 15), dpi=300)

    ax.plot(x_trainlossi, trainlossi, label='Train Loss', color=colors['Train'], alpha=0.4, linewidth=1.5)

    if len(trainlossi) > window_size:
        trainlossi_ma = moving_average(trainlossi, window_size)
        # The k-th average covers steps k .. k+window_size-1, so plot it at the
        # position of the last step in its window. (The window is in steps, not
        # epochs, so it must not be used as an epoch offset.)
        x_trainlossi_ma = x_trainlossi[window_size - 1:]
        ax.plot(x_trainlossi_ma, trainlossi_ma, label='Train Loss MA', color=colors['Train'], linestyle='--', linewidth=1.5)

        best_step = int(trainlossi_ma.argmin())
        min_trainlossi_ma = float(trainlossi_ma[best_step])
        ax.axvline(x=x_trainlossi_ma[best_step], color=colors['Train'], alpha=0.4)
        ax.axhline(y=min_trainlossi_ma, color=colors['Train'], alpha=0.4)

        # Annotate the minimum just outside the right edge of the axes.
        ax.text(ax.get_xlim()[1] + .1, min_trainlossi_ma, f'{min_trainlossi_ma:.2f}',
                verticalalignment='center', horizontalalignment='left', color=colors['Train'], fontweight='bold')

    ax.plot(x_testlossi, testlossi, label='Test Loss', color=colors['Test'], alpha=1, linewidth=1.5)

    best_validation = int(np.argmin(testlossi))
    min_testlossi = testlossi[best_validation]
    ax.axvline(x=x_testlossi[best_validation], color=colors['Test'], alpha=0.4)
    ax.axhline(y=min_testlossi, color=colors['Test'], alpha=0.4)

    # Annotate the minimum just outside the right edge of the axes.
    ax.text(ax.get_xlim()[1] + .1, min_testlossi, f'{min_testlossi:.2f}',
            verticalalignment='center', horizontalalignment='left', color=colors['Test'], fontweight='bold')

    ax.set_xlabel('epoch', fontweight='bold')
    ax.set_ylabel('loss', fontweight='bold')
    ax.set_ylim([0, 1])
    ax.legend(loc='upper right')

    fig.savefig(path, bbox_inches='tight')
    # Close explicitly: a new figure is created at every validation.
    plt.close(fig)


for epoch in tqdm(range(5000)):
    model.train()
    for batch in trainloader:
        Xi, yi = _to_inputs_and_targets(batch)
        optimizer.zero_grad()
        logits = model(Xi)
        loss = criterion(logits, yi)
        loss.backward()
        optimizer.step()
        trainlossi.append(loss.item())

    if epoch % validation_frequency_epochs == 0:
        testlossi.append(_mean_test_loss())

        if testlossi[-1] < best_dev_loss:
            best_dev_loss = testlossi[-1]
            best_dev_loss_epoch = epoch

        _plot_losses()



 11%|█▏        | 572/5000 [27:51<44:12:07, 35.94s/it]

In [None]:
def _collect_predictions(loader):
    """Return ``(y_true, y_pred)`` class indices for every sample in `loader`.

    `loader` must yield batches interleaved as (X_0, y_0, X_1, y_1, ...), the
    layout produced by the multi-rat train/test loaders. The inputs are passed
    to the model as a per-rat list, and the targets are concatenated in the same
    order as the model's output rows.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            Xi = [X.to(device) for X in batch[::2]]
            yi = torch.cat(batch[1::2])
            # argmax of the logits equals argmax of their softmax.
            y_pred.append(model(Xi).argmax(dim=1).cpu())
            # Targets are stored per class along dim 1; reduce them to class indices.
            y_true.append(yi.argmax(dim=1))
    return torch.cat(y_true), torch.cat(y_pred)


# Train split.
y_true, y_pred = _collect_predictions(trainloader)
print(classification_report(y_true=y_true,y_pred=y_pred))

# Test split.
y_true, y_pred = _collect_predictions(testloader)
print(classification_report(y_true=y_true,y_pred=y_pred))

In [None]:
ids = sorted({recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)})

# Evaluate one rat on its full 'PF' recording.
# NOTE(review): if this rat is among the first `n_rats` ids used for training,
# most of this recording was trained on, so the report is not a held-out estimate.
rat_index = 9
X,y = load_ekyn(ids[rat_index],'PF')

# NOTE: this rebinds `testloader` to a single-rat loader yielding (X, y) pairs,
# replacing the multi-rat test loader for any cell run afterwards.
# Not shuffled: the model normalises with per-batch statistics, so a fixed batch
# composition keeps the predictions reproducible.
testloader = DataLoader(TensorDataset(X,y),batch_size=512,shuffle=False)

Xi,yi = next(iter(testloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())


def _predict_single_rat(Xi, rat_index):
    """Return logits for one rat's batch, normalised by that rat's own norm layer.

    The model picks the norm layer by position in its input list, so the batch
    is placed at position `rat_index` behind zero-length placeholders. The
    placeholders contribute no rows to the concatenated batch (their empty
    statistics may trigger a harmless warning).
    """
    Xi = Xi.to(device)
    placeholders = [Xi.new_empty((0, *Xi.shape[1:])) for _ in range(rat_index)]
    return model(placeholders + [Xi])


model.eval()
with torch.no_grad():
    y = torch.vstack([
        torch.vstack([_predict_single_rat(Xi, rat_index).argmax(dim=1).cpu(), yi.argmax(dim=1)]).T
        for Xi,yi in testloader
    ])
y_pred = y[:,0]
y_true = y[:,1]
print(classification_report(y_true=y_true,y_pred=y_pred))