[ONNX] Update 'Functionalize' pass to support pre-decomp graph; Drop 'aten_graph' arg for 'DynamoExporter' #99667

Closed · wants to merge 9 commits
31 changes: 26 additions & 5 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -488,11 +488,32 @@ def forward(self, x):
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, Model(), (input,))

    @pytorch_test_common.xfail(
-        "RuntimeError: false INTERNAL ASSERT FAILED at "
-        "'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
-        " please report a bug to PyTorch. mutating a non-functional tensor with a "
-        "functional tensor is not allowed. Please ensure that all of your inputs are "
-        "wrapped inside of a functionalize() call."
+        "RuntimeError: Unknown call_function target: aten.var_mean.correction"
    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet, making this test flaky with SegFaults",
+        version="1.15",
+        dynamic_only=True,
+    )
+    @skip_if_no_torchvision
+    def test_resnet18(self):
+        model = torchvision.models.resnet18(pretrained=False)
+        dummy_input = torch.randn(1, 3, 224, 224)
+
+        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+            self,
+            model,
+            (dummy_input,),
+        )
+
+    @pytorch_test_common.xfail(
+        "Found unsupported input types on PyTorch Op aten.convolution.default with "
+        "ValueError: Unexpected input argument type is found in node arguments. arg: None;"
+    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet, making this test flaky with SegFaults",
+        version="1.15",
+        dynamic_only=True,
+    )
    @skip_if_no_torchvision
    def test_shufflenet_v2(self):
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_exporter.py
@@ -164,10 +164,7 @@ def export(self) -> torch.onnx.ExportOutput:
        # TODO(wechi): There are several symbolic tracing mechanisms to convert
        # nn.Module to FX graph. We should choose the right one after they are
        # matured.
-        # TODO(titaiwang): Set `tracing_mode` according to `self.options.dynamic_shapes`
-        graph_module, graph_guard = torch._dynamo.export(
-            wrapped_model, *args, aten_graph=True, **kwargs
-        )
+        graph_module, graph_guard = torch._dynamo.export(wrapped_model, *args, **kwargs)
        del graph_guard  # Unused
        torch._dynamo.reset()
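For context, a minimal standalone sketch of the new call (not part of this diff; `fn` and the inputs are illustrative): without `aten_graph=True`, `torch._dynamo.export` returns a pre-decomposition, torch-level FX graph, and lowering to aten ops is deferred to the exporter's own passes.

import torch

def fn(x):
    return torch.relu(x) + 1

# Without `aten_graph=True`, the captured graph keeps torch-level ops
# (e.g. torch.relu); decomposition to aten happens later in the exporter.
graph_module, guards = torch._dynamo.export(fn, torch.randn(2, 3))
print(graph_module.graph)
torch._dynamo.reset()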

26 changes: 18 additions & 8 deletions torch/onnx/_internal/fx/fx_exporter.py
@@ -298,22 +298,32 @@ def export_fx_to_onnx(
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ) -> torch.onnx.ExportOutput:
-        # Apply decomposition table to the input graph.
-        module = passes.Decompose(
-            fx_module,
-            self.decomposition_table,
-            enable_dynamic_axes=self.options.dynamic_shapes,
-        ).run(*fx_module_args)
-
        # ONNX does not support views and mutations.
        # Functionalize to get a semantically equivalent graph without mutations.
+        # NOTE: Functionalize must run before decomposition and aten graph lowering.
+        # https://github.com/pytorch/pytorch/issues/99662
        module = passes.Functionalize(
-            module, enable_dynamic_axes=self.options.dynamic_shapes
+            fx_module, enable_dynamic_axes=self.options.dynamic_shapes
        ).run(*fx_module_args)
        # Input mutations are detected and distilled after the `Functionalize` pass.
        # Remove them, since ONNX inference does not need them.
        module = passes.RemoveInputMutation(module).run(*fx_module_args)

+        # Apply the decomposition table to the functionalized graph.
+        module = passes.Decompose(
+            module,
+            self.decomposition_table,
+            enable_dynamic_axes=self.options.dynamic_shapes,
+        ).run(*fx_module_args)
+
+        # NOTE: This pass would not be needed if functionalize could be applied to the
+        # decomposed graph. https://github.com/pytorch/pytorch/issues/99662
+        # It is a workaround that replaces inplace-variant ops with their outplace
+        # versions. These ops are created by aten graph lowering and decomposition
+        # post functionalization. No real mutation is expected, as mutations should
+        # already have been handled by functionalization.
+        module = passes.ReplaceInplacePostFunctionalization(module).run(*fx_module_args)

        # Run ShapeInferenceWithFakeTensor to get static shapes of nodes for op_level_debug purposes.
        # The pass adds nodes with static shape into the original node metadata:
        # node.meta["static_shape"]: FakeTensor/int/float/SymInt/SymFloat
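A minimal standalone sketch of the ordering this hunk implements (not exporter code; `fn` is illustrative, and `torch.func.functionalize` plus `make_fx` are assumed to stand in for the `Functionalize` pass): functionalizing before tracing/decomposition hands the later passes a mutation-free graph.

import torch
import torch.func
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    b = a + 1
    b.relu_()  # inplace mutation on an intermediate
    return b

# Functionalize first, then trace: the captured graph contains no mutations,
# so Decompose and aten lowering operate on a mutation-free graph.
gm = make_fx(torch.func.functionalize(fn))(torch.randn(3))
print(gm.graph)  # relu_ is replaced by out-of-place relu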
7 changes: 6 additions & 1 deletion torch/onnx/_internal/fx/passes/__init__.py
@@ -1,5 +1,9 @@
from .decomp import Decompose
-from .functionalization import Functionalize, RemoveInputMutation
+from .functionalization import (
+    Functionalize,
+    RemoveInputMutation,
+    ReplaceInplacePostFunctionalization,
+)
from .fx_to_onnxscript import export_fx_to_onnxscript
from .shape_inference import ShapeInferenceWithFakeTensor
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
@@ -11,5 +15,6 @@
    "MovePlaceholderToFront",
    "RemoveInputMutation",
    "ReplaceGetAttrWithPlaceholder",
+    "ReplaceInplacePostFunctionalization",
    "ShapeInferenceWithFakeTensor",
]
114 changes: 113 additions & 1 deletion torch/onnx/_internal/fx/passes/functionalization.py
@@ -1,12 +1,15 @@
from __future__ import annotations

+from typing import Dict, Optional
+
import torch
+import torch._ops
import torch.func
import torch.fx

from torch.fx.experimental import proxy_tensor
from torch.onnx._internal import _beartype
-from torch.onnx._internal.fx import _pass
+from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils

@@ -52,6 +55,9 @@ def fn(a, b):
    For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass.
    ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``,
    which are not needed for ONNX inference.
+
+    NOTE: Functionalize must run before decomposition and aten graph lowering.
+    https://github.com/pytorch/pytorch/issues/99662
    """

    @_beartype.beartype
@@ -102,3 +108,109 @@ def _run(self, *args) -> torch.fx.GraphModule:
            ):
                self.module.graph.erase_node(node)
        return self.module
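For intuition, a standalone sketch (illustrative, not exporter code) of the "fix up" nodes that ``RemoveInputMutation`` erases: functionalizing a function that mutates its input appends a copy of the result back into the input placeholder, which ONNX inference does not need.

import torch
import torch.func
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    a.add_(1)  # mutates an input
    return a * 2

gm = make_fx(torch.func.functionalize(fn, remove="mutations_and_views"))(torch.randn(3))
# The traced graph ends with an aten.copy_ that writes the result back into
# the input placeholder -- the "fix up" node that RemoveInputMutation erases.
print(gm.graph)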


class ReplaceInplacePostFunctionalization(_pass.Transform):
    """Replace inplace variant ops with their outplace versions.

    This pass assumes that the graph has been functionalized and decomposed to aten level.
    No real mutation is expected, as it should have been handled by functionalization.

    Every inplace variant op node is expected to be the last user of its first argument.
    That is, the input being mutated cannot be used by another node afterwards.
    Otherwise, a RuntimeError is raised.

    This pass only handles ops under the "aten" namespace.

    NOTE: This pass is not needed if functionalize can be applied on the decomposed graph.
    https://github.com/pytorch/pytorch/issues/99662
    """

    @_beartype.beartype
    def _outplace_target(
        self, inplace_target: torch._ops.OpOverload
    ) -> Optional[torch._ops.OpOverload]:
        assert inplace_target.namespace == "aten"
        outplace_name = inplace_target._schema.name.split("::")[1][:-1]
        overload_name = inplace_target._overloadname
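        # For example (illustrative values): for `torch.ops.aten.add_.Tensor`,
        # `_schema.name` is "aten::add_", so `outplace_name` is "add" and
        # `overload_name` is "Tensor"; the lookup below then yields
        # `torch.ops.aten.add.Tensor`.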

        opoverloadpacket = getattr(torch.ops.aten, outplace_name)
        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
            return None

        return getattr(opoverloadpacket, overload_name, None)

    @_beartype.beartype
    def _run(self, *args) -> torch.fx.GraphModule:
        # Run through the nodes in reverse and record the first encountered use
        # of each node. This represents the *last* use of the node in the
        # execution order of the program, which we will use to validate that
        # the mutated input value is not used after the mutation.
        node_to_last_use: Dict[torch.fx.Node, torch.fx.Node] = {}

        def register_last_uses(n: torch.fx.Node, user: torch.fx.Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user

        for node in reversed(self.module.graph.nodes):
            torch.fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
            torch.fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))

        for node in self.module.graph.nodes:
            if node.op != "call_function" or not isinstance(
                node.target, torch._ops.OpOverload
            ):
                continue

            target = node.target
            mutated_input = node.args[0]

            # Inplace variants end with a single trailing "_" (e.g. `add_`);
            # dunder names like `__and__` are excluded.
            name_without_overload = target._schema.name
            is_inplace = name_without_overload.endswith(
                "_"
            ) and not name_without_overload.endswith("__")
            is_aten = target.namespace == "aten"

            if not is_inplace:
                continue

            if not is_aten:
                # TODO(bowbao): Turn this into an individual diagnostic.
                diagnostic = diagnostics.export_context().inflight_diagnostic(
                    rule=diagnostics.rules.fx_pass
                )
                diagnostic.level = diagnostics.levels.WARNING
                diagnostic.with_additional_message(
                    f"Found non-aten op {target} in graph with inplace naming convention. "
                    f"Skip replacing this op with outplace version."
                )
                continue

            assert isinstance(
                mutated_input, torch.fx.Node
            ), f"Expected mutated input to be a torch.fx.Node. Got {type(mutated_input)}"

            if node_to_last_use[mutated_input] != node:
                # TODO(bowbao): Turn this into an individual diagnostic.
                raise RuntimeError(
                    f"Found inplace op node {node} that is not the last use of its input. "
                    f"Its mutated input is later used by {node_to_last_use[mutated_input]}. "
                    f"Please run the Functionalize pass before ReplaceInplacePostFunctionalization."
                )

            outplace_target = self._outplace_target(target)

            if outplace_target is None:
                # TODO(bowbao): Turn this into an individual diagnostic.
                diagnostic = diagnostics.export_context().inflight_diagnostic(
                    rule=diagnostics.rules.fx_pass
                )
                diagnostic.level = diagnostics.levels.WARNING
                diagnostic.with_additional_message(
                    f"Failed to find outplace version of {target}. Skip replacing this op."
                )
                continue

            node.target = outplace_target

        return self.module
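Taken together, the pass does the equivalent of the following standalone sketch (simplified: diagnostics and the `_pass.Transform` plumbing are elided; `f` and the helper name are illustrative):

import torch
import torch._ops
import torch.fx

def replace_inplace_with_outplace(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Reverse scan: the first use seen for a node is its *last* use in
    # execution order.
    node_to_last_use = {}
    for node in reversed(gm.graph.nodes):
        torch.fx.node.map_arg(node.args, lambda n: node_to_last_use.setdefault(n, node))
        torch.fx.node.map_arg(node.kwargs, lambda n: node_to_last_use.setdefault(n, node))

    for node in gm.graph.nodes:
        if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
            continue
        name = node.target._schema.name  # e.g. "aten::add_"
        if not (name.endswith("_") and not name.endswith("__")):
            continue  # not an inplace variant
        if node.target.namespace != "aten":
            continue  # the real pass emits a WARNING diagnostic here
        assert node_to_last_use[node.args[0]] is node, "input is used after the mutation"
        packet = getattr(torch.ops.aten, name.split("::")[1][:-1], None)
        outplace = None
        if isinstance(packet, torch._ops.OpOverloadPacket):
            outplace = getattr(packet, node.target._overloadname, None)
        if outplace is not None:  # the real pass warns when no outplace op exists
            node.target = outplace
    gm.recompile()
    return gm

# Usage on a toy graph containing an inplace aten op:
def f(x):
    return torch.ops.aten.add_.Tensor(x.clone(), 1.0)

gm = torch.fx.symbolic_trace(f)
replace_inplace_with_outplace(gm)
print(gm.graph)  # aten.add_.Tensor replaced by aten.add.Tensor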