[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'
ghstack-source-id: 78e3950a12d32e764987857a43faf3935004ae7f
Pull Request resolved: #99667
BowenBao committed Apr 29, 2023
1 parent 0ed3b7d commit c684f4c
Showing 3 changed files with 70 additions and 15 deletions.
40 changes: 34 additions & 6 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -303,15 +303,43 @@ def forward(self, x):
         _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, Model(), (input,))
 
     @pytorch_test_common.xfail(
-        "RuntimeError: false INTERNAL ASSERT FAILED at "
-        "'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
-        " please report a bug to PyTorch. mutating a non-functional tensor with a "
-        "functional tensor is not allowed. Please ensure that all of your inputs are "
-        "wrapped inside of a functionalize() call."
+        "RuntimeError: Unknown call_function target: aten.mean.dim"
     )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
+    )
+    @skip_if_no_torchvision
+    def test_resnet18(self):
+        # TODO(bowbao): Note [training vs eval in dynamo_export]
+        # So we are effectively exporting all models in training mode by
+        # default. But for the sake of this export we are only interested in eval mode.
+        # The question is, should we call `model.eval()` in `dynamo_export`?
+        # This particular test fails 'functionalization' in training mode.
+        # So we are explicitly calling `model.eval()` for any model that contains
+        # batch norm.
+        # Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221
+        model = torchvision.models.resnet18(pretrained=False).eval()
+        dummy_input = torch.randn(1, 3, 224, 224)
+
+        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+            model,
+            (dummy_input,),
+        )
+
+    @pytorch_test_common.xfail(
+        "RuntimeError: Unknown call_function target: aten.mean.dim"
+    )
+    @pytorch_test_common.skip_min_ort_version(
+        reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
+        version="1.15",
+        dynamic_only=True,
+    )
     @skip_if_no_torchvision
     def test_shufflenet_v2(self):
-        model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
+        # TODO(bowbao): see Note [training vs eval in dynamo_export]
+        model = torchvision.models.shufflenet_v2_x0_5(pretrained=False).eval()
         dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
         test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
 
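The note in test_resnet18 above is the crux: in training mode, batch norm mutates its running-stat buffers in place during forward, and that buffer mutation is what trips the functionalization pass. A minimal sketch, not part of this commit, showing the mutation that `.eval()` avoids:

import torch

bn = torch.nn.BatchNorm2d(3)
x = torch.randn(1, 3, 4, 4)

# Training mode: forward() updates running_mean/running_var in place.
bn.train()
before = bn.running_mean.clone()
bn(x)
print(torch.equal(before, bn.running_mean))  # False: buffers were mutated

# Eval mode: forward() only reads the buffers, so nothing is mutated.
bn.eval()
before = bn.running_mean.clone()
bn(x)
print(torch.equal(before, bn.running_mean))  # True
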
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
@@ -188,12 +188,9 @@ def generate_fx(
 
     # Translate callable to FX graph.
     #
-    # TODO(wechi): There are several symbolic tracing mechanisms to convert
-    # nn.Module to FX graph. We should choose the right one after they are
-    # matured.
     fx_mode = "symbolic" if options.dynamic_shapes else "fake"
     graph_module, graph_guard = torch._dynamo.export(
-        wrapped_model, *args, aten_graph=True, tracing_mode=fx_mode, **kwargs
+        wrapped_model, *args, tracing_mode=fx_mode, **kwargs
     )
     del graph_guard  # Unused
     torch._dynamo.reset()
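With `aten_graph=True` dropped, `torch._dynamo.export` hands back a torch-level FX graph; lowering to ATen now happens later, inside the `Functionalize` pass below, via `make_fx`. A rough sketch of the difference between the two capture modes, assuming a PyTorch build contemporary with this commit (both the flag and this positional-args calling convention were reworked in later releases):

import torch

def f(x):
    return torch.relu(x) + 1.0

x = torch.randn(3)

# What the exporter requests after this change: torch-level capture.
# The graph contains nodes like `torch.relu` and `operator.add`.
gm_torch, _ = torch._dynamo.export(f, x)
print(gm_torch.graph)

# The dropped path: ATen-level capture. The graph contains decomposed
# ops like `aten.relu.default` and `aten.add.Tensor`.
gm_aten, _ = torch._dynamo.export(f, x, aten_graph=True)
print(gm_aten.graph)
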
40 changes: 35 additions & 5 deletions torch/onnx/_internal/fx/passes/functionalization.py
@@ -1,25 +1,30 @@
 from __future__ import annotations
 
+from typing import Callable
+
 import torch
 import torch._ops
+import torch.func
 import torch.fx
 
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal import _beartype
 from torch.onnx._internal.fx import _pass
 from torch.onnx._internal.fx.passes import _utils
+from torch.utils import _pytree as pytree
 
 
 class Functionalize(_pass.Transform):
     """Functionalize a GraphModule.
 
-    This pass utilizes ``torch.func.functionalize`` to convert a GraphModule into a
-    functional form. The two main functionalities are (copied from its documentations):
+    This pass utilizes the ``functionalization`` utility of ``torch._functorch`` to
+    convert a GraphModule into a functional form. The two main functionalities are
+    (copied from its documentation):
 
-    * ``torch.func.functionalize`` removes (intermediate) mutations and aliasing from a
+    * ``functionalization`` removes (intermediate) mutations and aliasing from a
       function, while preserving the function's semantics.
 
-    * ``torch.func.functionalize`` also removes mutations (and views) that were performed
+    * ``functionalization`` also removes mutations (and views) that were performed
       on function inputs. However to preserve semantics, functionalize will "fix up" the
       mutations after the transform has finished running, by detecting if any tensor inputs
       "should have" been mutated, and copying the new data back to the inputs if necessary.
@@ -59,12 +64,37 @@ def __init__(self, module: torch.fx.GraphModule, enable_dynamic_axes: bool):
         super().__init__(module)
         self.enable_dynamic_axes = enable_dynamic_axes
 
+    def _functionalize(self, function: Callable) -> Callable:
+        # Working around a dispatcher issue with `torch.func.functionalize` when used
+        # together with `make_fx`.
+        # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391
+        def wrapped(*inputs):
+            inputs_functional = pytree.tree_map_only(
+                torch.Tensor, torch._to_functional_tensor, inputs
+            )
+            torch._enable_functionalization(reapply_views=True)
+            try:
+                out = function(*inputs_functional)
+            finally:
+                torch._disable_functionalization()
+            flat_inputs, _ = pytree.tree_flatten(inputs)
+            flat_inputs_functional, _ = pytree.tree_flatten(inputs_functional)
+            for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
+                if isinstance(input_functional, torch.Tensor):
+                    torch._sync(input_functional)
+                    inpt_new = torch._from_functional_tensor(input_functional)
+            pytree.tree_map(torch._sync, out)
+            out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
+            return out_unwrapped
+
+        return wrapped
+
     @_beartype.beartype
     def _run(self, *args) -> torch.fx.GraphModule:
         # To preserve stack trace info after `make_fx`.
         module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)
 
-        functionalized_callable = torch.func.functionalize(module)
+        functionalized_callable = self._functionalize(module)
         fx_mode = "symbolic" if self.enable_dynamic_axes else "fake"
 
         graph_module = proxy_tensor.make_fx(
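Immediately after this hunk, `_run` feeds `functionalized_callable` to `proxy_tensor.make_fx` with the chosen `tracing_mode`, which is where the ATen-level graph is now produced. A condensed, self-contained sketch of that composition with the pass machinery stripped away; `simple_module` is a hypothetical stand-in for the wrapped GraphModule, and the private `torch._*` functionalization hooks are used exactly as in `wrapped` above:

import torch
from torch.fx.experimental import proxy_tensor
from torch.utils import _pytree as pytree

def simple_module(x):  # hypothetical stand-in for the wrapped GraphModule
    y = x.clone()
    y.relu_()
    return y

def functionalize(function):
    # Same shape as `_functionalize` above, minus the input fix-up loop.
    def wrapped(*inputs):
        inputs_functional = pytree.tree_map_only(
            torch.Tensor, torch._to_functional_tensor, inputs
        )
        torch._enable_functionalization(reapply_views=True)
        try:
            out = function(*inputs_functional)
        finally:
            torch._disable_functionalization()
        pytree.tree_map(torch._sync, out)
        return pytree.tree_map(torch._from_functional_tensor, out)

    return wrapped

# "fake" tracing mode mirrors the static-axes path in `_run`.
gm = proxy_tensor.make_fx(functionalize(simple_module), tracing_mode="fake")(
    torch.randn(3)
)
print(gm.code)  # the in-place relu_ is gone; the graph uses aten.relu
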
