# FlopCounterMode

Reference: [the-ideal-pytorch-flop-counter-with-torch-dispatch](https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505)

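`FlopCounterMode` can (1) count FLOPs at the operator level, (2) optionally aggregate those counts at the module level, (3) capture FLOPs in the backward pass, and (4) work in eager mode. It can also count the FLOPs of Jacobian or Hessian computations built from arbitrary transforms such as vmap.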

In [1]:
from torch.utils.flop_counter import FlopCounterMode

In [2]:
FlopCounterMode??

Init signature:
FlopCounterMode(
    mods: Union[torch.nn.modules.module.Module, list[torch.nn.modules.module.Module], NoneType] = None,
    depth: int = 2,
    display: bool = True,
    custom_mapping: Optional[dict[Any, Any]] = None,
)
Source:

In [3]:
import torch
import torchvision.models as models


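# Build the input and the model on the meta device: tensors carry
# shapes and dtypes but no data, so nothing is actually computed.
with torch.device('meta'):
    inp = torch.randn((8, 3, 224, 224))
    mod = models.resnet18()
# Count FLOPs for one forward pass; the table is printed on exit.
with FlopCounterMode() as flop_counter:
    mod(inp)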

Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
 ResNet.conv1          1.888B      6.51%
  - aten.convolution   1.888B      6.51%
 ResNet.fc             0.008B      0.03%
  - aten.addmm         0.008B      0.03%
 ResNet.layer1         7.399B     25.49%
  - aten.convolution   7.399B     25.49%
 ResNet.layer2         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer3         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer4         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
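As a sanity check on the table above, the `ResNet.conv1` and `ResNet.fc` rows can be reproduced by hand. The sketch below assumes one multiply-add counts as 2 FLOPs and uses the standard ResNet-18 shapes (a 7x7 convolution from 3 to 64 channels with a 112x112 output, and a 512-to-1000 linear layer). The helper names are hypothetical and not part of `torch.utils.flop_counter`.

```python
def conv2d_flops(batch, c_in, c_out, kernel_hw, out_hw):
    """FLOPs of a dense 2D convolution, counting a multiply-add as 2 FLOPs."""
    kh, kw = kernel_hw
    oh, ow = out_hw
    macs = batch * c_out * oh * ow * c_in * kh * kw
    return 2 * macs


def linear_flops(batch, in_features, out_features):
    """FLOPs of the matrix multiply in a linear layer (bias add ignored)."""
    return 2 * batch * in_features * out_features


# Batch size 8, matching the run above.
conv1 = conv2d_flops(8, c_in=3, c_out=64, kernel_hw=(7, 7), out_hw=(112, 112))
fc = linear_flops(8, in_features=512, out_features=1000)

print(f"conv1: {conv1 / 1e9:.3f}B")  # 1.888B, the ResNet.conv1 row
print(f"fc:    {fc / 1e9:.3f}B")     # 0.008B, the ResNet.fc row
```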


In [4]:
import torch
import torchvision.models as models


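# Same model and input on the CPU: the forward pass really runs here,
# and the reported FLOP counts match the meta-device run.
with torch.device('cpu'):
    inp = torch.randn((8, 3, 224, 224))
    mod = models.resnet18()
with FlopCounterMode() as flop_counter:
    mod(inp)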

Module                   FLOP    % Total
--------------------  -------  ---------
ResNet                29.025B    100.00%
 - aten.convolution   29.017B     99.97%
 - aten.addmm          0.008B      0.03%
 ResNet.conv1          1.888B      6.51%
  - aten.convolution   1.888B      6.51%
 ResNet.fc             0.008B      0.03%
  - aten.addmm         0.008B      0.03%
 ResNet.layer1         7.399B     25.49%
  - aten.convolution   7.399B     25.49%
 ResNet.layer2         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer3         6.577B     22.66%
  - aten.convolution   6.577B     22.66%
 ResNet.layer4         6.577B     22.66%
  - aten.convolution   6.577B     22.66%


In [None]:
from torch.utils.flop_counter import FlopCounterMode
import torch
import torchvision.models as models


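# Batch size 1 this time, so the counts are per image.
with torch.device('meta'):
    inp = torch.randn((1, 3, 224, 224))
    mod = models.resnet18()
# Forward pass only: no .backward() is called, so backward FLOPs are not included.
with FlopCounterMode() as flop_counter:
    mod(inp).sum()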

Module                     FLOP    % Total
--------------------  ---------  ---------
ResNet                3628.147M    100.00%
 - aten.convolution   3627.123M     99.97%
 - aten.addmm            1.024M      0.03%
 ResNet.conv1          236.028M      6.51%
  - aten.convolution   236.028M      6.51%
 ResNet.fc               1.024M      0.03%
  - aten.addmm           1.024M      0.03%
 ResNet.layer1         924.844M     25.49%
  - aten.convolution   924.844M     25.49%
 ResNet.layer2         822.084M     22.66%
  - aten.convolution   822.084M     22.66%
 ResNet.layer3         822.084M     22.66%
  - aten.convolution   822.084M     22.66%
 ResNet.layer4         822.084M     22.66%
  - aten.convolution   822.084M     22.66%


In [18]:
print(f"{flop_counter.get_total_flops()/1e9:.2f} GFLOPs")

3.63 GFLOPs
