In [1]:
import mne
import torch

In [2]:
class ChannelShuffle(torch.nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, channels, length = x.size()
        x = x.view(batch_size, self.groups, channels // self.groups, length)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, channels, length)
        return x


class ResidualBlock1(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock1, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.batchnorm1 = torch.nn.BatchNorm1d(64)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=7, groups=8)
        self.batchnorm2 = torch.nn.BatchNorm1d(64)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        out += x
        return out


class ResidualBlock2(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock2, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.batchnorm1 = torch.nn.BatchNorm1d(128)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=7, groups=16)
        self.batchnorm2 = torch.nn.BatchNorm1d(128)
        self.shuffle2 = ChannelShuffle(16)

        self.match_dimensions = torch.nn.Conv1d(64, 128, kernel_size=1, stride=2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        x = self.match_dimensions(x)

        out += x
        return out


class LightSleepNet(torch.nn.Module):
    def __init__(self):
        super(LightSleepNet, self).__init__()
        self.conv = torch.nn.Conv1d(1, 64, kernel_size=16, padding=7, stride=2)
        self.batchnorm = torch.nn.BatchNorm1d(64)
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = torch.relu(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.dropout(x)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [31]:
SUBJECT = 1

In [43]:
raw = mne.io.read_raw_fif(f"processed/{SUBJECT}/{SUBJECT}_annotator_1_eeg.fif")
channel_names = raw.ch_names
channel_names

Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.


['F3', 'C3', 'O1', 'F4', 'C4', 'O2']

In [44]:
raw.pick(channel_names[0])
epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
annotations = raw.annotations

Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped


In [45]:
data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")
data_tensor

tensor([[[ 5.7221e-09, -1.2322e-07, -9.0448e-07,  ...,  1.1952e-06,
           8.0148e-07,  2.9259e-07]],

        [[ 1.4076e-07,  5.7259e-07,  1.2287e-06,  ..., -4.3679e-07,
          -1.5679e-07, -4.7951e-07]],

        [[-8.5183e-07, -7.4884e-07, -2.8801e-07,  ...,  5.5505e-07,
           4.8791e-07,  1.1929e-06]],

        ...,

        [[ 1.4363e-06,  1.9024e-06,  2.1679e-06,  ...,  3.6915e-06,
           2.6608e-06,  6.2829e-07]],

        [[-2.1630e-07,  1.6255e-06,  4.6689e-06,  ...,  9.3061e-06,
           9.2466e-06,  9.2840e-06]],

        [[ 9.4373e-06,  9.4930e-06,  9.2550e-06,  ..., -9.7391e-07,
          -1.1395e-06, -1.1578e-06]]], device='cuda:0')

In [46]:
annotations_int = annotations.description.astype(int)
annotations_int[annotations_int == 5] = 4
labels_tensor = torch.tensor(annotations_int).to("cuda")
labels_tensor

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 1, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 0, 1, 2, 2, 2, 2, 2, 2, 0,
        1, 0, 1, 0, 1, 2, 2, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 0, 1, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 2, 1, 1, 1,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1,
        1, 1, 2, 2, 2, 2, 4, 4, 4, 4, 4,

In [47]:
train_dataset = torch.utils.data.TensorDataset(data_tensor, labels_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [60]:
model = LightSleepNet().to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [66]:
for training_epoch in range(100):
    print(training_epoch + 1)

    correct = 0
    total = 0
    for X_batch, y_batch in train_loader:
        outputs = model(X_batch).float()
        loss = criterion(outputs, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        print(predicted)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)

    print(correct / total * 100)

1
tensor([2, 2, 0, 1, 4, 1, 1, 3, 3, 2, 2, 0, 0, 2, 4, 3, 4, 2, 2, 1, 3, 0, 4, 2,
        3, 0, 4, 3, 4, 0, 2, 4], device='cuda:0')
tensor([2, 1, 0, 2, 0, 0, 2, 4, 2, 3, 4, 2, 2, 3, 3, 3, 3, 2, 1, 0, 2, 4, 2, 2,
        1, 0, 3, 4, 2, 2, 4, 2], device='cuda:0')
tensor([2, 0, 3, 1, 2, 2, 0, 3, 3, 2, 3, 2, 2, 0, 4, 3, 4, 0, 2, 2, 2, 4, 2, 4,
        2, 1, 2, 2, 3, 2, 0, 3], device='cuda:0')
tensor([2, 3, 2, 2, 0, 3, 3, 4, 3, 1, 2, 1, 2, 0, 0, 2, 0, 3, 2, 4, 0, 0, 2, 3,
        2, 2, 4, 0, 1, 4, 3, 0], device='cuda:0')
tensor([0, 2, 0, 2, 2, 3, 0, 4, 1, 2, 2, 4, 2, 3, 4, 0, 2, 4, 1, 2, 3, 2, 3, 1,
        1, 0, 3, 1, 2, 2, 2, 1], device='cuda:0')
tensor([2, 3, 1, 2, 1, 3, 0, 0, 0, 3, 2, 0, 2, 2, 4, 0, 4, 1, 0, 3, 4, 1, 2, 0,
        3, 2, 4, 4, 0, 1, 2, 3], device='cuda:0')
tensor([2, 0, 4, 3, 2, 0, 2, 2, 4, 1, 3, 3, 1, 2, 2, 0, 2, 0, 3, 3, 2, 2, 2, 1,
        0, 0, 0, 3, 2, 2, 4, 2], device='cuda:0')
tensor([4, 3, 2, 2, 1, 2, 2, 2, 3, 0, 0, 0, 2, 2, 2, 2, 3, 2, 2, 2, 2, 1, 3, 3,
        

KeyboardInterrupt: 

In [72]:
raw = mne.io.read_raw_fif(f"processed/2/2_annotator_1_eeg.fif")
channel_names = raw.ch_names
raw.pick(channel_names[0])
data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")

annotations_int = annotations.description.astype(int)
annotations_int[annotations_int == 5] = 4
labels_tensor = torch.tensor(annotations_int).to("cuda")

Opening raw data file processed/2/2_annotator_1_eeg.fif...
    Range : 0 ... 5645999 =      0.000 ... 28229.995 secs
Ready.


In [73]:
test_dataset = torch.utils.data.TensorDataset(data_tensor, labels_tensor)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

In [74]:
test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for X_batch, y_batch in test_loader:

        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()

        _, predicted = outputs.max(1)
        print(predicted)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)

tensor([2, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1,
        1, 1, 1, 2, 2, 1, 2, 1], device='cuda:0')
tensor([1, 1, 1, 1, 4, 2, 2, 2, 0, 0, 0, 1, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 2], device='cuda:0')
tensor([2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 0], device='cuda:0')
tensor([1, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 0], device='cuda:0')
tensor([2, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, 2, 1, 2, 2, 2, 0, 2, 2, 2, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 1, 2, 2], device='cuda:0')
tensor([1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 2, 3], device='cuda:0')
tensor([1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2,

In [75]:
test_loss / len(test_loader)

1.508563132584095

In [76]:
100 * correct / total

57.9664570230608

In [77]:
correct

553

In [78]:
total

954

In [41]:
predicted

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2], device='cuda:0')