## Setup

In [None]:
!pip install -q fvcore
!pip install -q git+https://github.com/rwightman/pytorch-image-models

In [None]:
from fvcore.nn import FlopCountAnalysis
import torch
import timm

## BiT Models

In [None]:
# Collect the timm model names that match the wildcard pattern "*bit*",
# then display the list as the cell output.
all_bit_models = timm.list_models("*bit*")
all_bit_models

In [None]:
# Random (1, 3, 224, 224) tensor: used only as the example input for tracing.
input = torch.randn(1, 3, 224, 224)
flop_map_bit = {}

# Model names containing any of these fragments are left out of the table.
excluded_bit_tags = ("teacher", "distilled", "in21k")

for bit_model_name in all_bit_models:
    if any(tag in bit_model_name for tag in excluded_bit_tags):
        continue

    # Switch to inference mode before tracing, consistent with how the
    # "Other Models" loop in this notebook builds its models.
    bit_model = timm.create_model(bit_model_name).eval()
    flop_analysis = FlopCountAnalysis(bit_model, input)
    flops = flop_analysis.total() / 1e6  # raw count scaled down by 1e6
    flop_map_bit[bit_model_name] = f"{flops:.3f} M"

In [None]:
# Display the BiT results: model name -> formatted count (total / 1e6, suffixed " M").
flop_map_bit

## ViT Models

In [None]:
# Collect the timm model names that match the wildcard pattern "vit*",
# then display the list as the cell output.
all_vit_models = timm.list_models("vit*")
all_vit_models

In [None]:
# Random (1, 3, 224, 224) tensor: used only as the example input for tracing.
input = torch.randn(1, 3, 224, 224)
flop_map_vit = {}

# Model names containing any of these fragments are left out of the table.
excluded_vit_tags = ("384", "in21k", "r26", "r50", "resnet", "tiny", "miil", "sam")

for vit_model_name in all_vit_models:
    if any(tag in vit_model_name for tag in excluded_vit_tags):
        continue

    # Switch to inference mode before tracing, consistent with how the
    # "Other Models" loop in this notebook builds its models.
    vit_model = timm.create_model(vit_model_name).eval()
    flop_analysis = FlopCountAnalysis(vit_model, input)
    flops = flop_analysis.total() / 1e6  # raw count scaled down by 1e6
    flop_map_vit[vit_model_name] = f"{flops:.3f} M"

In [None]:
# Display the ViT results: model name -> formatted count (total / 1e6, suffixed " M").
flop_map_vit

In [None]:
# 79.086%
# NOTE(review): the figure above is presumably the reported accuracy of this
# model -- confirm its source and metric.
input = torch.randn(1, 3, 224, 224)
# Switch to inference mode before tracing, consistent with how the
# "Other Models" loop in this notebook builds its models.
vit_model = timm.create_model("vit_small_patch16_224").eval()
flops = FlopCountAnalysis(vit_model, input)
# Last expression is the cell output: total count scaled down by 1e6.
flops.total() / 1e6

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::mul encountered 12 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.



4608.338304

## Other Models

In [None]:
def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is true contribute; each one
    adds its ``numel()`` to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Names of the additional timm architectures measured in this section.
other_models = [
    "gernet_l", "gernet_m", "gernet_s",
    "skresnet18", "skresnet34", "skresnext50_32x4d",
    "gc_efficientnetv2_rw_t",
]

# Random (1, 3, 224, 224) tensor: used only as the example input for tracing.
input = torch.randn(1, 3, 224, 224)
flop_map_others, parameters_others = {}, {}

for other_model_name in other_models:
    print(other_model_name)
    other_model = timm.create_model(other_model_name).eval()

    # Both figures are scaled down by 1e6 before being formatted.
    flops = FlopCountAnalysis(other_model, input).total() / 1e6
    parameters = count_parameters(other_model) / 1e6

    parameters_others[other_model_name] = f"{parameters:.3f} M"
    flop_map_others[other_model_name] = f"{flops:.3f} M"

In [None]:
# Display both result dicts as a tuple: (parameter counts, FLOP counts),
# each mapping model name -> value / 1e6 formatted with an " M" suffix.
parameters_others, flop_map_others

## ResNet50

In [None]:
import torchvision

# Build torchvision's ResNet-50 (no arguments are passed) and switch it to
# inference mode before tracing, consistent with the "Other Models" loop.
resnet50 = torchvision.models.resnet50().eval()
# Random (1, 3, 224, 224) tensor: used only as the example input for tracing.
input = torch.randn(1, 3, 224, 224)
flop_analysis = FlopCountAnalysis(resnet50, input)
flops = flop_analysis.total() / 1e6  # raw count scaled down by 1e6
flops