In [None]:
class ConvolutionalProcessingBlock_bn(nn.Module):
    """Two stride-1 convolutions, each followed by BatchNorm2d and leaky ReLU.

    Layers are created in ``build_module_with_bn`` by pushing a dummy zero
    tensor of ``input_shape`` through the convolutions to infer channel counts.

    Args:
        input_shape: Shape of an example input, expected to be NCHW
            (batch, channels, height, width); ``input_shape[1]`` becomes the
            first convolution's ``in_channels``.
        num_filters: Output channels of both convolutions, and the number of
            features normalised by each BatchNorm2d layer.
        kernel_size: Forwarded to both ``nn.Conv2d`` layers.
        padding: Forwarded to both ``nn.Conv2d`` layers.
        bias: Whether the convolutions have a bias term.
        dilation: Forwarded to both ``nn.Conv2d`` layers.
    """

    def __init__(self, input_shape, num_filters, kernel_size, padding, bias, dilation):
        # Zero-argument super() always refers to this class; naming another
        # class here raises a NameError/TypeError at construction.
        super().__init__()

        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.input_shape = input_shape
        self.padding = padding
        self.bias = bias
        self.dilation = dilation

        self.build_module_with_bn()

    def _make_conv(self, in_channels):
        """Build a stride-1 Conv2d using this block's stored hyper-parameters."""
        return nn.Conv2d(in_channels=in_channels, out_channels=self.num_filters, bias=self.bias,
                         kernel_size=self.kernel_size, dilation=self.dilation,
                         padding=self.padding, stride=1)

    def build_module_with_bn(self):
        """Create the conv / batch-norm layers, inferring shapes from a dummy input.

        Only the convolutions are run on the dummy tensor. Batch norm and
        leaky ReLU preserve shape. Running the zero tensor through BatchNorm2d
        in training mode would overwrite its running statistics (and bump
        ``num_batches_tracked``) before any real data is seen.
        """
        self.layer_dict = nn.ModuleDict()
        x = torch.zeros(self.input_shape)

        self.layer_dict['conv_0'] = self._make_conv(x.shape[1])
        with torch.no_grad():  # shape inference only; no autograd graph needed
            out = self.layer_dict['conv_0'](x)
        self.layer_dict['bn_0'] = nn.BatchNorm2d(self.num_filters)

        self.layer_dict['conv_1'] = self._make_conv(out.shape[1])
        with torch.no_grad():
            out = self.layer_dict['conv_1'](out)
        self.layer_dict['bn_1'] = nn.BatchNorm2d(self.num_filters)

        # Log the inferred output shape of the block.
        print(out.shape)

    def forward_with_bn(self, x):
        """Apply (conv -> batch norm -> leaky ReLU) twice.

        Args:
            x: Input tensor whose channel dimension (dim 1) matches
                ``input_shape[1]``.

        Returns:
            Tensor with ``num_filters`` channels.
        """
        out = x

        out = self.layer_dict['conv_0'](out)
        out = self.layer_dict['bn_0'](out)
        out = F.leaky_relu(out)

        out = self.layer_dict['conv_1'](out)
        out = self.layer_dict['bn_1'](out)
        out = F.leaky_relu(out)

        return out

    def forward(self, x):
        """Standard ``nn.Module`` entry point so ``block(x)`` works; same as ``forward_with_bn``."""
        return self.forward_with_bn(x)


class ConvolutionalDimensionalityReductionBlock_bn(nn.Module):
    """Conv -> BN -> leaky ReLU, average pooling, then conv -> BN -> leaky ReLU.

    Average pooling uses ``reduction_factor`` as its kernel size (and so, by
    ``F.avg_pool2d``'s default, as its stride). Layers are created in
    ``build_module_with_bn`` by pushing a dummy zero tensor of
    ``input_shape`` through the convolutions and pooling to infer shapes.

    Args:
        input_shape: Shape of an example input, expected to be NCHW
            (batch, channels, height, width); ``input_shape[1]`` becomes the
            first convolution's ``in_channels``.
        num_filters: Output channels of both convolutions, and the number of
            features normalised by each BatchNorm2d layer.
        kernel_size: Forwarded to both ``nn.Conv2d`` layers.
        padding: Forwarded to both ``nn.Conv2d`` layers.
        bias: Whether the convolutions have a bias term.
        dilation: Forwarded to both ``nn.Conv2d`` layers.
        reduction_factor: Kernel size passed to ``F.avg_pool2d`` between the
            two convolutions.
    """

    def __init__(self, input_shape, num_filters, kernel_size, padding, bias, dilation, reduction_factor):
        # Zero-argument super() always refers to this class; naming another
        # class here raises a NameError/TypeError at construction.
        super().__init__()

        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.input_shape = input_shape
        self.padding = padding
        self.bias = bias
        self.dilation = dilation
        self.reduction_factor = reduction_factor
        self.build_module_with_bn()

    def _make_conv(self, in_channels):
        """Build a stride-1 Conv2d using this block's stored hyper-parameters."""
        return nn.Conv2d(in_channels=in_channels, out_channels=self.num_filters, bias=self.bias,
                         kernel_size=self.kernel_size, dilation=self.dilation,
                         padding=self.padding, stride=1)

    def build_module_with_bn(self):
        """Create the conv / batch-norm layers, inferring shapes from a dummy input.

        Only shape-changing ops (convolutions and pooling) are run on the
        dummy tensor. Batch norm and leaky ReLU preserve shape. Running the
        zero tensor through BatchNorm2d in training mode would overwrite its
        running statistics (and bump ``num_batches_tracked``) before any real
        data is seen.
        """
        self.layer_dict = nn.ModuleDict()
        x = torch.zeros(self.input_shape)

        self.layer_dict['conv_0'] = self._make_conv(x.shape[1])
        with torch.no_grad():  # shape inference only; no autograd graph needed
            out = self.layer_dict['conv_0'](x)
        self.layer_dict['bn_0'] = nn.BatchNorm2d(self.num_filters)

        with torch.no_grad():
            out = F.avg_pool2d(out, self.reduction_factor)

        self.layer_dict['conv_1'] = self._make_conv(out.shape[1])
        with torch.no_grad():
            out = self.layer_dict['conv_1'](out)
        self.layer_dict['bn_1'] = nn.BatchNorm2d(self.num_filters)

        # Log the inferred output shape of the block.
        print(out.shape)

    def forward_with_bn(self, x):
        """Apply conv/BN/leaky-ReLU, average pooling, then conv/BN/leaky-ReLU.

        Args:
            x: Input tensor whose channel dimension (dim 1) matches
                ``input_shape[1]``.

        Returns:
            Tensor with ``num_filters`` channels, spatially pooled by
            ``reduction_factor``.
        """
        out = x

        out = self.layer_dict['conv_0'](out)
        out = self.layer_dict['bn_0'](out)
        out = F.leaky_relu(out)

        out = F.avg_pool2d(out, self.reduction_factor)

        out = self.layer_dict['conv_1'](out)
        out = self.layer_dict['bn_1'](out)
        out = F.leaky_relu(out)

        return out

    def forward(self, x):
        """Standard ``nn.Module`` entry point so ``block(x)`` works; same as ``forward_with_bn``."""
        return self.forward_with_bn(x)