In [None]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
import seaborn as sns

# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA,
# then CPU. On an MPS machine this is still 'mps' as before; elsewhere the
# notebook now runs instead of failing on the first `.to(device)`.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Directory holding one file per recording (each loaded with torch.load below),
# relative to this notebook.
path_to_pt_ekyn = '../../sleep/pt_ekyn'

# Seed torch's global RNG so weight initialisation and DataLoader shuffle order
# repeat across runs (GPU kernels may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Load a single recording. os.listdir returns entries in arbitrary order, so
# sort for a deterministic choice, and skip hidden files (e.g. macOS
# '.DS_Store', which would otherwise sort first and make torch.load fail).
recording_filenames = sorted(
    filename
    for filename in os.listdir(path_to_pt_ekyn)
    if not filename.startswith('.')
)
if not recording_filenames:
    raise FileNotFoundError(f'no recordings found in {path_to_pt_ekyn!r}')

id = recording_filenames[0]  # NOTE: shadows the builtin id(); name kept for compatibility
print(id)
# torch.load unpickles the file; only point this at trusted data.
# Presumably each file holds an (X, y) pair of tensors — confirm against the exporter.
X, y = torch.load(f'{path_to_pt_ekyn}/{id}')

# Insert a size-1 channel axis at dim 1 (the layout Conv1d expects is (N, C, L)).
# NOTE(review): assumes X was (windows, samples) before this — TODO confirm.
X = X.unsqueeze(1)
# 500 Hz -> 50 Hz by keeping every 10th sample.
# NOTE(review): plain strided decimation applies no low-pass (anti-aliasing)
# filter; if the rates above are right, content above 25 Hz aliases into the
# kept band.
X = X[:, :, ::10]

dataloader = DataLoader(TensorDataset(X, y), batch_size=512, shuffle=True)
# Peek at one batch to sanity-check shapes.
Xi, yi = next(iter(dataloader))
print(Xi.shape, yi.shape)

In [None]:
class SimpleNorm(nn.Module):
    """Standardise a tensor with one mean/std pooled over all of its elements,
    then apply a learnable scalar gain and bias.

    NOTE(review): because the statistics are pooled over the whole input (every
    sample in the batch, every channel, every position), a sample's normalised
    values depend on the other samples in the same batch. A single-element
    input yields NaN, since the default (unbiased) std of one value is NaN.

    Args:
        eps: small constant added to the standard deviation so the division
            never hits zero.
    """

    def __init__(self, eps):
        super().__init__()
        self.eps = eps
        # Learnable scalars, initialised so the affine step starts as identity.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Return ``((x - mean) / (std + eps)) * scale + shift`` with global stats."""
        flat = x.flatten()
        mu, sigma = flat.mean(), flat.std()
        standardized = (x - mu) / (sigma + self.eps)
        return self.shift + self.scale * standardized
    
class CNNSleepStager(nn.Module):
    """Minimal 1-D CNN mapping a batch of signal windows to class logits.

    Pipeline: SimpleNorm -> Conv1d -> ReLU -> global average pool -> Linear.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        in_channels: input channels (default 1, matching ``X.unsqueeze(1)``
            in this notebook).
        n_filters: number of conv filters, i.e. pooled features per sample.
        kernel_size: width of the convolution kernel (no padding, so inputs
            must be at least this long).
        n_classes: number of output logits per sample.

    Input:  (batch, in_channels, length).
    Output: (batch, n_classes) raw logits.
    """

    def __init__(self, *args, in_channels=1, n_filters=64, kernel_size=7,
                 n_classes=3, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = SimpleNorm(1e-5)
        self.c1 = nn.Conv1d(in_channels=in_channels, out_channels=n_filters,
                            kernel_size=kernel_size)
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(in_features=n_filters, out_features=n_classes)

    def forward(self, x):
        x = self.norm(x)
        x = relu(self.c1(x))
        x = self.gap(x)      # (batch, n_filters, 1)
        # Drop ONLY the pooled length axis. A bare squeeze() also removes the
        # batch axis when batch == 1 (e.g. a final DataLoader batch holding one
        # sample), producing (n_classes,) logits that CrossEntropyLoss rejects.
        x = x.squeeze(-1)    # (batch, n_filters)
        return self.classifier(x)
    
# Move the model to the target device *before* building the optimizer, as the
# torch.optim docs recommend, so the optimizer is created over the parameters
# in their final location.
model = CNNSleepStager().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Consumes raw logits of shape (batch, n_classes).
# NOTE(review): presumably y holds integer class indices — confirm its dtype.
criterion = nn.CrossEntropyLoss()
model  # display the architecture

In [None]:
# Train on the single loaded recording and record the loss of every batch.
# Re-running this cell keeps training the SAME model; re-run the model cell
# first for a fresh start.
N_EPOCHS = 200
lossi = []  # per-batch training loss, in optimizer-step order

model.train()  # explicit; the current CNNSleepStager has no dropout/batch-norm
for epoch in tqdm(range(N_EPOCHS)):
    for Xi, yi in dataloader:
        Xi, yi = Xi.to(device), yi.to(device)
        optimizer.zero_grad()
        logits = model(Xi)
        loss = criterion(logits, yi)
        loss.backward()
        optimizer.step()
        # .item() copies the scalar back to the host each step.
        lossi.append(loss.item())

fig, ax = plt.subplots()
ax.plot(lossi)
ax.set(title='Training loss', xlabel='Optimizer step (batch)', ylabel='Cross-entropy loss');