In [1]:
%load_ext autoreload
%autoreload 2

import torch
from pytorch_lightning import seed_everything
from torchinfo import summary
from torchvision.models import mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small

from src.architectures.feature_extractors.mobilenet import MobileNet
from src.architectures.head import MulticlassLinearClassificationHead
from src.module.multiclass import MulticlassImageClassifier

def count_params(model: torch.nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: If True (default), count only parameters with
            ``requires_grad=True``; if False, count every parameter.

    Returns:
        Sum of ``numel()`` over the selected parameter tensors.
    """
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )

SEED = 42
# Seed globally before any random tensor or model weights are created so that both the
# dummy input and every model's initial weights are reproducible across runs.
seed_everything(SEED)

# Random dummy image batch in NCHW layout (batch=1, channels=3, 224x224),
# consumed by the `summary(...)` calls below.
img = torch.rand(1, 3, 224, 224)

Global seed set to 42


In [6]:
NUM_CLASSES = 1000  # same 1000-way output as the torchvision reference models


def build_classifier_net(backbone, num_classes: int = NUM_CLASSES):
    """Attach a linear classification head to ``backbone`` and return the wrapped ``.net``.

    Class names are placeholder strings ``"0"`` ... ``str(num_classes - 1)``; this
    notebook only inspects architecture and parameter counts, not predictions.

    Args:
        backbone: Feature extractor exposing ``out_dim`` (the head's input width).
        num_classes: Number of output logits produced by the head.

    Returns:
        The ``.net`` attribute of the constructed ``MulticlassImageClassifier``.
    """
    head = MulticlassLinearClassificationHead(backbone.out_dim, num_classes=num_classes)
    return MulticlassImageClassifier(
        backbone,
        head=head,
        classes=[str(i) for i in range(num_classes)],
    ).net


# NOTE(review): positional args presumably are (in_channels, width multiplier, variant);
# confirm against MobileNet's signature.
v2_backbone = MobileNet(3, 1.0, "v2")
v3_small_backbone = MobileNet(3, 1.0, "v3_small")
v3_large_backbone = MobileNet(3, 1.0, "v3_large")

# All backbones are built before any head; keep this order so the seeded RNG yields the
# same initial weights from run to run.
v2 = build_classifier_net(v2_backbone)
v3_small = build_classifier_net(v3_small_backbone)
v3_large = build_classifier_net(v3_large_backbone)

  rank_zero_warn(
  rank_zero_warn(


In [7]:
# torchvision reference implementations used to sanity-check the custom MobileNets.
# No `weights` argument is passed, so torchvision's default (no pretrained weights) applies.
v2_torch = mobilenet_v2()
v3_small_torch = mobilenet_v3_small()
v3_large_torch = mobilenet_v3_large()


In [25]:
# Trainable-parameter counts as (torchvision reference, custom implementation).
# The custom model has 288 more parameters. In the summaries below the first block differs
# by exactly that amount (464 in torchvision's InvertedResidual vs 752 in the custom
# Bottleneck) while the following visible rows match.
# TODO(review): check whether the custom first bottleneck adds a 1x1 expansion conv + BN
# (16*16 + 2*16 = 288 params) that torchvision skips when expand == in_channels.
(count_params(v3_large_torch), count_params(v3_large))

(5483032, 5483320)

In [26]:
# torchvision reference; depth=2 is enough to list the individual blocks inside `features`.
summary(model=v3_large_torch, input_data=img, depth=2)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [1, 1000]                 --
├─Sequential: 1-1                                  [1, 960, 7, 7]            --
│    └─Conv2dNormActivation: 2-1                   [1, 16, 112, 112]         464
│    └─InvertedResidual: 2-2                       [1, 16, 112, 112]         464
│    └─InvertedResidual: 2-3                       [1, 24, 56, 56]           3,440
│    └─InvertedResidual: 2-4                       [1, 24, 56, 56]           4,440
│    └─InvertedResidual: 2-5                       [1, 40, 28, 28]           10,328
│    └─InvertedResidual: 2-6                       [1, 40, 28, 28]           20,992
│    └─InvertedResidual: 2-7                       [1, 40, 28, 28]           20,992
│    └─InvertedResidual: 2-8                       [1, 80, 14, 14]           32,080
│    └─InvertedResidual: 2-9                       [1, 80, 14, 14]           34,760
│    └─

In [27]:
# Custom implementation. It is nested one level deeper than torchvision's
# (Sequential -> MobileNetV3Large -> Sequential -> blocks), so depth=3 lines the block
# rows up with the torchvision summary above.
summary(model=v3_large, input_data=img, depth=3)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                                  Output Shape              Param #
Sequential                                              [1, 1000]                 --
├─MobileNetV3Large: 1-1                                 [1, 1280, 1, 1]           --
│    └─Sequential: 2-1                                  --                        --
│    │    └─CNNBlock: 3-1                               [1, 16, 112, 112]         464
│    │    └─Bottleneck: 3-2                             [1, 16, 112, 112]         752
│    │    └─Bottleneck: 3-3                             [1, 24, 56, 56]           3,440
│    │    └─Bottleneck: 3-4                             [1, 24, 56, 56]           4,440
│    │    └─Bottleneck: 3-5                             [1, 40, 28, 28]           10,328
│    │    └─Bottleneck: 3-6                             [1, 40, 28, 28]           20,992
│    │    └─Bottleneck: 3-7                             [1, 40, 28, 28]           20,992
│    │    └─Bottleneck: 3-8             

In [24]:
# Final feature stage of torchvision's MobileNetV3-Small, shown for side-by-side
# comparison with the custom `last_stage` in the next cell.
torch_last_stage = v3_small_torch.features[-1]
torch_last_stage

Conv2dNormActivation(
  (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  (2): Hardswish()
)

In [19]:
# Custom counterpart of torchvision's `features[-1]`. Differences visible in the two reprs:
# - conv_1 wraps the 96->576 conv in an SEBlock (two extra Linear layers, 166,608 params);
#   torchvision's last feature stage has no squeeze-and-excitation.
# - BatchNorm uses PyTorch defaults (eps=1e-05, momentum=0.1) instead of torchvision's
#   eps=0.001, momentum=0.01.
# NOTE(review): if exact torchvision parity is the goal, both are worth revisiting in
# src.architectures; TODO confirm whether the paper or torchvision is the intended reference.
custom_backbone = v3_small[0]
custom_backbone.net.last_stage

EfficientLastStage(
  (conv_1): SEBlock(
    (block): CNNBlock(
      (conv): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batch_norm): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation_fn): Hardswish()
    )
    (squeeze): Sequential(
      (0): AdaptiveAvgPool2d(output_size=(1, 1))
      (1): Flatten(start_dim=1, end_dim=-1)
    )
    (excitation): Sequential(
      (0): Linear(in_features=576, out_features=144, bias=True)
      (1): ReLU()
      (2): Linear(in_features=144, out_features=576, bias=True)
      (3): Hardsigmoid()
    )
  )
  (global_pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (conv_2): CNNBlock(
    (conv): Conv2d(576, 1024, kernel_size=(1, 1), stride=(1, 1))
    (activation_fn): Hardswish()
  )
)