In [None]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader,WeightedRandomSampler
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from sklearn.model_selection import train_test_split
import numpy as np

def moving_average(data, window_size=10):
    """Return the simple moving average of a 1-D sequence.

    Args:
        data: 1-D array-like of numbers.
        window_size: Number of consecutive samples averaged per output value.

    Returns:
        A float ``numpy.ndarray`` with ``len(data) - window_size + 1`` entries,
        or an empty array when ``data`` holds fewer than ``window_size`` samples.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')
    values = np.atleast_1d(data)
    # In 'valid' mode np.convolve returns max(M, N) - min(M, N) + 1 values even
    # when the kernel is longer than the data; those values are not averages,
    # so inputs shorter than the window are answered with an empty result.
    if values.ndim == 1 and values.shape[0] < window_size:
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(window_size), 'valid') / window_size

# Plot colors keyed by data split.
colors = {
    'Train': '#007AFF',  # Apple Blue
    'Test': '#FF9500'    # Apple Orange
}

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory holding the recordings, stored as `{id}_{condition}.pt` files.
path_to_pt_ekyn = '../pt_ekyn'

In [None]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
# A recording id is the filename prefix before the first underscore
# (`{id}_{condition}.pt`). Only `.pt` entries are considered so that stray
# directory entries (e.g. `.DS_Store`) cannot show up as ids.
ids = sorted({
    recording_filename.split('_')[0]
    for recording_filename in os.listdir(path_to_pt_ekyn)
    if recording_filename.endswith('.pt')
})
print(len(ids),ids)
if not ids:
    raise FileNotFoundError(f'no .pt recordings found in {path_to_pt_ekyn}')

# NOTE(review): `id` shadows the Python builtin; the name is kept so that any
# cell relying on it keeps working.
id = ids[0]
condition = 'PF'

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load recordings that come from a trusted source.
X,y = torch.load(f'{path_to_pt_ekyn}/{id}_{condition}.pt',weights_only=False)
X = X.unsqueeze(1)  # insert a singleton axis at position 1 (the Conv1d channel axis)
# Keep every 10th sample along the last axis.
# NOTE(review): plain striding applies no anti-aliasing filter before decimating.
X = X[:,:,::10] # 500 Hz -> 50 Hz

# Saving your EEG data to HDF5 (one-time setup)
with h5py.File("eeg_data.h5", "w") as f:
    # Assuming your EEG data is a tensor of shape [samples, channels, time]
    # The tensors are converted to NumPy arrays explicitly before being handed to h5py.
    f.create_dataset("signals", data=X.detach().cpu().numpy())
    f.create_dataset("labels", data=y.detach().cpu().numpy())

# Custom Dataset for dynamic loading
class EEGDataset(Dataset):
    """Map-style dataset that reads one (signal, label) pair at a time from HDF5.

    The file must contain two datasets named "signals" and "labels", both
    indexed by sample along their first axis.

    The file handle stays open while the object is in use. Call :meth:`close`
    (or use the instance in a ``with`` statement) to release it; the handle is
    also released when the object is garbage-collected.
    """
    def __init__(self, h5_file):
        """Open ``h5_file`` read-only and bind its "signals" and "labels" datasets.

        Raises:
            KeyError: If either dataset is missing (the file is closed first).
        """
        self.file = h5py.File(h5_file, "r")
        try:
            self.signals = self.file["signals"]
            self.labels = self.file["labels"]
        except KeyError:
            # Do not leave a half-initialized object holding an open handle.
            self.close()
            raise

    def __len__(self):
        """Return the number of samples (length of the "signals" dataset)."""
        return len(self.signals)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` tensors stored at position ``idx``."""
        signal = torch.tensor(self.signals[idx])  # Load only this sample
        label = torch.tensor(self.labels[idx])
        return signal, label

    def close(self):
        """Close the underlying file; calling it again is a no-op."""
        handle = getattr(self, "file", None)
        if handle is not None:
            handle.close()
        # Drop the references so a closed dataset cannot be read by accident.
        self.file = None
        self.signals = None
        self.labels = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Best-effort cleanup: a finalizer cannot usefully propagate errors
        # (e.g. while the interpreter is shutting down).
        try:
            self.close()
        except Exception:
            pass

# The HDF5-backed dataset is instantiated here, but the active training loader
# below iterates over the tensors that are already held in memory.
dataset = EEGDataset("eeg_data.h5")
# trainloader = DataLoader(dataset, batch_size=32, shuffle=True)
trainloader = DataLoader(
    TensorDataset(X, y),
    batch_size=32,
    shuffle=True,
)

In [None]:
# Draw a single batch and show its tensor shapes together with how often each
# argmax label index occurs in it.
batch_iterator = iter(trainloader)
Xi, yi = next(batch_iterator)
print(Xi.shape, yi.shape, yi.argmax(dim=1).bincount())


In [None]:
class SimpleNorm(nn.Module):
    """Standardize a tensor using one mean and one std taken over all its elements.

    The standardized tensor is then passed through a learnable scalar affine
    map: ``scale`` (initialized to 1) and ``shift`` (initialized to 0).
    """
    def __init__(self,eps):
        """Store ``eps`` (added to the std before dividing) and create the affine parameters."""
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self,x):
        """Return ``(x - mean) / (std + eps) * scale + shift`` with global statistics."""
        # Statistics are computed over every element of the input, so all axes
        # (including the leading one) share the same mean and std.
        flat = x.flatten()
        mu, sigma = flat.mean(), flat.std()
        standardized = (x - mu) / (sigma + self.eps)
        return standardized * self.scale + self.shift
    
class CNNSleepStager(nn.Module):
    """1-D CNN that maps a single-channel signal window to 3 class logits.

    Pipeline: SimpleNorm, then four Conv1d + ReLU + MaxPool1d(2) stages with
    1, 4, 8, 16, 32 channels, global average pooling over time, Linear(32, 16)
    and finally Linear(16, 3).
    """
    def __init__(self, *args, **kwargs):
        """Build the layers; positional/keyword arguments are forwarded to ``nn.Module``."""
        super().__init__(*args, **kwargs)
        self.norm = SimpleNorm(1e-5)
        self.c1 = nn.Conv1d(in_channels=1,out_channels=4,kernel_size=7,padding='same')
        self.c2 = nn.Conv1d(in_channels=4,out_channels=8,kernel_size=5,padding='same')
        self.c3 = nn.Conv1d(in_channels=8,out_channels=16,kernel_size=3,padding='same')
        self.c4 = nn.Conv1d(in_channels=16,out_channels=32,kernel_size=3,padding='same')

        # Parameter-free layers, shared by every stage.
        self.mp = nn.MaxPool1d(kernel_size=2)
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(in_features=32,out_features=16)
        self.classifier = nn.Linear(in_features=16,out_features=3)

    def forward(self,x):
        """Compute class logits.

        Args:
            x: Tensor of shape (batch, 1, time). The time axis is halved by
                each of the four pooling stages, so it needs at least 16 samples.

        Returns:
            Logits of shape (batch, 3), including when ``batch == 1``.
        """
        x = self.norm(x)

        # Each stage: convolution, ReLU, then halve the time axis.
        for conv in (self.c1, self.c2, self.c3, self.c4):
            x = self.mp(relu(conv(x)))

        # (batch, 32, 1) to (batch, 32). Only the trailing length-1 axis is
        # removed; a bare squeeze() would also drop a batch axis of size 1 and
        # return logits without a batch dimension.
        x = self.gap(x).squeeze(-1)

        # NOTE(review): there is no activation between fc1 and classifier, so
        # the two layers compose into a single linear map. Left unchanged.
        x = self.fc1(x)
        return self.classifier(x)
    
# Instantiate the network together with its loss function and optimizer.
model = CNNSleepStager()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-3)
# Move the parameters to `device`; as the last expression of the cell this also
# displays the module.
model.to(device)

In [None]:
# Loss of every optimization step, in the order the batches were processed.
trainlossi = []

for epoch in tqdm(range(5000)):
    for Xi, yi in trainloader:
        # Transfer the batch to `device`.
        Xi = Xi.to(device)
        yi = yi.to(device)

        # Forward pass and loss.
        logits = model(Xi)
        loss = criterion(logits, yi)

        # Clear the previous step's gradients, backpropagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        trainlossi.append(loss.item())

In [None]:
# Training-set metrics. Inference runs under torch.no_grad() so that no autograd
# graph is built for the forward passes.
train_pred_batches = []
train_true_batches = []
with torch.no_grad():
    for batch_X, batch_y in trainloader:
        # argmax is taken directly on the logits: softmax is monotonic, so
        # applying it first does not change which class wins.
        train_pred_batches.append(model(batch_X.to(device)).argmax(dim=1).cpu())
        # Targets are reduced to class indices with argmax over dim 1
        # (assumes one row of per-class scores per sample - TODO confirm).
        train_true_batches.append(batch_y.argmax(dim=1))
y_pred = torch.cat(train_pred_batches)
y_true = torch.cat(train_true_batches)
print(classification_report(y_true=y_true,y_pred=y_pred))

# Held-out metrics are reported by the next cell, which is where `testloader`
# is created. Evaluating it here raised NameError on a top-to-bottom run and
# duplicated the report printed by that cell.

In [None]:
# Position, within the sorted recording ids, of the recording used for
# held-out evaluation.
TEST_RECORDING_INDEX = 9

# Only `.pt` entries are considered so that stray directory entries cannot
# shift the index of the selected recording.
ids = sorted({
    recording_filename.split('_')[0]
    for recording_filename in os.listdir(path_to_pt_ekyn)
    if recording_filename.endswith('.pt')
})
if TEST_RECORDING_INDEX >= len(ids):
    raise IndexError(
        f'need at least {TEST_RECORDING_INDEX + 1} recordings in {path_to_pt_ekyn}, found {len(ids)}'
    )

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load recordings that come from a trusted source.
# Dedicated names keep the training tensors `X` / `y` intact when this cell runs.
X_test,y_test = torch.load(f'{path_to_pt_ekyn}/{ids[TEST_RECORDING_INDEX]}_PF.pt',weights_only=False)
X_test = X_test.unsqueeze(1)
X_test = X_test[:,:,::10] # 500 Hz -> 50 Hz

# Shuffling is kept on purpose: the model's SimpleNorm layer standardizes each
# batch with statistics taken over the whole batch, so the composition of a
# batch can influence its predictions.
testloader = DataLoader(TensorDataset(X_test,y_test),batch_size=512,shuffle=True)

Xi,yi = next(iter(testloader))
print(Xi.shape,yi.shape,yi.argmax(dim=1).bincount())

# Inference runs under torch.no_grad() so that no autograd graph is built.
test_pred_batches = []
test_true_batches = []
with torch.no_grad():
    for batch_X, batch_y in testloader:
        # softmax is monotonic, so argmax over the raw logits picks the same class.
        test_pred_batches.append(model(batch_X.to(device)).argmax(dim=1).cpu())
        test_true_batches.append(batch_y.argmax(dim=1))
y_pred = torch.cat(test_pred_batches)
y_true = torch.cat(test_true_batches)
print(classification_report(y_true=y_true,y_pred=y_pred))