In [123]:
import mne
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [88]:
class ChannelShuffle(torch.nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, channels, length = x.size()
        x = x.view(batch_size, self.groups, channels // self.groups, length)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, channels, length)
        return x


class ResidualBlock1(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock1, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.batchnorm1 = torch.nn.BatchNorm1d(64)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=7, groups=8)
        self.batchnorm2 = torch.nn.BatchNorm1d(64)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        out += x
        return out


class ResidualBlock2(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock2, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.batchnorm1 = torch.nn.BatchNorm1d(128)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=7, groups=16)
        self.batchnorm2 = torch.nn.BatchNorm1d(128)
        self.shuffle2 = ChannelShuffle(16)

        self.match_dimensions = torch.nn.Conv1d(64, 128, kernel_size=1, stride=2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        x = self.match_dimensions(x)

        out += x
        return out


class LightSleepNet(torch.nn.Module):
    def __init__(self):
        super(LightSleepNet, self).__init__()
        self.conv = torch.nn.Conv1d(1, 64, kernel_size=16, padding=7, stride=2)
        self.batchnorm = torch.nn.BatchNorm1d(64)
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = torch.relu(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.dropout(x)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [119]:
CHANNEL_NAME = "F3"
ANNOTATOR = "annotator_1"
SLEEP_STAGES = ["W", "N1", "N2", "N3", "REM"]

In [94]:
subjects_data = []
subjects_labels = []

for subject in range(1, 11):
    print(f"==={subject}===")
    raw = mne.io.read_raw_fif(f"processed/{subject}/{subject}_{ANNOTATOR}_eeg.fif")
    raw.pick(CHANNEL_NAME)
    epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
    annotations = raw.annotations

    data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")

    annotations_int = annotations.description.astype(int)
    annotations_int[annotations_int == 5] = 4
    labels_tensor = torch.tensor(annotations_int).to("cuda")

    subjects_data.append(data_tensor)
    subjects_labels.append(labels_tensor)

===1===
Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.
Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped
===2===
Opening raw data file processed/2/2_annotator_1_eeg.fif...
    Range : 0 ... 5645999 =      0.000 ... 28229.995 secs
Ready.
Not setting metadata
941 matching events found
No baseline correction applied
0 projection items activated
Loading data for 941 events and 6000 original time points ...
0 bad epochs dropped
===3===
Opening raw data file processed/3/3_annotator_1_eeg.fif...
    Range : 0 ... 4943999 =      0.000 ... 24719.995 secs
Ready.
Not setting metadata
824 matching events found
No baseline correction applied
0 projection items activated
Loading data for 824 events and 6000 original time points ...
0 bad epochs dropped
===4===
Opening raw data file proc

In [95]:
SUBJECT_TEST = 10

In [104]:
# for subject_test in range(1, 11):
for subject_test in [SUBJECT_TEST]:
    print(f"===SUBJECT TEST: {subject_test}===")

    subjects_data_train = subjects_data[:subject_test - 1] + subjects_data[subject_test:]
    subjects_label_train = subjects_labels[:subject_test - 1] + subjects_labels[subject_test:]
    subjects_data_train_tensor = torch.cat(subjects_data_train)
    subjects_label_train_tensor = torch.cat(subjects_label_train)
    train_dataset = torch.utils.data.TensorDataset(subjects_data_train_tensor, subjects_label_train_tensor)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    subject_data_test_tensor = subjects_data[subject_test - 1]
    subject_label_test_tensor = subjects_labels[subject_test - 1]
    test_dataset = torch.utils.data.TensorDataset(subject_data_test_tensor, subject_label_test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = LightSleepNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    for training_epoch in range(100):
        print(training_epoch + 1)

        train_loss = 0
        correct = 0
        total = 0
        for X_batch, y_batch in train_loader:
            outputs = model(X_batch).float()
            loss = criterion(outputs, y_batch)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = outputs.max(1)
            correct += predicted.eq(y_batch).sum().item()
            total += y_batch.size(0)

        print(f"train\taccuracy: {100 * correct / total:.3f}\tloss: {train_loss:.3f}")

        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for X_batch, y_batch in test_loader:

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                _, predicted = outputs.max(1)
                correct += predicted.eq(y_batch).sum().item()
                total += y_batch.size(0)

        print(f"test\taccuracy: {100 * correct / total:.3f}\tloss: {test_loss:.3f}")

===SUBJECT TEST: 10===
1
train   accuracy: 68.133   loss: 202.714
test    accuracy: 30.402   loss: 41.736
2
train   accuracy: 75.287   loss: 156.648
test    accuracy: 32.663   loss: 44.439
3
train   accuracy: 76.918   loss: 144.354
test    accuracy: 34.171   loss: 43.188
4
train   accuracy: 77.511   loss: 142.161
test    accuracy: 35.050   loss: 43.883
5
train   accuracy: 78.142   loss: 136.152
test    accuracy: 34.171   loss: 50.051
6
train   accuracy: 78.512   loss: 136.797
test    accuracy: 34.422   loss: 44.915
7
train   accuracy: 78.920   loss: 131.651
test    accuracy: 40.955   loss: 40.709
8
train   accuracy: 79.538   loss: 128.536
test    accuracy: 33.920   loss: 46.939
9
train   accuracy: 80.292   loss: 124.871
test    accuracy: 37.437   loss: 44.514
10
train   accuracy: 80.230   loss: 124.990
test    accuracy: 34.925   loss: 48.130
11
train   accuracy: 80.180   loss: 123.235
test    accuracy: 36.055   loss: 46.055
12
train   accuracy: 80.279   loss: 123.655
test    accuracy: 

In [130]:
y_pred = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model(X_batch)
        _, predicted = outputs.max(1)
        y_pred.append(predicted.cpu())

y_pred = torch.cat(y_pred).numpy()
y_true = subject_label_test_tensor.to("cpu").numpy()

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1 = f1_score(y_true, y_pred, average=None)

print(f"Accuracy: {accuracy:.3f}")
for sleep_stage, precision_val, f1_val, recall_val in zip(SLEEP_STAGES, precision, recall, f1):
    print(f"{sleep_stage:7}precision={precision_val:.3f}  recall={recall_val:.3f}  f1={f1_val:.3f}")

Accuracy: 0.553
W      precision=0.669  recall=0.620  f1=0.578
N1     precision=0.584  recall=0.492  f1=0.425
N2     precision=0.396  recall=0.523  f1=0.768
N3     precision=0.776  recall=0.789  f1=0.804
REM    precision=0.780  recall=0.364  f1=0.237
