In [1]:
import mne
import torch

In [2]:
class ChannelShuffle(torch.nn.Module):
    """Interleave channels across groups (ShuffleNet-style channel shuffle).

    The channel axis (dim 1) is viewed as ``(groups, channels // groups)``,
    the two factors are swapped, and the result is flattened back, so that
    consecutive output channels come from different groups. This lets a
    following grouped convolution see channels from every preceding group.

    Args:
        groups: Number of groups the channel axis is split into. Must be a
            positive integer that divides the number of input channels.
    """

    def __init__(self, groups):
        super().__init__()
        if groups < 1:
            raise ValueError(f"groups must be >= 1, got {groups}")
        self.groups = groups

    def forward(self, x):
        """Return ``x`` with its channels (dim 1) shuffled across groups.

        Args:
            x: Tensor with at least two dimensions, laid out as
                ``(batch, channels, ...)``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If the channel count is not divisible by ``groups``.
        """
        batch, channels = x.shape[0], x.shape[1]
        trailing = x.shape[2:]
        if channels % self.groups != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by groups ({self.groups})"
            )
        # (B, C, ...) -> (B, G, C // G, ...) -> swap -> (B, C // G, G, ...)
        x = x.reshape(batch, self.groups, channels // self.groups, *trailing)
        x = x.transpose(1, 2)
        # reshape (not view): the transposed tensor is no longer contiguous.
        return x.reshape(batch, channels, *trailing)


class ResidualBlock1(torch.nn.Module):
    """Residual block of two grouped 1-D convolutions on 64 channels.

    Each stage is ``Conv1d(64 -> 64, kernel 16, 8 groups) -> ReLU ->
    ChannelShuffle(8)``; the block input is added to the output of the
    second stage. The output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.shuffle1 = ChannelShuffle(8)
        self.conv2 = torch.nn.Conv1d(64, 64, kernel_size=16, padding=8, groups=8)
        self.shuffle2 = ChannelShuffle(8)

    def forward(self, x):
        """Apply both conv stages and add the identity shortcut.

        Args:
            x: Tensor of shape ``(batch, 64, length)``.

        Returns:
            Tensor of shape ``(batch, 64, length)``.
        """
        length = x.shape[-1]

        # An even kernel (16) with symmetric padding 8 yields length + 1
        # samples; drop the extra trailing sample so the residual add lines up.
        out = self.conv1(x)[..., :length]
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)[..., :length]
        out = torch.relu(out)
        out = self.shuffle2(out)

        # Out-of-place add: an in-place ``+=`` could overwrite the ReLU output
        # that autograd keeps for the backward pass.
        return out + x


class ResidualBlock2(torch.nn.Module):
    """Down-sampling residual block: 64 -> 128 channels, length halved.

    Main path: ``Conv1d(64 -> 128, kernel 16, stride 2, 16 groups) -> ReLU ->
    ChannelShuffle(16)`` followed by ``Conv1d(128 -> 128, kernel 16,
    16 groups) -> ReLU -> ChannelShuffle(16)``.

    Shortcut: the input cannot be added directly because the main path
    changes both the channel count and the length. A parameter-free shortcut
    is used instead: every second time step of the input is kept, and the
    channel axis is zero-padded up to the main path's channel count. This
    keeps the block's parameters exactly ``conv1`` and ``conv2``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv1d(64, 128, kernel_size=16, padding=8, stride=2, groups=16)
        self.shuffle1 = ChannelShuffle(16)
        self.conv2 = torch.nn.Conv1d(128, 128, kernel_size=16, padding=8, groups=16)
        self.shuffle2 = ChannelShuffle(16)

    def forward(self, x):
        """Apply both conv stages and add the down-sampled shortcut.

        Args:
            x: Tensor of shape ``(batch, 64, length)``.

        Returns:
            Tensor of shape ``(batch, 128, ceil(length / 2))``.
        """
        # Stride-2 subsampling gives the shortcut ceil(length / 2) samples.
        shortcut = x[..., ::2]
        length = shortcut.shape[-1]

        # conv1 yields floor(length_in / 2) + 1 samples and conv2 adds one
        # more (even kernel, symmetric padding); crop both to the shortcut
        # length so the residual add lines up.
        out = self.conv1(x)[..., :length]
        out = torch.relu(out)
        out = self.shuffle1(out)

        out = self.conv2(out)[..., :length]
        out = torch.relu(out)
        out = self.shuffle2(out)

        # Zero-pad the channel axis (second-to-last pair of the pad spec)
        # so the shortcut matches the main path's channel count.
        extra_channels = out.shape[1] - shortcut.shape[1]
        shortcut = torch.nn.functional.pad(shortcut, (0, 0, 0, extra_channels))

        # Out-of-place add: an in-place ``+=`` could overwrite the ReLU output
        # that autograd keeps for the backward pass.
        return out + shortcut


class LightSleepNet(torch.nn.Module):
    """1-D convolutional classifier producing probabilities over 5 classes.

    Pipeline: strided stem convolution (1 -> 64 channels) with ReLU,
    ``ResidualBlock1``, ``ResidualBlock2``, dropout (p=0.5), global average
    pooling over the last axis, and a linear layer (128 -> 5) followed by
    softmax.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=1,
            out_channels=64,
            kernel_size=16,
            stride=2,
            padding=8,
        )
        self.residual1 = ResidualBlock1()
        self.residual2 = ResidualBlock2()
        self.dropout = torch.nn.Dropout(p=0.5)
        self.pooling = torch.nn.AdaptiveAvgPool1d(output_size=1)
        self.linear = torch.nn.Linear(in_features=128, out_features=5)

    def forward(self, x):
        """Run the network on a batch of single-channel signals.

        Args:
            x: Tensor of shape ``(batch, 1, length)``.

        Returns:
            Tensor of shape ``(batch, 5)`` whose rows are softmax
            probabilities (not logits).
        """
        stem = torch.relu(self.conv(x))
        features = self.residual2(self.residual1(stem))
        # Dropout is applied to the feature map before it is averaged.
        pooled = self.pooling(self.dropout(features))
        scores = self.linear(pooled.flatten(start_dim=1))
        return scores.softmax(dim=1)