In [1]:
import torch 
from torchvision.models import resnet18 
from loguru import logger

In [2]:
# Use resnet18 from the PyTorch (torchvision) model zoo as the example network.
# eval() puts it in inference mode (BatchNorm then uses its running statistics)
# and returns the module itself, so the bare `model` below displays its structure.
model = resnet18().eval()
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
# Random example input used to trace the model: a batch of one 3-channel,
# 224x224 tensor with values drawn uniformly from [0, 1).
DUMMY_INPUT_SHAPE = (1, 3, 224, 224)
dummy_input = torch.rand(DUMMY_INPUT_SHAPE)

In [4]:
# Generate the TorchScript IR by tracing: run `model` once on `dummy_input`
# and record the ops that actually execute. no_grad() keeps autograd from
# recording this example forward pass.
with torch.no_grad():
    jit_model = torch.jit.trace(model, example_inputs=dummy_input)

In [5]:
# Display the traced module. Its repr mirrors the eager ResNet's submodule
# hierarchy, with each entry tagged by its original Python class (original_name=...).
jit_model

ResNet(
  original_name=ResNet
  (conv1): Conv2d(original_name=Conv2d)
  (bn1): BatchNorm2d(original_name=BatchNorm2d)
  (relu): ReLU(original_name=ReLU)
  (maxpool): MaxPool2d(original_name=MaxPool2d)
  (layer1): Sequential(
    original_name=Sequential
    (0): BasicBlock(
      original_name=BasicBlock
      (conv1): Conv2d(original_name=Conv2d)
      (bn1): BatchNorm2d(original_name=BatchNorm2d)
      (relu): ReLU(original_name=ReLU)
      (conv2): Conv2d(original_name=Conv2d)
      (bn2): BatchNorm2d(original_name=BatchNorm2d)
    )
    (1): BasicBlock(
      original_name=BasicBlock
      (conv1): Conv2d(original_name=Conv2d)
      (bn1): BatchNorm2d(original_name=BatchNorm2d)
      (relu): ReLU(original_name=ReLU)
      (conv2): Conv2d(original_name=Conv2d)
      (bn2): BatchNorm2d(original_name=BatchNorm2d)
    )
  )
  (layer2): Sequential(
    original_name=Sequential
    (0): BasicBlock(
      original_name=BasicBlock
      (conv1): Conv2d(original_name=Conv2d)
      (bn1): B

In [6]:
# Take the traced `layer1` submodule and print its TorchScript graph IR.
# Before inlining, each BasicBlock appears only as a prim::GetAttr lookup followed by
# a prim::CallMethod[name="forward"] call, not as its individual ops.
jit_layer1 = jit_model.layer1 
print(jit_layer1.graph) 

graph(%self.11 : __torch__.torch.nn.modules.container.Sequential,
      %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)):
  %_1.1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.11)
  %_0.1 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.11)
  %6 : Tensor = prim::CallMethod[name="forward"](%_0.1, %4)
  %7 : Tensor = prim::CallMethod[name="forward"](%_1.1, %6)
  return (%7)



In [8]:
# Print the Python-like TorchScript code regenerated from layer1's (not yet inlined)
# graph: forward simply chains block "0"'s forward into block "1"'s forward.
print(jit_layer1.code)

def forward(self,
    argument_1: Tensor) -> Tensor:
  _1 = getattr(self, "1")
  _0 = getattr(self, "0")
  _2 = (_1).forward((_0).forward(argument_1, ), )
  return _2



In [9]:
# Inline every submodule's forward call into layer1's graph so the regenerated code
# shows the underlying ops (torch._convolution, torch.batch_norm, torch.relu_) directly.
# NOTE: this is a private API (leading underscore) and it rewrites jit_layer1.graph in
# place, so re-running the earlier `print(jit_layer1.graph)` / `.code` cells afterwards
# shows the inlined version. Presumably jit_model.layer1 is affected too, since it is
# the same submodule — verify before relying on the un-inlined graph again.
torch._C._jit_pass_inline(jit_layer1.graph) 
print(jit_layer1.code) 

def forward(self,
    argument_1: Tensor) -> Tensor:
  _1 = getattr(self, "1")
  _0 = getattr(self, "0")
  bn2 = _0.bn2
  conv2 = _0.conv2
  relu = _0.relu
  bn1 = _0.bn1
  conv1 = _0.conv1
  weight = conv1.weight
  input = torch._convolution(argument_1, weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
  running_var = bn1.running_var
  running_mean = bn1.running_mean
  bias = bn1.bias
  weight0 = bn1.weight
  input0 = torch.batch_norm(input, weight0, bias, running_mean, running_var, False, 0.10000000000000001, 1.0000000000000001e-05, True)
  input1 = torch.relu_(input0)
  weight1 = conv2.weight
  input2 = torch._convolution(input1, weight1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
  running_var0 = bn2.running_var
  running_mean0 = bn2.running_mean
  bias0 = bn2.bias
  weight2 = bn2.weight
  out = torch.batch_norm(input2, weight2, bias0, running_mean0, running_var0, False, 0.10000000000000001, 1.0000000000000001e-05, True

# Test ONNX export: tracing vs. scripting

In [10]:
import torch

class Model(torch.nn.Module):
    """Toy network that applies the same 3x3 convolution ``n`` times in a Python loop.

    The conv has no padding, so every pass shrinks the spatial size by 2
    (e.g. a 10x10 input becomes 6x6 for ``n=2``). Tracing records the ops that
    actually run, so the loop is unrolled into ``n`` conv calls, while
    ``torch.jit.script`` compiles the ``for`` loop itself — which is what makes
    this model useful for comparing the two ONNX export paths.
    """

    def __init__(self, n: int):
        """
        Args:
            n: number of times ``self.conv`` is applied in ``forward``; ``0``
                makes ``forward`` return its input unchanged.
        """
        super().__init__()
        self.n = n
        # 3 -> 3 channels, 3x3 kernel; stride 1 and padding 0 are the Conv2d defaults.
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.conv`` to ``x`` ``self.n`` times and return the result."""
        for _ in range(self.n):  # the loop index itself is not needed
            x = self.conv(x)
        return x



# Two otherwise-identical toy models that apply the conv 2 and 3 times respectively.
models = [Model(depth) for depth in (2, 3)]
# Name each model after its depth ('model_2', 'model_3'); used as ONNX file-name prefixes.
model_names = [f'model_{m.n}' for m in models]

In [13]:
# Export every toy model to ONNX twice — once from a traced module and once from a
# scripted module — so the two resulting graphs can be compared side by side.
# The loop uses `net` / `sample_input` instead of `model` / `dummy_input` so it does
# not overwrite the resnet18 globals that the earlier tracing cells depend on.
for net, model_name in zip(models, model_names):
    sample_input = torch.rand(1, 3, 10, 10)
    eager_output = net(sample_input)
    model_trace = torch.jit.trace(net, sample_input)
    model_script = torch.jit.script(net)

    # Sanity check: both TorchScript versions must reproduce eager mode before export.
    torch.testing.assert_close(model_trace(sample_input), eager_output)
    torch.testing.assert_close(model_script(sample_input), eager_output)

    # Exporting the traced module is equivalent to calling torch.onnx.export(net, ...)
    # directly, since the (TorchScript-based) exporter traces a plain nn.Module itself.
    torch.onnx.export(model_trace, sample_input, f'{model_name}_trace.onnx')
    # To export via scripting, the model must first be compiled with torch.jit.script.
    torch.onnx.export(model_script, sample_input, f'{model_name}_script.onnx')

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR



