In [1]:
import torch
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score

from isruc_sleep import load_dataset, load_annotations
from lightsleepnet import LightSleepNet

In [2]:
CHANNELS = ["F3-A2", "C3-A2", "O1-A2", "F4-A1", "C4-A1", "O2-A1"]
SLEEP_STAGES = ["W", "N1", "N2", "N3", "REM"]

In [3]:
subjects_data = load_dataset()
subjects_labels = load_annotations()

Loading subject 1
Loading subject 2
Loading subject 3
Loading subject 4
Loading subject 5
Loading subject 6
Loading subject 7
Loading subject 8
Loading subject 9
Loading subject 10


In [4]:
SUBJECT_TEST = 1

In [5]:
# for subject_test in range(1, 11):
for subject_test in [SUBJECT_TEST]:
    print(f"===TEST SUBJECT: {subject_test}===")

    subjects_data_train = subjects_data[:subject_test - 1] + subjects_data[subject_test:]
    subjects_label_train = subjects_labels[:subject_test - 1] + subjects_labels[subject_test:]
    subjects_data_train_tensor = torch.cat(subjects_data_train)
    subjects_label_train_tensor = torch.cat(subjects_label_train)
    train_dataset = torch.utils.data.TensorDataset(subjects_data_train_tensor, subjects_label_train_tensor)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    subject_data_test_tensor = subjects_data[subject_test - 1]
    subject_label_test_tensor = subjects_labels[subject_test - 1]
    test_dataset = torch.utils.data.TensorDataset(subject_data_test_tensor, subject_label_test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = LightSleepNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for training_epoch in range(100):
        print(f"{training_epoch + 1:<5}", end="")

        train_loss = 0
        correct = 0
        total = 0
        for X_batch, y_batch in train_loader:
            X_batch.requires_grad = True

            outputs = model(X_batch).float()
            loss = criterion(outputs, y_batch)

            weights = []

            for epoch_idx in range(X_batch.size(0)):
                epoch_data = X_batch[epoch_idx].unsqueeze(0)
                epoch_loss = criterion(model(epoch_data), y_batch[epoch_idx].unsqueeze(0))

                epoch_grads = torch.autograd.grad(epoch_loss, epoch_data, retain_graph=True)[0]
                grad_norms = epoch_grads.norm(p=2, dim=1).squeeze(0)

                delta = 0.1 * grad_norms.std().item()
                density = ((grad_norms.unsqueeze(1) - grad_norms.unsqueeze(0)).abs() < delta).sum(dim=1).float()

                epoch_weight = density.mean()
                weights.append(epoch_weight.detach())

            weights = torch.tensor(weights, device="cuda")
            weights /= weights.sum()
            # print(weights)
            weighted_loss = (weights * loss).sum()
            train_loss += weighted_loss.item()

            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            correct += predicted.eq(y_batch).sum().item()
            total += y_batch.size(0)

        print(f"Loss: {train_loss / len(train_loader):.3f}  Accuracy: {100 * correct / total:.3f}")

===TEST SUBJECT: 1===
1    Loss: 0.785  Accuracy: 68.834
2    Loss: 0.608  Accuracy: 75.665
3    Loss: 0.590  Accuracy: 75.892
4    Loss: 0.569  Accuracy: 77.202
5    Loss: 0.560  Accuracy: 77.795
6    Loss: 0.556  Accuracy: 77.580
7    Loss: 0.541  Accuracy: 77.996
8    Loss: 0.537  Accuracy: 78.715
9    Loss: 0.518  Accuracy: 79.206
10   Loss: 0.509  Accuracy: 79.534
11   Loss: 0.508  Accuracy: 79.067
12   Loss: 0.494  Accuracy: 80.000
13   Loss: 0.502  Accuracy: 79.521
14   Loss: 0.486  Accuracy: 80.189
15   Loss: 0.479  Accuracy: 80.428
16   Loss: 0.480  Accuracy: 80.164
17   Loss: 0.479  Accuracy: 80.491
18   Loss: 0.479  Accuracy: 80.416
19   Loss: 0.474  Accuracy: 80.290
20   Loss: 0.469  Accuracy: 81.147
21   Loss: 0.466  Accuracy: 80.756
22   Loss: 0.465  Accuracy: 81.260
23   Loss: 0.465  Accuracy: 80.945
24   Loss: 0.466  Accuracy: 81.096
25   Loss: 0.464  Accuracy: 81.134
26   Loss: 0.455  Accuracy: 81.512
27   Loss: 0.449  Accuracy: 81.865
28   Loss: 0.456  Accuracy: 81.68

In [6]:
y_pred = []
test_loss = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        y_pred.append(predicted.cpu())

y_pred = torch.cat(y_pred).numpy()
y_true = subject_label_test_tensor.to("cpu").numpy()

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)

print(f"Loss: {test_loss / len(test_loader):.3f}")
print(f"Accuracy: {accuracy:.3f}")
for sleep_stage, precision_val, recall_val, f1_val in zip(SLEEP_STAGES, precision, recall, f1):
    print(f"{sleep_stage:7}precision={precision_val:.3f}  recall={recall_val:.3f}  f1={f1_val:.3f}")

Loss: 1.137
Accuracy: 0.682
W      precision=0.891  recall=0.745  f1=0.812
N1     precision=0.451  recall=0.538  f1=0.490
N2     precision=0.637  recall=0.909  f1=0.749
N3     precision=0.882  recall=0.545  f1=0.674
REM    precision=0.875  recall=0.235  f1=0.371


In [8]:
confusion_matrix(y_true, y_pred)

array([[123,  27,  12,   3,   0],
       [  9,  64,  43,   0,   3],
       [  4,  19, 339,  10,   1],
       [  1,   0,  80,  97,   0],
       [  1,  32,  58,   0,  28]])