# Compile PyTorch Models

**Author**: [Yaoda Zhou](https://github.com/juda)

This tutorial shows how to optimize PyTorch models with the `optimize_torch` function. To follow along, you need PyTorch and TorchVision installed:
```bash
pip install torch
pip install torchvision
```
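The commands above install only PyTorch and TorchVision; TVM itself must already be available, because the imports below use `tvm.contrib.torch`. A minimal pre-flight sketch using only the standard library can confirm the three packages are importable before you start. The helper name `missing_packages` is hypothetical. It only checks that the packages are installed; it cannot tell whether your TVM build was configured with the `USE_PT_TVMDSOOP` option that `optimize_torch` needs.

```python
import importlib.util

# Top-level packages this notebook imports (assumed list).
REQUIRED = ("torch", "torchvision", "tvm")


def missing_packages(names=REQUIRED):
    """Return the names in `names` that cannot be found on the import path.

    find_spec() on a top-level name locates the package without importing it
    and returns None when it is not installed.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages found.")
```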

```python
# Local setup module used by this notebook's environment;
# it is not part of TVM or PyTorch.
import set_env

import torch
import torch.nn as nn
import torch.nn.functional as F

# Import library for profiling
import torch.utils.benchmark as benchmark
from torchvision.models import resnet18

# Import `optimize_torch` function
from tvm.contrib.torch import optimize_torch
```



## Define a simple module written in PyTorch

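```python
class SimpleModel(nn.Module):
    """Two 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)   # 1 input channel -> 20 channels
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 channels -> 20 channels

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```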

## Optimize SimpleModel with TVM MetaSchedule

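We provide the `optimize_torch` function, whose usage is similar to `torch.jit.trace`. You pass in the PyTorch model to optimize together with an example input, and TVM tunes the module for the target hardware. If no target is specified, the model is tuned for the CPU.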


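```python
simple_model = SimpleModel()
# Example input: a batch of 20 single-channel 10x10 images.
example_input = torch.randn(20, 1, 10, 10)
# max_trials_global caps the total number of tuning trials across all tasks;
# a value of 2 keeps this demo short.
model_optimized_by_tvm = optimize_torch(simple_model, example_input, max_trials_global=2)
```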

```text
2024-03-20 12:21:19 [INFO] Logging directory: /tmp/tmpl0j3jqte/logs
2024-03-20 12:21:36 [INFO] LocalBuilder: max_workers = 24
2024-03-20 12:21:38 [INFO] LocalRunner: max_workers = 1
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #0: "fused_layout_transform"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #1: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #2: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #3: "fused_layout_transform_1"
2024-03-20 12:21:40 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |      
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |      
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
-----------------------------------------------------------------------------------------------
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |      
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |      
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |      
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
-----------------------------------------------------------------------------------------------
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |    Y 
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |      
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |      
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
-----------------------------------------------------------------------------------------------
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |    Y 
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |      
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
-----------------------------------------------------------------------------------------------
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |    Y 
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |      
-----------------------------------------------------------------------------------------------
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318] 
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |    Y 
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |    Y 
-----------------------------------------------------------------------------------------------

ValueError: optimize_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake
```

In this run, tuning finishes but the call ends with a `ValueError`: `optimize_torch` needs a TVM build configured with `USE_PT_TVMDSOOP` in `config.cmake`. Because the call failed, `model_optimized_by_tvm` is never defined, so the later cell that uses it fails with a `NameError`.
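The FLOP and Speed columns in the tuning log can be reproduced by hand. The sketch below assumes that a multiply-add counts as 2 FLOPs and that the fused bias add and ReLU each cost one FLOP per output element; this counting convention is our assumption, but it matches the 748800 and 1603200 FLOP values reported for the two fused conv2d tasks. It also assumes the Speed column is simply FLOP divided by latency. All function names are hypothetical.

```python
def conv2d_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1


def fused_conv_bias_relu_flops(n, c_in, c_out, h_in, w_in, k):
    """FLOPs of conv2d + bias add + ReLU, plus the output shape (N, C, H, W).

    Assumed convention: 2 FLOPs per multiply-add, and 1 FLOP per output
    element each for the bias add and the ReLU.
    """
    h_out = conv2d_out_size(h_in, k)
    w_out = conv2d_out_size(w_in, k)
    outputs = n * c_out * h_out * w_out
    conv = outputs * c_in * k * k * 2
    return conv + outputs + outputs, (n, c_out, h_out, w_out)


def gflops(flop, latency_us):
    """Throughput in GFLOP/s from a FLOP count and a latency in microseconds."""
    return flop / (latency_us * 1e-6) / 1e9


# SimpleModel applied to example_input of shape (20, 1, 10, 10)
flops1, shape1 = fused_conv_bias_relu_flops(20, 1, 20, 10, 10, 5)  # conv1 + ReLU
n, c, h, w = shape1
flops2, shape2 = fused_conv_bias_relu_flops(n, c, 20, h, w, 5)     # conv2 + ReLU

print(flops1, shape1)                # 748800 (20, 20, 6, 6)
print(flops2, shape2)                # 1603200 (20, 20, 2, 2)
print(f"{gflops(1, 11.0077):.4f}")   # 0.0001, as in the row for task 0
```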

## Save/Load Module

We can save and load the optimized module just like a standard `nn.Module`.

First, let's run the optimized module.

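```python
# Run the optimized module once and keep its output for the checks below.
ret1 = model_optimized_by_tvm(example_input)

# Save and reload it exactly as you would a regular nn.Module.
# Note: since PyTorch 2.6, torch.load defaults to weights_only=True, so
# loading a full module object there requires passing weights_only=False.
torch.save(model_optimized_by_tvm, "model_optimized.pt")
model_loaded = torch.load("model_optimized.pt")
```

```text
NameError: name 'model_optimized_by_tvm' is not defined
```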

```python
from numpy import testing

# Run the reloaded module on the same input.
ret2 = model_loaded(example_input)

# Two checks:
# (1) saving and loading is safe: the reloaded module returns the same
#     result as the module before it was saved;
# (2) the TVM-optimized module returns the same result as the original
#     PyTorch model.
ret3 = simple_model(example_input)
testing.assert_allclose(ret1.detach().numpy(), ret2.detach().numpy(), atol=1e-5, rtol=1e-5)
testing.assert_allclose(ret1.detach().numpy(), ret3.detach().numpy(), atol=1e-5, rtol=1e-5)
```
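The checks above pass only if every element satisfies the tolerance test that `numpy.testing.assert_allclose` documents: `|actual - desired| <= atol + rtol * |desired|`. The pure-Python sketch below mirrors that rule for finite values so you can see what `atol=1e-5, rtol=1e-5` accepts. The helper name `allclose` is hypothetical, and its defaults are set to the values used in the cell above, not to NumPy's own defaults.

```python
def allclose(actual, desired, rtol=1e-5, atol=1e-5):
    """Element-wise tolerance test mirroring numpy.testing.assert_allclose.

    Assumes finite values (no NaN/inf handling). Note that the rtol term is
    scaled by |desired|, so the test is not symmetric in its arguments.
    """
    if len(actual) != len(desired):
        raise ValueError("inputs must have the same length")
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))


# Small absolute error: accepted.
print(allclose([1.0], [1.000001]))     # True
# Large values: the rtol term dominates, so an error of 5e-4 is still accepted.
print(allclose([100.0], [100.0005]))   # True
# Near zero the rtol term vanishes and only atol applies, so 5e-5 is rejected.
print(allclose([0.0], [5e-5]))         # False
```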

## Optimize resnet18

The following section shows that this approach can also accelerate common models such as resnet18.

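```python
# Tune the model for the GPU, using TVM's target tag for an NVIDIA GeForce RTX 3070.
target_cuda = "nvidia/geforce-rtx-3070"

# PyTorch code can be written as usual; the only change is applying
# `optimize_torch` to the resnet18 model. As in the SimpleModel example,
# `max_trials_global` sets the total tuning budget. The value 64 is
# illustrative only: a larger budget takes longer but can find faster schedules.
resnet18_tvm = optimize_torch(
    resnet18().cuda().eval(),
    [torch.rand(1, 3, 224, 224).cuda()],
    target=target_cuda,
    max_trials_global=64,
)

# TorchScript also provides a built-in `optimize_for_inference` function to
# accelerate inference; it serves as the baseline for the comparison below.
resnet18_torch = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
```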


## Compare the performance of the two approaches

```python
# Time both models on 5 fresh random inputs.
results = []
for i in range(5):
    test_input = torch.rand(1, 3, 224, 224).cuda()
    sub_label = f"[test {i}]"
    results.append(
        benchmark.Timer(
            stmt="resnet18_tvm(test_input)",
            setup="from __main__ import resnet18_tvm",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by meta",
        ).blocked_autorange()
    )
    results.append(
        benchmark.Timer(
            stmt="resnet18_torch(test_input)",
            setup="from __main__ import resnet18_torch",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by jit",
        ).blocked_autorange()
    )

# Print all measurements side by side as a table.
compare = benchmark.Compare(results)
compare.print()
```
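`blocked_autorange()` chooses how many times to run each statement so that a measurement is long enough to time reliably. Python's standard library has a simpler helper in the same spirit, `timeit.Timer.autorange()`, which increases the loop count (1, 2, 5, 10, 20, ...) until one timing run takes at least 0.2 seconds. The sketch below uses it to estimate a per-call latency; the names `per_call_seconds` and `toy_model` are hypothetical stand-ins. Unlike `torch.utils.benchmark.Timer`, `timeit` does not synchronize asynchronous CUDA work, so this pattern only suits CPU-side code.

```python
import timeit


def per_call_seconds(fn, *args):
    """Estimate the mean wall-clock time of one call to fn(*args), in seconds."""
    timer = timeit.Timer(lambda: fn(*args))
    # autorange() returns (number_of_loops, total_seconds) for the first
    # loop count whose run takes at least 0.2 s.
    number, total = timer.autorange()
    return total / number


def toy_model(xs):
    # Stand-in CPU workload; in the notebook this would be a model call.
    return sum(x * x for x in xs)


data = list(range(1000))
t = per_call_seconds(toy_model, data)
print(f"{t * 1e6:.1f} us per call")
```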

In the author's environment (PyTorch 1.11.0), the average inference time is 620.0 us for `resnet18_tvm` and 980.0 us for `resnet18_torch`, i.e. about 37% lower latency, or roughly a 1.58x speedup.
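A gap between two latencies can be stated either as a latency reduction or as a speedup factor, and the two numbers differ. The small sketch below computes both from the timings reported above (620.0 us for `resnet18_tvm`, 980.0 us for `resnet18_torch`); the function name `compare_latency` is hypothetical.

```python
def compare_latency(baseline_us, optimized_us):
    """Return (fractional latency reduction, speedup factor)."""
    reduction = (baseline_us - optimized_us) / baseline_us
    speedup = baseline_us / optimized_us
    return reduction, speedup


reduction, speedup = compare_latency(baseline_us=980.0, optimized_us=620.0)
print(f"latency reduced by {reduction:.1%}, speedup {speedup:.2f}x")
# latency reduced by 36.7%, speedup 1.58x
```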
