# In[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# In[ ]:
class InvertedResidual(nn.Module):
    """Inverted-residual bottleneck: (optional) 1x1 expand -> 3x3 depthwise -> 1x1 linear project.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the final 1x1 projection.
        stride: Stride of the 3x3 depthwise convolution (padding is 1, so the
            spatial size becomes ceil(H / stride) x ceil(W / stride)).
        expand_ratio: Width multiplier for the hidden (expanded) layer. When it
            equals 1 the 1x1 expansion stage is omitted entirely.

    The input must be 4-D ``(N, in_channels, H, W)`` because the stages use
    ``BatchNorm2d``. An identity shortcut is added only when the block keeps
    both resolution (``stride == 1``) and width (``in_channels == out_channels``).
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()

        expanded = int(round(in_channels * expand_ratio))

        stages = []
        if expand_ratio != 1:
            # 1x1 expansion to the wider hidden space, followed by BN + ReLU6.
            stages += [
                nn.Conv2d(in_channels, expanded, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(expanded),
                nn.ReLU6(inplace=True),
            ]

        stages += [
            # 3x3 depthwise conv (groups == channels): one filter per channel;
            # this is the only stage that applies the stride.
            nn.Conv2d(expanded, expanded, kernel_size=3, stride=stride, padding=1, groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU6(inplace=True),
            # 1x1 projection back down; deliberately no activation afterwards
            # (the "linear bottleneck").
            nn.Conv2d(expanded, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        ]

        # Layer order (and therefore state_dict keys "block.<i>.*") must stay stable.
        self.block = nn.Sequential(*stages)

        # The shortcut is only shape-compatible when neither resolution nor width changes.
        self.use_res_connect = stride == 1 and in_channels == out_channels

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_res_connect else out