In [1]:
import torch
import torch.nn as nn 
import torchvision

In [10]:
class SAE_block(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is average-pooled to a single value per channel, passed through
    a two-layer 1x1-convolution bottleneck (ReLU in between, sigmoid at the
    end), and the resulting per-channel weights multiply the original input.

    Args:
        in_channels: number of channels of the input feature map.
        r: reduction rate of the excitation bottleneck. The hidden width is
            ``in_channels // r``, clamped to at least 1.

    Raises:
        ValueError: if ``in_channels`` or ``r`` is smaller than 1.
    """
    def __init__(self, in_channels, r):
        super().__init__()
        if in_channels < 1 or r < 1:
            raise ValueError(
                f"in_channels and r must be >= 1, got in_channels={in_channels}, r={r}")
        self.in_channels = in_channels
        self.r = r                     # r is reduction rate in excitation part of the block

        # Without the clamp, r > in_channels would floor-divide to 0 and build
        # a bottleneck with no channels at all.
        hidden_channels = max(1, in_channels // r)

        self.block = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), # Squeeze (b, c, h, w) ---> (b, c, 1, 1)
            nn.Conv2d(in_channels, out_channels=hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the gate computed from ``x``."""
        # The gate has shape (b, c, 1, 1) and broadcasts over the spatial dims.
        return x * self.block(x)

In [12]:
# Smoke test: run a random batch through the SE block on the GPU when one is
# available (CPU otherwise) and print the shape of the result.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
input = torch.randn(64, 16, 224, 224).to(device)
layer = SAE_block(in_channels=16, r=4).to(device)
output = layer(input)
print(output.shape)

torch.Size([64, 16, 224, 224])


In [13]:
class InvertedResidual(nn.Module):
    """Inverted residual block: 1x1 expansion -> 3x3 depthwise -> 1x1 projection.

    Args:
        in_channels: number of channels of the input; the projection maps
            back to the same number of channels.
        expansion_rate: multiplier applied to ``in_channels`` to obtain the
            width of the expanded (hidden) representation.
        **kwargs: extra keyword arguments forwarded to the depthwise
            ``nn.Conv2d`` (for example ``stride``). ``padding`` defaults to 1
            so that the 3x3 kernel keeps the spatial size at stride 1.
            ``kernel_size`` and ``groups`` are fixed by the block; passing
            them raises ``TypeError``.
    """
    def __init__(self, in_channels, expansion_rate, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.expansion_rate = expansion_rate

        hidden_channels = in_channels * expansion_rate
        # A 3x3 convolution without padding shrinks each spatial dim by 2,
        # which would make the skip connection in forward() impossible.
        kwargs.setdefault('padding', 1)

        self.expansion_layer = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU())

        # groups == channels makes the convolution depthwise.
        self.depthwise_layer = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3,
                      groups=hidden_channels, **kwargs),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU())

        self.pointwise_layer = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels))

    def forward(self, x):
        """Apply the block; add the skip connection when shapes allow it.

        Returns ``x + block(x)`` when the block output has the same shape as
        ``x``, and the block output alone otherwise (e.g. with ``stride=2``,
        where the spatial size is reduced).
        """
        el = self.expansion_layer(x)
        dl = self.depthwise_layer(el)
        pl = self.pointwise_layer(dl)
        if pl.shape == x.shape:
            return x + pl
        return pl

In [None]:
# Random input batch plus two inverted-residual blocks that differ only in the
# stride passed through to the block (1 and 2), all placed on `device`.
x = torch.randn(64, 16, 224, 224).to(device)
layer1 = InvertedResidual(in_channels=16, expansion_rate=6, stride=1).to(device)
layer2 = InvertedResidual(in_channels=16, expansion_rate=6, stride=2).to(device)
