In [1]:
import mne
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
class ChannelShuffle(torch.nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, channels, length = x.size()
        x = x.view(batch_size, self.groups, channels // self.groups, length)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, channels, length)
        return x


class ResidualBlock1(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock1, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.batchnorm1 = torch.nn.BatchNorm1d(64)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=7, groups=8)
        self.batchnorm2 = torch.nn.BatchNorm1d(64)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        out += x
        return out


class ResidualBlock2(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock2, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.batchnorm1 = torch.nn.BatchNorm1d(128)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=7, groups=16)
        self.batchnorm2 = torch.nn.BatchNorm1d(128)
        self.shuffle2 = ChannelShuffle(16)

        self.match_dimensions = torch.nn.Conv1d(64, 128, kernel_size=1, stride=2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        x = self.match_dimensions(x)

        out += x
        return out


class LightSleepNet(torch.nn.Module):
    def __init__(self):
        super(LightSleepNet, self).__init__()
        self.conv = torch.nn.Conv1d(1, 64, kernel_size=16, padding=7, stride=2)
        self.batchnorm = torch.nn.BatchNorm1d(64)
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = torch.relu(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.dropout(x)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [3]:
CHANNEL_NAME = "F3"
ANNOTATOR = "annotator_1"
SLEEP_STAGES = ["W", "N1", "N2", "N3", "REM"]

In [4]:
subjects_data = []
subjects_labels = []

for subject in range(1, 11):
    print(f"==={subject}===")
    raw = mne.io.read_raw_fif(f"processed/{subject}/{subject}_{ANNOTATOR}_eeg.fif")
    raw.pick(CHANNEL_NAME)
    epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
    annotations = raw.annotations

    data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")

    annotations_int = annotations.description.astype(int)
    annotations_int[annotations_int == 5] = 4
    labels_tensor = torch.tensor(annotations_int).to("cuda")

    subjects_data.append(data_tensor)
    subjects_labels.append(labels_tensor)

===1===
Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.
Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped
===2===
Opening raw data file processed/2/2_annotator_1_eeg.fif...
    Range : 0 ... 5645999 =      0.000 ... 28229.995 secs
Ready.
Not setting metadata
941 matching events found
No baseline correction applied
0 projection items activated
Loading data for 941 events and 6000 original time points ...
0 bad epochs dropped
===3===
Opening raw data file processed/3/3_annotator_1_eeg.fif...
    Range : 0 ... 4943999 =      0.000 ... 24719.995 secs
Ready.
Not setting metadata
824 matching events found
No baseline correction applied
0 projection items activated
Loading data for 824 events and 6000 original time points ...
0 bad epochs dropped
===4===
Opening raw data file proc

In [5]:
SUBJECT_TEST = 1

In [6]:
# for subject_test in range(1, 11):
for subject_test in [SUBJECT_TEST]:
    print(f"===TEST SUBJECT: {subject_test}===")

    subjects_data_train = subjects_data[:subject_test - 1] + subjects_data[subject_test:]
    subjects_label_train = subjects_labels[:subject_test - 1] + subjects_labels[subject_test:]
    subjects_data_train_tensor = torch.cat(subjects_data_train)
    subjects_label_train_tensor = torch.cat(subjects_label_train)
    train_dataset = torch.utils.data.TensorDataset(subjects_data_train_tensor, subjects_label_train_tensor)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    subject_data_test_tensor = subjects_data[subject_test - 1]
    subject_label_test_tensor = subjects_labels[subject_test - 1]
    test_dataset = torch.utils.data.TensorDataset(subject_data_test_tensor, subject_label_test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = LightSleepNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for training_epoch in range(100):
        print(f"{training_epoch + 1:<5}", end="")

        train_loss = 0
        correct = 0
        total = 0
        for X_batch, y_batch in train_loader:
            X_batch.requires_grad = True

            outputs = model(X_batch).float()
            loss = criterion(outputs, y_batch)

            weights = []

            for epoch_idx in range(X_batch.size(0)):
                epoch_data = X_batch[epoch_idx].unsqueeze(0)
                epoch_loss = criterion(model(epoch_data), y_batch[epoch_idx].unsqueeze(0))

                epoch_grads = torch.autograd.grad(epoch_loss, epoch_data, retain_graph=True)[0]
                grad_norms = epoch_grads.norm(p=2, dim=1).squeeze(0)

                delta = 0.1 * grad_norms.std().item()
                density = ((grad_norms.unsqueeze(1) - grad_norms.unsqueeze(0)).abs() < delta).sum(dim=1).float()

                epoch_weight = density.mean()
                weights.append(epoch_weight)

            weights = torch.tensor(weights, device="cuda")
            weights /= weights.sum()
            # print(weights)
            weighted_loss = (weights * loss).sum()
            train_loss += weighted_loss.item()

            optimizer.zero_grad()
            weighted_loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            correct += predicted.eq(y_batch).sum().item()
            total += y_batch.size(0)

        print(f"Loss: {train_loss:.3f}  Accuracy: {100 * correct / total:.3f}")

===TEST SUBJECT: 1===
1    Loss: 201.805  Accuracy: 67.083
2    Loss: 152.078  Accuracy: 75.186
3    Loss: 147.166  Accuracy: 75.539
4    Loss: 140.391  Accuracy: 77.240
5    Loss: 136.279  Accuracy: 77.631
6    Loss: 133.074  Accuracy: 78.286
7    Loss: 131.896  Accuracy: 78.664
8    Loss: 127.939  Accuracy: 79.004
9    Loss: 126.434  Accuracy: 79.597
10   Loss: 125.915  Accuracy: 79.332
11   Loss: 121.764  Accuracy: 80.088
12   Loss: 119.662  Accuracy: 80.504
13   Loss: 117.922  Accuracy: 80.681
14   Loss: 122.334  Accuracy: 80.038
15   Loss: 117.438  Accuracy: 80.630
16   Loss: 118.807  Accuracy: 80.781
17   Loss: 115.486  Accuracy: 80.769
18   Loss: 115.249  Accuracy: 81.046
19   Loss: 114.695  Accuracy: 81.147
20   Loss: 116.358  Accuracy: 80.504
21   Loss: 113.981  Accuracy: 81.437
22   Loss: 114.139  Accuracy: 81.298
23   Loss: 114.486  Accuracy: 81.172
24   Loss: 115.363  Accuracy: 81.122
25   Loss: 112.503  Accuracy: 81.311
26   Loss: 113.535  Accuracy: 81.727
27   Loss: 110.4

In [7]:
y_pred = []
test_loss = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        y_pred.append(predicted.cpu())

y_pred = torch.cat(y_pred).numpy()
y_true = subject_label_test_tensor.to("cpu").numpy()

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)

print(f"Loss: {test_loss:.3f}")
print(f"Accuracy: {accuracy:.3f}")
for sleep_stage, precision_val, f1_val, recall_val in zip(SLEEP_STAGES, precision, recall, f1):
    print(f"{sleep_stage:7}precision={precision_val:.3f}  recall={recall_val:.3f}  f1={f1_val:.3f}")

Loss: 30.894
Accuracy: 0.684
W      precision=0.874  recall=0.732  f1=0.630
N1     precision=0.434  recall=0.496  f1=0.580
N2     precision=0.686  recall=0.743  f1=0.810
N3     precision=0.756  recall=0.787  f1=0.820
REM    precision=0.744  recall=0.395  f1=0.269
