In [None]:
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: each input channel is filtered independently.

    Wraps a single ``nn.Conv2d`` with ``groups=in_channels`` so every input
    channel produces ``out_channels // in_channels`` output channels without
    mixing information across input channels.

    Args:
        in_channels: Number of input channels (must be positive).
        out_channels: Number of output channels; must be a multiple of
            ``in_channels``.
        kernel_size: Kernel size (int or pair), forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a learnable bias.
        auto_padding: If true, padding is computed by ``check_AutoPadding``
            (defined elsewhere) from ``kernel_size`` and ``padding`` is ignored.
        padding: Explicit padding, used when ``auto_padding`` is false.
        stride: Convolution stride.

    Raises:
        ValueError: If ``in_channels`` is not positive or ``out_channels`` is
            not divisible by ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False,
                 auto_padding=False, padding=0, stride=1):
        # Raise instead of assert: assert statements vanish under ``python -O``.
        if in_channels <= 0:
            raise ValueError('#in_channels must be positive')
        if out_channels % in_channels != 0:
            raise ValueError('#out_channels is not divisible by #in_channels')

        super().__init__()
        # Derive padding only on request so an explicit ``padding`` is honoured.
        if auto_padding:
            padding = check_AutoPadding(kernel_size, auto_padding)

        self.depthwise = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=kernel_size, padding=padding,
                                   stride=stride, groups=in_channels, bias=bias)

    def forward(self, X):
        """Apply the depthwise convolution to ``X`` of shape (N, in_channels, H, W)."""
        return self.depthwise(X)
        
        

class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (depthwise) convolution with ``in_channels`` outputs is
    followed by a 1x1 (pointwise) convolution that mixes channels.

    Args:
        in_channels: Number of input channels.
        out_channels: Base number of output channels.
        kernel_size: Kernel size of the depthwise stage, forwarded to ``nn.Conv2d``.
        depth: Multiplier on ``out_channels``; the block emits
            ``depth * out_channels`` channels.
        bias: Whether both convolutions have learnable biases.
        auto_padding: If true, the depthwise padding is computed by
            ``check_AutoPadding`` (defined elsewhere) and ``padding`` is ignored.
        padding: Explicit depthwise padding, used when ``auto_padding`` is false.
        stride: Stride of the depthwise stage (the pointwise stage uses 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth=1,
                 bias=False, auto_padding=False, padding=0, stride=1):
        super().__init__()
        # Derive padding only on request so an explicit ``padding`` is honoured.
        if auto_padding:
            padding = check_AutoPadding(kernel_size, auto_padding)

        depthwise = nn.Conv2d(in_channels, in_channels,
                              kernel_size=kernel_size, padding=padding,
                              stride=stride, groups=in_channels, bias=bias)
        pointwise = nn.Conv2d(in_channels, depth * out_channels,
                              kernel_size=1, bias=bias)

        self.depthwise_separable_convolution = nn.Sequential(depthwise, pointwise)

    def forward(self, X):
        """Apply depthwise then pointwise convolution to ``X`` (N, in_channels, H, W)."""
        return self.depthwise_separable_convolution(X)

In [None]:
class EEGNet(nn.Module):
    """EEGNet-style convolutional classifier.

    Pipeline: temporal convolution (kernel ``(1, 64)``) -> depthwise
    convolution spanning the full height ``C`` -> average pooling ->
    depthwise-separable convolution (kernel ``(1, 16)``) -> average pooling ->
    linear classifier.

    Args:
        C: Height of the input (presumably EEG electrodes); the depthwise
            kernel ``(C, 1)`` spans this whole axis.
        T: Width of the input (presumably time samples).
        F1: Number of filters in the first temporal convolution.
        D: Depth multiplier of the depthwise convolution (``D * F1`` outputs).
        F2: Number of output channels of the separable convolution.
        dropout_rate: Dropout probability applied at the end of each block.
        num_classes: Number of output scores per sample.

    ``forward`` expects a tensor of shape ``(N, 1, C, T)`` and returns raw
    scores of shape ``(N, num_classes)`` (no softmax is applied).
    """

    def __init__(self, C, T, F1, D, F2, dropout_rate, num_classes):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size=(1, 64), stride=1, padding=0),
            nn.BatchNorm2d(F1),
            DepthwiseConv2D(F1, D * F1, kernel_size=(C, 1), stride=1, padding=0),
            nn.BatchNorm2d(D * F1),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=4, padding=0),
            nn.Dropout(dropout_rate),
        )

        self.block2 = nn.Sequential(
            DepthwiseSeparableConv2D(D * F1, F2, kernel_size=(1, 16), stride=1, padding=0),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=8, padding=0),
            nn.Dropout(dropout_rate),
        )

        # The temporal convolution is unpadded (padding=0), so the time axis
        # shrinks by more than the total pooling factor of 32 and
        # ``F2 * int(T / 32)`` does not equal the real feature count.
        # Measure the flattened size instead.
        self.classifier = nn.Linear(self._flattened_size(C, T), num_classes)

    def _flattened_size(self, C, T):
        """Return the per-sample feature count block2 emits for a (1, 1, C, T) input."""
        import torch

        was_training = self.training
        # Eval mode: BatchNorm running stats stay untouched, Dropout is inactive.
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 1, C, T)
                n_features = self.block2(self.block1(dummy)).numel()
        finally:
            self.train(was_training)
        return n_features

    def forward(self, X):
        """Classify ``X`` of shape ``(N, 1, C, T)``; returns ``(N, num_classes)`` scores."""
        X = self.block1(X)
        X = self.block2(X)
        # nn.Linear only acts on the last axis, so flatten (F2, H, W) per sample.
        X = X.flatten(start_dim=1)
        return self.classifier(X)

In [None]:
class CompactNet(nn.Module):
    """EEGNet-style convolutional classifier with a long temporal kernel.

    Pipeline: temporal convolution (kernel ``(1, 256)``) -> depthwise
    convolution spanning the full height ``C`` -> average pooling ->
    depthwise-separable convolution (kernel ``(1, 16)``) -> average pooling ->
    linear classifier.

    Args:
        C: Height of the input (presumably EEG electrodes); the depthwise
            kernel ``(C, 1)`` spans this whole axis.
        T: Width of the input (presumably time samples); must be large
            enough for the ``(1, 256)`` kernel and the subsequent pooling.
        F1: Number of filters in the first temporal convolution.
        D: Depth multiplier of the depthwise convolution (``D * F1`` outputs).
        F2: Number of output channels of the separable convolution.
        dropout_rate: Dropout probability applied at the end of each block.
        num_classes: Number of output scores per sample.

    ``forward`` expects a tensor of shape ``(N, 1, C, T)`` and returns raw
    scores of shape ``(N, num_classes)`` (no softmax is applied).
    """

    def __init__(self, C, T, F1, D, F2, dropout_rate, num_classes):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size=(1, 256), stride=1, padding=0),
            nn.BatchNorm2d(F1),
            DepthwiseConv2D(F1, D * F1, kernel_size=(C, 1), stride=1, padding=0),
            nn.BatchNorm2d(D * F1),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=4, padding=0),
            nn.Dropout(dropout_rate),
        )

        self.block2 = nn.Sequential(
            DepthwiseSeparableConv2D(D * F1, F2, kernel_size=(1, 16), stride=1, padding=0),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=8, padding=0),
            nn.Dropout(dropout_rate),
        )

        # The temporal convolution is unpadded (padding=0), so the time axis
        # shrinks by more than the total pooling factor of 32 and
        # ``F2 * int(T / 32)`` does not equal the real feature count.
        # Measure the flattened size instead.
        self.classifier = nn.Linear(self._flattened_size(C, T), num_classes)

    def _flattened_size(self, C, T):
        """Return the per-sample feature count block2 emits for a (1, 1, C, T) input."""
        import torch

        was_training = self.training
        # Eval mode: BatchNorm running stats stay untouched, Dropout is inactive.
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 1, C, T)
                n_features = self.block2(self.block1(dummy)).numel()
        finally:
            self.train(was_training)
        return n_features

    def forward(self, X):
        """Classify ``X`` of shape ``(N, 1, C, T)``; returns ``(N, num_classes)`` scores."""
        X = self.block1(X)
        X = self.block2(X)
        # nn.Linear only acts on the last axis, so flatten (F2, H, W) per sample.
        X = X.flatten(start_dim=1)
        return self.classifier(X)