In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchsummary import summary

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [237]:
class DepthwiseConv2D(nn.Module):
    """Depthwise 2-D convolution: every input channel gets its own set of filters.

    Implemented as a single ``nn.Conv2d`` with ``groups=in_channels``, so
    ``out_channels`` must be an integer multiple of ``in_channels``; the ratio
    is the number of filters applied to each input channel.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (a multiple of ``in_channels``).
        kernel_size: kernel size, in any form ``nn.Conv2d`` accepts.
        bias: whether the convolution carries a learnable bias.
        auto_padding: accepted for signature compatibility; currently ignored.
        padding: explicit padding forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.

    Raises:
        AssertionError: if ``out_channels`` is not divisible by ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False, auto_padding=False, padding=0, stride=1):
        super().__init__()
        assert not out_channels % in_channels, '#out_channels is not divisible by #in_channels'
        self.depthwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, X):
        """Apply the grouped (depthwise) convolution to ``X``."""
        conv = self.depthwise
        return conv(X)
        
        

class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage keeps the channel count (``groups=in_channels``); the
    pointwise stage mixes channels into ``depth * out_channels`` outputs.

    Note: both convolutions are registered directly (``depthwise``,
    ``pointwise``) and again inside ``depthwise_separable_convolution``.
    ``parameters()`` deduplicates them, but ``state_dict()`` lists each weight
    under both names, and module-walking tools may report the convs twice.

    Args:
        in_channels: number of input channels.
        out_channels: base number of output channels.
        kernel_size: kernel size of the depthwise stage.
        depth: multiplier applied to ``out_channels`` in the pointwise stage.
        bias: whether both convolutions carry a learnable bias.
        auto_padding: accepted for signature compatibility; currently ignored.
        padding: padding of the depthwise stage.
        stride: stride of the depthwise stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth=1, bias=False, auto_padding=False, padding=0, stride=1):
        super().__init__()
        # Creation order matters: it fixes the RNG draws used for weight init.
        per_channel = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        channel_mixer = nn.Conv2d(in_channels, depth * out_channels, kernel_size=1, bias=bias)

        # Registration order determines state_dict key order; keep it stable.
        self.depthwise = per_channel
        self.pointwise = channel_mixer
        self.depthwise_separable_convolution = nn.Sequential(per_channel, channel_mixer)

    def forward(self, X):
        """Run the depthwise stage, then the pointwise stage, on ``X``."""
        pipeline = self.depthwise_separable_convolution
        return pipeline(X)

In [4]:
class EEGNet(nn.Module):
    """EEGNet-style convolutional classifier.

    Expects input of shape ``(batch, 1, C, T)`` and returns raw scores of shape
    ``(batch, num_classes)`` (no softmax is applied).

    Args:
        C: input height; the depthwise stage uses a ``(C, 1)`` kernel to
            collapse it to 1.
        T: input width (the axis the temporal kernels slide over).
        F1: number of filters in the first ``(1, 64)`` convolution.
        D: depth multiplier of the depthwise stage (``D * F1`` outputs).
        F2: number of output filters of the separable convolution.
        dropout_rate: dropout probability after each block.
        num_classes: number of output scores.
    """

    def __init__(self, C, T, F1, D, F2, dropout_rate, num_classes):
        super(EEGNet, self).__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size = (1, 64), stride = 1, padding = 0),
            nn.BatchNorm2d(F1),
            DepthwiseConv2D(F1, D*F1, kernel_size = (C, 1), stride = 1, padding = 0),
            nn.BatchNorm2d(D*F1),
            nn.ELU(),
            nn.AvgPool2d(kernel_size = (1, 4), stride = 4, padding = 0),
            nn.Dropout(dropout_rate)
        )

        self.block2 = nn.Sequential(
            DepthwiseSeparableConv2D(D*F1, F2, kernel_size = (1, 16), stride = 1, padding = 0),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size = (1, 8), stride = 8, padding = 0),
            nn.Dropout(dropout_rate)
        )

        # The convolutions use padding=0, so the time axis shrinks by the
        # kernel sizes before each pooling step; the flattened length is NOT
        # F2 * T/32. Measure it from the actual layers instead of guessing.
        self.classifier = nn.Linear(self._feature_size(C, T), num_classes)

    def _feature_size(self, C, T):
        """Return the per-sample length of the flattened block2 output for a (1, C, T) input.

        The probe runs under ``torch.no_grad`` with the module temporarily in
        eval mode, so BatchNorm running statistics are not updated and Dropout
        draws no random numbers. The previous train/eval mode is restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 1, C, T)
                return self.block2(self.block1(probe)).numel()
        finally:
            self.train(was_training)

    def forward(self, X):
        """Map a ``(batch, 1, C, T)`` tensor to ``(batch, num_classes)`` scores."""
        X = self.block1(X)
        X = self.block2(X)
        # Collapse (F2, 1, W) into one feature vector per sample; without this
        # the Linear layer would be applied to the last (time) axis only.
        X = torch.flatten(X, 1)
        return self.classifier(X)

In [246]:
class CompactNet(nn.Module):
    """Compact EEGNet variant with a long first temporal kernel.

    Expects input of shape ``(batch, 1, C, T)`` with ``T >= 257`` and returns
    raw scores of shape ``(batch, num_classes)`` (no softmax is applied).

    Args:
        C: input height; the depthwise stage uses a ``(C, 1)`` kernel to
            collapse it to 1.
        T: input width (the axis the temporal kernels slide over).
        F1: number of filters in the first ``(1, 257)`` convolution.
        D: depth multiplier of the depthwise stage (``D * F1`` outputs).
        F2: number of output filters of the separable convolution.
        dropout_rate: dropout probability after each block.
        num_classes: number of output scores.
    """

    def __init__(self, C, T, F1, D, F2, dropout_rate, num_classes):
        super(CompactNet, self).__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size = (1, 257), stride = 1, padding = 0),
            nn.BatchNorm2d(F1),
            DepthwiseConv2D(F1, D*F1, kernel_size = (C, 1), stride = 1, padding = 0),
            nn.BatchNorm2d(D*F1),
            nn.ELU(),
            nn.AvgPool2d(kernel_size = (1, 2), stride = 2, padding = 0),
            nn.Dropout(dropout_rate)
        )

        self.block2 = nn.Sequential(
            DepthwiseSeparableConv2D(D*F1, F2, kernel_size = (1, 17), stride = 1, padding = (0, 8)),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size = (1, 8), stride = 8, padding = 0),
            nn.Dropout(dropout_rate)
        )

        # The unpadded (1, 257) conv removes 256 samples, so the flattened
        # width is (T - 256) // 16. That equals T/32 only at T == 512; measure
        # it from the actual layers so any T >= 257 works.
        self.classifier = nn.Linear(self._feature_size(C, T), num_classes)

    def _feature_size(self, C, T):
        """Return the per-sample length of the flattened block2 output for a (1, C, T) input.

        The probe runs under ``torch.no_grad`` with the module temporarily in
        eval mode, so BatchNorm running statistics are not updated and Dropout
        draws no random numbers. The previous train/eval mode is restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 1, C, T)
                return self.block2(self.block1(probe)).numel()
        finally:
            self.train(was_training)

    def forward(self, X):
        """Map a ``(batch, 1, C, T)`` tensor to ``(batch, num_classes)`` scores."""
        X = self.block1(X)
        X = self.block2(X)
        X = torch.flatten(X, 1)
        X = self.classifier(X)
        return X
    
# CompactNet for inputs of height C=64 and width T=512 (the (1, 64, 512)
# shape passed to summary), producing 40 class scores; moved to `device`.
model = CompactNet(
    C=64,
    T=256 * 2,
    F1=96,
    D=1,
    F2=96,
    dropout_rate=0.5,
    num_classes=40,
).to(device)

In [245]:
# torchsummary builds its probe tensors on "cuda" unless told otherwise, which
# crashes on CPU-only machines; pass the device type the model was moved to.
summary(model, (1, 64, 512), device=device.type)

torch.Size([2, 1, 64, 512])
torch.Size([2, 96, 1, 128])
torch.Size([2, 96, 1, 16])
torch.Size([2, 1536])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 96, 64, 256]          24,768
       BatchNorm2d-2          [-1, 96, 64, 256]             192
            Conv2d-3           [-1, 96, 1, 256]           6,144
   DepthwiseConv2D-4           [-1, 96, 1, 256]               0
       BatchNorm2d-5           [-1, 96, 1, 256]             192
               ELU-6           [-1, 96, 1, 256]               0
         AvgPool2d-7           [-1, 96, 1, 128]               0
           Dropout-8           [-1, 96, 1, 128]               0
            Conv2d-9           [-1, 96, 1, 128]           1,632
           Conv2d-10           [-1, 96, 1, 128]           1,632
           Conv2d-11           [-1, 96, 1, 128]           9,216
           Conv2d-12           [-1, 96, 1, 128]           9,21