resnet18 fails through torch.onnx.dynamo_export #99662

Closed
abock opened this issue Apr 20, 2023 · 3 comments
Labels
module: onnx Related to torch.onnx · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@abock
Contributor

abock commented Apr 20, 2023

In PyTorch 2.1.0a0+git418a9fb (418a9fb):

Script

import torch
from torchvision.models import resnet18

torch.onnx.dynamo_export(resnet18(), torch.randn(1, 3, 224, 224))

Error

RuntimeError: false INTERNAL ASSERT FAILED at "/home/aaron/src/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp":3725, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

While executing %copy_ : [#users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%_tensor_constant1_1, %add_1), kwargs = {})
Original traceback:
  File "/home/aaron/.local/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/home/aaron/.local/lib/python3.10/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
    x = self.bn1(x)

Traceback

Traceback (most recent call last):
  File "/home/aaron/src/pytorch/onnx_playground.py", line 6, in <module>
    for node in torch.onnx.dynamo_export(
  File "<@beartype(torch.onnx._internal.exporter.dynamo_export) at 0x7f30fed33370>", line 53, in dynamo_export
  File "/home/aaron/src/pytorch/torch/onnx/_internal/exporter.py", line 549, in dynamo_export
    ).export()
  File "/home/aaron/src/pytorch/torch/onnx/_internal/fx/dynamo_exporter.py", line 181, in export
    return self.export_fx_to_onnx(graph_module, merged_args)
  File "<@beartype(torch.onnx._internal.fx.fx_exporter.FXGraphModuleExporter.export_fx_to_onnx) at 0x7f30f9c2e680>", line 53, in export_fx_to_onnx
  File "/home/aaron/src/pytorch/torch/onnx/_internal/fx/fx_exporter.py", line 310, in export_fx_to_onnx
    module = passes.Functionalize(
  File "/home/aaron/src/pytorch/torch/onnx/_internal/diagnostics/infra/decorator.py", line 124, in wrapper
    return_values = fn(*args, **kwargs)
  File "/home/aaron/src/pytorch/torch/onnx/_internal/fx/_pass.py", line 190, in run
    module = self._run(*args, **kwargs)
  File "<@beartype(torch.onnx._internal.fx.passes.functionalization.Functionalize._run) at 0x7f30f9c17400>", line 11, in _run
  File "/home/aaron/src/pytorch/torch/onnx/_internal/fx/passes/functionalization.py", line 70, in _run
    graph_module = proxy_tensor.make_fx(
  File "/home/aaron/src/pytorch/torch/fx/experimental/proxy_tensor.py", line 771, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/aaron/src/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/aaron/src/pytorch/torch/fx/experimental/proxy_tensor.py", line 467, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/aaron/src/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/aaron/src/pytorch/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/home/aaron/src/pytorch/torch/fx/experimental/proxy_tensor.py", line 484, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/aaron/src/pytorch/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/aaron/src/pytorch/torch/_functorch/eager_transforms.py", line 1600, in wrapped
    func_outputs = func(*func_args, **func_kwargs)
  File "/home/aaron/src/pytorch/torch/onnx/_internal/fx/passes/_utils.py", line 30, in wrapped
    return torch.fx.Interpreter(graph_module).run(*args)
  File "/home/aaron/src/pytorch/torch/fx/interpreter.py", line 137, in run
    self.env[node] = self.run_node(node)
  File "/home/aaron/src/pytorch/torch/fx/interpreter.py", line 179, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/aaron/src/pytorch/torch/fx/interpreter.py", line 251, in call_function
    return target(*args, **kwargs)
  File "/home/aaron/src/pytorch/torch/_ops.py", line 398, in __call__
    return self._op(*args, **kwargs or {})
@abock added module: onnx Related to torch.onnx and triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Apr 20, 2023
@BowenBao
Collaborator

BowenBao commented Apr 20, 2023

A dissected repro, for reference:

import torch
import torch._dynamo
from torch import func
from torch.fx.experimental import proxy_tensor
from torchvision.models import resnet18

dummy_input = torch.randn(1, 3, 224, 224)
gm, _ = torch._dynamo.export(resnet18(), dummy_input, aten_graph=True)
gm = proxy_tensor.make_fx(func.functionalize(gm))(dummy_input)

"""
Traceback (most recent call last):
  File "repro_resnet.py", line 9, in <module>
    gm = proxy_tensor.make_fx(func.functionalize(gm))(dummy_input)
  File "/home/bowbao/pytorch/torch/fx/experimental/proxy_tensor.py", line 771, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/bowbao/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/bowbao/pytorch/torch/fx/experimental/proxy_tensor.py", line 467, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/bowbao/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/bowbao/pytorch/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/home/bowbao/pytorch/torch/fx/experimental/proxy_tensor.py", line 484, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/bowbao/pytorch/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/bowbao/pytorch/torch/_functorch/eager_transforms.py", line 1600, in wrapped
    func_outputs = func(*func_args, **func_kwargs)
  File "/home/bowbao/pytorch/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/bowbao/pytorch/torch/fx/graph_module.py", line 281, in __call__
    raise e
  File "/home/bowbao/pytorch/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/bowbao/pytorch/torch/fx/_symbolic_trace.py", line 756, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/bowbao/pytorch/torch/fx/experimental/proxy_tensor.py", line 433, in call_module
    return forward(*args, **kwargs)
  File "/home/bowbao/pytorch/torch/fx/_symbolic_trace.py", line 749, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/bowbao/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.4", line 15, in forward
  File "/home/bowbao/pytorch/torch/_ops.py", line 398, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: false INTERNAL ASSERT FAILED at "/home/bowbao/pytorch/build/aten/src/ATen/RegisterFunctionalization_2.cpp":7718, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.
"""

cc @bdhirsh

BowenBao added a commit that referenced this issue Apr 21, 2023
"[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'"


Summary
- Previously this was required by `tracing_mode=symbolic` for `dynamic` tracing.
  That argument will be dropped by #99555.
- The later decomposition pass already does graph lowering, so this step is duplicated.
- Functionalization currently cannot work properly on an aten-level graph,
  so it must happen before lowering and decompositions.
- Introduce a `ReplaceInplacePostFunctionalization` pass to replace inplace-variant ops with their outplace versions.
  These ops are created by aten graph lowering and decomposition after functionalization; they
  won't perform any real mutation, since mutation is expected to have been handled by functionalization.
  (A rough sketch of such a replacement follows this commit note.)

Workaround to unblock #99662.

[ghstack-poisoned]
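
As a rough illustration of the `ReplaceInplacePostFunctionalization` idea in the summary above, here is a hedged sketch (not the actual pass; it assumes the inplace aten op name is simply the outplace name plus a trailing underscore):

# Sketch only: swap inplace aten ops (e.g. aten.add_) for their outplace variants
# (aten.add) in an FX graph. This is only sound because functionalization has
# already removed real mutation semantics, as the commit summary notes.
import torch
import torch.fx


def replace_inplace_with_outplace(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
            continue
        schema = node.target._schema
        base_name = schema.name.split("::")[-1]  # e.g. "add_"
        if not base_name.endswith("_"):
            continue
        outplace_packet = getattr(torch.ops.aten, base_name[:-1], None)  # e.g. aten.add
        if outplace_packet is None:
            continue
        overload = schema.overload_name or "default"
        outplace_op = getattr(outplace_packet, overload, None)
        if outplace_op is None:
            continue
        node.target = outplace_op  # same args/kwargs, out-of-place op
    gm.graph.lint()
    gm.recompile()
    return gm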
@BowenBao
Collaborator

BowenBao commented Apr 28, 2023

A further reduced repro: BatchNorm, `make_fx` with symbolic tracing, and functionalize.

Update: this also only repros in training mode; calling `.eval()` on the model makes it pass.

import torch
from torch.fx.experimental import proxy_tensor
from typing import Callable
from torch.utils import _pytree as pytree


def _functionalize(function: Callable) -> Callable:
    def wrapped(*inputs):
        inputs_functional = pytree.tree_map_only(
            torch.Tensor, torch._to_functional_tensor, inputs
        )
        torch._enable_functionalization(reapply_views=True)
        try:
            out = function(*inputs_functional)
        finally:
            torch._disable_functionalization()
        flat_inputs, _ = pytree.tree_flatten(inputs)
        flat_inputs_functional, _ = pytree.tree_flatten(inputs_functional)
        for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
            if isinstance(input_functional, torch.Tensor):
                torch._sync(input_functional)
                inpt_new = torch._from_functional_tensor(input_functional)
        pytree.tree_map(torch._sync, out)
        out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
        return out_unwrapped

    return wrapped


dummy_input = torch.randn(2, 3, 224, 224)


class BNModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)


gm = proxy_tensor.make_fx(
    _functionalize(BNModule()), tracing_mode="symbolic", _allow_non_fake_inputs=True
)(dummy_input)
# RuntimeError: false INTERNAL ASSERT FAILED at "/home/bowbao/pytorch_dev/build/aten/src/ATen/RegisterFunctionalization_2.cpp":7731, 
# please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor 
# is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() 
# call.
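
For illustration, a minimal sketch of the `.eval()` observation above, reusing `_functionalize`, `BNModule`, and `dummy_input` from the repro (a hedged sketch, not verified here):

eval_gm = proxy_tensor.make_fx(
    _functionalize(BNModule().eval()), tracing_mode="symbolic", _allow_non_fake_inputs=True
)(dummy_input)
# In eval mode batch norm does not update its running-stat buffers, so the
# functionalization assert above is not expected to trigger.
print(eval_gm.graph)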

Side note: the same error can be reproduced via `aot_function` too.

from torch._functorch.aot_autograd import aot_function


def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module


aot_fn = aot_function(BNModule(), print_compile_fn)
aot_fn(dummy_input)

BowenBao added a commit that referenced this issue Apr 29, 2023
"[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'; Apply workaround in 'Functionalize' pass"


Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing.
  That is resolved by #99555 and its follow-ups.
- The later decomposition pass already does graph lowering, so this step is duplicated.
- Updated `Functionalization` to work around #99774 (comment)

Todo
- Training vs eval in dynamo_export.
  We are effectively exporting all models in training mode by default,
  but for the sake of this export we are only interested in eval mode.
  The question is: should we call `model.eval()` in `dynamo_export`?
  Tests with models containing batch norm fail 'functionalization' in training mode;
  we explicitly call `model.eval()` for these models for now.
- Merge the decomposition and functionalize passes. Both call into `make_fx`.
  Merging potentially improves performance, but it is unclear
  whether it would result in different behavior.

Workaround to unblock #99662.

[ghstack-poisoned]
@bdhirsh
Contributor

bdhirsh commented May 2, 2023

Hey @BowenBao: this is (unfortunately) a known limitation of functionalization: batch norm during training will try to mutate the buffers on your module, but if you just try to run functionalization on a module's forward, we won't be able to functionalize the mutations that happen to that module's state (params and buffers).

In the next few weeks, I'm hoping to have an API in AOTAutograd that gives (limited) support for exporting an inference / training graph (without the optimizer step) and properly lifts params/buffers into graph inputs. I can tag you when there's a PR ready for it. In the meantime, one thing you could do to hack around it is to manually lift params and buffers into graph inputs yourself, the same way that AOTAutograd does:

def functional_call(named_params, named_buffers, *args, **kwargs):
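
For illustration, a hedged sketch of that idea (not AOTAutograd's actual code): it reuses the `BNModule` repro above and `torch.func.functional_call` (available in PyTorch >= 2.0) to pass the module's params and buffers in as explicit inputs before functionalizing and tracing; the `lifted_forward` helper name is illustrative.

import torch
from torch.fx.experimental import proxy_tensor


class BNModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)


module = BNModule()
state_keys = [k for k, _ in module.named_parameters()] + [k for k, _ in module.named_buffers()]
flat_state = list(module.parameters()) + list(module.buffers())
dummy_input = torch.randn(2, 3, 224, 224)


def lifted_forward(x, *state_values):
    # Rebuild the state dict from explicit tensor inputs and run the module
    # statelessly, so every tensor the forward mutates (including batch norm's
    # running stats) is an input that functionalization has wrapped.
    state = dict(zip(state_keys, state_values))
    return torch.func.functional_call(module, state, (x,))


gm = proxy_tensor.make_fx(torch.func.functionalize(lifted_forward))(dummy_input, *flat_state)
print(gm.graph)

With the state lifted like this, the buffer updates should show up in the traced graph as copies back into the lifted inputs rather than hitting the assert above.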
