In [1]:
import torch
import torch.nn as nn
import math

In [2]:
# Public model builders exported by `from <module> import *`.
# NOTE(review): mobilenetv4_hybrid_medium / mobilenetv4_hybrid_large are not
# defined in the cells visible here; confirm they exist later in the notebook,
# otherwise `from <module> import *` raises AttributeError.
__all__ = [
    'mobilenetv4_conv_small',
    'mobilenetv4_conv_medium',
    'mobilenetv4_conv_large',
    'mobilenetv4_hybrid_medium',
    'mobilenetv4_hybrid_large',
]

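The `make_divisible` function is a utility often used in deep learning (e.g., MobileNet, EfficientNet) to round a value (typically the number of channels or filters) to the nearest multiple of a given divisor (e.g., 8 or 16). The result is never smaller than `min_value`, which defaults to the divisor. With `round_down_protect=True`, one extra divisor is added whenever rounding would drop the result below 90% of the original value. Keeping channel counts at such multiples helps maintain alignment with hardware requirements for memory and computational efficiency.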

In [3]:
def make_divisible(value, divisor, min_value=None, round_down_protect=True):
  """Round ``value`` to the nearest multiple of ``divisor``.

  Args:
    value: Quantity to round, e.g. a channel count. May be a float such as
      ``in_channels * expand_ratio``.
    divisor: The multiple to round to.
    min_value: Lower bound for the result. ``None`` means ``divisor``.
    round_down_protect: If True, add one more ``divisor`` when rounding
      landed below 90% of ``value``, to avoid shrinking too much.

  Returns:
    The rounded value.
  """
  floor = divisor if min_value is None else min_value
  # Adding half a divisor before the integer division rounds half-up.
  rounded = int(value + divisor / 2) // divisor * divisor
  result = max(floor, rounded)
  needs_bump = round_down_protect and result < 0.9 * value
  return result + divisor if needs_bump else result


| Term                     | Meaning                                                            |
| ------------------------ | ------------------------------------------------------------------ |
| `inplace=True`           | Modify the input tensor directly (save memory, faster, but risky)  |
| `(kernel_size - 1) // 2` | Padding that preserves the spatial size for odd kernel sizes at stride 1 ("same padding"); at stride 2 the output size is halved (rounded up) |


In [5]:
class ConvBN(nn.Module):
  """Conv2d (no bias) -> BatchNorm2d -> ReLU, padded by (kernel_size - 1) // 2.

  Args:
    in_channels: Channels of the input tensor.
    out_channels: Channels produced by the convolution.
    kernel_size: Square kernel size of the convolution.
    stride: Convolution stride (default 1).
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=1):
    super().__init__()
    padding = (kernel_size - 1) // 2
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    # Bias is omitted because the following BatchNorm supplies its own shift.
    self.block = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

  def forward(self, x):
    return self.block(x)

| Parameter                | Type    | Description                                                                                                                                    |
| ------------------------ | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| `self`                   | object  | Reference to the class instance (standard in Python class methods).                                                                            |
| `in_channels`            | `int`   | Number of input channels to the block.                                                                                                         |
| `out_channels`           | `int`   | Number of output channels from the block.                                                                                                      |
| `expand_ratio`           | `float` | Expansion factor for the hidden (intermediate) channels: the expanded width is `in_channels * expand_ratio`, rounded to a multiple of 8 with `make_divisible`. |
| `start_dw_kernel_size`   | `int`   | Kernel size of the optional depthwise convolution applied **before** the expansion; `0` omits this layer. |
| `middle_dw_kernel_size`  | `int`   | Kernel size of the optional depthwise convolution applied **after** the expansion; `0` omits this layer. |
| `stride`                 | `int`   | Block stride, applied by either the start or the middle depthwise convolution (1 = no downsampling, 2 = halve spatial size). |
| `middle_dw_downsample`   | `bool`  | If `True`, the stride is applied by the middle depthwise conv; if `False`, it is applied by the start depthwise conv instead. |
| `use_layer_scale`        | `bool`  | If `True`, applies a **learnable scaling factor** (Layer Scale) to the block's output. Often used to stabilize training in very deep networks. |
| `layer_scale_init_value` | `float` | Initial value for the layer scale parameter (e.g., `1e-5` is a small initial scaling). Used only if `use_layer_scale` is `True`.               |


| Code Line                                                     | Purpose                                                                                                                                                                              |
| ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `if use_layer_scale:`                                         | Conditional check to determine whether to apply **Layer Scale**.                                                                                                                     |
| `self.gamma = nn.Parameter(...)`                              | Creates a learnable parameter `gamma`, a 1D tensor of shape `(out_channels,)`, initialized with `layer_scale_init_value`. This parameter will later scale the block’s output.        |
| `torch.ones((out_channels))`                                  | Initializes the scale as a tensor of ones — one scale per output channel.                                                                                                            |
| `nn.Parameter(..., requires_grad=True)`                       | Ensures that `gamma` is a learnable parameter, updated via backpropagation.                                                                                                          |
| `self.use_layer_scale = use_layer_scale`                      | Stores the flag so that it can be used in `forward()` or elsewhere in the module.                                                                                                    |
| `self.identity = stride == 1 and in_channels == out_channels` | Indicates whether the block can use a **residual/identity connection** (i.e., skip connection), which is only valid if the input and output shapes match spatially and channel-wise. |


In [6]:
class UniversalInvertedBottleneck(nn.Module):
  """Universal Inverted Bottleneck (UIB) block used by MobileNetV4.

  Layout:
    [optional start depthwise conv + BN]
    -> 1x1 expand conv + BN + ReLU
    -> [optional middle depthwise conv + BN + ReLU]
    -> 1x1 projection conv + BN
    -> [optional per-channel layer scale]
    -> [residual add when stride == 1 and in_channels == out_channels]

  Args:
    in_channels: Channels of the input tensor.
    out_channels: Channels of the output tensor.
    expand_ratio: Multiplier for the hidden width. The expanded width is
      ``make_divisible(in_channels * expand_ratio, 8)``.
    start_dw_kernel_size: Kernel size of the depthwise conv before the
      expansion. 0 (any falsy value) omits it.
    middle_dw_kernel_size: Kernel size of the depthwise conv after the
      expansion. 0 (any falsy value) omits it.
    stride: Spatial stride of the block.
    middle_dw_downsample: If True, ``stride`` is applied by the middle
      depthwise conv; otherwise by the start depthwise conv. Note: if the conv
      that should carry the stride is omitted, no downsampling happens.
    use_layer_scale: If True, the projected output is multiplied by a
      learnable per-channel scale ``gamma`` before the residual add.
    layer_scale_init_value: Initial value of every entry of ``gamma``.
  """

  def __init__(self,
               in_channels,
               out_channels,
               expand_ratio,
               start_dw_kernel_size,
               middle_dw_kernel_size,
               stride,
               middle_dw_downsample: bool = True,
               use_layer_scale: bool = False,
               layer_scale_init_value: float = 1e-5):
    super(UniversalInvertedBottleneck, self).__init__()
    self.start_dw_kernel_size = start_dw_kernel_size
    self.middle_dw_kernel_size = middle_dw_kernel_size

    # Optional depthwise conv before expansion (BN only, no activation).
    if start_dw_kernel_size:
      start_stride = 1 if middle_dw_downsample else stride
      self.start_dw_conv = nn.Conv2d(in_channels, in_channels, start_dw_kernel_size,
                                     start_stride, (start_dw_kernel_size - 1) // 2,
                                     groups=in_channels, bias=False)
      self.start_dw_norm = nn.BatchNorm2d(in_channels)

    # Pointwise expansion; width rounded to a multiple of 8.
    expand_channels = make_divisible(in_channels * expand_ratio, 8)
    self.expand_conv = nn.Conv2d(in_channels, expand_channels, 1, 1, bias=False)
    self.expand_norm = nn.BatchNorm2d(expand_channels)
    self.expand_act = nn.ReLU(inplace=True)

    # Optional depthwise conv on the expanded features.
    if middle_dw_kernel_size:
      middle_stride = stride if middle_dw_downsample else 1
      self.middle_dw_conv = nn.Conv2d(expand_channels, expand_channels, middle_dw_kernel_size,
                                      middle_stride, (middle_dw_kernel_size - 1) // 2,
                                      groups=expand_channels, bias=False)
      self.middle_dw_norm = nn.BatchNorm2d(expand_channels)
      self.middle_dw_act = nn.ReLU(inplace=True)

    # Pointwise projection back to out_channels (no activation afterwards).
    self.proj_conv = nn.Conv2d(expand_channels, out_channels, 1, 1, bias=False)
    self.proj_norm = nn.BatchNorm2d(out_channels)

    if use_layer_scale:
      self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(out_channels), requires_grad=True)

    self.use_layer_scale = use_layer_scale
    # Residual add is only shape-safe when stride and channel count are unchanged.
    self.identity = stride == 1 and in_channels == out_channels

  def forward(self, x):
    shortcut = x

    if self.start_dw_kernel_size:
      x = self.start_dw_norm(self.start_dw_conv(x))

    x = self.expand_act(self.expand_norm(self.expand_conv(x)))

    if self.middle_dw_kernel_size:
      x = self.middle_dw_act(self.middle_dw_norm(self.middle_dw_conv(x)))

    x = self.proj_norm(self.proj_conv(x))

    if self.use_layer_scale:
      # gamma has shape (C,); reshape to (1, C, 1, 1) so it scales the channel
      # axis of the 4-D conv output rather than broadcasting against width.
      x = self.gamma.view(1, -1, 1, 1) * x

    return x + shortcut if self.identity else x



| Layer/Variable               | Description                                                                                                                         |
| ---------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- |
| `self.avgpool`               | `nn.AdaptiveAvgPool2d((1, 1))` — Reduces each feature map to **1×1** (global average pooling), output shape becomes `(B, C, 1, 1)`. |
| `hidden_channels`            | Set to `1280` — used as the number of channels after the final pointwise conv layer. Common in lightweight CNNs.                    |
| `self.conv`                  | `ConvBN(c, 1280, 1)` — A `1×1` convolution that expands from `c` (last `out_channels`) to 1280. Applies batch norm + activation.    |
| `self.classifier`            | `nn.Linear(1280, num_classes)` — Final fully connected layer mapping from feature vector to output logits.                          |
| `self._initialize_weights()` | Custom method that initializes the weights: He (Kaiming) normal for convolutions, ones/zeros for BatchNorm, and Normal(0, 0.01) with zero bias for the linear layer. |


| Component     | Purpose                              |
| ------------- | ------------------------------------ |
| `x.size(0)`   | Keeps the **batch size** unchanged   |
| `-1`          | Flattens all remaining dimensions    |
| `x.view(...)` | Reshapes tensor without copying data |


| Module Type      | Initialization Method                   | Reason                                                  |
| ---------------- | --------------------------------------- | ------------------------------------------------------- |
| `nn.Conv2d`      | He (Kaiming) Normal with std `√(2 / n)`, where `n = kernel_h × kernel_w × out_channels` (fan-out) | Suitable for ReLU activation (helps maintain variance). |
|                  | Bias: zero                              | Standard to avoid early biasing.                        |
| `nn.BatchNorm2d` | Weights = 1, Biases = 0                 | Keeps initial feature scaling neutral.                  |
| `nn.Linear`      | Normal(0, 0.01) for weights             | Standard small weight init for fully connected layers.  |
|                  | Bias = 0                                | Keeps outputs unbiased at start.                        |


In [9]:
class MobileNetV4(nn.Module):
  """MobileNetV4 backbone plus classification head, built from a spec list.

  Each entry of ``block_specs`` is a tuple whose first element picks the block:
    * ``('conv_bn', kernel_size, stride, out_channels)`` -> ``ConvBN``
    * ``('uib', start_dw_kernel_size, middle_dw_kernel_size, stride,
      out_channels, expand_ratio)`` -> ``UniversalInvertedBottleneck``

  Head: global average pooling -> 1x1 ``ConvBN`` to 1280 channels ->
  flatten -> ``nn.Linear`` to ``num_classes`` logits.

  Args:
    block_specs: Iterable of block tuples as described above.
    num_classes: Number of output logits (default 1000).

  Raises:
    NotImplementedError: If a spec names an unknown block type.
  """

  def __init__(self, block_specs, num_classes=1000):
    super(MobileNetV4, self).__init__()

    c = 3  # the first layer expects 3-channel input
    layers = []
    for block_type, *block_cfg in block_specs:
      if block_type == 'conv_bn':
        k, s, f = block_cfg
        layers.append(ConvBN(c, f, k, s))
      elif block_type == 'uib':
        start_k, middle_k, s, f, e = block_cfg
        layers.append(UniversalInvertedBottleneck(c, f, e, start_k, middle_k, s))
      else:
        raise NotImplementedError(f"Unknown block type: {block_type!r}")
      c = f  # next block consumes this block's output channels

    # Built once, after all blocks are known (not inside the loop).
    self.features = nn.Sequential(*layers)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    hidden_channels = 1280
    self.conv = ConvBN(c, hidden_channels, 1)
    self.classifier = nn.Linear(hidden_channels, num_classes)
    self._initialize_weights()

  def forward(self, x):
    x = self.features(x)
    x = self.avgpool(x)
    x = self.conv(x)
    x = x.view(x.size(0), -1)  # (B, 1280, 1, 1) -> (B, 1280)
    return self.classifier(x)

  def _initialize_weights(self):
    """He-normal convs, unit-weight/zero-bias BatchNorm, N(0, 0.01) Linear."""
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        # Fan-out: kernel area times output channels.
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
      elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
          nn.init.zeros_(m.bias)




In [11]:
def mobilenetv4_conv_small(**kwargs):
  """Build MobileNetV4-Conv-Small.

  Spec formats:
    ('conv_bn', kernel_size, stride, out_channels)
    ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)

  Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
  The px comments give the spatial size after each stage for a 224px input.
  """
  stem = [
      ('conv_bn', 3, 2, 32),                 # 112px
  ]
  stage1 = [
      ('conv_bn', 3, 2, 32),                 # 56px
      ('conv_bn', 1, 1, 32),
  ]
  stage2 = [
      ('conv_bn', 3, 2, 96),                 # 28px
      ('conv_bn', 1, 1, 64),
  ]
  stage3 = (
      [('uib', 5, 5, 2, 96, 3.0)]            # 14px, ExtraDW
      + [('uib', 0, 3, 1, 96, 2.0)] * 4      # IB
      + [('uib', 3, 0, 1, 96, 4.0)]          # ConvNext
  )
  stage4 = (
      [('uib', 3, 3, 2, 128, 6.0),           # 7px, ExtraDW
       ('uib', 5, 5, 1, 128, 4.0),           # ExtraDW
       ('uib', 0, 5, 1, 128, 4.0),           # IB
       ('uib', 0, 5, 1, 128, 3.0)]           # IB
      + [('uib', 0, 3, 1, 128, 4.0)] * 2     # IB
  )
  head = [
      ('conv_bn', 1, 1, 960),                # Conv
  ]
  block_specs = stem + stage1 + stage2 + stage3 + stage4 + head
  return MobileNetV4(block_specs, **kwargs)




In [12]:
def mobilenetv4_conv_medium(**kwargs):
  """Build MobileNetV4-Conv-Medium.

  Spec formats:
    ('conv_bn', kernel_size, stride, out_channels)
    ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)

  Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
  The original signature used ``*kwargs``, which made every call fail at the
  ``**kwargs`` forward.
  """
  block_specs = [
        # stem
        ('conv_bn', 3, 2, 32),
        ('conv_bn', 3, 2, 128),
        ('conv_bn', 1, 1, 48),
        # 3rd stage
        ('uib', 3, 5, 2, 80, 4.0),
        ('uib', 3, 3, 1, 80, 2.0),
        # 4th stage
        ('uib', 3, 5, 2, 160, 6.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 5, 1, 160, 4.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 0, 1, 160, 4.0),
        ('uib', 0, 0, 1, 160, 2.0),
        ('uib', 3, 0, 1, 160, 4.0),
        # 5th stage
        ('uib', 5, 5, 2, 256, 6.0),
        ('uib', 5, 5, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 3, 0, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 2.0),
        ('uib', 5, 5, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 5, 0, 1, 256, 2.0),
        # final 1x1 conv (kernel 1, stride 1, 960 channels)
        ('conv_bn', 1, 1, 960),
  ]
  return MobileNetV4(block_specs, **kwargs)



In [14]:
def mobilenetv4_conv_large(**kwargs):
  """Build MobileNetV4-Conv-Large.

  Spec formats:
    ('conv_bn', kernel_size, stride, out_channels)
    ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)

  Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
  Every UIB block in this variant uses expand_ratio 4.0.
  """
  stem = [
      ('conv_bn', 3, 2, 24),
      ('conv_bn', 3, 2, 96),
      ('conv_bn', 1, 1, 48),
  ]
  stage_96 = [
      ('uib', 3, 5, 2, 96, 4.0),
      ('uib', 3, 3, 1, 96, 4.0),
  ]
  stage_192 = (
      [('uib', 3, 5, 2, 192, 4.0)]
      + [('uib', 3, 3, 1, 192, 4.0)] * 3
      + [('uib', 3, 5, 1, 192, 4.0)]
      + [('uib', 5, 3, 1, 192, 4.0)] * 5
      + [('uib', 3, 0, 1, 192, 4.0)]
  )
  stage_512 = (
      [('uib', 5, 5, 2, 512, 4.0)]
      + [('uib', 5, 5, 1, 512, 4.0)] * 3
      + [('uib', 5, 0, 1, 512, 4.0),
         ('uib', 5, 3, 1, 512, 4.0)]
      + [('uib', 5, 0, 1, 512, 4.0)] * 2
      + [('uib', 5, 3, 1, 512, 4.0),
         ('uib', 5, 5, 1, 512, 4.0)]
      + [('uib', 5, 0, 1, 512, 4.0)] * 3
  )
  head = [('conv_bn', 1, 1, 960)]
  return MobileNetV4(stem + stage_96 + stage_192 + stage_512 + head, **kwargs)
