# FLOPs

```{todo} To be rewritten
```

Reference: [FLOPs](https://zhuanlan.zhihu.com/p/663566912?utm_psn=1701033625972346880)

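FLOPs (floating-point operations) and MACs (multiply-accumulate operations) are two metrics commonly used to quantify the computational complexity of deep learning models. They offer a quick, simple way to see how many arithmetic operations a given computation requires. For example, when choosing among model architectures such as MobileNet or DenseNet for an edge device, people use MACs or FLOPs to estimate model performance. The word "estimate" is deliberate: both metrics are approximations, not measurements of actual runtime performance. Even so, they give useful insight into energy consumption and compute requirements, which matters a great deal in edge computing.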

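FLOPs specifically counts floating-point operations: additions, subtractions, multiplications, and divisions on floating-point numbers. These operations dominate the mathematics of machine learning, for example matrix multiplication, activation functions, and gradient computation. FLOPs is typically used to measure the computational cost or complexity of a model, or of a specific operation within a model. It is most useful when an estimate of the total number of arithmetic operations is needed, usually when assessing computational efficiency.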

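MACs, by contrast, counts only multiply-accumulate operations, each of which multiplies two numbers and adds the product to a running sum. This operation underlies many linear-algebra routines, such as matrix multiplication, convolution, and dot products. For models that rely heavily on linear algebra, such as convolutional neural networks (CNNs), MACs is often used as a more specific measure of computational complexity.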
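To make the MAC count concrete, the sketch below counts MACs for a fully connected layer and for a 2-D convolution by treating every output element as one dot product. The function names `linear_macs` and `conv2d_macs` are illustrative only and do not come from any library; bias terms are ignored.

```python
def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    """MACs of y = x @ W.T for x of shape (batch, in_features); bias ignored."""
    # Each of the batch * out_features outputs is a dot product of length in_features.
    return batch * out_features * in_features


def conv2d_macs(c_in: int, c_out: int, kernel_h: int, kernel_w: int,
                out_h: int, out_w: int, batch: int = 1, groups: int = 1) -> int:
    """MACs of a 2-D convolution; bias ignored."""
    # Each output element is a dot product over one kernel window.
    dot_length = (c_in // groups) * kernel_h * kernel_w
    out_elements = batch * c_out * out_h * out_w
    return out_elements * dot_length


print(linear_macs(5, 4))                    # 20
print(conv2d_macs(64, 64, 3, 3, 56, 56))    # 115605504
```

The convolution count depends on the output resolution, not the input resolution, so stride and padding enter only through `out_h` and `out_w`.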

```{note}
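All-caps FLOPS stands for "floating-point operations per second". It describes computation speed and is typically used as a measure of hardware performance. The "S" stands for "second"; together with the "P" (for "per") it denotes a rate.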
```
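The two quantities combine naturally: dividing a model's FLOPs (a count) by a device's FLOPS (a rate) gives an idealized lower bound on execution time. The sketch below assumes the computation is purely compute-bound; the helper name `ideal_latency_seconds`, the `utilization` parameter, and the numbers in the usage line are all hypothetical. Real latency is usually higher because of memory traffic and scheduling overhead.

```python
def ideal_latency_seconds(model_flops: float,
                          peak_flops_per_second: float,
                          utilization: float = 1.0) -> float:
    """Idealized compute-bound latency: operation count divided by achieved rate."""
    if not 0.0 < utilization <= 1.0:
        raise ValueError("utilization must be in (0, 1]")
    return model_flops / (peak_flops_per_second * utilization)


# Hypothetical numbers: a model needing 8 GFLOPs on a device with 1 TFLOPS peak.
latency = ideal_latency_seconds(8e9, 1e12)
print(f"{latency * 1e3:.1f} ms")  # 8.0 ms
```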


The common convention in the AI community is that one MAC is roughly equal to two FLOPs.
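The "roughly" matters. A dot product of length `k` needs `k` multiplications but only `k - 1` additions, so the exact count is `2k - 1`; the factor of two becomes exact once a bias is added to each output. The sketch below compares both counts for a layer shaped like the `nn.Linear(5, 4)` used in this page. The helper names are illustrative, not part of any library.

```python
def dot_flops_exact(k: int) -> int:
    """k multiplications plus k - 1 additions for a length-k dot product."""
    return 2 * k - 1


def macs_to_flops(macs: int) -> int:
    """The common convention: one MAC is counted as two FLOPs."""
    return 2 * macs


in_features, out_features = 5, 4
macs = in_features * out_features                        # one MAC per weight
no_bias = out_features * dot_flops_exact(in_features)    # exact count, no bias
with_bias = no_bias + out_features                       # one extra add per output

print(macs, macs_to_flops(macs), no_bias, with_bias)     # 20 40 36 40
```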

In [1]:
import torch
from torch import nn, Tensor
from torch.fx.node import Node
from tabulate import tabulate
from torch_book.scan.flop import get_FLOPs

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(5, 4, bias=True)
        self.__add_one = True

    def forward(self, x: Tensor):
        x = self.layer(x)
        if self.__add_one:
            x += 1.
        return x

In [2]:
model = SimpleModel()
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(1, 5)

In [3]:
# Column headers for the per-node report returned by get_FLOPs.
result_header = ['node_name', 'node_op', 'op_target', 'nn_module_stack[-1]', 'flops']
result_table = get_FLOPs(gm, sample_input)
# Per-column placeholders for missing cells; a missing FLOP count is shown as 'ERROR'.
__missing_values = [''] * 4 + ['ERROR']
# table_str = tabulate(result_table, result_header, tablefmt='rst', missingval=__missing_values)
table_str = tabulate(result_table, result_header, tablefmt='fancy_grid', missingval=__missing_values)
print(table_str)
# Transpose the rows, take the last column (flops), and keep only integer entries.
valid_flops_list = list(filter(lambda _f: isinstance(_f, int), list(zip(*result_table))[-1]))
total_flops = sum(valid_flops_list)
# Number of nodes for which no FLOP count was produced.
num_empty_flops = len(result_table) - len(valid_flops_list)
print(f"total_flops = {total_flops:3,}")

╒═════════════╤═══════════════╤═════════════════════════╤═══════════════════════╤═════════╕
│ node_name   │ node_op       │ op_target               │ nn_module_stack[-1]   │   flops │
╞═════════════╪═══════════════╪═════════════════════════╪═══════════════════════╪═════════╡
│ x           │ placeholder   │ x                       │                       │       0 │
├─────────────┼───────────────┼─────────────────────────┼───────────────────────┼─────────┤
│ layer       │ call_module   │ layer                   │ Linear                │      40 │
├─────────────┼───────────────┼─────────────────────────┼───────────────────────┼─────────┤
│ add         │ call_function │ <built-in function add> │                       │       4 │
├─────────────┼───────────────┼─────────────────────────┼───────────────────────┼─────────┤
│ output      │ output        │ output                  │                       │       0 │
╘═════════════╧═══════════════╧═════════════════════════╧═══════════════════════╧═════════╛

## Testing with timm

In [4]:
import timm
vit = timm.create_model('vit_base_patch16_224').eval()
resnet = timm.create_model('resnet50').eval()

In [5]:
gm = torch.fx.symbolic_trace(vit)
sample_input = torch.randn([1, 3, 224, 224])
result_header = ['node_name', 'node_op', 'op_target', 'nn_module_stack[-1]', 'flops']
result_table = get_FLOPs(gm, sample_input)
__missing_values = [''] * 4 + ['ERROR']
# table_str = tabulate(result_table, result_header, tablefmt='rst', missingval=__missing_values)
table_str = tabulate(result_table, result_header, tablefmt='fancy_grid', missingval=__missing_values)
print(table_str)
valid_flops_list = list(filter(lambda _f: isinstance(_f, int), list(zip(*result_table))[-1]))
total_flops = sum(valid_flops_list)
num_empty_flops = len(result_table) - len(valid_flops_list)
print(f"total_flops = {total_flops:3,}")

╒═════════════════════════════════╤═══════════════╤════════════════════════════════════════════════════════╤═══════════════════════╤════════════════╕
│ node_name                       │ node_op       │ op_target                                              │ nn_module_stack[-1]   │ flops          │
╞═════════════════════════════════╪═══════════════╪════════════════════════════════════════════════════════╪═══════════════════════╪════════════════╡
│ x                               │ placeholder   │ x                                                      │                       │ 0              │
├─────────────────────────────────┼───────────────┼────────────────────────────────────────────────────────┼───────────────────────┼────────────────┤
│ getattr_1                       │ call_function │ <built-in function getattr>                            │ PatchEmbed            │ 0              │
├─────────────────────────────────┼───────────────┼─────────────────────────────────────────────────


In [6]:
gm = torch.fx.symbolic_trace(resnet)
sample_input = torch.randn([1, 3, 224, 224])
result_header = ['node_name', 'node_op', 'op_target', 'nn_module_stack[-1]', 'flops']
result_table = get_FLOPs(gm, sample_input)
__missing_values = [''] * 4 + ['ERROR']
# table_str = tabulate(result_table, result_header, tablefmt='rst', missingval=__missing_values)
table_str = tabulate(result_table, result_header, tablefmt='fancy_grid', missingval=__missing_values)
print(table_str)
valid_flops_list = list(filter(lambda _f: isinstance(_f, int), list(zip(*result_table))[-1]))
total_flops = sum(valid_flops_list)
num_empty_flops = len(result_table) - len(valid_flops_list)
print(f"total_flops = {total_flops:3,}")

╒═══════════════════════╤═══════════════╤═════════════════════════╤═══════════════════════╤═══════════╕
│ node_name             │ node_op       │ op_target               │ nn_module_stack[-1]   │     flops │
╞═══════════════════════╪═══════════════╪═════════════════════════╪═══════════════════════╪═══════════╡
│ x                     │ placeholder   │ x                       │                       │         0 │
├───────────────────────┼───────────────┼─────────────────────────┼───────────────────────┼───────────┤
│ conv1                 │ call_module   │ conv1                   │ Conv2d                │ 235225088 │
├───────────────────────┼───────────────┼─────────────────────────┼───────────────────────┼───────────┤
│ bn1                   │ call_module   │ bn1                     │ BatchNorm2d           │   3211264 │
├───────────────────────┼───────────────┼─────────────────────────┼───────────────────────┼───────────┤
│ act1                  │ call_module   │ act1                  
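The first two non-trivial rows of the ResNet-50 table can be checked by hand. The implementation of `get_FLOPs` is not shown on this page, so the formulas below are assumptions that happen to reproduce the printed numbers: `2k - 1` FLOPs per output element for the bias-free 7x7 stem convolution, and four FLOPs per element for batch normalization.

```python
# Shapes for a 1 x 3 x 224 x 224 input: the stem conv maps 3 -> 64 channels
# with a 7 x 7 kernel and stride 2, giving a 64 x 112 x 112 output.
out_elements = 64 * 112 * 112
dot_length = 3 * 7 * 7          # input channels * kernel height * kernel width

conv1_macs = out_elements * dot_length
conv1_flops_convention = 2 * conv1_macs                   # "1 MAC = 2 FLOPs"
conv1_flops_exact = out_elements * (2 * dot_length - 1)   # assumed formula
bn1_flops = 4 * out_elements                              # assumed: 4 ops per element

print(conv1_macs)              # 118013952
print(conv1_flops_convention)  # 236027904
print(conv1_flops_exact)       # 235225088, the conv1 value in the table
print(bn1_flops)               # 3211264, the bn1 value in the table
```

The gap between the two convolution counts is exactly one addition per output element, which is the cost of a bias that this layer does not have.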

In [8]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import use_fused_attn, Mlp, DropPath


class Attention(nn.Module):
    '''
    REF: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    '''
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @torch.no_grad()
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    @torch.no_grad()
    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

In [9]:
C = 768

# Define the model: an attention block (refer to "timm": https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py)
block = Block(C, num_heads=2, qkv_bias=True)
block.attn.fused_attn = False
block.eval()
model = block

# Input
# N: number of tokens
N = 14**2 + 1
B = 1

gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn([B, N, C])
result_header = ['node_name', 'node_op', 'op_target', 'nn_module_stack[-1]', 'flops']
result_table = get_FLOPs(gm, sample_input)
__missing_values = [''] * 4 + ['ERROR']
# table_str = tabulate(result_table, result_header, tablefmt='rst', missingval=__missing_values)
table_str = tabulate(result_table, result_header, tablefmt='fancy_grid', missingval=__missing_values)
print(table_str)
valid_flops_list = list(filter(lambda _f: isinstance(_f, int), list(zip(*result_table))[-1]))
total_flops = sum(valid_flops_list)
num_empty_flops = len(result_table) - len(valid_flops_list)
print(f"total_flops = {total_flops:3,}")

╒════════════════╤═══════════════╤═════════════════════════════╤═══════════════════════╤═══════════╕
│ node_name      │ node_op       │ op_target                   │ nn_module_stack[-1]   │     flops │
╞════════════════╪═══════════════╪═════════════════════════════╪═══════════════════════╪═══════════╡
│ x              │ placeholder   │ x                           │                       │         0 │
├────────────────┼───────────────┼─────────────────────────────┼───────────────────────┼───────────┤
│ norm1          │ call_module   │ norm1                       │ LayerNorm             │    605184 │
├────────────────┼───────────────┼─────────────────────────────┼───────────────────────┼───────────┤
│ getattr_1      │ call_function │ <built-in function getattr> │ Attention             │         0 │
├────────────────┼───────────────┼─────────────────────────────┼───────────────────────┼───────────┤
│ getitem        │ call_function │ <built-in function getitem> │ Attention             │   
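As a cross-check on per-node reports like the one above, the matrix-multiplication cost of a transformer block can be estimated in closed form. The sketch below counts MACs only for the linear layers and the two attention matmuls, ignoring biases, softmax, GELU, normalization, and residual additions. The function name `vit_block_macs` and the dictionary keys are labels chosen for this example. The final line records the observation that four FLOPs per element reproduces the `norm1` row; that formula is a guess, since the counting rule is not shown here.

```python
def vit_block_macs(n_tokens: int, dim: int, mlp_ratio: float = 4.0, batch: int = 1) -> dict:
    """Approximate MACs of the matmuls in one transformer block."""
    hidden = int(dim * mlp_ratio)
    per_sample = {
        "qkv": n_tokens * dim * 3 * dim,
        # Summed over heads, head_dim * num_heads == dim, so the head count cancels.
        "q @ k^T": n_tokens * n_tokens * dim,
        "attn @ v": n_tokens * n_tokens * dim,
        "proj": n_tokens * dim * dim,
        "mlp first linear": n_tokens * dim * hidden,
        "mlp second linear": n_tokens * hidden * dim,
    }
    return {name: batch * value for name, value in per_sample.items()}


# Same sizes as the block traced above: N = 14**2 + 1 tokens, C = 768.
macs = vit_block_macs(n_tokens=14**2 + 1, dim=768)
total_macs = sum(macs.values())
for name, value in macs.items():
    print(f"{name:18s} {value:>15,}")
print(f"total MACs = {total_macs:,}")      # 1,453,954,560

layernorm_flops = 4 * (14**2 + 1) * 768    # assumed 4 ops per element
print(layernorm_flops)                     # 605184, the norm1 value in the table
```

With these sizes the attention matmuls contribute only about 4% of the total, so the linear layers dominate the cost at this sequence length.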