In [6]:
import numpy as np
from torch import nn
from torch.nn import init
from torch.nn.functional import elu
from braindecode.torch_ext.modules import Expression, AvgPool2dWithConv
from braindecode.torch_ext.functions import identity
from braindecode.torch_ext.util import np_to_var


class Deep4Net(object):
    """
    Deep ConvNet model from [1]_.
    References
    ----------
    .. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J., 
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    def __init__(self, in_chans,
                 n_classes,
                 input_time_length,
                 final_conv_length,
                 n_filters_time=25,
                 n_filters_spat=25,
                 filter_time_length=10,
                 pool_time_length=3,
                 pool_time_stride=3,
                 n_filters_2=50,
                 filter_length_2=10,
                 n_filters_3=100,
                 filter_length_3=10,
                 n_filters_4=200,
                 filter_length_4=10,
                 first_nonlin=elu,
                 first_pool_mode='max',
                 first_pool_nonlin=identity,
                 later_nonlin=elu,
                 later_pool_mode='max',
                 later_pool_nonlin=identity,
                 drop_prob=0.5,
                 double_time_convs=False,
                 split_first_layer=True,
                 batch_norm=True,
                 batch_norm_alpha=0.1,
                 stride_before_pool=False):
        if final_conv_length == 'auto':
            assert input_time_length is not None

        self.__dict__.update(locals())
        del self.self

    def create_network(self):
        if self.stride_before_pool:
            conv_stride = self.pool_time_stride
            pool_stride = 1
        else:
            conv_stride = 1
            pool_stride = self.pool_time_stride
        pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
        first_pool_class = pool_class_dict[self.first_pool_mode]
        later_pool_class = pool_class_dict[self.later_pool_mode]
        model = nn.Sequential()
        if self.split_first_layer:
            model.add_module('dimshuffle', Expression(_transpose_time_to_spat))
            model.add_module('conv_time', nn.Conv2d(1, self.n_filters_time,
                                                    (
                                                    self.filter_time_length, 1),
                                                    stride=1, ))
            model.add_module('conv_spat',
                             nn.Conv2d(self.n_filters_time, self.n_filters_spat,
                                       (1, self.in_chans),
                                       stride=(conv_stride, 1),
                                       bias=not self.batch_norm))
            n_filters_conv = self.n_filters_spat
        else:
            model.add_module('conv_time',
                             nn.Conv2d(self.in_chans, self.n_filters_time,
                                       (self.filter_time_length, 1),
                                       stride=(conv_stride, 1),
                                       bias=not self.batch_norm))
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            model.add_module('bnorm',
                             nn.BatchNorm2d(n_filters_conv,
                                            momentum=self.batch_norm_alpha,
                                            affine=True,
                                            eps=1e-5),)
        model.add_module('conv_nonlin', Expression(self.first_nonlin))
        model.add_module('pool',
                         first_pool_class(
                             kernel_size=(self.pool_time_length, 1),
                             stride=(pool_stride, 1)))
        model.add_module('pool_nonlin', Expression(self.first_pool_nonlin))

        def add_conv_pool_block(model, n_filters_before,
                                n_filters, filter_length, block_nr):
            suffix = '_{:d}'.format(block_nr)
            model.add_module('drop' + suffix,
                             nn.Dropout(p=self.drop_prob))
            model.add_module('conv' + suffix.format(block_nr),
                             nn.Conv2d(n_filters_before, n_filters,
                                       (filter_length, 1),
                                       stride=(conv_stride, 1),
                                       bias=not self.batch_norm))
            if self.batch_norm:
                model.add_module('bnorm' + suffix,
                             nn.BatchNorm2d(n_filters,
                                            momentum=self.batch_norm_alpha,
                                            affine=True,
                                            eps=1e-5))
            model.add_module('nonlin' + suffix,
                             Expression(self.later_nonlin))

            model.add_module('pool' + suffix,
                             later_pool_class(
                                 kernel_size=(self.pool_time_length, 1),
                                 stride=(pool_stride, 1)))
            model.add_module('pool_nonlin' + suffix,
                             Expression(self.later_pool_nonlin))

        add_conv_pool_block(model, n_filters_conv, self.n_filters_2,
                            self.filter_length_2, 2)
        add_conv_pool_block(model, self.n_filters_2, self.n_filters_3,
                            self.filter_length_3, 3)
        add_conv_pool_block(model, self.n_filters_3, self.n_filters_4,
                            self.filter_length_4, 4)


        model.eval()
        if self.final_conv_length == 'auto':
            out = model(np_to_var(np.ones(
                (1, self.in_chans, self.input_time_length,1),
                dtype=np.float32)))
            n_out_time = out.cpu().data.numpy().shape[2]
            self.final_conv_length = n_out_time
        model.add_module('conv_classifier',
                             nn.Conv2d(self.n_filters_4, self.n_classes,
                                       (self.final_conv_length, 1), bias=True))
        model.add_module('softmax', nn.LogSoftmax())
        model.add_module('squeeze',  Expression(_squeeze_final_output))

        # Initialization, xavier is same as in our paper...
        # was default from lasagne
        init.xavier_uniform(model.conv_time.weight, gain=1)
        # maybe no bias in case of no split layer and batch norm
        if self.split_first_layer or (not self.batch_norm):
            init.constant(model.conv_time.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform(model.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant(model.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant(model.bnorm.weight, 1)
            init.constant(model.bnorm.bias, 0)
        param_dict = dict(list(model.named_parameters()))
        for block_nr in range(2,5):
            conv_weight = param_dict['conv_{:d}.weight'.format(block_nr)]
            init.xavier_uniform(conv_weight, gain=1)
            if not self.batch_norm:
                conv_bias = param_dict['conv_{:d}.bias'.format(block_nr)]
                init.constant(conv_bias, 0)
            else:
                bnorm_weight = param_dict['bnorm_{:d}.weight'.format(block_nr)]
                bnorm_bias = param_dict['bnorm_{:d}.bias'.format(block_nr)]
                init.constant(bnorm_weight, 1)
                init.constant(bnorm_bias, 0)

        init.xavier_uniform(model.conv_classifier.weight, gain=1)
        init.constant(model.conv_classifier.bias, 0)

        # Start in eval mode
        model.eval()
        return model


# remove empty dim at end and potentially remove empty time dim
# do not just use squeeze as we never want to remove first dim
def _squeeze_final_output(x):
    assert x.size()[3] == 1
    x = x[:,:,:,0]
    if x.size()[2] == 1:
        x = x[:,:,0]
    return x


def _transpose_time_to_spat(x):
    return x.permute(0, 3, 2, 1)


In [8]:
from torchsummary import summary
net = Deep4Net(1, 60, 500, 10)
model = net.create_network()
# dir(model)
print(model)
summary(model, (1, 60, 500))



Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d (1, 25, kernel_size=(10, 1), stride=(1, 1))
  (conv_spat): Conv2d (25, 25, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True)
  (conv_nonlin): Expression(expression=elu)
  (pool): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), dilation=(1, 1))
  (pool_nonlin): Expression(expression=identity)
  (drop_2): Dropout(p=0.5)
  (conv_2): Conv2d (25, 50, kernel_size=(10, 1), stride=(1, 1), bias=False)
  (bnorm_2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True)
  (nonlin_2): Expression(expression=elu)
  (pool_2): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), dilation=(1, 1))
  (pool_nonlin_2): Expression(expression=identity)
  (drop_3): Dropout(p=0.5)
  (conv_3): Conv2d (50, 100, kernel_size=(10, 1), stride=(1, 1), bias=False)
  (bnorm_3): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True)
  (nonlin_3): Expression(expression=elu)
  

RuntimeError: Given groups=1, weight[25, 1, 10, 1], so expected input[1, 500, 60, 1] to have 1 channels, but got 500 channels instead

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=60, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()
if torch.cuda.is_available():
    model.cuda()

In [16]:
summary(model, (60, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 24, 24]           15010
            Conv2d-2             [-1, 20, 8, 8]            5020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]           16050
            Linear-5                   [-1, 10]             510
Total params: 36590
Trainable params: 36590
Non-trainable params: 0
----------------------------------------------------------------
