In [6]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader,WeightedRandomSampler
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from sklearn.model_selection import train_test_split
import numpy as np
import h5py
import torch
import os

def moving_average(data, window_size=10):
    return np.convolve(data, np.ones(window_size), 'valid') / window_size

colors = {
    'Train': '#007AFF',  # Apple Blue
    'Test': '#FF9500'    # Apple Orange
}

device = 'cuda'
path_to_pt_ekyn = f'../../pt_ekyn'
path_to_pt_snezana_mice = f'../../pt_snezana_mice'

ekyn_ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))
snezana_mice_ids = sorted(set([recording_filename.split('.')[0] for recording_filename in os.listdir(path_to_pt_snezana_mice)]))
print(len(ekyn_ids),ekyn_ids)
print(len(snezana_mice_ids),snezana_mice_ids)

16 ['A1-0', 'A1-1', 'A4-0', 'B1-0', 'B3-1', 'C1-0', 'C4-0', 'C4-1', 'D1-0', 'E1-0', 'E2-1', 'E4-0', 'E4-1', 'F1-0', 'F1-1', 'F5-1']
58 ['21-HET-1', '21-HET-10', '21-HET-11', '21-HET-12', '21-HET-13', '21-HET-2', '21-HET-3', '21-HET-4', '21-HET-5', '21-HET-7', '21-HET-8', '21-HET-9', '21-KO-1', '21-KO-10', '21-KO-11', '21-KO-12', '21-KO-2', '21-KO-3', '21-KO-4', '21-KO-5', '21-KO-6', '21-KO-7', '21-KO-8', '21-KO-9', '21-WK-1', '21-WK-10', '21-WK-11', '21-WK-12', '21-WK-13', '21-WK-15', '21-WK-16', '21-WK-17', '21-WK-18', '21-WK-2', '21-WK-3', '21-WK-4', '21-WK-5', '21-WK-6', '21-WK-8', '21-WK-9', '21-WT-1', '21-WT-10', '21-WT-12', '21-WT-13', '21-WT-2', '21-WT-3', '21-WT-4', '21-WT-5', '21-WT-6', '21-WT-7', '21-WT-8', '21-WT-9', '354', '378', '381', '382', '386', '429']


In [None]:
X, y = torch.load(f'{path_to_pt_ekyn}/{id}_PF.pt',weights_only=False)
X.shape,y.shape
trainloader = DataLoader(TensorDataset(X,y),batch_size=32,shuffle=True)

In [None]:
X, y = torch.load(f'{path_to_pt_snezana_mice}/{snezana_mice_ids[0]}.pt',weights_only=False)
X.shape,y.shape

In [None]:

# with h5py.File("eeg_data.h5", "w") as f:
#     ekyn_group = f.create_group("ekyn")
#     for i,id in enumerate(ekyn_ids):
#         rat_group = ekyn_group.create_group(f"{i}")
#         for j,condition in enumerate(['PF','Vehicle']):
#             condition_group = rat_group.create_group(f"{j}")
#             X, y = torch.load(f'{path_to_pt_ekyn}/{id}_{condition}.pt',weights_only=False)
#             X = X.unsqueeze(1)
#             X = X[:,:,::10] # 500 Hz -> 50 Hz
#             condition_group.create_dataset("X", data=X)
#             condition_group.create_dataset("y", data=y)

#             condition_group.attrs["rat_id"] = i
#             condition_group.attrs["raw_rat_id"] = id
#             condition_group.attrs["condition_id"] = j
#             condition_group.attrs["raw_condition_id"] = condition

#     snezana_mice_group = f.create_group("mice")
#     for i,id in enumerate(snezana_mice_ids):
#         mice_group = snezana_mice_group.create_group(f"{i}")
#         X, y = torch.load(f'{path_to_pt_snezana_mice}/{id}.pt',weights_only=False)
#         X = X.unsqueeze(1)
#         X = X[:,:,::10] # 500 Hz -> 50 Hz
#         mice_group.create_dataset("X", data=X)
#         mice_group.create_dataset("y", data=y)

#         mice_group.attrs["mouse_id"] = i
#         mice_group.attrs["raw_mouse_id"] = id


# print("Data saved to eeg_data.h5!")

In [7]:
with h5py.File("eeg_data.h5", "r") as f:
    trainloader = DataLoader(
        ConcatDataset(
            [
                TensorDataset(torch.from_numpy(f['ekyn'][key]['0']['X'][:]),torch.from_numpy(f['ekyn'][key]['0']['y'][:])) for key in list(f['ekyn'].keys())[:8]
            ]
        )
        ,batch_size=512,shuffle=True)
    testloader = DataLoader(
        ConcatDataset(
            [
                TensorDataset(torch.from_numpy(f['mice'][key]['X'][:]),torch.from_numpy(f['mice'][key]['y'][:])) for key in list(f['mice'].keys())[:2]
            ]
        )
        ,batch_size=512,shuffle=True)

In [8]:
Xi,yi = next(iter(trainloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())

Xi,yi = next(iter(testloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())

torch.Size([512, 1, 500]) torch.Size([512, 3]) tensor([ 33, 236, 243])
torch.Size([512, 1, 500]) torch.Size([512, 3]) tensor([ 16, 187, 309])


In [9]:
class SimpleNorm(nn.Module):
    def __init__(self,eps):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))
    def forward(self,x):
        mean = x.flatten().mean()
        std = x.flatten().std()
        x = (x - mean) / (std + self.eps)
        return x * self.scale + self.shift
    
class CNNSleepStager(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = SimpleNorm(1e-5)
        self.c1 = nn.Conv1d(in_channels=1,out_channels=4,kernel_size=7,padding='same')
        self.c2 = nn.Conv1d(in_channels=4,out_channels=8,kernel_size=5,padding='same')
        self.c3 = nn.Conv1d(in_channels=8,out_channels=16,kernel_size=3,padding='same')
        self.c4 = nn.Conv1d(in_channels=16,out_channels=32,kernel_size=3,padding='same')

        self.mp = nn.MaxPool1d(kernel_size=2)
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(in_features=32,out_features=16)
        self.classifier = nn.Linear(in_features=16,out_features=3)
    def forward(self,x):
        x = self.norm(x)

        x = self.c1(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c2(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c3(x)
        x = relu(x)
        x = self.mp(x)

        x = self.c4(x)
        x = relu(x)
        x = self.mp(x)

        x = self.gap(x)
        x = x.squeeze()

        x = self.fc1(x)
        x = self.classifier(x)
        return x
    
model = CNNSleepStager()
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-3)
criterion = nn.CrossEntropyLoss()
model.to(device)

CNNSleepStager(
  (norm): SimpleNorm()
  (c1): Conv1d(1, 4, kernel_size=(7,), stride=(1,), padding=same)
  (c2): Conv1d(4, 8, kernel_size=(5,), stride=(1,), padding=same)
  (c3): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=same)
  (c4): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=same)
  (mp): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gap): AdaptiveAvgPool1d(output_size=1)
  (fc1): Linear(in_features=32, out_features=16, bias=True)
  (classifier): Linear(in_features=16, out_features=3, bias=True)
)

In [10]:
trainlossi = []
testlossi = []

window_size = 10
validation_frequency_epochs = 20
best_dev_loss = torch.inf
best_dev_loss_epoch = 0

plt.style.use('dark_background')

for epoch in tqdm(range(5000)):
    for Xi,yi in trainloader:
        Xi,yi = Xi.to(device),yi.to(device)
        optimizer.zero_grad()
        logits = model(Xi)
        loss = criterion(logits,yi)
        loss.backward()
        optimizer.step()
        trainlossi.append(loss.item())

    if epoch % validation_frequency_epochs == 0:
        model.eval()
        with torch.no_grad():
            testlossi.append(torch.hstack([criterion(model(Xi.to(device)),yi.to(device)).cpu() for Xi,yi in testloader]).mean().item())

        if testlossi[-1] < best_dev_loss:
            best_dev_loss = testlossi[-1]
            best_dev_loss_epoch = epoch

        fig,axes = plt.subplots(nrows=1,ncols=1,figsize=(8,15),dpi=300)

        # # Adjust figure background
        # fig.patch.set_facecolor('black')

        # # Adjust axes background
        # axes.set_facecolor('black')

        x_trainlossi = torch.linspace(0,(len(testlossi)-1)*validation_frequency_epochs,len(trainlossi))
        x_testlossi = torch.linspace(0,(len(testlossi)-1)*validation_frequency_epochs,len(testlossi))

        plt.plot(x_trainlossi,trainlossi, label='Train Loss', color=colors['Train'], alpha=0.4, linewidth=1.5)

        if len(trainlossi) > window_size:
            x_trainlossi_ma = torch.linspace(window_size-1,(len(testlossi)-1)*validation_frequency_epochs,len(trainlossi)-(window_size-1))
            trainlossi_ma = moving_average(trainlossi, window_size)
            plt.plot(x_trainlossi_ma, trainlossi_ma, label='Train Loss MA', color=colors['Train'], linestyle='--', linewidth=1.5)

            plt.axvline(x=x_trainlossi_ma[trainlossi_ma.argmin()],color=colors['Train'], alpha=0.4)
            min_trainlossi_ma = torch.tensor(trainlossi_ma).min()
            plt.axhline(y=min_trainlossi_ma,color=colors['Train'], alpha=0.4)

            # Add text on the right-hand side at the orange line value
            plt.text(plt.xlim()[1] + .1, min_trainlossi_ma, f'{min_trainlossi_ma:.2f}', 
                    verticalalignment='center', horizontalalignment='left', color=colors['Train'], fontweight='bold')


        plt.plot(x_testlossi,testlossi, label='Test Loss', color=colors['Test'], alpha=1, linewidth=1.5)

        plt.axvline(x=x_testlossi[torch.tensor(testlossi).argmin()],color=colors['Test'], alpha=0.4)
        min_testlossi = torch.tensor(testlossi).min()
        plt.axhline(y=min_testlossi,color=colors['Test'], alpha=0.4)

        # Add text on the right-hand side at the orange line value
        plt.text(plt.xlim()[1] + .1, min_testlossi, f'{min_testlossi:.2f}', 
                verticalalignment='center', horizontalalignment='left', color=colors['Test'], fontweight='bold')
        

        plt.xlabel('epoch',fontweight='bold')
        plt.ylabel('loss',fontweight='bold')

        plt.ylim([0,1])
        plt.savefig('loss.jpg',bbox_inches='tight')
        plt.close()
        model.train()



  1%|▏         | 64/5000 [00:38<49:53,  1.65it/s]  


KeyboardInterrupt: 

In [None]:
y = torch.vstack([torch.vstack([model(Xi.to(device)).softmax(dim=1).argmax(dim=1).cpu(),yi.argmax(dim=1)]).T for Xi,yi in trainloader])
y_pred = y[:,0]
y_true = y[:,1]
print(classification_report(y_true=y_true,y_pred=y_pred))

y = torch.vstack([torch.vstack([model(Xi.to(device)).softmax(dim=1).argmax(dim=1).cpu(),yi.argmax(dim=1)]).T for Xi,yi in testloader])
y_pred = y[:,0]
y_true = y[:,1]
print(classification_report(y_true=y_true,y_pred=y_pred))

In [None]:
ids = sorted(set([recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)]))

X,y = torch.load(f'{path_to_pt_ekyn}/{ids[9]}_PF.pt',weights_only=False)
X = X.unsqueeze(1)
X = X[:,:,::10] # 500 Hz -> 50 Hz

testloader = DataLoader(TensorDataset(X,y),batch_size=512,shuffle=True)

Xi,yi = next(iter(testloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())

y = torch.vstack([torch.vstack([model(Xi.to(device)).softmax(dim=1).argmax(dim=1).cpu(),yi.argmax(dim=1)]).T for Xi,yi in testloader])
y_pred = y[:,0]
y_true = y[:,1]
print(classification_report(y_true=y_true,y_pred=y_pred))