# 测试模型分析

In [2]:
import torch
from torchvision import models
from torch_book.scan.crawler import crawl_module

In [3]:
# torchvision classification architectures benchmarked below, grouped by family.
# The families are concatenated in the same order as the original flat list.
TORCHVISION_MODELS = (
    ["alexnet", "googlenet"]
    + [f"vgg{depth}{suffix}" for depth in (11, 13, 16, 19) for suffix in ("", "_bn")]
    + [f"resnet{depth}" for depth in (18, 34, 50, 101, 152)]
    + ["inception_v3", "squeezenet1_0", "squeezenet1_1", "wide_resnet50_2", "wide_resnet101_2"]
    + [f"densenet{depth}" for depth in (121, 161, 169, 201)]
    + ["resnext50_32x4d", "resnext101_32x8d", "mobilenet_v2"]
    + [f"shufflenet_v2_x{width}" for width in ("0_5", "1_0", "1_5", "2_0")]
    + [f"mnasnet{width}" for width in ("0_5", "0_75", "1_0", "1_3")]
)


In [4]:
# Benchmark every architecture in TORCHVISION_MODELS (built with default constructor
# arguments, eval mode). Print one row per model with:
#   - total parameters (millions),
#   - FLOPs / MACs / DMAs (billions, summed over all layers reported by `crawl_module`),
#   - the `rf` value stored on the first reported layer.
device = "cuda" if torch.cuda.is_available() else "cpu"

headers = ["Model", "Params (M)", "FLOPs (G)", "MACs (G)", "DMAs (G)", "RF"]
max_w = [20, 10, 10, 10, 10, 10]
# Header and data rows share one column separator. Previously the header was joined with
# a 4-space margin while rows used " | ", so the columns drifted out of alignment.
col_sep = " | "

info_str = [col_sep.join(f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, max_w))]
info_str.append("-" * len(info_str[0]))
print("\n".join(info_str))
for name in TORCHVISION_MODELS:
    model = models.__dict__[name]().eval().to(device)
    # inception_v3 is measured at 299x299; every other model at 224x224.
    dsize = (3, 299, 299) if "inception" in name else (3, 224, 224)
    model_info = crawl_module(model, dsize)
    layers = model_info["layers"]

    tot_params = sum(layer["grad_params"] + layer["nograd_params"] for layer in layers)
    tot_flops = sum(layer["flops"] for layer in layers)
    tot_macs = sum(layer["macs"] for layer in layers)
    tot_dmas = sum(layer["dmas"] for layer in layers)
    # NOTE(review): every row in the recorded output shows RF == 1; confirm that
    # layers[0]["rf"] really is the whole-model receptive field.
    rf = layers[0]["rf"]
    row = [
        f"{name:<{max_w[0]}}",
        f"{tot_params / 1e6:<{max_w[1]}.2f}",
        f"{tot_flops / 1e9:<{max_w[2]}.2f}",
        f"{tot_macs / 1e9:<{max_w[3]}.2f}",
        f"{tot_dmas / 1e9:<{max_w[4]}.2f}",
        f"{rf:<{max_w[5]}.0f}",
    ]
    print(col_sep.join(row))
    # Drop our reference so the previous model can be freed before the next one is
    # built (otherwise two models coexist on the device during construction).
    del model



Model                   Params (M)    FLOPs (G)     MACs (G)      DMAs (G)      RF        
------------------------------------------------------------------------------------------
alexnet              | 61.10      | 1.43       | 0.71       | 0.72       | 1         




googlenet            | 6.62       | 3.01       | 1.51       | 1.53       | 1         
vgg11                | 132.86     | 15.23      | 7.61       | 7.64       | 1         
vgg11_bn             | 132.87     | 15.26      | 7.63       | 7.66       | 1         
vgg13                | 133.05     | 22.63      | 11.31      | 11.35      | 1         
vgg13_bn             | 133.05     | 22.68      | 11.33      | 11.37      | 1         
vgg16                | 138.36     | 30.96      | 15.47      | 15.52      | 1         
vgg16_bn             | 138.37     | 31.01      | 15.50      | 15.55      | 1         
vgg19                | 143.67     | 39.28      | 19.63      | 19.69      | 1         
vgg19_bn             | 143.68     | 39.34      | 19.66      | 19.72      | 1         
resnet18             | 11.69      | 3.64       | 1.82       | 1.84       | 1         
resnet34             | 21.80      | 7.34       | 3.67       | 3.70       | 1         
resnet50             | 25.56      | 8.21       | 4.11 



inception_v3         | 23.83      | 11.45      | 5.73       | 5.77       | 1         
squeezenet1_0        | 1.25       | 1.64       | 0.82       | 0.83       | 1         
squeezenet1_1        | 1.24       | 0.70       | 0.35       | 0.36       | 1         
wide_resnet50_2      | 68.88      | 22.84      | 11.43      | 11.51      | 1         
wide_resnet101_2     | 126.89     | 45.58      | 22.80      | 22.95      | 1         
densenet121          | 7.98       | 5.74       | 2.87       | 2.90       | 1         
densenet161          | 28.68      | 15.59      | 7.79       | 7.86       | 1         
densenet169          | 14.15      | 6.81       | 3.40       | 3.44       | 1         
densenet201          | 20.01      | 8.70       | 4.34       | 4.39       | 1         
resnext50_32x4d      | 25.03      | 8.51       | 4.26       | 4.30       | 1         
resnext101_32x8d     | 88.79      | 32.93      | 16.48      | 16.61      | 1         
mobilenet_v2         | 3.50       | 0.63       | 0.31 

In [7]:
import io
import sys
from collections import OrderedDict

import pytest
import torch.nn as nn

from torch_book.scan import crawler


def test_apply():
    """``crawler.apply`` calls the callback on nested submodules with a dotted name."""
    inner_convs = nn.Sequential(nn.Conv2d(16, 32, 3), nn.Conv2d(32, 64, 3))
    mod = nn.Sequential(nn.Conv2d(3, 16, 3), inner_convs)

    # Tag module attributes: depth is the number of dots in the name, and the short
    # name is its last dot-separated component.
    def tag_name(mod, name):
        mod.__depth__ = name.count(".")
        mod.__name__ = name.split(".")[-1]

    crawler.apply(mod, tag_name)

    innermost = mod[1][1]
    assert innermost.__depth__ == 2
    assert innermost.__name__ == "1"


def test_crawl_module():
    """``crawl_module`` on a single conv layer returns a dict with totals and per-layer info."""
    conv = nn.Conv2d(3, 8, 3)

    info = crawler.crawl_module(conv, (3, 32, 32))
    assert isinstance(info, dict)
    # 3 * 8 * 3 * 3 weights + 8 biases = 224 trainable parameters.
    assert info["overall"]["grad_params"] == 224
    # An unpadded 3x3 kernel shrinks 32x32 to 30x30; the leading -1 presumably stands
    # for the batch dimension.
    assert info["layers"][0]["output_shape"] == (-1, 8, 30, 30)


def _capture_stdout(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return everything it printed to stdout.

    ``contextlib.redirect_stdout`` restores the previous ``sys.stdout`` on exit, even if
    ``fn`` raises. The hand-rolled redirect this replaces had two flaws. It reset the
    stream to ``sys.__stdout__``, which under Jupyter is the kernel process's original
    stdout, not the notebook's output stream. It also left stdout hijacked whenever
    ``fn`` raised.
    """
    from contextlib import redirect_stdout

    buffer = io.StringIO()
    with redirect_stdout(buffer):
        fn(*args, **kwargs)
    return buffer.getvalue()


def test_summary():
    """Check the text report printed by ``crawler.summary``.

    The assertions index fixed line numbers of the report, so they pin its layout.
    """
    mod = nn.Conv2d(3, 8, 3)

    lines = _capture_stdout(crawler.summary, mod, (3, 32, 32)).split("\n")
    assert lines[7] == "Total params: 224"

    # Check receptive field
    lines = _capture_stdout(crawler.summary, mod, (3, 32, 32), receptive_field=True).split("\n")
    assert lines[1].rpartition("  ")[-1] == "Receptive field"
    assert lines[3].split()[-1] == "3"

    # Check effective stats
    lines = _capture_stdout(
        crawler.summary, mod, (3, 32, 32), receptive_field=True, effective_rf_stats=True
    ).split("\n")
    assert lines[1].rpartition("  ")[-1] == "Effective padding"
    assert lines[3].split()[-1] == "0"

    # Max depth > model hierarchy
    with pytest.raises(ValueError):
        crawler.summary(mod, (3, 32, 32), max_depth=1)

    mod = nn.Sequential(
        OrderedDict(
            [
                ("features", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))),
                ("pool", nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))),
                ("classifier", nn.Linear(8, 1)),
            ]
        )
    )

    lines = _capture_stdout(crawler.summary, mod, (3, 32, 32), max_depth=1).split("\n")
    assert lines[4].startswith("├─features ")


In [9]:
import os

import torch

from torch_book.scan import process


def test_get_process_gpu_ram():
    """``get_process_gpu_ram`` reports 0 without CUDA and a non-negative value otherwise."""
    # The original tested `torch.cuda.is_initialized` without calling it. A function
    # object is always truthy, so the `== 0` branch could never run. Branch on
    # availability instead: without a usable CUDA device the process cannot hold GPU
    # memory.
    # NOTE(review): with CUDA available but not yet initialised, the helper may still
    # report a non-zero reading, which is why that case only checks `>= 0`.
    ram = process.get_process_gpu_ram(os.getpid())
    if torch.cuda.is_available():
        assert ram >= 0
    else:
        assert ram == 0


test_get_process_gpu_ram()

In [11]:
import pytest

from torch_book.scan import utils


def test_format_name():
    """``utils.format_name`` prefixes a module name with tree-drawing characters."""
    module_name = "mymodule"
    # Without a depth the name comes back untouched.
    assert utils.format_name(module_name) == module_name
    # Deeper levels draw the branch characters in front of the name.
    for depth, prefix in ((1, "├─"), (3, "|    |    └─")):
        assert utils.format_name(module_name, depth=depth) == prefix + module_name


def test_wrap_string():
    """Exercise the "end" and "mid" truncation modes of ``utils.wrap_string``."""
    example = ".".join("a" * 10)  # "a.a.a.a.a.a.a.a.a.a"
    max_len, wrap = 10, "[...]"

    kept_end = example[: max_len - len(wrap)]
    assert utils.wrap_string(example, max_len, mode="end") == kept_end + wrap

    kept_mid = example[: max_len - 2 - len(wrap)]
    assert utils.wrap_string(example, max_len, mode="mid") == kept_mid + wrap + ".a"

    # A string that already fits is returned unchanged.
    assert utils.wrap_string(example, len(example), mode="end") == example

    # Unknown modes are rejected.
    with pytest.raises(ValueError):
        utils.wrap_string(example, max_len, mode="test")


@pytest.mark.parametrize(
    "input_val, num_val, unit",
    [
        (3e14, 300, "T"),
        (3e10, 30, "G"),
        (3e7, 30, "M"),
        (15e3, 15, "k"),
        (500, 500, ""),
    ],
)
def test_unit_scale(input_val, num_val, unit):
    """``utils.unit_scale`` splits a magnitude into a scaled number and a unit prefix."""
    scaled = utils.unit_scale(input_val)
    assert scaled == (num_val, unit)


In [None]:
import pytest
import torch
from torch import nn

from torch_book.scan import modules


class MyModule(nn.Module):
    """Bare ``nn.Module`` subclass with no layers.

    The parametrized tables below use it as the "unknown module" case. There the
    analysis helpers are expected to return 0 and emit a warning.
    """


def test_module_flops_warning():
    """An unrecognised module makes ``module_flops`` emit a ``UserWarning``."""
    unknown = MyModule()
    with pytest.warns(UserWarning):
        modules.module_flops(unknown, None, None)


@pytest.mark.parametrize(
    "mod, input_shape, output_shape, expected_val",
    [
        # Unknown module: module_flops is expected to return 0 (and warn)
        [MyModule(), (1,), (1,), 0],
        # Fully-connected: each of the 4 outputs costs 2 * 8 - 1 ops, plus 4 for the bias
        [nn.Linear(8, 4), (1, 8), (1, 4), 4 * (2 * 8 - 1) + 4],
        [nn.Linear(8, 4, bias=False), (1, 8), (1, 4), 4 * (2 * 8 - 1)],
        # Leading extra dimension of 2 doubles the count
        [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 2 * (4 * (2 * 8 - 1) + 4)],
        # Activations
        [nn.Identity(), (1, 8), (1, 8), 0],
        [nn.Flatten(), (1, 8), (1, 8), 0],
        [nn.ReLU(), (1, 8), (1, 8), 8],
        [nn.ELU(), (1, 8), (1, 8), 48],
        [nn.LeakyReLU(), (1, 8), (1, 8), 32],
        [nn.ReLU6(), (1, 8), (1, 8), 16],
        [nn.Tanh(), (1, 8), (1, 8), 48],
        [nn.Sigmoid(), (1, 8), (1, 8), 32],
        # BN
        [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 144 + 32 + 32 * 3 + 48],
        # Pooling (32 == 8 * 2 * 2 output elements)
        [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        [nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        # Dropout (p=0 is expected to cost nothing)
        [nn.Dropout(), (1, 8), (1, 8), 8],
        [nn.Dropout(p=0), (1, 8), (1, 8), 0],
        # Conv: 388800 == 8 * 30 * 30 output elements * (2 * 3 * 3 * 3) ops each
        [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 388800],
        [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 499408],
    ],
)
def test_module_flops(mod, input_shape, output_shape, expected_val):
    """Compare ``module_flops`` with hand-computed FLOP counts.

    ``module_flops`` receives its inputs as a tuple of tensors (here a 1-tuple of zeros).
    """
    assert modules.module_flops(mod, (torch.zeros(input_shape),), torch.zeros(output_shape)) == expected_val


def test_transformer_flops():
    """FLOP count of a small encoder-decoder ``nn.Transformer`` on fixed-shape inputs."""
    transformer = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=3)
    src, tgt = torch.rand((10, 16, 64)), torch.rand((20, 16, 64))
    output = transformer(src, tgt)
    assert modules.module_flops(transformer, (src, tgt), output) == 774952841


def test_module_macs_warning():
    """An unrecognised module makes ``module_macs`` emit a ``UserWarning``."""
    unknown = MyModule()
    with pytest.warns(UserWarning):
        modules.module_macs(unknown, None, None)


@pytest.mark.parametrize(
    "mod, input_shape, output_shape, expected_val",
    [
        # Unknown module: module_macs is expected to return 0 (and warn)
        [MyModule(), (1,), (1,), 0],
        # Fully-connected: 8 * 4 multiply-accumulates per input row
        [nn.Linear(8, 4), (1, 8), (1, 4), 8 * 4],
        [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 8 * 4 * 2],
        # Activations
        [nn.ReLU(), (1, 8), (1, 8), 0],
        # BN
        [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 64 + 24 + 56 + 32],
        # Pooling (32 == 8 * 2 * 2 output elements)
        [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32],
        [nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        [nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32],
        # Dropout
        [nn.Dropout(), (1, 8), (1, 8), 0],
        # Conv: 194400 == 8 * 30 * 30 output elements * (3 * 3 * 3) MACs each
        [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 194400],
        [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 249704],
    ],
)
def test_module_macs(mod, input_shape, output_shape, expected_val):
    """Compare ``module_macs`` with hand-computed multiply-accumulate counts.

    ``module_macs`` is called with a bare input tensor (not a tuple of tensors).
    """
    assert modules.module_macs(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val


def test_module_dmas_warning():
    """An unrecognised module makes ``module_dmas`` emit a ``UserWarning``."""
    unknown = MyModule()
    with pytest.warns(UserWarning):
        modules.module_dmas(unknown, None, None)


@pytest.mark.parametrize(
    "mod, input_shape, output_shape, expected_val",
    [
        # Unknown module: module_dmas is expected to return 0 (and warn)
        [MyModule(), (1,), (1,), 0],
        # Fully-connected: 4 * (8 + 1) equals the weight + bias element count; the
        # remaining terms add the input (8) and output (4) element counts per row
        [nn.Linear(8, 4), (1, 8), (1, 4), 4 * (8 + 1) + 8 + 4],
        [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 4 * (8 + 1) + 2 * (8 + 4)],
        # Activations (the in-place ReLU is expected to count half of the out-of-place one)
        [nn.Identity(), (1, 8), (1, 8), 8],
        [nn.Flatten(), (1, 8), (1, 8), 16],
        [nn.ReLU(), (1, 8), (1, 8), 8 * 2],
        [nn.ReLU(inplace=True), (1, 8), (1, 8), 8],
        [nn.ELU(), (1, 8), (1, 8), 17],
        [nn.Tanh(), (1, 8), (1, 8), 24],
        [nn.Sigmoid(), (1, 8), (1, 8), 16],
        # BN
        [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 32 + 17 + 16 + 1 + 17 + 32],
        # Pooling: input has 8 * 4 * 4 == 4 * 32 elements, output 8 * 2 * 2 == 32
        [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32],
        [nn.MaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32],
        [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32],
        [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32],
        # Dropout
        [nn.Dropout(), (1, 8), (1, 8), 17],
        # Conv
        [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 201824],
        [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 259178],
    ],
)
def test_module_dmas(mod, input_shape, output_shape, expected_val):
    """Compare ``module_dmas`` with hand-computed memory-access counts.

    ``module_dmas`` is called with a bare input tensor (not a tuple of tensors).
    """
    assert modules.module_dmas(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val


# @torch.no_grad()
# def test_module_rf(self):

#     # Check for unknown module that it returns 0 and throws a warning
#     self.assertEqual(modules.module_rf(MyModule(), None, None), (1, 1, 0))
#     self.assertWarns(UserWarning, modules.module_rf, MyModule(), None, None)

#     # Common unit tests
#     # Linear
#     self.assertEqual(modules.module_rf(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))),
#                      (1, 1, 0))
#     # Activation
#     self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     # Conv
#     input_t = torch.rand((1, 3, 32, 32))
#     mod = nn.Conv2d(3, 8, 3)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (3, 1, 0))
#     # Check for dilation support
#     mod = nn.Conv2d(3, 8, 3, dilation=2)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (5, 1, 0))
#     # ConvTranspose
#     mod = nn.ConvTranspose2d(3, 8, 3)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (-3, 1, 0))
#     # BN
#     self.assertEqual(modules.module_rf(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))),
#                      (1, 1, 0))

#     # Pooling
#     self.assertEqual(modules.module_rf(nn.MaxPool2d((2, 2)),
#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
#                      (2, 2, 0))
#     self.assertEqual(modules.module_rf(nn.AdaptiveMaxPool2d((2, 2)),
#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
#                      (2, 2, 0))

#     # Dropout
#     self.assertEqual(modules.module_rf(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
