In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MBConvSE(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        expansion_factor=6,
        kernel_size=3,
        stride=1,
        se_ratio=0.25,
    ) -> None:
        super().__init__()
        hidden_dim = in_channels * expansion_factor
        self.use_residual = in_channels == out_channels and stride == 1

        # Expansion
        self.expand = (
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=hidden_dim,
                kernel_size=1,
                bias=False,
            )
            if expansion_factor > 1
            else None
        )
        self.bn1 = nn.BatchNorm2d(hidden_dim) if expansion_factor > 1 else None

        # Depthwise Convolution
        self.depthwise = nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            bias=False,
            stride=stride,
            padding=kernel_size // 2,
            groups=hidden_dim,
        )
        self.bn2 = nn.BatchNorm2d(hidden_dim)

        # Squeeze-and-Excite
        se_hidden_dim = max(1, int(hidden_dim * se_ratio))
        self.se = (
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(hidden_dim, se_hidden_dim, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(se_hidden_dim, hidden_dim, kernel_size=1),
                nn.Sigmoid(),
            )
            if se_ratio > 0
            else None
        )

        # Projection
        self.project = nn.Conv2d(
            hidden_dim, out_channels=out_channels, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = x

        if self.expand:
            print("EXPANSION USED")
            out = F.relu6(self.bn1(self.expand(out)))

        out = F.relu6(self.bn2(self.depthwise(out)))

        if self.se:
            print("SE USED")
            out *= self.se(out)

        out = self.bn3(self.project(out))

        if self.use_residual:
            print("RESIDUAL USED")
            out += x

        return out

In [9]:
inp = torch.randn(1, 16, 32, 32)
model = MBConvSE(16, 16)
model(inp).size()

EXPANSION USED
SE USED
RESIDUAL USED


torch.Size([1, 16, 32, 32])