# torch._dynamo.export 结论（2023-12-07）
1. 支持由静态值（Python flag、张量 shape）控制的分支：导出时按示例输入确定走哪个分支，图中只保留该分支。
2. 不支持由 Tensor 数值决定的分支（数据相关控制流）；torch 推荐的 `functorch.experimental.control_flow.cond` 也未能调通，同样报“不支持”错误。

In [2]:
import torch
from pprint import pprint

In [3]:
from torch import nn

In [7]:
def fn_flag(x):
    """Return ``relu(x)`` if ``x`` has more than 2 entries along dim 0, else ``sigmoid(x)``.

    The branch depends only on the size ``x.shape[0]``, not on tensor values,
    so export records just the branch taken for the example input.
    """
    return torch.relu(x) if x.shape[0] > 2 else torch.sigmoid(x)


# Export the shape-controlled fn_flag with a 2-element example input.
# x.shape[0] is 2 here, so the recorded graph below contains only the sigmoid
# branch (the `gt` comparison is traced but its result is unused).
x = torch.randn(2)
out = torch._dynamo.export(fn_flag)(x)
out.graph_module.print_readable()

class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: f32[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # File: /tmp/ipykernel_194925/4153613159.py:2, code: if x.shape[0] > 2:
        size = arg0.size()
        getitem = size[0];  size = None
        gt = getitem > 2;  getitem = None
        
        # File: /tmp/ipykernel_194925/4153613159.py:4, code: return torch.sigmoid(x)
        sigmoid = torch.sigmoid(arg0);  arg0 = None
        return pytree.tree_unflatten([sigmoid], self._out_spec)
        


'class GraphModule(torch.nn.Module):\n    def forward(self, x):\n        arg0: f32[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)\n        # File: /tmp/ipykernel_194925/4153613159.py:2, code: if x.shape[0] > 2:\n        size = arg0.size()\n        getitem = size[0];  size = None\n        gt = getitem > 2;  getitem = None\n        \n        # File: /tmp/ipykernel_194925/4153613159.py:4, code: return torch.sigmoid(x)\n        sigmoid = torch.sigmoid(arg0);  arg0 = None\n        return pytree.tree_unflatten([sigmoid], self._out_spec)\n        '

In [9]:
# Materialize the guards recorded by the export above. The output below shows
# grad-mode, type/dtype/device/ndim checks on L['x'], and the `torch` global.
list(out.guards)

[
         global '' GRAD_MODE
         {
             'guard_types': ['GRAD_MODE'],
             'code': ['___is_grad_enabled()'],
             'obj_weakref': None
             'guarded_class': None
         }
         ,
 
         local "L['x']" TENSOR_MATCH
         {
             'guard_types': ['TYPE_MATCH', 'TENSOR_MATCH'],
             'code': ["___check_type_id(L['x'], 92004160)", "str(L['x'].dtype) == 'torch.float32'", "str(L['x'].device) == 'cpu'", "L['x'].requires_grad == False", "L['x'].ndimension() == 1", "hasattr(L['x'], '_dynamo_dynamic_indices') == False"],
             'obj_weakref': <weakref at 0x7fd58f23ab80; to 'Tensor' at 0x7fd5aa2bc270>
             'guarded_class': <weakref at 0x7fd5abdef950; to 'torch._C._TensorMeta' at 0x57bdf40 (Tensor)>
         }
         ,
 
         global "G['torch']" FUNCTION_MATCH
         {
             'guard_types': None,
             'code': None,
             'obj_weakref': None
             'guarded_class': None
         }
       

In [5]:
# Using a Tensor directly as the branch condition is not supported by export
# (it raises an error); values of other Python types do not raise.
# `.item()` turns the random comparison into a plain Python bool here, so
# fn_flag below sees a non-Tensor flag.
flag = (torch.randn(1)[0] > 0.5).item()

# Alternative non-Tensor flag using a numpy bool:
# import numpy as np
# flag = np.random.uniform(size=(1,))[0] > 0.5

def fn_flag(x):
    """Return ``relu(x)`` if the module-level ``flag`` is truthy, else ``sigmoid(x)``.

    ``flag`` is a global read at call time rather than a tensor, so export
    bakes in its current value and keeps only one branch in the graph.
    """
    if not flag:
        return torch.sigmoid(x)
    return torch.relu(x)


# Export the flag-controlled fn_flag. In the recorded run the flag was truthy,
# so the graph below contains only the relu branch and no condition at all.
x = torch.randn(2)
out = torch._dynamo.export(fn_flag)(x)
out.graph_module.print_readable()

class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: f32[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # File: /tmp/ipykernel_194925/2675718805.py:10, code: return torch.relu(x)
        relu = torch.relu(arg0);  arg0 = None
        return pytree.tree_unflatten([relu], self._out_spec)
        


'class GraphModule(torch.nn.Module):\n    def forward(self, x):\n        arg0: f32[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)\n        # File: /tmp/ipykernel_194925/2675718805.py:10, code: return torch.relu(x)\n        relu = torch.relu(arg0);  arg0 = None\n        return pytree.tree_unflatten([relu], self._out_spec)\n        '

In [36]:
def fn_data_dependent_control_flow(x):
    """Return ``relu(x)`` when ``x.sum() > 0``, otherwise ``sigmoid(x)``.

    The predicate depends on tensor *values* (data-dependent control flow),
    which is the case this cell probes ``torch._dynamo.export`` with.
    """
    return torch.relu(x) if x.sum() > 0 else torch.sigmoid(x)

In [None]:
def test_fn_data_dependent_control_flow():
    """Export ``fn_data_dependent_control_flow`` on a random length-10 input,
    then print the resulting graph module followed by each recorded guard.

    The target function is looked up by global name when this runs, so it
    exercises whichever definition is currently bound in the session.
    """
    example = torch.randn(10)
    result = torch._dynamo.export(fn_data_dependent_control_flow)(example)
    print(result.graph_module)
    for guard_entry in result.guards:
        print(guard_entry)


test_fn_data_dependent_control_flow()

In [None]:
# 尝试过 lambda function，x.sum(), x.shape 等判断条件，都不支持
def true_fn(x):
    return torch.relu(x)


def false_fn(x):
    return torch.sigmoid(x)


def fn_data_dependent_control_flow(x):
    return torch._higher_order_ops.cond(x.shape[0] > 0, true_fn, false_fn, (x,))

In [None]:
def test_fn_data_dependent_control_flow():
    """Export whichever ``fn_data_dependent_control_flow`` is currently bound
    (the ``cond``-based one when cells run in order) on a random length-10
    tensor, then print the graph module and every recorded guard.
    """
    export_result = torch._dynamo.export(fn_data_dependent_control_flow)(
        torch.randn(10)
    )
    print(export_result.graph_module)
    for recorded_guard in export_result.guards:
        print(recorded_guard)


test_fn_data_dependent_control_flow()

In [40]:
import torch
from torch import nn

class Net(torch.nn.Module):
    """Small conv -> ReLU -> sigmoid network used as an export target.

    Submodules: ``conv`` (1 -> 3 channels, 3x3 kernel, stride 3) and
    ``relu``; the sigmoid is applied functionally via ``torch.sigmoid``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept stable: they become state_dict keys and
        # appear in traced/exported node names.
        self.conv = nn.Conv2d(1, 3, 3, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return torch.sigmoid(self.relu(self.conv(x)))

# Example input and model; `y` (the eager output) is not used further here.
x  = torch.randn(1, 1, 10, 10)
net = Net()
y = net(x)
# Export with aten_graph=True and pre_dispatch=True; the next cell prints the
# resulting graph of aten.conv2d / aten.relu / aten.sigmoid calls.
exported = torch._dynamo.export(net, aten_graph=True, pre_dispatch=True)(x)

In [41]:
# Print the exported ATen-level graph (the recorded output below is truncated).
exported.graph_module.print_readable()


class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: f32[1, 1, s0, s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # File: /tmp/ipykernel_22132/4130250000.py:11, code: x = self.conv(x)
        _param_constant0 = self._param_constant0
        _param_constant1 = self._param_constant1
        conv2d_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [3, 3]);  arg0 = _param_constant0 = _param_constant1 = None
        
        # File: /tmp/ipykernel_22132/4130250000.py:12, code: x = self.relu(x)
        relu_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
        
        # File: /tmp/ipykernel_22132/4130250000.py:13, code: x = torch.sigmoid(x)
        sigmoid_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.sigmoid.default(relu_default);  relu_default = None
        return pytree.tree_unflatten([sigmoid_default], self._out_s

'class GraphModule(torch.nn.Module):\n    def forward(self, x):\n        arg0: f32[1, 1, s0, s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)\n        # File: /tmp/ipykernel_22132/4130250000.py:11, code: x = self.conv(x)\n        _param_constant0 = self._param_constant0\n        _param_constant1 = self._param_constant1\n        conv2d_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [3, 3]);  arg0 = _param_constant0 = _param_constant1 = None\n        \n        # File: /tmp/ipykernel_22132/4130250000.py:12, code: x = self.relu(x)\n        relu_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None\n        \n        # File: /tmp/ipykernel_22132/4130250000.py:13, code: x = torch.sigmoid(x)\n        sigmoid_default: f32[1, 3, (s0//3), (s0//3)] = torch.ops.aten.sigmoid.default(relu_default);  relu_default = None\n        return pytree.tree_unflatten([sigmoid_default

In [15]:
import horizon_plugin_pytorch as hz

class Net(nn.Module):
    """conv -> BatchNorm -> ReLU block used as the export / QAT target.

    Submodules: ``conv`` (1 -> 3 channels, 3x3 kernel, stride 3), ``bn``
    (BatchNorm2d over the 3 conv output channels) and ``relu``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [26]:
# Export the conv-bn-relu Net without lowering to ATen (aten_graph=False).
x  = torch.randn(1, 1, 10, 10)
net = Net()
y = net(x)
exported = torch._dynamo.export(net, aten_graph=False)(x)
# NOTE(review): the recorded run raised
#   AttributeError: 'function' object has no attribute 'graph_module'
# i.e. `exported` was a function here, even though the same call pattern
# works in the earlier aten_graph=True cell. The cause is not visible from
# this notebook; confirm before relying on `gm` in later cells.
gm = exported.graph_module

AttributeError: 'function' object has no attribute 'graph_module'

In [24]:
# horizon_plugin_pytorch QAT setup. Only default_qat_8bit_fake_quant_qconfig
# is used in this cell; the other qconfigs are imported for reference.
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from horizon_plugin_pytorch.march import March, set_march
set_march("bayes")
# Single qconfig under the "" key -- presumably applied to the whole model;
# confirm against the horizon_plugin_pytorch docs.
qconfig = {"": default_qat_8bit_fake_quant_qconfig}
# Prepare the dynamo-exported `gm` for QAT. In the recorded output,
# conv+bn+relu were fused into one ConvReLU2d with FakeQuantize on the
# activation and the weight, and the bn/relu modules became Identity.
hz.quantization.prepare_qat_fx(gm, qconfig)



GraphModule(
  (L__self___conv): ConvReLU2d(
    1, 3, kernel_size=(3, 3), stride=(3, 3)
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
      (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
  )
  (L__self___bn): Identity()
  (L__self___relu): Identity()
)

In [28]:
# Prepare the eager `net` (rather than an exported graph) for QAT with the same qconfig.
qat_model = hz.quantization.prepare_qat_fx(net, qconfig)

In [29]:
# Run the QAT model. The recorded output is a horizon QTensor: fake-quantized
# data plus scale / zero_point / dtype metadata.
qat_model(x)



QTensor(
  data = tensor([[[[0.4563, 0.0000, 0.6794],
          [0.0000, 0.0000, 0.6592],
          [0.0000, 0.0000, 0.5679]],

         [[0.0000, 0.6186, 0.0000],
          [0.3854, 1.0039, 0.0000],
          [0.1825, 0.5679, 0.0000]],

         [[0.5172, 1.2879, 0.2941],
          [0.4766, 0.5983, 0.0000],
          [0.0000, 0.0000, 0.0000]]]], grad_fn=<AliasBackward0>),
  scale = tensor([0.0101]),
  zero_point = tensor([0]),
  dtype = qint8,
  per_channel_axis = -1,
  is_quantized = False
)

In [115]:
import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
class MyReLU(torch.nn.Module):
    """Thin wrapper that owns an ``nn.ReLU`` as the submodule ``myrelu``."""

    def __init__(self):
        super().__init__()
        self.myrelu = torch.nn.ReLU()

    def forward(self, x):
        return self.myrelu(x)


class MyAct(torch.nn.Module):
    """Calls one shared ``MyReLU`` instance twice on the same input and sums.

    The repeated call to a single nested submodule is what the tracing cells
    below inspect (module stack / node-to-scope mapping).
    """

    def __init__(self):
        super().__init__()
        self.act = MyReLU()

    def forward(self, x):
        first = self.act(x)
        second = self.act(x)
        return first + second
class Net(torch.nn.Module):
    """conv -> ReLU -> sigmoid -> ``MyAct`` network with nested custom submodules.

    Submodules: ``conv`` (1 -> 3 channels, 3x3 kernel, stride 3), ``relu`` and
    ``custom_act``; the sigmoid is applied functionally via ``torch.sigmoid``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 3, 3)
        self.relu = nn.ReLU()
        self.custom_act = MyAct()

    def forward(self, x):
        activated = torch.sigmoid(self.relu(self.conv(x)))
        return self.custom_act(activated)

x  = torch.randn(1, 1, 10, 10)
net = Net()
y = net(x)
# Capture a pre-autograd ATen graph, then print each node's nn_module_stack
# metadata (the chain of module calls it was traced under). Nodes without that
# metadata print 'None'. In the recorded output, the custom activation nodes
# carry the nested custom_act -> act -> myrelu stack.
exported = capture_pre_autograd_graph(net, (x, ))
from torch.ao.quantization.pt2e.utils import _get_node_name_to_scope  # not used in this cell
for n in exported.graph.nodes:
    pprint(n.meta.get("nn_module_stack", "None"))
    print("---")
# Alternative export paths tried earlier (kept for reference; note that they
# call _get_node_name_to_scope on `gm`, not on `exported`):
# print(exported)
# pprint(exported)
# exported = torch._dynamo.export(net, aten_graph=True, pre_dispatch=True)(x)
# from torch.ao.quantization.pt2e.utils import _get_node_name_to_scope
# _get_node_name_to_scope(gm)
# # print(exported.graph_module)

# exported = torch._dynamo.export(net, aten_graph=False, tracing_mode="real")(x)
# from torch.ao.quantization.pt2e.utils import _get_node_name_to_scope
# _get_node_name_to_scope(gm)
# print(exported.graph_module)

'None'
---
{'L__self___conv': ("L['self'].conv", <class 'torch.nn.modules.conv.Conv2d'>)}
---
{'L__self___conv': ("L['self'].conv", <class 'torch.nn.modules.conv.Conv2d'>)}
---
{'L__self___conv': ("L['self'].conv", <class 'torch.nn.modules.conv.Conv2d'>)}
---
{'L__self___relu': ("L['self'].relu",
                    <class 'torch.nn.modules.activation.ReLU'>)}
---
'None'
---
{'L__self___custom_act': ("L['self'].custom_act", <class '__main__.MyAct'>),
 'L__self___custom_act_act': ("L['self'].custom_act.act",
                              <class '__main__.MyReLU'>),
 'L__self___custom_act_act_myrelu': ("L['self'].custom_act.act.myrelu",
                                     <class 'torch.nn.modules.activation.ReLU'>)}
---
{'L__self___custom_act': ("L['self'].custom_act", <class '__main__.MyAct'>),
 'L__self___custom_act_act': ("L['self'].custom_act.act",
                              <class '__main__.MyReLU'>),
 'L__self___custom_act_act_myrelu': ("L['self'].custom_act.act.myrelu",
      

In [53]:
# Map each node of `gm` to (module path, module type) using a private torch
# helper (leading underscore, so it may change between versions). `gm` is
# whatever GraphModule was last bound in the kernel. The recorded output
# corresponds to a dynamo export of the conv-bn-relu Net.
from torch.ao.quantization.pt2e.utils import _get_node_name_to_scope
_get_node_name_to_scope(gm)

{'arg0': ('', NoneType),
 'l__self___conv': ('conv', torch.nn.modules.conv.Conv2d),
 'l__self___bn': ('bn', torch.nn.modules.batchnorm.BatchNorm2d),
 'l__self___relu': ('relu', torch.nn.modules.activation.ReLU),
 'output': ('', NoneType)}

In [64]:
# Symbolically trace `net` with torch.fx; this rebinds `gm` to the traced GraphModule.
gm = torch.fx.symbolic_trace(net)

In [69]:
# Inspect the GraphModule's instance attributes (submodules, graph, generated code).
gm.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('conv',
               Conv2d(1, 3, kernel_size=(3, 3), stride=(3, 3))),
              ('relu', ReLU())]),
 '_graph': <torch.fx.graph.Graph at 0x7faa7b0001c0>,
 '_code': '\n\n\ndef forward(self, x):\n    conv = self.conv(x);  x = None\n    relu = self.relu(conv);  conv = None\n    sigmoid = torch.sigmoid(relu);  relu = None\n    return sigmoid\n    ',
 '_trace

In [106]:
# Use a Tracer directly (instead of symbolic_trace) so that its tracing
# bookkeeping, such as `node_name_to_scope`, remains accessible afterwards.
tracer = torch.fx.Tracer()

In [108]:
# NOTE: Tracer.trace returns a torch.fx.Graph, not a GraphModule, despite the
# name `gm`; wrap it as torch.fx.GraphModule(net, gm) if a module is needed.
gm = tracer.trace(net)

In [110]:
# Per-node (module qualified name, module type) recorded during tracing. Nodes
# created outside any submodule call map to ('', None). In the recorded output,
# the two calls to the shared ReLU get distinct node names but the same scope.
tracer.node_name_to_scope

{'x': ('', None),
 'conv': ('conv', torch.nn.modules.conv.Conv2d),
 'relu': ('relu', torch.nn.modules.activation.ReLU),
 'sigmoid': ('', None),
 'custom_act_act_myrelu': ('custom_act.act.myrelu',
  torch.nn.modules.activation.ReLU),
 'custom_act_act_myrelu_1': ('custom_act.act.myrelu',
  torch.nn.modules.activation.ReLU),
 'add': ('custom_act', __main__.MyAct),
 'output': ('', None)}