[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'
ghstack-source-id: 6d04ee3f2ee259a6c51cd527c290a357096846f2
Pull Request resolved: #99667
BowenBao committed Apr 21, 2023
1 parent 224ffcb commit eb50da5
Showing 5 changed files with 145 additions and 19 deletions.
31 changes: 26 additions & 5 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -488,11 +488,32 @@ def forward(self, x):
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, Model(), (input,))

    @pytorch_test_common.xfail(
-        "RuntimeError: false INTERNAL ASSERT FAILED at "
-        "'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
-        " please report a bug to PyTorch. mutating a non-functional tensor with a "
-        "functional tensor is not allowed. Please ensure that all of your inputs are "
-        "wrapped inside of a functionalize() call."
+        "RuntimeError: Unknown call_function target: aten.var_mean.correction"
    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
+    )
+    @skip_if_no_torchvision
+    def test_resnet18(self):
+        model = torchvision.models.resnet18(pretrained=False)
+        dummy_input = torch.randn(1, 3, 224, 224)
+
+        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+            self,
+            model,
+            (dummy_input,),
+        )
+
+    @pytorch_test_common.xfail(
+        "Found unsupported input types on PyTorch Op aten.convolution.default with "
+        "ValueError: Unexpected input argument type is found in node arguments. arg: None;"
+    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
+    )
    @skip_if_no_torchvision
    def test_shufflenet_v2(self):
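For reference, the `_run_test_with_fx_to_onnx_exporter_and_onnx_runtime` helper asserts numeric parity between eager PyTorch and ONNX Runtime. A minimal standalone sketch of that pattern, using the stable TorchScript exporter as a stand-in for the FX exporter under development here (assumptions: `onnxruntime` and `torchvision` are installed; the output filename and opset are arbitrary):

import numpy as np
import onnxruntime
import torch
import torchvision

model = torchvision.models.resnet18(pretrained=False).eval()
dummy_input = torch.randn(1, 3, 224, 224)

# Export via the stable TorchScript-based exporter (stand-in for the FX exporter).
torch.onnx.export(model, (dummy_input,), "resnet18.onnx", opset_version=17)

# Run the same input through ONNX Runtime and compare against eager PyTorch.
session = onnxruntime.InferenceSession(
    "resnet18.onnx", providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name
ort_out = session.run(None, {input_name: dummy_input.numpy()})[0]
torch_out = model(dummy_input).detach().numpy()
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)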
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_exporter.py
@@ -164,10 +164,7 @@ def export(self) -> torch.onnx.ExportOutput:
        # TODO(wechi): There are several symbolic tracing mechanisms to convert
        # nn.Module to FX graph. We should choose the right one after they are
        # matured.
-        # TODO(titaiwang): Set `tracing_mode` according to `self.options.dynamic_shapes`
-        graph_module, graph_guard = torch._dynamo.export(
-            wrapped_model, *args, aten_graph=True, **kwargs
-        )
+        graph_module, graph_guard = torch._dynamo.export(wrapped_model, *args, **kwargs)
        del graph_guard  # Unused
        torch._dynamo.reset()

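The dropped `aten_graph=True` flag controls which IR `torch._dynamo.export` produces: with it, the returned graph is already lowered to aten ops; without it, the graph keeps torch-level ops, which lets the exporter functionalize before decomposing (see the NOTE added in fx_exporter.py below). A quick sketch of the difference, assuming a PyTorch build of this vintage where `torch._dynamo.export` returns a `(GraphModule, guards)` pair; `f` is a hypothetical example, not from this PR:

import torch

def f(x):
    return torch.nn.functional.relu(x) + 1

example = torch.randn(3)

# Torch-level IR: node targets are torch-level callables
# (e.g. torch.nn.functional.relu).
gm_torch_level, _ = torch._dynamo.export(f, example)
gm_torch_level.graph.print_tabular()

# Aten-level IR: node targets are aten OpOverloads (e.g. aten.relu.default).
gm_aten_level, _ = torch._dynamo.export(f, example, aten_graph=True)
gm_aten_level.graph.print_tabular()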
20 changes: 12 additions & 8 deletions torch/onnx/_internal/fx/fx_exporter.py
@@ -298,22 +298,26 @@ def export_fx_to_onnx(
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ) -> torch.onnx.ExportOutput:
-        # Apply decomposition table to the input graph.
-        module = passes.Decompose(
-            fx_module,
-            self.decomposition_table,
-            enable_dynamic_axes=self.options.dynamic_shapes,
-        ).run(*fx_module_args)
-
        # ONNX does not support views and mutations.
        # Functionalize to get a semantically equivalent graph without mutations.
+        # NOTE: Functionalize must run before decomposition and aten graph lowering.
+        # https://github.com/pytorch/pytorch/issues/99662
        module = passes.Functionalize(
-            module, enable_dynamic_axes=self.options.dynamic_shapes
+            fx_module, enable_dynamic_axes=self.options.dynamic_shapes
        ).run(*fx_module_args)
        # Input mutations are detected and distilled after `Functionalize` pass.
        # Remove them since ONNX inference does not need them.
        module = passes.RemoveInputMutation(module).run(*fx_module_args)
+
+        # Apply decomposition table to the input graph.
+        module = passes.Decompose(
+            module,
+            self.decomposition_table,
+            enable_dynamic_axes=self.options.dynamic_shapes,
+        ).run(*fx_module_args)
+
+        module = passes.ReplaceInplacePostFunctionalization(module).run(*fx_module_args)

        # Run ShapeInferenceWithFakeTensor to get static shape of nodes for op_level_debug purposes
        # The pass added nodes with static shape into original node metadata:
        # node.meta["static_shape"]: FakeTensor/int/float/SymInt/SymFloat
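A minimal illustration of the new ordering using public APIs — `torch.func.functionalize` composed with `make_fx`, standing in for the internal `Functionalize` and `Decompose` passes (`f` below is a hypothetical example, not from this PR):

import torch
import torch.func
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y.add_(1)  # in-place mutation: not representable in ONNX
    return y

# Functionalize first, so y.add_(1) becomes an out-of-place aten.add;
# tracing/decomposition then sees a mutation-free graph.
gm = make_fx(torch.func.functionalize(f))(torch.randn(2))
print(gm.code)  # the printed graph contains aten.add, not aten.add_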
3 changes: 2 additions & 1 deletion torch/onnx/_internal/fx/passes/__init__.py
@@ -1,5 +1,5 @@
from .decomp import Decompose
-from .functionalization import Functionalize, RemoveInputMutation
+from .functionalization import Functionalize, RemoveInputMutation, ReplaceInplacePostFunctionalization
from .fx_to_onnxscript import export_fx_to_onnxscript
from .shape_inference import ShapeInferenceWithFakeTensor
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
@@ -11,5 +11,6 @@
    "MovePlaceholderToFront",
    "RemoveInputMutation",
    "ReplaceGetAttrWithPlaceholder",
+    "ReplaceInplacePostFunctionalization",
    "ShapeInferenceWithFakeTensor",
]
105 changes: 104 additions & 1 deletion torch/onnx/_internal/fx/passes/functionalization.py
@@ -3,12 +3,15 @@
import torch
import torch.func
import torch.fx
+import torch._ops

from torch.fx.experimental import proxy_tensor
from torch.onnx._internal import _beartype
-from torch.onnx._internal.fx import _pass
+from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils

+from typing import Dict, Optional


class Functionalize(_pass.Transform):
    """Functionalize a GraphModule.
@@ -52,6 +55,9 @@ def fn(a, b):
    For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass.
    ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``,
    which are not needed for ONNX inference.
+
+    NOTE: Functionalize must run before decomposition and aten graph lowering.
+    https://github.com/pytorch/pytorch/issues/99662
    """

    @_beartype.beartype
@@ -102,3 +108,100 @@ def _run(self, *args) -> torch.fx.GraphModule:
        ):
            self.module.graph.erase_node(node)
        return self.module
+
+
+class ReplaceInplacePostFunctionalization(_pass.Transform):
+    """Replace in-place aten ops with their out-of-place equivalents.
+
+    NOTE: This pass is not needed if ``Functionalize`` can be applied on the decomposed graph.
+    https://github.com/pytorch/pytorch/issues/99662
+    """
+
+    @_beartype.beartype
+    def _outplace_target(
+        self, inplace_target: torch._ops.OpOverload
+    ) -> Optional[torch._ops.OpOverload]:
+        assert inplace_target.namespace == "aten"
+        outplace_name = inplace_target._schema.name.split("::")[1][:-1]
+        overload_name = inplace_target._overloadname
+
+        opoverloadpacket = getattr(torch.ops.aten, outplace_name)
+        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
+            return None
+
+        return getattr(opoverloadpacket, overload_name, None)
+
+    @_beartype.beartype
+    def _run(self, *args) -> torch.fx.GraphModule:
+        # Run through reverse nodes and record the first instance of a use
+        # of a given node. This represents the *last* use of the node in the
+        # execution order of the program, which we will use to validate that
+        # the mutated input value is not used after the mutation.
+        node_to_last_use: Dict[torch.fx.Node, torch.fx.Node] = {}
+
+        def register_last_uses(n: torch.fx.Node, user: torch.fx.Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+
+        for node in reversed(self.module.graph.nodes):
+            torch.fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
+            torch.fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+        for node in self.module.graph.nodes:
+            if node.op != "call_function" or not isinstance(
+                node.target, torch._ops.OpOverload
+            ):
+                continue
+
+            target = node.target
+            mutated_input = node.args[0]
+
+            name_without_overload = target._schema.name
+            is_inplace = name_without_overload.endswith(
+                "_"
+            ) and not name_without_overload.endswith("__")
+            is_aten = target.namespace == "aten"
+
+            if not is_inplace:
+                continue
+
+            if not is_aten:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                diagnostic = diagnostics.export_context().inflight_diagnostic(
+                    rule=diagnostics.rules.fx_pass
+                )
+                diagnostic.level = diagnostics.levels.WARNING
+                diagnostic.with_additional_message(
+                    f"Found non-aten op {target} in graph with inplace naming convention. "
+                    f"Skip replacing this op with outplace version."
+                )
+                continue
+
+            assert isinstance(
+                mutated_input, torch.fx.Node
+            ), f"Expected mutated input to be a torch.fx.Node. Got {type(mutated_input)}"
+
+            if node_to_last_use[mutated_input] != node:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                raise RuntimeError(
+                    f"Found inplace op node {node} that is not the last use of its input. "
+                    f"Its mutated input is later used by {node_to_last_use[mutated_input]}. "
+                    f"Please run RemoveInputMutation pass before ReplaceInplacePostFunctionalization."
+                )
+
+            outplace_target = self._outplace_target(target)
+
+            if outplace_target is None:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                diagnostic = diagnostics.export_context().inflight_diagnostic(
+                    rule=diagnostics.rules.fx_pass
+                )
+                diagnostic.level = diagnostics.levels.WARNING
+                diagnostic.with_additional_message(
+                    f"Failed to find outplace version of {target}. Skip replacing this op."
+                )
+                continue
+
+            node.target = outplace_target
+
+        return self.module
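`_outplace_target` relies on the aten naming convention in which an in-place overload carries a trailing underscore (`aten.add_.Tensor` vs. `aten.add.Tensor`). A small sketch of that lookup outside the pass, using only the same attributes the pass itself touches:

import torch
import torch._ops

inplace = torch.ops.aten.add_.Tensor
assert inplace.namespace == "aten"

outplace_name = inplace._schema.name.split("::")[1][:-1]  # "aten::add_" -> "add"
packet = getattr(torch.ops.aten, outplace_name)           # OpOverloadPacket aten.add
assert isinstance(packet, torch._ops.OpOverloadPacket)
outplace = getattr(packet, inplace._overloadname, None)   # OpOverload aten.add.Tensor
print(outplace)  # aten.add.Tensor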
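The last-use bookkeeping in `_run` follows the same reverse-walk liveness idea `torch.fx.Interpreter` uses to free values early: iterating nodes in reverse, the first user recorded for a value is its last user in execution order. A toy demonstration on a symbolically traced function (`g` is hypothetical):

import torch
import torch.fx

def g(x):
    y = x + 1
    z = y * 2
    return z + y  # y is used again here, after the mul

gm = torch.fx.symbolic_trace(g)

node_to_last_use = {}
for node in reversed(gm.graph.nodes):
    # setdefault keeps only the first (i.e. latest-executing) user seen.
    torch.fx.node.map_arg(node.args, lambda n: node_to_last_use.setdefault(n, node))

# The node producing y is last used by the final add, so an in-place op
# mutating y before that point could not be safely replaced out-of-place.
for n, last in node_to_last_use.items():
    print(f"{n.name} last used by {last.name}")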
