<a href="https://colab.research.google.com/github/patrick22414/colab-FLOPs/blob/main/FLOPs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ptflops

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

import ptflops

In [None]:
class MBConv(nn.Module):
    """Mobile inverted-bottleneck block used by the MobileNetV3 / EfficientNet sketches below.

    Layout: 1x1 expansion conv (skipped when no expansion is needed) ->
    depthwise kxk conv -> optional squeeze-and-excitation gate -> 1x1
    projection conv, plus an identity shortcut when the block keeps shape.
    All activations are ``nn.ReLU``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: depthwise kernel size; padding is ``kernel_size // 2``.
        stride: stride of the depthwise conv (keyword-only, required).
        expansion: hidden width is ``in_channels * expansion`` when
            ``mid_channels`` is not given.
        mid_channels: explicit hidden width; takes precedence over
            ``expansion`` when truthy.
        se_ratio: SE reduction ratio. The SE bottleneck has
            ``mid_channels // se_ratio`` channels (clamped to at least 1).
            0 disables SE.

    Raises:
        ValueError: if neither ``expansion`` nor ``mid_channels`` is given.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        *,
        stride,
        expansion=None,
        mid_channels=None,
        se_ratio=0,
    ):
        super().__init__()

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not (expansion or mid_channels):
            raise ValueError("MBConv needs either `expansion` or `mid_channels`")
        if not mid_channels:
            mid_channels = in_channels * expansion

        if mid_channels == in_channels:
            # Nothing to expand. The reference MobileNetV3 / EfficientNet
            # blocks omit the 1x1 expand conv in this case, so it must not
            # contribute to the counted FLOPs/params.
            self.inv_bottleneck = nn.Identity()
        else:
            self.inv_bottleneck = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, 1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),
            )

        # Depthwise conv: one filter per channel (groups == channels).
        # kernel_size // 2 padding keeps the spatial size for odd kernels at stride 1.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(
                mid_channels,
                mid_channels,
                kernel_size,
                stride=stride,
                padding=kernel_size // 2,
                groups=mid_channels,
                bias=False,
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        )

        if se_ratio > 0:
            # Clamp to >= 1 so a small width or a large ratio never builds a
            # zero-channel conv; int() also tolerates a float ratio.
            se_mid_channels = max(1, int(mid_channels // se_ratio))
            self.se_module = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(mid_channels, se_mid_channels, 1),
                nn.ReLU(),
                nn.Conv2d(se_mid_channels, mid_channels, 1),
                nn.Sigmoid(),
            )
        else:
            self.se_module = None

        # Projection back to out_channels: conv + BN with no activation.
        self.pointwise_conv = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        # A residual add is only shape-valid when resolution and width are unchanged.
        self.use_shortcut = stride == 1 and in_channels == out_channels

    def forward(self, x):
        """Apply expand -> depthwise -> (SE gate) -> project, then add the shortcut if enabled."""
        y = self.inv_bottleneck(x)
        y = self.depthwise_conv(y)
        if self.se_module is not None:
            # The SE branch yields a per-channel gate in (0, 1) that rescales y.
            y = y * self.se_module(y)
        y = self.pointwise_conv(y)
        if self.use_shortcut:
            y = y + x
        return y


class MBStage(nn.Module):
    """A stage of ``num_blocks`` consecutive MBConv blocks.

    Only the first block applies ``stride`` and maps ``in_channels`` to
    ``out_channels``. The remaining blocks use stride 1 and map
    ``out_channels`` to ``out_channels``. Every block uses the same
    ``kernel_size`` and ``se_ratio``.

    Args:
        num_blocks: number of MBConv blocks in the stage.
        in_channels: input channels of the first block.
        out_channels: output channels of every block.
        kernel_size: depthwise kernel size of every block.
        stride: stride of the first block (others use 1).
        expansion: expansion ratio, either one value for all blocks or a
            list with exactly ``num_blocks`` entries.
        mid_channels: hidden width, either one value for all blocks or a
            list with exactly ``num_blocks`` entries.
        se_ratio: squeeze-and-excitation ratio passed to every block.

    Raises:
        ValueError: if neither ``expansion`` nor ``mid_channels`` is given,
            or a per-block list does not have ``num_blocks`` entries.
    """

    def __init__(
        self,
        num_blocks,
        in_channels,
        out_channels,
        kernel_size,
        *,
        stride=1,
        expansion=None,
        mid_channels=None,
        se_ratio=0,
    ):
        super().__init__()

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not (expansion or mid_channels):
            raise ValueError("MBStage needs either `expansion` or `mid_channels`")
        expansion = self._per_block(expansion, num_blocks, "expansion")
        mid_channels = self._per_block(mid_channels, num_blocks, "mid_channels")

        self.blocks = nn.Sequential(*[
            MBConv(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size,
                stride=stride if i == 0 else 1,
                expansion=e,
                mid_channels=mc,
                se_ratio=se_ratio,
            )
            for i, (e, mc) in enumerate(zip(expansion, mid_channels))
        ])

    @staticmethod
    def _per_block(value, num_blocks, name):
        """Broadcast a scalar to a per-block list, or check that a given list has one entry per block."""
        if not isinstance(value, list):
            return [value] * num_blocks
        if len(value) != num_blocks:
            # Without this check, zip() would silently drop blocks.
            raise ValueError(
                f"`{name}` has {len(value)} entries but num_blocks={num_blocks}"
            )
        return value

    def forward(self, x):
        """Run the blocks sequentially."""
        y = self.blocks(x)

        return y


class MobileNetV3(nn.Module):
    """MobileNetV3 feature extractor used for FLOPs/params counting.

    Structure: a strided 3x3 stem conv and one MBConv, then five MBStage
    stages, then a 1x1 conv to 960 channels. The channel plan appears to
    follow MobileNetV3-Large; confirm if exact parity matters. Every
    activation is ``nn.ReLU``, and no classifier head is built.
    """

    def __init__(self):
        super().__init__()

        stem_width = 16
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(),
            MBConv(stem_width, stem_width, 3, stride=1, mid_channels=stem_width),
        )

        # (num_blocks, in_channels, out_channels, kernel_size, extra MBStage kwargs)
        stage_table = [
            (2, 16, 24, 3, {"stride": 2, "mid_channels": [64, 72]}),
            (3, 24, 40, 5, {"stride": 2, "expansion": 3, "se_ratio": 6}),
            (4, 40, 80, 3, {"stride": 2, "mid_channels": [240, 200, 184, 184]}),
            (2, 80, 112, 3, {"stride": 1, "expansion": 6, "se_ratio": 6}),
            (3, 112, 160, 5, {"stride": 2, "expansion": 6, "se_ratio": 6}),
        ]
        self.stages = nn.Sequential(*[
            MBStage(count, c_in, c_out, k, **opts)
            for count, c_in, c_out, k, opts in stage_table
        ])

        backbone_width = stage_table[-1][2]
        head_width = 960
        self.one_more = nn.Sequential(
            nn.Conv2d(backbone_width, head_width, 1, bias=False),
            nn.BatchNorm2d(head_width),
            nn.ReLU(),
        )
        # No classifier is built. A head would be: global average pool ->
        # 1x1 conv 960->1280 -> ReLU -> 1x1 conv 1280->1000.

    def forward(self, x):
        """Return the 960-channel feature map for input ``x``."""
        return self.one_more(self.stages(self.stem(x)))


class EfficientNetB0(nn.Module):
    """EfficientNet-B0 feature extractor used for FLOPs/params counting.

    Structure: a strided 3x3 stem conv and an expansion-1 MBConv, then six
    MBStage stages (expansion 6, SE ratio 4), then a 1x1 conv to 1280
    channels. Every activation is ``nn.ReLU``, and no classifier head is
    built.
    """

    def __init__(self):
        super().__init__()

        stem_width, first_width = 32, 16
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(),
            MBConv(stem_width, first_width, 3, stride=1, expansion=1),
        )

        # (num_blocks, in_channels, out_channels, kernel_size, stride);
        # every stage shares expansion=6 and se_ratio=4.
        stage_table = [
            (2, 16, 24, 3, 2),
            (2, 24, 40, 5, 2),
            (3, 40, 80, 3, 2),
            (3, 80, 112, 5, 1),
            (4, 112, 192, 5, 2),
            (1, 192, 320, 3, 1),
        ]
        self.stages = nn.Sequential(*[
            MBStage(count, c_in, c_out, k, stride=s, expansion=6, se_ratio=4)
            for count, c_in, c_out, k, s in stage_table
        ])

        backbone_width = stage_table[-1][2]
        head_width = 1280
        self.one_more = nn.Sequential(
            nn.Conv2d(backbone_width, head_width, 1, bias=False),
            nn.BatchNorm2d(head_width),
            nn.ReLU(),
        )
        # No classifier is built. A head would be: global average pool ->
        # 1x1 conv 1280->1000.

    def forward(self, x):
        """Return the 1280-channel feature map for input ``x``."""
        return self.one_more(self.stages(self.stem(x)))

In [None]:
# Square input resolution; input_size is (channels, height, width) for both models.
reso = 320
input_size = (3, reso, reso)


def _profile(model, report_path, label):
    """Profile *model* with ptflops and return its ``(macs, params)`` result.

    Writes the per-layer report to *report_path* and prints a one-line
    summary prefixed with *label*, so the two models' summaries are
    distinguishable.
    """
    with open(report_path, "w", encoding="utf-8") as fo:
        res = ptflops.get_model_complexity_info(
            model,
            input_size,
            print_per_layer_stat=True,
            ost=fo,
        )
    print("{}: MACs: {}; Params: {}".format(label, *res))
    return res


###
mbv3 = MobileNetV3()
res = _profile(mbv3, "ptflops-mbv3.txt", "MobileNetV3")

###
enb0 = EfficientNetB0()
res = _profile(enb0, "ptflops-enb0.txt", "EfficientNetB0")

In [None]:
# Show every line mentioning MBStage in each per-layer report written by the
# profiling cell, plus one line of trailing context (-A1). Presumably that
# next line holds the stage's ptflops stats — TODO confirm against the report.
!grep -A1 'MBStage' ptflops-mbv3.txt
# Blank lines to visually separate the two grep outputs.
print("\n" * 2)
!grep -A1 'MBStage' ptflops-enb0.txt