In [1]:
import mne
import torch

In [208]:
class ChannelShuffle(torch.nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, channels, length = x.size()
        x = x.view(batch_size, self.groups, channels // self.groups, length)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, channels, length)
        return x


class ResidualBlock1(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock1, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=7, groups=8)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        out = self.conv1(x)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        out += x
        return out


class ResidualBlock2(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock2, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=7, groups=16)
        self.shuffle2 = ChannelShuffle(16)

        self.match_dimensions = torch.nn.Conv1d(64, 128, kernel_size=1, stride=2)

    def forward(self, x):
        print(x.shape)
        out = self.conv1(x)
        print(out.shape)
        out = torch.relu(out)
        out = self.shuffle1(out)
        print(out.shape)

        out = self.conv2(out)
        print(out.shape)
        out = torch.relu(out)
        out = self.shuffle2(out)
        print(out.shape)

        x = self.match_dimensions(x)
        print(x.shape)

        out += x
        return out


class LightSleepNet(torch.nn.Module):
    def __init__(self):
        super(LightSleepNet, self).__init__()
        self.conv = torch.nn.Conv1d(1, 64, kernel_size=16, padding=7, stride=2)
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv(x)
        x = torch.relu(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.dropout(x)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return torch.softmax(x, dim=1)

In [209]:
SUBJECT = 1

In [210]:
raw = mne.io.read_raw_fif(f"processed/{SUBJECT}/{SUBJECT}_annotator_1_eeg.fif")
channel_names = raw.ch_names
channel_names

Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.


['F3', 'C3', 'O1', 'F4', 'C4', 'O2']

In [211]:
raw.pick(channel_names[0])
epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
annotations = raw.annotations

Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped


In [212]:
data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")
data_tensor

tensor([[[ 5.7221e-09, -1.2322e-07, -9.0448e-07,  ...,  1.1952e-06,
           8.0148e-07,  2.9259e-07]],

        [[ 1.4076e-07,  5.7259e-07,  1.2287e-06,  ..., -4.3679e-07,
          -1.5679e-07, -4.7951e-07]],

        [[-8.5183e-07, -7.4884e-07, -2.8801e-07,  ...,  5.5505e-07,
           4.8791e-07,  1.1929e-06]],

        ...,

        [[ 1.4363e-06,  1.9024e-06,  2.1679e-06,  ...,  3.6915e-06,
           2.6608e-06,  6.2829e-07]],

        [[-2.1630e-07,  1.6255e-06,  4.6689e-06,  ...,  9.3061e-06,
           9.2466e-06,  9.2840e-06]],

        [[ 9.4373e-06,  9.4930e-06,  9.2550e-06,  ..., -9.7391e-07,
          -1.1395e-06, -1.1578e-06]]], device='cuda:0')

In [213]:
labels_tensor = torch.tensor(list(map(float, annotations.description)), dtype=torch.float32).to("cuda")
labels_tensor

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        1., 2., 2., 2., 0., 0., 0., 0., 0., 0., 1., 1., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 2., 0., 0., 0., 1., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 0., 1., 2.,
        2., 2., 2., 2., 2., 0., 1., 0., 1., 0., 1., 2., 2., 1., 0., 1., 1., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 2.,
        2., 2., 2., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3., 0., 1., 2., 2.,
        2., 2., 3., 3., 3., 3., 3., 3., 

In [214]:
train_dataset = torch.utils.data.TensorDataset(data_tensor, labels_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [215]:
model = LightSleepNet().to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [218]:
model.train()
for training_epoch in range(10):
    print(training_epoch + 1)
    for X_batch, y_batch in train_loader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

1
torch.Size([32, 64, 3000])
torch.Size([32, 128, 1501])
torch.Size([32, 128, 1501])
torch.Size([32, 128, 1500])
torch.Size([32, 128, 1500])
torch.Size([32, 128, 1500])


RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Float'