In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class DecoderBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each ordered Conv2d -> ReLU -> BatchNorm2d.

    Padding of 1 with a 3x3 kernel (stride 1) keeps the spatial size unchanged.
    The first stage maps ``in_channels`` to ``out_channels``; the second stage
    keeps ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # One Conv/ReLU/BN triple per stage; only the conv's input width differs.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on ``x`` and return the resulting feature map."""
        return self.double_conv(x)

In [4]:
class MBConvBlock(nn.Module):
    """Mobile inverted bottleneck (MBConv) block in the style of EfficientNet.

    Pipeline: optional 1x1 expansion -> depthwise KxK conv -> optional
    squeeze-and-excitation (SE) gate -> 1x1 linear projection. An identity
    shortcut is added when input and output shapes match
    (``in_channels == out_channels`` and ``stride == 1``).

    NOTE(review): all activations here are ReLU; the EfficientNet paper uses
    Swish (SiLU). Kept as-is to preserve current behavior.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Depthwise kernel size; must be a positive odd integer
            (EfficientNet uses 3 and 5). Padding is ``kernel_size // 2``
            ("same" padding for stride 1).
        stride: Stride of the depthwise convolution.
        expansion_rate: Channel multiplier for the 1x1 expansion conv. When it
            equals 1 the expansion conv is not created.
        se: Whether to apply the squeeze-and-excitation gate.
        se_ratio: SE bottleneck width as a fraction of ``in_channels``
            (clamped to at least 1 channel).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, expansion_rate=6, se=True,
                 se_ratio=0.25):
        super().__init__()
        # Stored because forward() branches on them.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expansion_rate = expansion_rate
        self.se = se

        # Validate explicitly: an `assert "<message>"` can never fire because a
        # non-empty string is truthy, which would leave `padding` undefined.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"unsupported kernel size {kernel_size!r}: expected a positive odd integer"
            )
        # Symmetric "same" padding: 1 for k=3, 2 for k=5.
        padding = kernel_size // 2

        # Width of the expanded (inverted-bottleneck) representation.
        expansion_channels = expansion_rate * in_channels
        # SE bottleneck width is derived from the block's input channels.
        se_channels = max(1, int(in_channels * se_ratio))

        # 1x1 expansion; omitted when expansion_rate == 1 (channel count unchanged).
        if expansion_rate != 1:
            self.expand_conv = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=expansion_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(expansion_channels),
                nn.ReLU()
            )
        # Depthwise conv: groups == channels, so each channel is filtered independently.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(in_channels=expansion_channels, out_channels=expansion_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, groups=expansion_channels, bias=False),
            nn.BatchNorm2d(expansion_channels),
            nn.ReLU()
        )
        # Squeeze-and-excitation: global average pool -> bottleneck -> per-channel sigmoid gate.
        if se:
            self.se_block = nn.Sequential(
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Conv2d(in_channels=expansion_channels, out_channels=se_channels, kernel_size=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(in_channels=se_channels, out_channels=expansion_channels, kernel_size=1, bias=False),
                nn.Sigmoid()
            )
        # 1x1 linear projection (no activation) down to out_channels.
        self.pointwise_conv = nn.Sequential(
            nn.Conv2d(in_channels=expansion_channels, out_channels=out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, inputs):
        """Apply the MBConv pipeline to ``inputs``.

        Args:
            inputs: 4-D tensor ``(N, in_channels, H, W)`` (BatchNorm2d requires 4-D input).

        Returns:
            Tensor with ``out_channels`` channels; spatial size is reduced by
            ``stride``. Includes the identity shortcut when shapes match.
        """
        x = inputs

        if self.expansion_rate != 1:
            x = self.expand_conv(x)

        x = self.depthwise_conv(x)

        if self.se:
            # Rescale each channel by its learned gate in (0, 1).
            x = self.se_block(x) * x

        x = self.pointwise_conv(x)

        # Identity shortcut only when the output shape equals the input shape.
        if self.in_channels == self.out_channels and self.stride == 1:
            x = x + inputs

        return x

In [None]:
class EffUNet(nn.Module):
    """U-Net-style model intended to use an EfficientNet-B0 encoder.

    Work in progress: in the visible code only the stem (``down_conv1``) is
    fully defined.

    Args:
        in_channels: Number of channels of the input image tensor.
        classes: Presumably the number of output segmentation classes; not
            used in the visible code (TODO confirm).
    """
    def __init__(self, in_channels, classes):
        super().__init__()

        # Stem: 3x3 conv with stride 2 and padding 1 halves the spatial size
        # (output ceil(H/2) x ceil(W/2)) and produces 32 channels, matching the
        # EfficientNet-B0 stem. Conv bias is omitted because BatchNorm follows.
        self.down_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        # NOTE(review): `down_block2` is never assigned before this line, so
        # calling it makes nn.Module.__getattr__ raise AttributeError when the
        # model is constructed. Probably meant to be an assignment such as
        # `self.down_block2 = nn.Sequential(MBConvBlock(...))` -- TODO finish.
        self.down_block2(
            
        )