In [1]:
import mne
import torch

In [2]:
class ChannelShuffle(torch.nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, channels, length = x.size()
        x = x.view(batch_size, self.groups, channels // self.groups, length)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, channels, length)
        return x


class ResidualBlock1(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock1, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.batchnorm1 = torch.nn.BatchNorm1d(64)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=7, groups=8)
        self.batchnorm2 = torch.nn.BatchNorm1d(64)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        out += x
        return out


class ResidualBlock2(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock2, self).__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.batchnorm1 = torch.nn.BatchNorm1d(128)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=7, groups=16)
        self.batchnorm2 = torch.nn.BatchNorm1d(128)
        self.shuffle2 = ChannelShuffle(16)

        self.match_dimensions = torch.nn.Conv1d(64, 128, kernel_size=1, stride=2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = torch.relu(out)
        out = self.shuffle2(out)

        x = self.match_dimensions(x)

        out += x
        return out


class LightSleepNet(torch.nn.Module):
    def __init__(self):
        super(LightSleepNet, self).__init__()
        self.conv = torch.nn.Conv1d(1, 64, kernel_size=16, padding=7, stride=2)
        self.batchnorm = torch.nn.BatchNorm1d(64)
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = torch.relu(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.dropout(x)
        x = self.pooling(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [3]:
SUBJECT = 1

In [4]:
raw = mne.io.read_raw_fif(f"processed/{SUBJECT}/{SUBJECT}_annotator_1_eeg.fif")
channel_names = raw.ch_names
channel_names

Opening raw data file processed/1/1_annotator_1_eeg.fif...
    Range : 0 ... 5723999 =      0.000 ... 28619.995 secs
Ready.


['F3', 'C3', 'O1', 'F4', 'C4', 'O2']

In [5]:
raw.pick(channel_names[0])
epochs = mne.make_fixed_length_epochs(raw, duration=30, preload=True)
annotations = raw.annotations

Not setting metadata
954 matching events found
No baseline correction applied
0 projection items activated
Loading data for 954 events and 6000 original time points ...
0 bad epochs dropped


In [6]:
data_tensor = torch.tensor(epochs.get_data(), dtype=torch.float32).to("cuda")
data_tensor

tensor([[[ 5.7221e-09, -1.2322e-07, -9.0448e-07,  ...,  1.1952e-06,
           8.0148e-07,  2.9259e-07]],

        [[ 1.4076e-07,  5.7259e-07,  1.2287e-06,  ..., -4.3679e-07,
          -1.5679e-07, -4.7951e-07]],

        [[-8.5183e-07, -7.4884e-07, -2.8801e-07,  ...,  5.5505e-07,
           4.8791e-07,  1.1929e-06]],

        ...,

        [[ 1.4363e-06,  1.9024e-06,  2.1679e-06,  ...,  3.6915e-06,
           2.6608e-06,  6.2829e-07]],

        [[-2.1630e-07,  1.6255e-06,  4.6689e-06,  ...,  9.3061e-06,
           9.2466e-06,  9.2840e-06]],

        [[ 9.4373e-06,  9.4930e-06,  9.2550e-06,  ..., -9.7391e-07,
          -1.1395e-06, -1.1578e-06]]], device='cuda:0')

In [7]:
labels_tensor = torch.tensor(list(map(int, annotations.description))).to("cuda")
labels_tensor

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 1, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 0, 1, 2, 2, 2, 2, 2, 2, 0,
        1, 0, 1, 0, 1, 2, 2, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 0, 1, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 2, 1, 1, 1,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1,
        1, 1, 2, 2, 2, 2, 5, 5, 5, 5, 5,

In [8]:
train_dataset = torch.utils.data.TensorDataset(data_tensor, labels_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [9]:
model = LightSleepNet().to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [15]:
model.train()
for training_epoch in range(10):
    print(training_epoch + 1)
    for X_batch, y_batch in train_loader:
        outputs = model(X_batch).float()
        loss = criterion(outputs, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

1


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [12]:
outputs.shape

torch.Size([32, 5])

In [13]:
y_batch

tensor([5, 2, 5, 2, 2, 0, 5, 3, 1, 3, 0, 2, 3, 2, 2, 3, 1, 2, 5, 2, 0, 5, 0, 5,
        2, 2, 0, 5, 1, 3, 0, 1], device='cuda:0')

In [14]:
torch.softmax(outputs, dim=1)

tensor([[0.2433, 0.2324, 0.1811, 0.1750, 0.1682],
        [0.2735, 0.2521, 0.2123, 0.1659, 0.0963],
        [0.2515, 0.2349, 0.1855, 0.1729, 0.1551],
        [0.2714, 0.2436, 0.2120, 0.1664, 0.1066],
        [0.2654, 0.2431, 0.2113, 0.1668, 0.1135],
        [0.2454, 0.2327, 0.1839, 0.1754, 0.1627],
        [0.2460, 0.2311, 0.1839, 0.1773, 0.1617],
        [0.2709, 0.2559, 0.2200, 0.1692, 0.0841],
        [0.2507, 0.2331, 0.1841, 0.1752, 0.1568],
        [0.2588, 0.2389, 0.2030, 0.1744, 0.1249],
        [0.2679, 0.2570, 0.2009, 0.1724, 0.1019],
        [0.2704, 0.2423, 0.2064, 0.1706, 0.1104],
        [0.2778, 0.2468, 0.2229, 0.1668, 0.0857],
        [0.2653, 0.2427, 0.2051, 0.1697, 0.1172],
        [0.2558, 0.2372, 0.1947, 0.1773, 0.1351],
        [0.2752, 0.2469, 0.2284, 0.1621, 0.0875],
        [0.2482, 0.2310, 0.1789, 0.1736, 0.1683],
        [0.2676, 0.2466, 0.2081, 0.1748, 0.1028],
        [0.2407, 0.2355, 0.1766, 0.1754, 0.1718],
        [0.2494, 0.2376, 0.1844, 0.1731, 0.1555],


In [14]:
torch.argmax(outputs, dim=1)

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')