In [None]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader,WeightedRandomSampler
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from sklearn.model_selection import train_test_split
import numpy as np

colors = {
    'Train': '#007AFF',  # Apple Blue
    'Test': '#FF9500'    # Apple Orange
}
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory holding one tensor file per recording, named '<id>_<condition>.pt' (see load_ekyn).
path_to_pt_ekyn = '../pt_ekyn'
# Seed torch's RNG (weight init, DataLoader shuffling) so runs are repeatable.
# Note: some CUDA kernels remain nondeterministic even with a fixed seed.
SEED = 0
torch.manual_seed(SEED)

def moving_average(data, window_size=10):
    """Trailing simple moving average of a 1-D sequence.

    Parameters
    ----------
    data : array-like of numbers
        Values to smooth (e.g. a list of per-step losses).
    window_size : int, default 10
        Number of consecutive values averaged per output point; must be >= 1.

    Returns
    -------
    np.ndarray
        ``len(data) - window_size + 1`` averages, where element ``i`` is the mean of
        ``data[i:i + window_size]``. Empty when ``data`` is shorter than the window.

    Raises
    ------
    ValueError
        If ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')
    values = np.asarray(data, dtype=float)
    if values.size < window_size:
        # np.convolve swaps its arguments when the kernel is longer than the data, so
        # 'valid' mode would return window_size - len(data) + 1 meaningless values.
        return np.empty(0)
    return np.convolve(values, np.ones(window_size), 'valid') / window_size

def load_ekyn(id,condition,path_to_pt_ekyn=path_to_pt_ekyn,downsample=10):
    """Load one EKYN recording as model-ready tensors.

    Parameters
    ----------
    id : str
        Recording id, i.e. the file-name prefix before '_'. (The name shadows the
        builtin ``id`` inside this function; kept for backward compatibility.)
    condition : str
        Condition suffix of the file name, e.g. 'PF'.
    path_to_pt_ekyn : str
        Directory containing '<id>_<condition>.pt'.
    downsample : int, default 10
        Keep every ``downsample``-th sample along the last axis. The default of 10
        takes the 500 Hz signal to 50 Hz; 1 keeps the full rate.

    Returns
    -------
    (X, y)
        ``X`` with a channel axis inserted at dim 1 and its last axis decimated;
        ``y`` exactly as stored in the file. Presumably ``X`` is (n_epochs, n_samples)
        on disk, giving (n_epochs, 1, ceil(n_samples / downsample)) -- TODO confirm.

    Raises
    ------
    ValueError
        If ``downsample`` < 1.
    """
    if downsample < 1:
        raise ValueError(f'downsample must be >= 1, got {downsample}')
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
    # which can execute code. Only load .pt files from a trusted source.
    X,y = torch.load(f'{path_to_pt_ekyn}/{id}_{condition}.pt',weights_only=False)
    X = X.unsqueeze(1)  # add a channel axis at dim 1 (the model's Conv1d expects 1 channel)
    X = X[:,:,::downsample]  # default 10: 500 Hz -> 50 Hz
    return X,y

In [None]:
"""
Don't Normalize at all?
"""
# Normalize together before training
# ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))
# print(ids[:2])
# X = torch.cat([load_ekyn(id,'PF')[0] for id in ids[:2]])
# y = torch.cat([load_ekyn(id,'PF')[1] for id in ids[:2]])
# mean = X.flatten().mean()
# std = X.flatten().std()
# X = (X - mean) / (std + 1e-5)

# trainloader = DataLoader(TensorDataset(X,y),batch_size=512,shuffle=True)

"""
Normalize separately before training
"""
# ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))
# print(ids[:2])
# X1,y1 = load_ekyn(ids[0],'PF')
# X2,y2 = load_ekyn(ids[1],'PF')

# mean = X1.flatten().mean()
# std = X1.flatten().std()
# X1 = (X1 - mean) / (std + 1e-5)

# mean = X2.flatten().mean()
# std = X2.flatten().std()
# X2 = (X2 - mean) / (std + 1e-5)

# X = torch.cat([X1,X2])
# y = torch.cat([y1,y2])

# trainloader = DataLoader(TensorDataset(X,y),batch_size=512,shuffle=True)

"""
Normalize together in training
"""
ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))
print(ids[:2])
X1,y1 = load_ekyn(ids[0],'PF')
X2,y2 = load_ekyn(ids[1],'PF')

X = torch.cat([X1,X2])
y = torch.cat([y1,y2])

trainloader = DataLoader(TensorDataset(X,y),batch_size=512,shuffle=True)

"""
(4) Normalize separately in training 
"""
# ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))
# print(ids[:2])
# X1,y1 = load_ekyn(ids[0],'PF')
# X2,y2 = load_ekyn(ids[1],'PF')

# trainloader = DataLoader(TensorDataset(X1,y1,X2,y2),batch_size=512,shuffle=True)

In [None]:
class SimpleNorm(nn.Module):
    """Standardize a whole tensor with a single mean/std, then apply a learned affine.

    Every element of ``x`` (all samples, channels and time steps in the batch) shares
    one mean and one unbiased standard deviation. As a result, a sample's output depends
    on the other samples in the same batch. ``scale`` and ``shift`` are learnable scalars
    initialised to the identity transform (1.0 and 0.0).

    Parameters
    ----------
    eps : float
        Added to the standard deviation to avoid division by zero.
    """
    def __init__(self,eps):
        super().__init__()
        self.eps = eps
        # Learnable scalar affine applied after standardization.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self,x):
        # Statistics are computed over every element at once, not per sample/channel.
        flat = x.flatten()
        mu, sigma = flat.mean(), flat.std()
        standardized = (x - mu) / (sigma + self.eps)
        return standardized * self.scale + self.shift
    
class CNNSleepStager(nn.Module):
    """Small 1-D CNN mapping a single-channel signal window to 3 class logits.

    Layers: SimpleNorm -> Conv1d(1->64, k=7) -> ReLU -> MaxPool(2)
            -> Conv1d(64->128, k=5) -> ReLU -> MaxPool(2)
            -> global average pool -> Linear(128->64) -> Linear(64->3).

    ``forward`` takes ``x`` shaped (batch, 1, length) and returns logits shaped
    (batch, 3). 'same' padding keeps each conv's length; each MaxPool halves it.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Batch-wide standardization with a learned scalar affine.
        self.norm1 = SimpleNorm(1e-5)
        # self.norm2 = SimpleNorm(1e-5)  # experiment (4): separate norm for a 2nd recording

        self.c1 = nn.Conv1d(in_channels=1,out_channels=64,kernel_size=7,padding='same')
        self.c2 = nn.Conv1d(in_channels=64,out_channels=128,kernel_size=5,padding='same')

        self.mp = nn.MaxPool1d(kernel_size=2)
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(in_features=128,out_features=64)
        self.classifier = nn.Linear(in_features=64,out_features=3)

    # def forward(self,x1,x2):  # experiment (4): two recordings normalized separately
    def forward(self,x):
        x = self.norm1(x)
        # x = torch.cat([self.norm1(x1),self.norm2(x2)])  # experiment (4)

        x = self.c1(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c2(x)
        x = relu(x)
        x = self.mp(x)

        x = self.gap(x)  # (batch, 128, 1)
        # Drop only the pooled length axis. A bare squeeze() also drops the batch axis
        # when a batch holds a single sample (e.g. a final partial batch), yielding
        # (3,) logits that no longer line up with the batch-shaped targets.
        x = x.squeeze(-1)

        x = self.fc1(x)
        # NOTE(review): there is no nonlinearity between fc1 and classifier, so these two
        # Linear layers compose to a single linear map; consider relu(x) here.
        x = self.classifier(x)
        return x
    
# Move the model to its device *before* building the optimizer, as the torch.optim
# docs recommend, so the optimizer is constructed over the parameters that will train.
model = CNNSleepStager().to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-3)
# The targets appear to be one-hot (later cells argmax them over dim 1); CrossEntropyLoss
# treats floating-point targets as per-class probabilities.
criterion = nn.CrossEntropyLoss()
model  # display the architecture

In [None]:
# Train on `trainloader`, logging every optimizer step's loss. Every
# `validation_frequency_epochs` epochs, redraw the loss curve into loss.jpg.
trainlossi = []  # one entry per optimizer step
testlossi = []  # NOTE(review): never filled in this notebook
window_size = 10  # moving-average window, in optimizer steps
validation_frequency_epochs = 20  # how often (in epochs) loss.jpg is redrawn
best_dev_loss = torch.inf  # NOTE(review): not updated by this loop yet
best_dev_loss_epoch = 0

plt.style.use('dark_background')
n_validations = 0
steps_per_epoch = len(trainloader)
for epoch in tqdm(range(5000)):
    for Xi,yi in trainloader:
    # for Xi1,yi1,Xi2,yi2 in trainloader:  # experiment (4)
        Xi,yi = Xi.to(device),yi.to(device)
        # Xi1,yi1,Xi2,yi2 = Xi1.to(device),yi1.to(device),Xi2.to(device),yi2.to(device)
        # yi = torch.cat([yi1,yi2])
        optimizer.zero_grad()
        # logits = model(Xi1,Xi2)
        logits = model(Xi)
        loss = criterion(logits,yi)
        loss.backward()
        optimizer.step()
        trainlossi.append(loss.item())

    if epoch % validation_frequency_epochs == 0:
        n_validations += 1
        fig,ax = plt.subplots(nrows=1,ncols=1,figsize=(8,15),dpi=300)

        # x of each logged step in fractional epochs (step i ran during epoch
        # i // steps_per_epoch), so raw and smoothed curves share one axis and unit.
        x_trainlossi = torch.arange(len(trainlossi)) / steps_per_epoch
        ax.plot(x_trainlossi,trainlossi, label='Train Loss', color=colors['Train'], alpha=0.4, linewidth=1.5)

        if len(trainlossi) > window_size:
            trainlossi_ma = moving_average(trainlossi, window_size)
            # Place each average at the x of the last step in its window. This keeps
            # len(x) == len(trainlossi_ma) and the units in epochs. The previous
            # linspace started at window_size-1 *epochs* and ran backwards at epoch 0.
            x_trainlossi_ma = x_trainlossi[window_size-1:]
            ax.plot(x_trainlossi_ma, trainlossi_ma, label='Train Loss MA', color=colors['Train'], linestyle='--', linewidth=1.5)

            # Reference line at the mean of the last 50 smoothed values...
            mean_trainlossi_ma = trainlossi_ma[-50:].mean()
            ax.axhline(y=mean_trainlossi_ma,color=colors['Train'], alpha=0.4)
            # ...labelled just to the right of the plot area.
            ax.text(ax.get_xlim()[1] + .1, mean_trainlossi_ma, f'{mean_trainlossi_ma:.2f}',
                    verticalalignment='center', horizontalalignment='left', color=colors['Train'], fontweight='bold')

        ax.set_xlabel('epoch',fontweight='bold')
        ax.set_ylabel('loss',fontweight='bold')
        ax.set_ylim([0,1])
        ax.legend(loc='upper right')
        fig.savefig('loss.jpg',bbox_inches='tight')
        plt.close(fig)



In [None]:
@torch.no_grad()
def predict_labels(model, loader):
    """Return ``(y_true, y_pred)`` class indices for every sample in ``loader``.

    Runs ``model`` in eval mode without autograd, then restores its previous
    train/eval mode. Targets are argmax-ed over dim 1, i.e. they are assumed to be
    one-hot / per-class scores (as in the other evaluation code in this notebook).
    """
    was_training = model.training
    model.eval()
    try:
        y_true_batches, y_pred_batches = [], []
        for Xi, yi in loader:
            logits = model(Xi.to(device))
            # argmax(logits) == argmax(softmax(logits)), so softmax is unnecessary.
            y_pred_batches.append(logits.argmax(dim=1).cpu())
            y_true_batches.append(yi.argmax(dim=1))
    finally:
        model.train(was_training)
    return torch.cat(y_true_batches), torch.cat(y_pred_batches)

# Training-set report. The held-out recording is evaluated in the following cell,
# which is where `testloader` is created; referencing it here fails on a fresh kernel.
y_true, y_pred = predict_labels(model, trainloader)
print(classification_report(y_true=y_true,y_pred=y_pred))

In [None]:
# Held-out evaluation on a single recording.
# NOTE(review): ids[9] is assumed not to be one of the training recordings (the data
# cell trains on ids[0] and ids[1]) -- confirm.
# Only .pt files are considered, so stray files such as '.DS_Store' cannot become ids.
ids = sorted({recording_filename.split('_')[0]
              for recording_filename in os.listdir(path_to_pt_ekyn)
              if recording_filename.endswith('.pt')})

# Same preprocessing as training via load_ekyn (channel axis + every 10th sample).
# Separate names keep the training X/y intact.
X_test,y_test = load_ekyn(ids[9],'PF')

# No shuffling: order is irrelevant to the report and keeps predictions aligned with y_test.
testloader = DataLoader(TensorDataset(X_test,y_test),batch_size=512,shuffle=False)

# Sanity check: batch shapes and per-class counts in the first batch.
Xi,yi = next(iter(testloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())

was_training = model.training
model.eval()
with torch.no_grad():
    # argmax(logits) == argmax(softmax(logits)), so softmax is skipped.
    y_pred = torch.cat([model(Xb.to(device)).argmax(dim=1).cpu() for Xb, _ in testloader])
model.train(was_training)
y_true = y_test.argmax(dim=1)
print(classification_report(y_true=y_true,y_pred=y_pred))