Skip to content

Commit

Permalink
[ONNX] Update 'Functionalize' pass to support pre-decomp graph; Drop …
Browse files Browse the repository at this point in the history
…'aten_graph' arg for 'DynamoExporter' (#99667)

Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing.
  That is resolved by #99555 and its follow ups.
- Later decomposition pass will do graph lowering, so this step is duplicated.
- Updated `Functionalization` to workaround #99774 (comment)

Todo
- Training vs eval in dynamo_export
  So we are effectively exporting all models in traning mode by
  default. But for the sake of this export we are only interested in eval mode.
  The question is, should we call `model.eval()` in `dynamo_export`?
  Tests with model containing batch norm fails 'functionalization' in training mode.
  We are explicitly calling `model.eval()` for these model for now.
- Merge decomp and functionalize pass. Both calls into `make_fx`.
  Merging potentially increases performance. However it is unclear
  if it will result in different behavior.

Fixes #99662. (For the functionalization issue. Still need missing op support.)
Pull Request resolved: #99667
Approved by: https://github.com/titaiwangms
  • Loading branch information
BowenBao authored and pytorchmergebot committed May 4, 2023
1 parent 9bc68fc commit f827563
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 15 deletions.
40 changes: 34 additions & 6 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,43 @@ def forward(self, x):
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(Model(), (input,))

@pytorch_test_common.xfail(
"RuntimeError: false INTERNAL ASSERT FAILED at "
"'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
" please report a bug to PyTorch. mutating a non-functional tensor with a "
"functional tensor is not allowed. Please ensure that all of your inputs are "
"wrapped inside of a functionalize() call."
"RuntimeError: Unknown call_function target: aten.mean.dim"
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
@skip_if_no_torchvision
def test_resnet18(self):
# TODO(bowbao): Note [training vs eval in dynamo_export]
# So we are effectively exporting all models in traning mode by
# default. But for the sake of this export we are only interested in eval mode.
# The question is, should we call `model.eval()` in `dynamo_export`?
# This particular test fails 'functionalization' in training mode.
# So we are explicitly calling `model.eval()` for any model that contains
# batch norm.
# Ref: https://github.com/pytorch/pytorch/issues/99662#issuecomment-1528178221
model = torchvision.models.resnet18(pretrained=False).eval()
dummy_input = torch.randn(1, 3, 224, 224)

self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model,
(dummy_input,),
)

@pytorch_test_common.xfail(
"RuntimeError: Unknown call_function target: aten.mean.dim"
)
@pytorch_test_common.skip_min_ort_version(
reason="ORT doesn't support dynamic fx exporter yet making SegFault flaky test",
version="1.15",
dynamic_only=True,
)
@skip_if_no_torchvision
def test_shufflenet_v2(self):
model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
# TODO(bowbao): see Note [training vs eval in dynamo_export]
model = torchvision.models.shufflenet_v2_x0_5(pretrained=False).eval()
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)

Expand Down
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,9 @@ def generate_fx(

# Translate callable to FX graph.
#
# TODO(wechi): There are several symbolic tracing mechanisms to convert
# nn.Module to FX graph. We should choose the right one after they are
# matured.
fx_mode = "symbolic" if options.dynamic_shapes else "fake"
graph_module, graph_guard = torch._dynamo.export(
wrapped_model, *args, aten_graph=True, tracing_mode=fx_mode, **kwargs
wrapped_model, *args, tracing_mode=fx_mode, **kwargs
)
del graph_guard # Unused
torch._dynamo.reset()
Expand Down
40 changes: 35 additions & 5 deletions torch/onnx/_internal/fx/passes/functionalization.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from __future__ import annotations

from typing import Callable

import torch
import torch._ops
import torch.func
import torch.fx

from torch.fx.experimental import proxy_tensor
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree


class Functionalize(_pass.Transform):
"""Functionalize a GraphModule.
This pass utilizes ``torch.func.functionalize`` to convert a GraphModule into a
functional form. The two main functionalities are (copied from its documentations):
This pass utilizes ``functionalization`` utility of ``torch._functorch`` to convert
a GraphModule into a functional form. The two main functionalities are (copied from
its documentations):
* ``torch.func.functionalize`` removes (intermediate) mutations and aliasing from a
* ``functionalization`` removes (intermediate) mutations and aliasing from a
function, while preserving the function's semantics.
* ``torch.func.functionalize`` also removes mutations (and views) that were performed
* ``functionalization`` also removes mutations (and views) that were performed
on function inputs. However to preserve semantics, functionalize will "fix up" the
mutations after the transform has finished running, by detecting if any tensor inputs
"should have" been mutated, and copying the new data back to the inputs if necessary.
Expand Down Expand Up @@ -64,12 +69,37 @@ def __init__(
super().__init__(diagnostic_context, module)
self.enable_dynamic_axes = enable_dynamic_axes

def _functionalize(self, function: Callable) -> Callable:
# Working around a dispatcher issue with `torch.func.functionalize` when used
# together with `make_fx`.
# Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391
def wrapped(*inputs):
inputs_functional = pytree.tree_map_only(
torch.Tensor, torch._to_functional_tensor, inputs
)
torch._enable_functionalization(reapply_views=True)
try:
out = function(*inputs_functional)
finally:
torch._disable_functionalization()
flat_inputs, _ = pytree.tree_flatten(inputs)
flat_inputs_functional, _ = pytree.tree_flatten(inputs_functional)
for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
if isinstance(input_functional, torch.Tensor):
torch._sync(input_functional)
inpt_new = torch._from_functional_tensor(input_functional)
pytree.tree_map(torch._sync, out)
out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
return out_unwrapped

return wrapped

@_beartype.beartype
def _run(self, *args) -> torch.fx.GraphModule:
# To preserve stack trace info after `make_fx`.
module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

functionalized_callable = torch.func.functionalize(module)
functionalized_callable = self._functionalize(module)
fx_mode = "symbolic" if self.enable_dynamic_axes else "fake"

graph_module = proxy_tensor.make_fx(
Expand Down

0 comments on commit f827563

Please sign in to comment.