In [8]:
import mne
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

from isruc_sleep import load_dataset, load_annotations
from lightsleepnet import LightSleepNet

In [2]:
CHANNELS = ["F3-A2", "C3-A2", "O1-A2", "F4-A1", "C4-A1", "O2-A1"]
SLEEP_STAGES = ["W", "N1", "N2", "N3", "REM"]

In [3]:
subjects_data = load_dataset()
subjects_labels = load_annotations()

Loading subject 1
Loading subject 2
Loading subject 3
Loading subject 4
Loading subject 5
Loading subject 6
Loading subject 7
Loading subject 8
Loading subject 9
Loading subject 10


In [5]:
SUBJECT_TEST = 1

In [6]:
# for subject_test in range(1, 11):
for subject_test in [SUBJECT_TEST]:
    print(f"===TEST SUBJECT: {subject_test}===")

    subjects_data_train = subjects_data[:subject_test - 1] + subjects_data[subject_test:]
    subjects_label_train = subjects_labels[:subject_test - 1] + subjects_labels[subject_test:]
    subjects_data_train_tensor = torch.cat(subjects_data_train)
    subjects_label_train_tensor = torch.cat(subjects_label_train)
    train_dataset = torch.utils.data.TensorDataset(subjects_data_train_tensor, subjects_label_train_tensor)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    subject_data_test_tensor = subjects_data[subject_test - 1]
    subject_label_test_tensor = subjects_labels[subject_test - 1]
    test_dataset = torch.utils.data.TensorDataset(subject_data_test_tensor, subject_label_test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = LightSleepNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for training_epoch in range(100):
        print(f"{training_epoch + 1:<5}", end="")

        train_loss = 0
        correct = 0
        total = 0
        for X_batch, y_batch in train_loader:
            X_batch.requires_grad = True

            outputs = model(X_batch).float()
            loss = criterion(outputs, y_batch)

            weights = []

            for epoch_idx in range(X_batch.size(0)):
                epoch_data = X_batch[epoch_idx].unsqueeze(0)
                epoch_loss = criterion(model(epoch_data), y_batch[epoch_idx].unsqueeze(0))

                epoch_grads = torch.autograd.grad(epoch_loss, epoch_data, retain_graph=True)[0]
                grad_norms = epoch_grads.norm(p=2, dim=1).squeeze(0)

                delta = 0.1 * grad_norms.std().item()
                density = ((grad_norms.unsqueeze(1) - grad_norms.unsqueeze(0)).abs() < delta).sum(dim=1).float()

                epoch_weight = density.mean()
                weights.append(epoch_weight.detach())

            weights = torch.tensor(weights, device="cuda")
            weights /= weights.sum()
            # print(weights)
            weighted_loss = (weights * loss).sum()
            train_loss += weighted_loss.item()

            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            correct += predicted.eq(y_batch).sum().item()
            total += y_batch.size(0)

        print(f"Loss: {train_loss / len(train_loader):.3f}  Accuracy: {100 * correct / total:.3f}")

===TEST SUBJECT: 1===
1    Loss: 0.787  Accuracy: 68.141
2    Loss: 0.604  Accuracy: 75.148
3    Loss: 0.577  Accuracy: 77.114
4    Loss: 0.545  Accuracy: 78.689
5    Loss: 0.538  Accuracy: 77.895
6    Loss: 0.518  Accuracy: 78.979
7    Loss: 0.510  Accuracy: 79.622
8    Loss: 0.508  Accuracy: 79.735
9    Loss: 0.511  Accuracy: 79.849
10   Loss: 0.506  Accuracy: 79.269
11   Loss: 0.495  Accuracy: 80.113
12   Loss: 0.490  Accuracy: 80.088
13   Loss: 0.492  Accuracy: 80.378
14   Loss: 0.490  Accuracy: 80.202
15   Loss: 0.492  Accuracy: 80.000
16   Loss: 0.482  Accuracy: 80.252
17   Loss: 0.483  Accuracy: 80.643
18   Loss: 0.470  Accuracy: 81.033
19   Loss: 0.465  Accuracy: 81.285
20   Loss: 0.476  Accuracy: 80.693
21   Loss: 0.482  Accuracy: 80.870
22   Loss: 0.462  Accuracy: 80.870
23   Loss: 0.464  Accuracy: 81.109
24   Loss: 0.456  Accuracy: 81.979
25   Loss: 0.463  Accuracy: 81.311
26   Loss: 0.459  Accuracy: 81.361
27   Loss: 0.460  Accuracy: 81.752
28   Loss: 0.447  Accuracy: 81.81

KeyboardInterrupt: 

In [7]:
y_pred = []
test_loss = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        y_pred.append(predicted.cpu())

y_pred = torch.cat(y_pred).numpy()
y_true = subject_label_test_tensor.to("cpu").numpy()

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)

print(f"Loss: {test_loss / len(test_loader):.3f}")
print(f"Accuracy: {accuracy:.3f}")
for sleep_stage, precision_val, f1_val, recall_val in zip(SLEEP_STAGES, precision, recall, f1):
    print(f"{sleep_stage:7}precision={precision_val:.3f}  recall={recall_val:.3f}  f1={f1_val:.3f}")

Loss: 1.290
Accuracy: 0.632
W      precision=0.876  recall=0.712  f1=0.600
N1     precision=0.421  recall=0.488  f1=0.580
N2     precision=0.614  recall=0.719  f1=0.869
N3     precision=0.790  recall=0.587  f1=0.466
REM    precision=0.636  recall=0.344  f1=0.235


In [9]:
confusion_matrix(y_true, y_pred)

array([[ 99,  44,  19,   2,   1],
       [  8,  69,  32,   0,  10],
       [  5,  20, 324,  19,   5],
       [  0,   0,  95,  83,   0],
       [  1,  31,  58,   1,  28]])