
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides w/ dynamo.export, make_fx and functionalize #99774

Closed
BowenBao opened this issue Apr 21, 2023 · 13 comments
Labels
module: functionalization · module: pt2-dispatcher · oncall: export · oncall: pt2 · triaged

Comments

@BowenBao
Collaborator

BowenBao commented Apr 21, 2023

Latest update

This is the most distilled repro.

import torch
import torch._dynamo
import torch.func
from torch.fx.experimental import proxy_tensor
from torch._dispatch.python import enable_python_dispatcher

def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

with enable_python_dispatcher():
    # RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
    gm = proxy_tensor.make_fx(torch.func.functionalize(func), tracing_mode="symbolic")(x, y)

Below is the original issue post, before further discussion.

🐛 Describe the bug

Distilled repro; hints on how to approach/debug this would be greatly appreciated.

import torch
import torch._dynamo
import torch.func
from torch.fx.experimental import proxy_tensor

def func(x, y):
    return torch.matmul(x, y.transpose(-1, -2))

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 3, 4)


gm, _ = torch._dynamo.export(func, x, y)
gm.print_readable()
gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)
gm.print_readable()
Traceback (most recent call last):
  File "/home/bowbao/pytorch_dev/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/bowbao/pytorch_dev/torch/fx/_symbolic_trace.py", line 756, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/experimental/proxy_tensor.py", line 433, in call_module
    return forward(*args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/_symbolic_trace.py", line 749, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.3", line 7, in forward
    matmul = torch.matmul(arg0, transpose);  arg0 = transpose = None
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

Call using an FX-traced Module, line 7 of the traced Module's generated forward function:
    transpose = arg1.transpose(-1, -2);  arg1 = None
    matmul = torch.matmul(arg0, transpose);  arg0 = transpose = None
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return pytree.tree_unflatten([matmul], self._out_spec)

Traceback (most recent call last):
  File "repro_simpler_func_dynamic.py", line 15, in <module>
    gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)
  File "/home/bowbao/pytorch_dev/torch/fx/experimental/proxy_tensor.py", line 771, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_autograd), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/bowbao/pytorch_dev/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/experimental/proxy_tensor.py", line 467, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/bowbao/pytorch_dev/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/home/bowbao/pytorch_dev/torch/fx/experimental/proxy_tensor.py", line 484, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/bowbao/pytorch_dev/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/_functorch/eager_transforms.py", line 1600, in wrapped
    func_outputs = func(*func_args, **func_kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/bowbao/pytorch_dev/torch/fx/graph_module.py", line 279, in __call__
    raise e.with_traceback(None)
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

Versions

Main on e978614

cc @bdhirsh @ezyang @msaroufim @wconstab @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @ngimel @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@yanboliang
Contributor

Are you trying to capture an aten graph? If yes, you can update these lines:

gm, _ = torch._dynamo.export(func, x, y)
gm.print_readable()
gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)
gm.print_readable()

to:

gm, _ = torch._dynamo.export(func, x, y, aten_graph=True, tracing_mode="symbolic")
gm.print_readable()

The export API will call make_fx automatically, and functionalization is also applied automatically.
It works well w/ your repro.
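For reference, here is that suggestion assembled into a self-contained script from the pieces already in this thread (a sketch; aten_graph and tracing_mode are the arguments shown above):

import torch
import torch._dynamo

def func(x, y):
    return torch.matmul(x, y.transpose(-1, -2))

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 3, 4)

# Export straight to an aten-level graph; export calls make_fx internally.
gm, _ = torch._dynamo.export(func, x, y, aten_graph=True, tracing_mode="symbolic")
gm.print_readable()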

@yanboliang added the triaged, oncall: pt2, module: dynamo, and module: export labels Apr 24, 2023
@BowenBao
Collaborator Author

@yanboliang thanks, I have heard there is a plan to turn on functionalization for export, but afaik that is not enabled yet? Changing to just export w/ aten_graph and symbolic, the original repro no longer raises; however, mutations aren't handled.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.tensor([True, False])
    def forward(self, x):
        x.view(3, 2).masked_fill_(self.mask.unsqueeze(0), torch.finfo(x.dtype).max)
        return x

m = M()
x = torch.randn(6)  # matches the f32[6] input in the printed graph below

gm, _ = torch._dynamo.export(m, x, tracing_mode="symbolic", aten_graph=True)
gm.print_readable()
"""
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: f32[6], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # File: test_mutation.py:8, code: x.view(3, 2).masked_fill_(self.mask.unsqueeze(0), torch.finfo(x.dtype).max)
        view_default: f32[3, 2] = torch.ops.aten.view.default(arg0, [3, 2])
        _tensor_constant0 = self._tensor_constant0
        unsqueeze_default: b8[1, 2] = torch.ops.aten.unsqueeze.default(_tensor_constant0, 0);  _tensor_constant0 = None
        masked_fill__scalar: f32[3, 2] = torch.ops.aten.masked_fill_.Scalar(view_default, unsqueeze_default, 3.4028234663852886e+38);  view_default = unsqueeze_default = None
        return pytree.tree_unflatten([arg0], self._out_spec)
"""

@ezyang
Contributor

ezyang commented Apr 24, 2023

We should just paste in the support for it in export. It'll be easier after @bdhirsh finishes AOT export.

To fix your problem on the original case, turn on the python dispatcher like this:

    with enable_python_dispatcher():
        fx_g = make_fx(helper)(*args)

but make_fx ALSO doesn't functionalize! So this won't actually unblock you

@eellison
Contributor

I think this is +1 to enabling the python dispatcher if we are enabling a FakeTensorMode with shape_env set

@BowenBao
Collaborator Author

@ezyang trying the below gives the same error. Sorry for not being clear in the OP that only the addition of functionalize triggers the issue. Somehow applying functionalize leads to sizes() being called somewhere on a symbolic tensor during torch.matmul (basically repeating the trace stack, lol, since I don't know how to get any further) ...

gm, _ = torch._dynamo.export(func, x, y)
with enable_python_dispatcher():
    gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)

@ezyang
Contributor

ezyang commented Apr 25, 2023

Oh, this might just be a coverage problem then. Can you try the playbook at https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#? Specifically, check the docs for how to localize a C++ exception with gdb.

@BowenBao
Collaborator Author

Handy. I've further narrowed down the repro to exclude dynamo and reduced it to a single matmul op. It looks like it is not dispatching to CompositeImplicitAutograd after applying functionalize. @ezyang

import torch
import torch._dynamo
import torch.func
from torch.fx.experimental import proxy_tensor
from torch._dispatch.python import enable_python_dispatcher

def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

with enable_python_dispatcher():
    # RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
    gm = proxy_tensor.make_fx(torch.func.functionalize(func), tracing_mode="symbolic")(x, y)

Forcing

    if op == torch.ops.aten.matmul.default and k == DispatchKey.Functionalize:
        return DispatchKey.CompositeImplicitAutograd

at

def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
would make it pass. I wonder if this is an edge case with dispatching? Hope it helps.
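For anyone who wants to reproduce that experiment without editing the resolve_key source, a hypothetical equivalent is to wrap it from the outside. This is a debugging sketch only, not a proposed fix; it assumes resolve_key is accessible as torch._ops.resolve_key and that aten.matmul has not already been dispatched (and cached):

import torch
import torch._ops
from torch._C import DispatchKey

_original_resolve_key = torch._ops.resolve_key

def _patched_resolve_key(op, k):
    # Redirect aten.matmul from the Functionalize key to its
    # CompositeImplicitAutograd decomposition, mirroring the forced lines above.
    if op == torch.ops.aten.matmul.default and k == DispatchKey.Functionalize:
        return DispatchKey.CompositeImplicitAutograd
    return _original_resolve_key(op, k)

torch._ops.resolve_key = _patched_resolve_key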

@BowenBao
Collaborator Author

Bubbling this up. Any thoughts? @ezyang @bdhirsh

@bdhirsh
Contributor

bdhirsh commented Apr 28, 2023

Oh hmm... I think I know what's going on.

(1) Functionalization (by default) doesn't try to handle operators with CompositeImplicitAutograd implementations (in particular, some composite ops like reshape() and to() cannot be supported by functionalization).

(2) Instead, it defers to running the CompositeImplicit decomp when you try to run the op with functionalization.

(3) Today, functionalization does this by assuming that at::native::{op} will direct you to the composite implicit decomp.

(4) That's no longer true with the python dispatcher: we need to faithfully run whatever the dispatcher has registered to the CompositeImplicitAutograd key, so the python dispatcher has a chance to intercept.

This isn't a real problem in the E2E torch.compile() workflow, because we're guaranteed that autograd runs (and decomposes CompositeImplicit ops) before functionalization has a chance to run. However, torch.func.functionalize() works a bit differently, and forces functionalization to run before anything else can run.
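A quick way to see points (1) and (3) for this particular op is to dump the dispatcher's registration table. This is a sketch relying on the private torch._C._dispatch_dump helper (the one torch._python_dispatcher uses), so treat the exact output as version-dependent:

import torch

# For aten::matmul this should show a CompositeImplicitAutograd registration and
# no dedicated Functionalize kernel, which is why functionalization has to defer
# to the decomposition.
print(torch._C._dispatch_dump("aten::matmul"))

# For contrast, an op with real backend kernels (e.g. CPU/CUDA registrations).
print(torch._C._dispatch_dump("aten::mm"))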

@bdhirsh
Contributor

bdhirsh commented Apr 28, 2023

This seems like something we should fix - we'd need to update the functionalization codegen so that the function it registers for each op with a composite implicit impl makes a call into the dispatcher, instead of manually calling at::native::{op}.

@BowenBao as an unblock, I confirmed that if you manually run functionalization using the underlying APIs, without using torch.func.functionalize directly, that works around the issue:

import torch
import torch._dynamo
import torch.func
from torch.fx.experimental import proxy_tensor
from torch._dispatch.python import enable_python_dispatcher
import torch.utils._pytree as pytree

def _functionalize(f):

    def wrapped(*inputs):
        # Wrap each tensor input in a FunctionalTensorWrapper before running f.
        inputs_functional = pytree.tree_map_only(torch.Tensor, torch._to_functional_tensor, inputs)
        torch._enable_functionalization(reapply_views=True)
        try:
            out = f(*inputs_functional)
        finally:
            torch._disable_functionalization()
        # Sync pending mutations on the functionalized inputs.
        flat_inputs, _ = pytree.tree_flatten(inputs)
        flat_inputs_functional, _ = pytree.tree_flatten(inputs_functional)
        for inpt, input_functional in zip(flat_inputs, flat_inputs_functional):
            torch._sync(input_functional)
            # Unwrapped post-mutation value of the input; if input mutations should
            # be reflected on the originals, inpt.copy_(inpt_new) would go here.
            inpt_new = torch._from_functional_tensor(input_functional)
        # Sync and unwrap the outputs before returning them.
        pytree.tree_map(torch._sync, out)
        out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
        return out_unwrapped

    return wrapped

def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

with enable_python_dispatcher():
    gm = proxy_tensor.make_fx(_functionalize(func), tracing_mode="symbolic")(x, y)

@bdhirsh added the module: functionalization label Apr 28, 2023
@BowenBao
Collaborator Author

This is very helpful and thorough. Thanks @bdhirsh ! Can confirm the workaround works for all my test cases.

One exception: I have found another potential bug with batch norm, possibly related to https://discuss.pytorch.org/t/does-torch-func-functionalize-support-fx-graphmodule-from-dynamo-export-w-aten-graph/177365 which I opened earlier. It only happens when training==True. #99662 (comment)

BowenBao added a commit that referenced this issue Apr 29, 2023
…porter'; Apply workaround in 'Functionalize' pass"


Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing.
  That is resolved by #99555 and its follow ups.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Updated `Functionalization` to work around #99774 (comment)

Todo
- Training vs eval in dynamo_export
  So we are effectively exporting all models in training mode by
  default. But for the sake of this export we are only interested in eval mode.
  The question is, should we call `model.eval()` in `dynamo_export`?
  Tests with models containing batch norm fail 'functionalization' in training mode.
  We are explicitly calling `model.eval()` for these models for now.
- Merge decomp and functionalize pass. Both call into `make_fx`.
  Merging potentially increases performance; however, it is unclear
  whether it will result in different behavior.

Workaround to unblock #99662.

[ghstack-poisoned]
@ezyang
Contributor

ezyang commented Apr 29, 2023

@bdhirsh another reason to hurry up with aot export I guess ;)

pytorchmergebot pushed a commit that referenced this issue May 4, 2023
…'aten_graph' arg for 'DynamoExporter' (#99667)

Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing.
  That is resolved by #99555 and its follow ups.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Updated `Functionalization` to work around #99774 (comment)

Todo
- Training vs eval in dynamo_export
  So we are effectively exporting all models in training mode by
  default. But for the sake of this export we are only interested in eval mode.
  The question is, should we call `model.eval()` in `dynamo_export`?
  Tests with models containing batch norm fail 'functionalization' in training mode.
  We are explicitly calling `model.eval()` for these models for now.
- Merge decomp and functionalize pass. Both call into `make_fx`.
  Merging potentially increases performance; however, it is unclear
  whether it will result in different behavior.

Fixes #99662 (for the functionalization issue; missing op support is still needed).
Pull Request resolved: #99667
Approved by: https://github.com/titaiwangms
@penguinwu added the module: pt2-dispatcher label Nov 29, 2023
@bdhirsh
Contributor

bdhirsh commented Dec 12, 2023

I'm going to tentatively close this, since you should no longer need to manually use make_fx and functionalize to get out a graph: you can use aot_export_module: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L910

@BowenBao feel free to re-open if you think the existing workarounds aren't the right fix
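For readers landing here later, a rough sketch of that replacement is below. The exact aot_export_module signature (positional example args plus a trace_joint keyword) is an assumption; check the linked source for the current API:

import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

# Functionalization and tracing are handled inside the export call, so no manual
# make_fx + torch.func.functionalize composition is needed.
gm, graph_signature = aot_export_module(M(), (x, y), trace_joint=False)
gm.print_readable()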

@bdhirsh closed this as completed Dec 12, 2023