In [1]:
import mne
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from lightsleepnet import LightSleepNet

In [2]:
CHANNEL_NAME = "F3"
ANNOTATOR = "annotator_1"
SLEEP_STAGES = ["W", "N1", "N2", "N3", "REM"]

In [3]:
subjects_data = []
subjects_labels = []

for subject in range(1, 11):
    print(f"==={subject}===")
    raw = mne.io.read_raw_fif(f"processed/{subject}/{subject}_{ANNOTATOR}_eeg.fif")
    raw.pick(CHANNEL_NAME)
    epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
    annotations = raw.annotations

    data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")

    annotations_int = annotations.description.astype(int)
    annotations_int[annotations_int == 5] = 4
    labels_tensor = torch.tensor(annotations_int).to("cuda")

    subjects_data.append(data_tensor)
    subjects_labels.append(labels_tensor)

===1===
Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.
Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped
===2===
Opening raw data file processed/2/2_annotator_1_eeg.fif...
    Range : 0 ... 5645999 =      0.000 ... 28229.995 secs
Ready.
Not setting metadata
941 matching events found
No baseline correction applied
0 projection items activated
Loading data for 941 events and 6000 original time points ...
0 bad epochs dropped
===3===
Opening raw data file processed/3/3_annotator_1_eeg.fif...
    Range : 0 ... 4943999 =      0.000 ... 24719.995 secs
Ready.
Not setting metadata
824 matching events found
No baseline correction applied
0 projection items activated
Loading data for 824 events and 6000 original time points ...
0 bad epochs dropped
===4===
Opening raw data file proc

In [4]:
SUBJECT_TEST = 1

In [6]:
# for subject_test in range(1, 11):
for subject_test in [SUBJECT_TEST]:
    print(f"===TEST SUBJECT: {subject_test}===")

    subjects_data_train = subjects_data[:subject_test - 1] + subjects_data[subject_test:]
    subjects_label_train = subjects_labels[:subject_test - 1] + subjects_labels[subject_test:]
    subjects_data_train_tensor = torch.cat(subjects_data_train)
    subjects_label_train_tensor = torch.cat(subjects_label_train)
    train_dataset = torch.utils.data.TensorDataset(subjects_data_train_tensor, subjects_label_train_tensor)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    subject_data_test_tensor = subjects_data[subject_test - 1]
    subject_label_test_tensor = subjects_labels[subject_test - 1]
    test_dataset = torch.utils.data.TensorDataset(subject_data_test_tensor, subject_label_test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = LightSleepNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for training_epoch in range(100):
        print(f"{training_epoch + 1:<5}", end="")

        train_loss = 0
        correct = 0
        total = 0
        for X_batch, y_batch in train_loader:
            X_batch.requires_grad = True

            outputs = model(X_batch).float()
            loss = criterion(outputs, y_batch)

            weights = []

            for epoch_idx in range(X_batch.size(0)):
                epoch_data = X_batch[epoch_idx].unsqueeze(0)
                epoch_loss = criterion(model(epoch_data), y_batch[epoch_idx].unsqueeze(0))

                epoch_grads = torch.autograd.grad(epoch_loss, epoch_data, retain_graph=True)[0]
                grad_norms = epoch_grads.norm(p=2, dim=1).squeeze(0)

                delta = 0.1 * grad_norms.std().item()
                density = ((grad_norms.unsqueeze(1) - grad_norms.unsqueeze(0)).abs() < delta).sum(dim=1).float()

                epoch_weight = density.mean()
                weights.append(epoch_weight.detach())

            weights = torch.tensor(weights, device="cuda")
            weights /= weights.sum()
            # print(weights)
            weighted_loss = (weights * loss).sum()
            train_loss += weighted_loss.item()

            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            correct += predicted.eq(y_batch).sum().item()
            total += y_batch.size(0)

        print(f"Loss: {train_loss / len(train_loader):.3f}  Accuracy: {100 * correct / total:.3f}")

===TEST SUBJECT: 1===
1    Loss: 0.789  Accuracy: 68.116
2    


KeyboardInterrupt



In [8]:
y_pred = []
test_loss = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model(
            X_batch,
            (
                subject_data_test_tensor.mean(dim=(0, 2)),
                subject_data_test_tensor.var(dim=(0, 2), unbiased=False)
            )
        )
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        y_pred.append(predicted.cpu())

y_pred = torch.cat(y_pred).numpy()
y_true = subject_label_test_tensor.to("cpu").numpy()

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)

print(f"Loss: {test_loss / len(test_loader):.3f}")
print(f"Accuracy: {accuracy:.3f}")
for sleep_stage, precision_val, f1_val, recall_val in zip(SLEEP_STAGES, precision, recall, f1):
    print(f"{sleep_stage:7}precision={precision_val:.3f}  recall={recall_val:.3f}  f1={f1_val:.3f}")

KeyboardInterrupt: 