<a href="https://colab.research.google.com/github/patrick22414/colab-FLOPs/blob/main/FLOPs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
!pip install ptflops



In [34]:
import io

import torch
from torch import nn
from torch.nn import functional as F

import torchvision

import ptflops

In [35]:
class MBConv(nn.Module):
    """Mobile inverted-bottleneck block.

    Layout: 1x1 conv to ``mid_channels`` -> kxk depthwise conv -> optional
    squeeze-and-excitation (SE) gate -> 1x1 projection to ``out_channels``.
    An identity shortcut is added when the block keeps both resolution
    (``stride == 1``) and channel count.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the final 1x1 projection.
        kernel_size: Depthwise kernel size. Padding is ``kernel_size // 2``,
            so odd kernels preserve spatial size at stride 1.
        stride: Stride of the depthwise conv (keyword-only, required).
        expansion: Used to derive ``mid_channels = in_channels * expansion``
            when ``mid_channels`` is not given.
        mid_channels: Explicit expanded width; takes precedence over
            ``expansion`` when truthy.
        se_ratio: Reduction factor of the SE bottleneck, which has
            ``mid_channels // se_ratio`` channels (at least 1). ``0`` disables
            the SE gate.

    Raises:
        ValueError: If neither ``expansion`` nor ``mid_channels`` is given.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        *,
        stride,
        expansion=None,
        mid_channels=None,
        se_ratio=0,
    ):
        super().__init__()

        # Explicit check instead of `assert`, which disappears under `python -O`.
        if not (expansion or mid_channels):
            raise ValueError("MBConv needs either `expansion` or `mid_channels`")
        if not mid_channels:
            mid_channels = in_channels * expansion

        self.inv_bottleneck = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        )

        # groups == channels makes this a depthwise convolution.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(
                mid_channels,
                mid_channels,
                kernel_size,
                stride=stride,
                padding=kernel_size // 2,
                groups=mid_channels,
                bias=False,
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        )

        if se_ratio > 0:
            # Clamp to 1: a zero-width bottleneck would build zero-element
            # conv weights and make the gate degenerate.
            se_mid_channels = max(1, mid_channels // se_ratio)
            self.se_module = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(mid_channels, se_mid_channels, 1),
                nn.ReLU(),
                nn.Conv2d(se_mid_channels, mid_channels, 1),
                nn.Sigmoid(),
            )
        else:
            self.se_module = None

        # Linear projection: BatchNorm but no activation afterwards.
        self.pointwise_conv = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.use_shortcut = stride == 1 and in_channels == out_channels

    def forward(self, x):
        y = self.inv_bottleneck(x)
        y = self.depthwise_conv(y)
        if self.se_module is not None:
            # The pooled, sigmoid-squashed gate broadcasts over H and W.
            gate = self.se_module(y)
            y = y * gate
        y = self.pointwise_conv(y)
        if self.use_shortcut:
            y = y + x

        return y


class MBStage(nn.Module):
    """A stack of ``num_blocks`` MBConv blocks forming one network stage.

    Only the first block changes resolution (``stride``) and channel count
    (``in_channels -> out_channels``). Every later block runs at stride 1
    with ``out_channels`` in and out.

    Args:
        num_blocks: Number of MBConv blocks in the stage.
        in_channels: Input channels of the first block.
        out_channels: Output channels of every block.
        kernel_size: Depthwise kernel size shared by all blocks.
        stride: Stride of the first block only.
        expansion: Either one value used for every block, or a list with
            exactly one entry per block.
        mid_channels: Either one value used for every block, or a list with
            exactly one entry per block.
        se_ratio: SE reduction factor shared by all blocks (0 disables SE).

    Raises:
        ValueError: If neither ``expansion`` nor ``mid_channels`` is given,
            or a per-block list does not have exactly ``num_blocks`` entries.
    """

    def __init__(
        self,
        num_blocks,
        in_channels,
        out_channels,
        kernel_size,
        *,
        stride=1,
        expansion=None,
        mid_channels=None,
        se_ratio=0,
    ):
        super().__init__()

        if not (expansion or mid_channels):
            raise ValueError("MBStage needs either `expansion` or `mid_channels`")
        # Validate list lengths: `zip` below would otherwise silently
        # truncate the stage to the shortest list.
        expansion = self._per_block(expansion, num_blocks, "expansion")
        mid_channels = self._per_block(mid_channels, num_blocks, "mid_channels")

        self.blocks = nn.Sequential(*[
            MBConv(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size,
                stride=stride if i == 0 else 1,
                expansion=e,
                mid_channels=mc,
                se_ratio=se_ratio,
            )
            for i, (e, mc) in enumerate(zip(expansion, mid_channels))
        ])

    @staticmethod
    def _per_block(value, num_blocks, name):
        """Broadcast a scalar to ``num_blocks`` entries, or validate a list's length."""
        if not isinstance(value, list):
            return [value] * num_blocks
        if len(value) != num_blocks:
            raise ValueError(
                f"`{name}` has {len(value)} entries but the stage has {num_blocks} blocks"
            )
        return value

    def forward(self, x):
        y = self.blocks(x)

        return y


class MobileNetV3(nn.Module):
    """MobileNetV3-style backbone built from MBConv/MBStage, for complexity counting.

    Stem: 3x3 stride-2 conv to 16 channels plus one 16 -> 16 MBConv. Then five
    stages reaching 160 channels, and a 1x1 conv to 960 channels. All
    activations are ReLU, and the SE gates use a plain sigmoid.

    By default the network stops at the 960-channel feature map, so the
    classifier head is not included in FLOP/parameter counts.

    Args:
        include_classifier: If True, append a global-average-pool ->
            1x1 conv (960 -> 1280) -> ReLU -> 1x1 conv (1280 -> num_classes)
            head. ``forward`` then returns flattened logits of shape
            (N, num_classes). Defaults to False (backbone only).
        num_classes: Output width of the optional classifier head.
    """

    def __init__(self, *, include_classifier=False, num_classes=1000):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            MBConv(16, 16, kernel_size=3, stride=1, mid_channels=16)
        )
        self.stages = nn.Sequential(
            MBStage(2, 16, 24, kernel_size=3, stride=2, mid_channels=[64, 72]),
            MBStage(3, 24, 40, kernel_size=5, stride=2, expansion=3, se_ratio=6),
            MBStage(4, 40, 80, kernel_size=3, stride=2, mid_channels=[240, 200, 184, 184]),
            MBStage(2, 80, 112, kernel_size=3, stride=1, expansion=6, se_ratio=6),
            MBStage(3, 112, 160, kernel_size=5, stride=2, expansion=6, se_ratio=6),
        )
        self.one_more = nn.Sequential(
            nn.Conv2d(160, 960, kernel_size=1, bias=False),
            nn.BatchNorm2d(960),
            nn.ReLU(),
        )
        # Opt-in head (previously commented-out code); None keeps the
        # default backbone-only profile unchanged.
        if include_classifier:
            self.classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(960, 1280, kernel_size=1, bias=False),
                nn.ReLU(),
                nn.Conv2d(1280, num_classes, kernel_size=1, bias=False),
            )
        else:
            self.classifier = None

    def forward(self, x):
        y = self.stem(x)
        y = self.stages(y)
        y = self.one_more(y)
        if self.classifier is not None:
            # The head outputs (N, num_classes, 1, 1); flatten to (N, num_classes).
            y = torch.flatten(self.classifier(y), 1)

        return y


class EfficientNetB0(nn.Module):
    """EfficientNet-B0-style backbone built from MBConv/MBStage, for complexity counting.

    Stem: 3x3 stride-2 conv to 32 channels plus one 32 -> 16 MBConv
    (expansion 1). Then six stages (24, 40, 80, 112, 192, 320 channels), and a
    1x1 conv to 1280 channels. All activations are ReLU, and the SE gates use
    a plain sigmoid with reduction factor 4 on the expanded width.

    By default the network stops at the 1280-channel feature map, so the
    classifier head is not included in FLOP/parameter counts.

    Args:
        include_classifier: If True, append a global-average-pool ->
            1x1 conv (1280 -> num_classes) head. ``forward`` then returns
            flattened logits of shape (N, num_classes). Defaults to False
            (backbone only).
        num_classes: Output width of the optional classifier head.
    """

    def __init__(self, *, include_classifier=False, num_classes=1000):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            MBConv(32, 16, kernel_size=3, stride=1, expansion=1)
        )
        self.stages = nn.Sequential(
            MBStage(2, 16, 24, kernel_size=3, stride=2, expansion=6, se_ratio=4),
            MBStage(2, 24, 40, kernel_size=5, stride=2, expansion=6, se_ratio=4),
            MBStage(3, 40, 80, kernel_size=3, stride=2, expansion=6, se_ratio=4),
            MBStage(3, 80, 112, kernel_size=5, stride=1, expansion=6, se_ratio=4),
            MBStage(4, 112, 192, kernel_size=5, stride=2, expansion=6, se_ratio=4),
            MBStage(1, 192, 320, kernel_size=3, stride=1, expansion=6, se_ratio=4),
        )
        self.one_more = nn.Sequential(
            nn.Conv2d(320, 1280, kernel_size=1, bias=False),
            nn.BatchNorm2d(1280),
            nn.ReLU(),
        )
        # Opt-in head (previously commented-out code); None keeps the
        # default backbone-only profile unchanged.
        if include_classifier:
            self.classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(1280, num_classes, kernel_size=1, bias=False),
            )
        else:
            self.classifier = None

    def forward(self, x):
        y = self.stem(x)
        y = self.stages(y)
        y = self.one_more(y)
        if self.classifier is not None:
            # The head outputs (N, num_classes, 1, 1); flatten to (N, num_classes).
            y = torch.flatten(self.classifier(y), 1)

        return y

In [36]:
reso = 300
input_size = (3, reso, reso)


def _report_complexity(model, report_path):
    """Run ptflops on ``model`` at ``input_size`` and print its totals.

    ptflops writes the per-layer breakdown to ``report_path``. The result of
    ``get_model_complexity_info`` (printed as MACs and params) is returned.
    """
    with open(report_path, "w") as report:
        totals = ptflops.get_model_complexity_info(
            model,
            input_size,
            print_per_layer_stat=True,
            ost=report,
        )
    print("MACs: {}; Params: {}".format(*totals))
    return totals


# Profile each backbone in turn at the same input resolution.
mbv3 = MobileNetV3()
res = _report_complexity(mbv3, "ptflops-mbv3.txt")

enb0 = EfficientNetB0()
res = _report_complexity(enb0, "ptflops-enb0.txt")

MACs: 0.43 GMac; Params: 2.47 M
MACs: 0.8 GMac; Params: 7.14 M


In [37]:
# Print each "MBStage(" line from the ptflops reports, plus the one line after it (-A1).
# That following line holds the stage's params / MACs summary.
!grep -A1 'MBStage' ptflops-mbv3.txt
!echo ==============================
!grep -A1 'MBStage' ptflops-enb0.txt

    (0): MBStage(
      0.008 M, 0.319% Params, 0.066 GMac, 15.420% MACs, 
--
    (1): MBStage(
      0.045 M, 1.815% Params, 0.057 GMac, 13.308% MACs, 
--
    (2): MBStage(
      0.131 M, 5.300% Params, 0.059 GMac, 13.704% MACs, 
--
    (3): MBStage(
      0.487 M, 19.722% Params, 0.095 GMac, 21.991% MACs, 
--
    (4): MBStage(
      1.641 M, 66.494% Params, 0.109 GMac, 25.334% MACs, 
    (0): MBStage(
      0.029 M, 0.409% Params, 0.113 GMac, 14.241% MACs, 
--
    (1): MBStage(
      0.079 M, 1.111% Params, 0.075 GMac, 9.455% MACs, 
--
    (2): MBStage(
      0.459 M, 6.428% Params, 0.084 GMac, 10.596% MACs, 
--
    (3): MBStage(
      1.016 M, 14.221% Params, 0.164 GMac, 20.571% MACs, 
--
    (4): MBStage(
      3.874 M, 54.235% Params, 0.189 GMac, 23.718% MACs, 
--
    (5): MBStage(
      1.27 M, 17.785% Params, 0.062 GMac, 7.729% MACs, 
