In [9]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader,ConcatDataset,TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score,classification_report
import numpy as np
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt

def moving_average(data, window_size=10):
    """Compute the simple moving average of a 1-D sequence.

    Parameters
    ----------
    data : sequence of numbers
        Values to smooth, e.g. a list of per-epoch metrics.
    window_size : int, default 10
        Number of consecutive values averaged for each output point.
        Must be at least 1.

    Returns
    -------
    numpy.ndarray
        Float array of length ``max(len(data) - window_size + 1, 0)``;
        element ``i`` is the mean of ``data[i:i + window_size]``.

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(data, dtype=float)
    # np.convolve swaps its operands when the kernel is longer than the
    # signal, so 'valid' mode would return window_size - len(data) + 1
    # meaningless values (and raise on empty input). There is no complete
    # window in that case, so return an empty result instead.
    if values.size < window_size:
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(window_size), 'valid') / window_size

class SleepStageClassifier(nn.Module):
    """Small 1-D CNN: one convolution, global average pooling, linear head.

    ``forward`` takes a batched single-channel signal of shape
    ``(batch, 1, length)`` with ``length >= 3`` (the convolution kernel size)
    and returns unnormalised scores of shape ``(batch, 3)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which initialisation draws
        # from the random number generator.
        self.c1 = nn.Conv1d(1, 64, 3)
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(64, 3)

    def forward(self, x):
        """Return class scores for a ``(batch, 1, length)`` input."""
        activated = torch.relu(self.c1(x))
        # Pool over the time axis, then merge the trailing length-1 axis so
        # the linear layer sees (batch, 64).
        pooled = torch.flatten(self.gap(activated), 1, 2)
        return self.fc1(pooled)
    
def _balanced_loader(path, batch_size):
    """Load an ``(X, y)`` tensor pair from ``path`` into a class-balanced DataLoader.

    Each sample is drawn with probability inversely proportional to the
    frequency of its class, using sampling with replacement, so one pass
    yields ``len(dataset)`` samples with roughly equal class representation.

    Parameters
    ----------
    path : str
        File produced by ``torch.save((X, y), path)``.
    batch_size : int
        Batch size of the returned loader.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over ``TensorDataset(X.unsqueeze(1), y)``.
    """
    # NOTE(review): torch.load unpickles the file; only load recordings that
    # come from a trusted source.
    X, y = torch.load(path)
    # Add a length-1 axis at dim 1 (presumably the channel axis expected by
    # the model's Conv1d -- confirm against the stored tensor layout).
    X = X.unsqueeze(1)
    # Class index per sample; assumes y holds one score/one-hot row per sample.
    labels = torch.argmax(y, dim=1)
    class_counts = torch.bincount(labels)
    # Inverse class frequency, looked up per sample. Classes with a zero count
    # get an infinite weight but are never indexed by `labels`.
    sample_weights = (1. / class_counts.float())[labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    return DataLoader(TensorDataset(X, y), batch_size=batch_size, sampler=sampler)


def get_dataloaders():
    """Build the train and test loaders from the two hard-coded recordings.

    Returns
    -------
    tuple of DataLoader
        ``(trainloader, testloader)``: the train loader uses batch size 32 and
        the test loader batch size 512. Both use class-balanced sampling with
        replacement.
    """
    trainloader = _balanced_loader('pt_ekyn_robust_50hz/A1-0_PF.pt', batch_size=32)
    # NOTE(review): the test loader is also randomly resampled with
    # replacement, so test metrics vary between passes and do not reflect the
    # recording's true class distribution -- confirm this is intended.
    testloader = _balanced_loader('pt_ekyn_robust_50hz/A1-1_Vehicle.pt', batch_size=512)
    return trainloader,testloader

def evaluate(dataloader,model,criterion,device):
    """Run the model over a dataloader and summarise loss and F1 scores.

    Parameters
    ----------
    dataloader : iterable of ``(Xi, yi)`` batches
        ``yi`` is expected to hold one row per sample whose argmax is the
        class index (e.g. one-hot targets).
    model : torch.nn.Module
        Model producing one score per class for each sample.
    criterion : callable
        Loss called as ``criterion(logits, class_indices)`` on CPU tensors.
    device : str or torch.device
        Device the inputs are moved to before the forward pass.

    Returns
    -------
    tuple
        ``(loss, f1, report)``: the scalar loss over all samples, the
        macro-averaged F1 score, and the ``classification_report`` dictionary.
    """
    # Switch to inference behaviour (matters for dropout / batch-norm layers)
    # and restore the caller's mode afterwards so training is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch_logits, batch_targets = [], []
            for Xi, yi in dataloader:
                batch_logits.append(model(Xi.to(device)).cpu())
                batch_targets.append(yi.cpu())
            logits = torch.cat(batch_logits)
            # The number of classes is taken from the tensor widths rather
            # than from fixed column slices.
            y_true = torch.cat(batch_targets).argmax(dim=1)
            # argmax of the logits equals argmax of their softmax.
            y_pred = logits.argmax(dim=1)
            loss = criterion(logits, y_true).item()
    finally:
        model.train(was_training)
    # Pass 1-D integer arrays so the report keys are '0', '1', ...
    y_true, y_pred = y_true.numpy(), y_pred.numpy()
    f1 = f1_score(y_true, y_pred, average='macro')
    report = classification_report(y_pred=y_pred, y_true=y_true, output_dict=True)
    return loss,f1,report

In [10]:
# Not referenced by the other visible cells; kept for any later cells.
experiments_dir = 'data'

# Prefer Apple-silicon MPS (the original hard-coded choice), then CUDA, then
# CPU, so the notebook also runs on machines without an MPS backend.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

trainloader,testloader = get_dataloaders()
model = SleepStageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(),lr=3e-4,weight_decay=1e-2)

In [11]:
# Move the model's parameters to the selected device before training.
model.to(device)

# Per-epoch loss histories.
trainlossi, testlossi = [], []
# Per-epoch F1 histories: suffix "i" is the macro average; "p", "s" and "w"
# receive report classes '0', '1' and '2' in the training cell.
trainf1i, trainf1p, trainf1s, trainf1w = [], [], [], []
testf1i, testf1p, testf1s, testf1w = [], [], [], []
# Best test-set loss / macro F1 seen so far and the epochs they occurred at.
best_dev_loss, best_dev_loss_epoch = torch.inf, 0
best_dev_f1, best_dev_f1_epoch = 0, 0

In [None]:
window_size = 10
num_epochs = 1000

# Plot styling does not change between epochs, so define it once.
train_color = '#1f77b4'
test_color = '#ff7f0e'
# Display names for report classes '0', '1', '2', each with its own colour so
# the per-class curves can be told apart.
labels = ['Paradoxical', 'Slow Wave', 'Wakefulness']
class_colors = ['#2ca02c', '#9467bd', '#8c564b']


def _plot_series_with_ma(axis, values, label, color, window):
    """Plot a faint raw series and, once it is longer than ``window``, its moving average."""
    axis.plot(values, label=label, color=color, alpha=0.4)
    if len(values) > window:
        axis.plot(range(window - 1, len(values)), moving_average(values, window),
                  label=f'{label} MA', color=color, linestyle='--')


for epoch in tqdm(range(num_epochs)):
    # One optimisation pass over the training loader.
    model.train()
    for Xi, yi in trainloader:
        Xi, yi = Xi.to(device), yi.to(device)
        logits = model(Xi)
        loss = criterion(logits, yi)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on both loaders and record the per-epoch metrics.
    loss, f1, report = evaluate(trainloader, model, criterion, device)
    trainlossi.append(loss)
    trainf1i.append(report['macro avg']['f1-score'])
    trainf1p.append(report['0']['f1-score'])
    trainf1s.append(report['1']['f1-score'])
    trainf1w.append(report['2']['f1-score'])

    loss, f1, report = evaluate(testloader, model, criterion, device)
    testlossi.append(loss)
    testf1i.append(report['macro avg']['f1-score'])
    testf1p.append(report['0']['f1-score'])
    testf1s.append(report['1']['f1-score'])
    testf1w.append(report['2']['f1-score'])

    # Checkpoint whenever the test loss or test macro F1 improves.
    if testlossi[-1] < best_dev_loss:
        torch.save(model.state_dict(), 'model_bestdevloss.pt')
        best_dev_loss = testlossi[-1]
        best_dev_loss_epoch = epoch
    if testf1i[-1] > best_dev_f1:
        torch.save(model.state_dict(), 'model_bestdevf1.pt')
        best_dev_f1 = testf1i[-1]
        best_dev_f1_epoch = epoch

    # Redraw the progress figure from scratch every epoch.
    fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(6,10))

    # First subplot - loss, with the best test loss marked by red dotted lines.
    _plot_series_with_ma(ax[0], trainlossi, 'Train Loss', train_color, window_size)
    _plot_series_with_ma(ax[0], testlossi, 'Test Loss', test_color, window_size)
    ax[0].axhline(best_dev_loss, color='r', linestyle=':')
    ax[0].axvline(best_dev_loss_epoch, color='r', linestyle=':')
    ax[0].set_ylabel('Loss')
    ax[0].legend()

    # Second subplot - macro F1, with the best test F1 marked.
    _plot_series_with_ma(ax[1], trainf1i, 'Train F1', train_color, window_size)
    _plot_series_with_ma(ax[1], testf1i, 'Test F1', test_color, window_size)
    ax[1].axhline(best_dev_f1, color='r', linestyle=':')
    ax[1].axvline(best_dev_f1_epoch, color='r', linestyle=':')
    ax[1].set_ylabel('Macro F1')
    ax[1].legend()

    # Third subplot - moving-average F1 per class: colour identifies the
    # class, line style separates train (solid) from test (dotted).
    for train, test, label, color in zip([trainf1p, trainf1s, trainf1w], [testf1p, testf1s, testf1w], labels, class_colors):
        if len(train) > window_size:
            ax[2].plot(range(window_size-1, len(train)), moving_average(train, window_size),
                       label=f'{label} (Train)', color=color, linestyle='-')
        if len(test) > window_size:
            ax[2].plot(range(window_size-1, len(test)), moving_average(test, window_size),
                       label=f'{label} (Test)', color=color, linestyle=':')
    ax[2].set_ylabel('Per-class F1 (moving average)')
    ax[2].set_xlabel('Epoch')
    # Nothing is drawn here during the first `window_size` epochs; skip the
    # legend then to avoid matplotlib's "no artists with labels" warning.
    if ax[2].get_legend_handles_labels()[0]:
        ax[2].legend(fontsize='small')

    plt.tight_layout()
    plt.savefig('loss_with_ma.jpg')
    plt.close(fig)

    # Always keep the latest weights as well as the best checkpoints.
    torch.save(model.state_dict(), 'model.pt')

 17%|█▋        | 173/1000 [04:06<20:22,  1.48s/it]