[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'
ghstack-source-id: 077fd3a1d13582ceb7a135f7b07ece93ff0b159d
Pull Request resolved: #99667
BowenBao committed Apr 28, 2023
1 parent a4da688 commit 19f2d1e
Showing 5 changed files with 196 additions and 20 deletions.
35 changes: 30 additions & 5 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -303,11 +303,36 @@ def forward(self, x):
         _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, Model(), (input,))

     @pytorch_test_common.xfail(
-        "RuntimeError: false INTERNAL ASSERT FAILED at "
-        "'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
-        " please report a bug to PyTorch. mutating a non-functional tensor with a "
-        "functional tensor is not allowed. Please ensure that all of your inputs are "
-        "wrapped inside of a functionalize() call."
+        "RuntimeError: Unknown call_function target: aten.mean.dim"
+    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
+    )
+    @skip_if_no_torchvision
+    def test_resnet18(self):
+        # TODO(bowbao): So we are effectively exporting all models in training mode by
+        # default. But for the sake of export we are really only interested in eval mode.
+        # The question is, should we call `model.eval()` in `dynamo_export`?
+        # This particular test fails 'functionalization' in training mode.
+        # Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221
+        model = torchvision.models.resnet18(pretrained=False).eval()
+        dummy_input = torch.randn(1, 3, 224, 224)
+
+        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+            model,
+            (dummy_input,),
+        )
+
+    @pytorch_test_common.xfail(
+        "Found unsupported input types on PyTorch Op aten.convolution.default with "
+        "ValueError: Unexpected input argument type is found in node arguments. arg: None;"
+    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
     )
     @skip_if_no_torchvision
     def test_shufflenet_v2(self):
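An aside on the TODO in test_resnet18 above: training-mode export is what trips functionalization here, most likely because batch norm in train mode updates its running-stat buffers in place. A minimal sketch of an eval-mode capture under the two-stage flow this PR moves to (illustrative only, not the test's exact harness):

    import torch
    import torchvision

    model = torchvision.models.resnet18(pretrained=False)
    model.eval()  # freeze batchnorm/dropout so capture sees no buffer mutation
    dummy_input = torch.randn(1, 3, 224, 224)

    # Torch-level capture, as generate_fx does after this PR (no aten_graph arg).
    graph_module, guards = torch._dynamo.export(model, dummy_input)
    print(graph_module.graph)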
27 changes: 19 additions & 8 deletions torch/onnx/_internal/exporter.py
@@ -474,22 +474,33 @@ def _export_fx_to_onnx(
     import torch.onnx._internal.fx.fx_exporter as fx_exporter
     import torch.onnx._internal.fx.passes as passes

-    # Apply decomposition table to the input graph.
-    module = passes.Decompose(
-        fx_module,
-        options.decomposition_table,
-        enable_dynamic_axes=options.dynamic_shapes,
-    ).run(*fx_module_args)
-
     # ONNX does not support views and mutations.
     # Functionalize to get a semantically equivalent graph without mutations.
+    # NOTE: Functionalize must run before make_fx (decomposition and aten graph
+    # lowering).
+    # https://github.com/pytorch/pytorch/issues/99662
     module = passes.Functionalize(
-        module, enable_dynamic_axes=options.dynamic_shapes
+        fx_module, enable_dynamic_axes=options.dynamic_shapes
     ).run(*fx_module_args)
     # Input mutations are detected and distilled after `Functionalize` pass.
     # Remove them since ONNX inference does not need them.
     module = passes.RemoveInputMutation(module).run(*fx_module_args)
+
+    # Apply decomposition table to the input graph.
+    module = passes.Decompose(
+        module,
+        options.decomposition_table,
+        enable_dynamic_axes=options.dynamic_shapes,
+    ).run(*fx_module_args)
+
+    # FIXME: This pass is not needed if functionalize can be applied on decomposed
+    # graph. https://github.com/pytorch/pytorch/issues/99662
+    # This is a workaround to replace inplace variant ops with outplace version.
+    # These ops are created by aten graph lowering and decomposition post
+    # functionalization. No real mutation is expected as it should have been handled
+    # by functionalization.
+    module = passes.ReplaceInplacePostFunctionalization(module).run(*fx_module_args)

     # Run ShapeInferenceWithFakeTensor to get static shape of nodes for op_level_debug purposes
     # The pass adds nodes with static shape into the original node metadata:
     # node.meta["static_shape"]: FakeTensor/int/float/SymInt/SymFloat
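With this reordering, the pipeline in `_export_fx_to_onnx` becomes: Functionalize -> RemoveInputMutation -> Decompose -> ReplaceInplacePostFunctionalization. The order matters because functionalization only works on the pre-decomposition graph (#99662), while decomposition can reintroduce inplace aten variants that the final pass then rewrites. A toy sketch of the property the first pass establishes, using the public `torch.func.functionalize` for brevity (the pass itself uses a manual wrapper, see functionalization.py below):

    import torch
    from torch.fx.experimental import proxy_tensor

    def f(x):
        y = x * 2
        y.relu_()  # inplace mutation: must be functionalized before aten lowering
        return y

    gm = proxy_tensor.make_fx(torch.func.functionalize(f))(torch.randn(3))
    # The traced aten graph contains functional relu only; the mutating
    # overload is gone.
    assert all(n.target is not torch.ops.aten.relu_.default for n in gm.graph.nodes)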
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
@@ -188,12 +188,9 @@ def generate_fx(

     # Translate callable to FX graph.
     #
-    # TODO(wechi): There are several symbolic tracing mechanisms to convert
-    # nn.Module to FX graph. We should choose the right one after they are
-    # matured.
     fx_mode = "symbolic" if options.dynamic_shapes else "fake"
     graph_module, graph_guard = torch._dynamo.export(
-        wrapped_model, *args, aten_graph=True, tracing_mode=fx_mode, **kwargs
+        wrapped_model, *args, tracing_mode=fx_mode, **kwargs
     )
     del graph_guard  # Unused
     torch._dynamo.reset()
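Dropping `aten_graph=True` means `generate_fx` now returns a torch-level graph; lowering to aten happens downstream, inside the Functionalize/Decompose passes, via `make_fx`. A rough sketch of the resulting two-stage flow (simplified; the real code threads fake/symbolic tracing modes and exporter options through):

    import torch
    from torch.fx.experimental import proxy_tensor

    def f(x):
        return torch.relu(x) + 1

    x = torch.randn(3)

    # Stage 1: Dynamo captures a torch-level FX graph.
    graph_module, graph_guard = torch._dynamo.export(f, x)
    torch._dynamo.reset()

    # Stage 2: aten lowering, now owned by the exporter passes instead of
    # dynamo.export itself.
    aten_module = proxy_tensor.make_fx(graph_module, tracing_mode="fake")(x)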
7 changes: 6 additions & 1 deletion torch/onnx/_internal/fx/passes/__init__.py
@@ -1,5 +1,9 @@
 from .decomp import Decompose
-from .functionalization import Functionalize, RemoveInputMutation
+from .functionalization import (
+    Functionalize,
+    RemoveInputMutation,
+    ReplaceInplacePostFunctionalization,
+)
 from .fx_to_onnxscript import export_fx_to_onnxscript
 from .shape_inference import ShapeInferenceWithFakeTensor
 from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder
@@ -11,5 +15,6 @@
"MovePlaceholderToFront",
"RemoveInputMutation",
"ReplaceGetAttrWithPlaceholder",
"ReplaceInplacePostFunctionalization",
"ShapeInferenceWithFakeTensor",
]
142 changes: 140 additions & 2 deletions torch/onnx/_internal/fx/passes/functionalization.py
@@ -1,13 +1,17 @@
 from __future__ import annotations

+from typing import Callable, Dict, Optional

 import torch
+import torch._ops
 import torch.func
 import torch.fx

 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal import _beartype
-from torch.onnx._internal.fx import _pass
+from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
+from torch.utils import _pytree as pytree


 class Functionalize(_pass.Transform):
@@ -52,19 +56,47 @@ def fn(a, b):
     For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass.
     ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``,
     which are not needed for ONNX inference.
+    NOTE: Functionalize must run before decomposition and aten graph lowering.
+    https://github.com/pytorch/pytorch/issues/99662
     """

     @_beartype.beartype
     def __init__(self, module: torch.fx.GraphModule, enable_dynamic_axes: bool):
         super().__init__(module)
         self.enable_dynamic_axes = enable_dynamic_axes

+    def _functionalize(self, function: Callable) -> Callable:
+        # Working around a dispatcher issue with `torch.func.functionalize` when used
+        # together with `make_fx`.
+        # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391
+        def wrapped(*inputs):
+            inputs_functional = pytree.tree_map_only(
+                torch.Tensor, torch._to_functional_tensor, inputs
+            )
+            torch._enable_functionalization(reapply_views=True)
+            try:
+                out = function(*inputs_functional)
+            finally:
+                torch._disable_functionalization()
+            flat_inputs, _ = pytree.tree_flatten(inputs)
+            flat_inputs_functional, _ = pytree.tree_flatten(inputs_functional)
+            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
+                if isinstance(input_functional, torch.Tensor):
+                    torch._sync(input_functional)
+                    inpt_new = torch._from_functional_tensor(input_functional)
+            pytree.tree_map(torch._sync, out)
+            out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
+            return out_unwrapped
+
+        return wrapped
+
     @_beartype.beartype
     def _run(self, *args) -> torch.fx.GraphModule:
         # To preserve stack trace info after `make_fx`.
         module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

-        functionalized_callable = torch.func.functionalize(module)
+        functionalized_callable = self._functionalize(module)
         fx_mode = "symbolic" if self.enable_dynamic_axes else "fake"

         graph_module = proxy_tensor.make_fx(
@@ -102,3 +134,109 @@ def _run(self, *args) -> torch.fx.GraphModule:
             ):
                 self.module.graph.erase_node(node)
         return self.module
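For intuition on the two passes above: functionalization turns a mutating function into a functional graph plus explicit copy-backs into any mutated inputs, and it is exactly those copy-back nodes that `RemoveInputMutation` erases, since ONNX has no notion of input mutation. A sketch with the public `torch.func.functionalize` (the pass swaps in the manual `_functionalize` wrapper above to sidestep the dispatcher issue in #99774):

    import torch
    from torch.fx.experimental import proxy_tensor

    def add_inplace(a):
        a.add_(1.0)  # mutates the function input
        return a

    gm = proxy_tensor.make_fx(torch.func.functionalize(add_inplace))(torch.randn(3))
    # Expected shape of the result: a functional aten.add for the math, plus a
    # trailing aten.copy_ writing the new value back into the input placeholder
    # -- the "fix up" node that RemoveInputMutation removes for ONNX.
    print(gm.code)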


+class ReplaceInplacePostFunctionalization(_pass.Transform):
+    """Replace inplace variant ops with their outplace versions.
+
+    This pass assumes that the graph has been functionalized and decomposed to
+    aten level. No real mutation is expected, as it should have been handled by
+    functionalization.
+
+    Each inplace variant op node is expected to be the last user of its first
+    argument. That is, the input being mutated cannot be used by another node
+    afterwards. Otherwise, a RuntimeError will be raised.
+
+    This pass only handles ops under the "aten" namespace.
+
+    NOTE: This pass is not needed if functionalize can be applied on the
+    decomposed graph. https://github.com/pytorch/pytorch/issues/99662
+    """
+
+    @_beartype.beartype
+    def _outplace_target(
+        self, inplace_target: torch._ops.OpOverload
+    ) -> Optional[torch._ops.OpOverload]:
+        assert inplace_target.namespace == "aten"
+        outplace_name = inplace_target._schema.name.split("::")[1][:-1]
+        overload_name = inplace_target._overloadname
+
+        opoverloadpacket = getattr(torch.ops.aten, outplace_name)
+        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
+            return None
+
+        return getattr(opoverloadpacket, overload_name, None)
+
+    @_beartype.beartype
+    def _run(self, *args) -> torch.fx.GraphModule:
+        # Run through reverse nodes and record the first instance of a use
+        # of a given node. This represents the *last* use of the node in the
+        # execution order of the program, which we will use to validate that
+        # the mutated input value is not used after the mutation.
+        node_to_last_use: Dict[torch.fx.Node, torch.fx.Node] = {}
+
+        def register_last_uses(n: torch.fx.Node, user: torch.fx.Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+
+        for node in reversed(self.module.graph.nodes):
+            torch.fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
+            torch.fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+        for node in self.module.graph.nodes:
+            if node.op != "call_function" or not isinstance(
+                node.target, torch._ops.OpOverload
+            ):
+                continue
+
+            target = node.target
+            mutated_input = node.args[0]
+
+            name_without_overload = target._schema.name
+            is_inplace = name_without_overload.endswith(
+                "_"
+            ) and not name_without_overload.endswith("__")
+            is_aten = target.namespace == "aten"
+
+            if not is_inplace:
+                continue
+
+            if not is_aten:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                diagnostic = diagnostics.export_context().inflight_diagnostic(
+                    rule=diagnostics.rules.fx_pass
+                )
+                diagnostic.level = diagnostics.levels.WARNING
+                diagnostic.with_additional_message(
+                    f"Found non-aten op {target} in graph with inplace naming convention. "
+                    f"Skip replacing this op with outplace version."
+                )
+                continue
+
+            assert isinstance(
+                mutated_input, torch.fx.Node
+            ), f"Expected mutated input to be a torch.fx.Node. Got {type(mutated_input)}"
+
+            if node_to_last_use[mutated_input] != node:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                raise RuntimeError(
+                    f"Found inplace op node {node} that is not the last use of its input. "
+                    f"Its mutated input is later used by {node_to_last_use[mutated_input]}. "
+                    f"Please run Functionalize pass before ReplaceInplacePostFunctionalization."
+                )
+
+            outplace_target = self._outplace_target(target)
+
+            if outplace_target is None:
+                # TODO(bowbao): Turn this into individual diagnostic.
+                diagnostic = diagnostics.export_context().inflight_diagnostic(
+                    rule=diagnostics.rules.fx_pass
+                )
+                diagnostic.level = diagnostics.levels.WARNING
+                diagnostic.with_additional_message(
+                    f"Failed to find outplace version of {target}. Skip replacing this op."
+                )
+                continue
+
+            node.target = outplace_target
+
+        return self.module
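The `_outplace_target` lookup leans entirely on the aten naming convention: strip the trailing underscore from the schema name and fetch the same overload from the resulting OpOverloadPacket. A standalone sketch of that mapping:

    import torch

    inplace = torch.ops.aten.add_.Tensor  # OpOverload for aten::add_
    base_name = inplace._schema.name.split("::")[1][:-1]  # "add_" -> "add"
    packet = getattr(torch.ops.aten, base_name)  # OpOverloadPacket aten.add
    outplace = getattr(packet, inplace._overloadname, None)
    assert outplace is torch.ops.aten.add.Tensor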
