RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides w/ dynamo.export, make_fx and functionalize #99774
Comments
Are you trying to capture an aten graph? If yes, you can update these lines:
to:
Since the
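A hedged sketch of the suggestion (the exact lines aren't quoted here; this assumes it amounts to passing `aten_graph=True` to `torch._dynamo.export`, as is done later in the thread):

```python
import torch
import torch._dynamo

# Hypothetical module and input, purely for illustration.
class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

m = M()
x = torch.randn(6)

# Capture an aten-level graph (with symbolic shapes) via dynamo's export.
gm, guards = torch._dynamo.export(m, x, aten_graph=True, tracing_mode="symbolic")
gm.print_readable()
```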
@yanboliang thanks, I have heard there is a plan to turn on functionalization for the exported aten graph. Today the exported graph still contains the in-place op:

    import torch
    import torch._dynamo

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mask = torch.tensor([True, False])

        def forward(self, x):
            x.view(3, 2).masked_fill_(self.mask.unsqueeze(0), torch.finfo(x.dtype).max)
            return x

    # Setup not shown in the original comment; a 6-element float tensor matches
    # the f32[6] input in the printed graph below.
    m = M()
    x = torch.randn(6)

    gm, _ = torch._dynamo.export(m, x, tracing_mode="symbolic", aten_graph=True)
    gm.print_readable()
    """
    class GraphModule(torch.nn.Module):
        def forward(self, x):
            arg0: f32[6], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
            # File: test_mutation.py:8, code: x.view(3, 2).masked_fill_(self.mask.unsqueeze(0), torch.finfo(x.dtype).max)
            view_default: f32[3, 2] = torch.ops.aten.view.default(arg0, [3, 2])
            _tensor_constant0 = self._tensor_constant0
            unsqueeze_default: b8[1, 2] = torch.ops.aten.unsqueeze.default(_tensor_constant0, 0); _tensor_constant0 = None
            masked_fill__scalar: f32[3, 2] = torch.ops.aten.masked_fill_.Scalar(view_default, unsqueeze_default, 3.4028234663852886e+38); view_default = unsqueeze_default = None
            return pytree.tree_unflatten([arg0], self._out_spec)
    """
We should just paste in the support for it in export. It'll be easier after @bdhirsh finishes AOT export. To fix your problem on the original case, turn on the python dispatcher, like this:
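A hedged sketch of what enabling the python dispatcher looks like, following the usage shown later in this thread:

```python
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental import proxy_tensor

# Illustrative function and inputs (matching the repro later in this thread).
def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

# Enabling the python dispatcher makes python-level decompositions visible
# while make_fx traces with symbolic shapes.
with enable_python_dispatcher():
    gm = proxy_tensor.make_fx(func, tracing_mode="symbolic")(x, y)
```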
but make_fx ALSO doesn't functionalize! So this won't actually unblock you.
I think this is +1 to enabling the python dispatcher if we are enabling a FakeTensorMode with shape_env set.
@ezyang trying the below gives the same error. Sorry for not being clear in the OP that the error only shows up with the addition of torch.func.functionalize:

    gm, _ = torch._dynamo.export(func, x, y)
    with enable_python_dispatcher():
        gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)
Oh, this might just be a coverage problem then. Can you try the playbook at https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#? Specifically, check the docs for how to localize a C++ exception with gdb.
Handy. I'm further narrowing down the repro to exclude dynamo.export:

    import torch
    import torch._dynamo
    import torch.func
    from torch.fx.experimental import proxy_tensor
    from torch._dispatch.python import enable_python_dispatcher

    def func(x, y):
        return torch.matmul(x, y)

    x = torch.randn(2, 4, 3, 4)
    y = torch.randn(2, 4, 4, 3)

    with enable_python_dispatcher():
        # RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
        gm = proxy_tensor.make_fx(torch.func.functionalize(func), tracing_mode="symbolic")(x, y)

Forcing

    if op == torch.ops.aten.matmul.default and k == DispatchKey.Functionalize:
        return DispatchKey.CompositeImplicitAutograd

at Line 137 in 67c329b makes the error go away.
Oh hmm... I think I know what's going on.

(1) Functionalization (by default) doesn't try to handle operators with CompositeImplicitAutograd kernels (like matmul).
(2) Instead, it defers to running the CompositeImplicit decomp when you try to run the op with functionalization.
(3) Today, functionalization does this by assuming that the decomp it should run is the C++ CompositeImplicitAutograd kernel, which it calls directly.
(4) That's no longer true with the python dispatcher: we need to faithfully run whatever the dispatcher has registered to the CompositeImplicitAutograd key, so the python dispatcher has a chance to intercept.

This isn't a real problem in the E2E torch.compile() workflow, because we're guaranteed that autograd runs (and decomposes CompositeImplicit ops) before functionalization has a chance to run. However, that ordering isn't guaranteed when functionalization is run directly through make_fx.
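As a hedged side note illustrating (1)-(2): matmul has no dedicated Functionalize kernel, only a CompositeImplicitAutograd registration, which can be seen by dumping its dispatch table:

```python
import torch

# Dump the dispatcher's registration table for aten::matmul.
# The output lists a CompositeImplicitAutograd kernel and no
# Functionalize-specific kernel, which is why functionalization
# defers to the decomposition for this op.
print(torch._C._dispatch_dump("aten::matmul"))
```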
This seems like something we should fix - we'd need to update the functionalization codegen so that the function it registers to each op with a composite implicit impl makes a call back into the dispatcher, instead of manually calling the C++ kernel.

@BowenBao as an unblock, I confirmed that if you manually run functionalization using the underlying API's, without going through torch.func.functionalize, the repro works.
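A hedged sketch of what "manually running functionalization using the underlying API's" can look like; the exact calls below are an assumption (they follow the private torch._enable_functionalization / torch._to_functional_tensor APIs), not a quote of the original workaround:

```python
import torch
from torch.fx.experimental import proxy_tensor
from torch._dispatch.python import enable_python_dispatcher

def func(x, y):
    return torch.matmul(x, y)

def functionalized(x, y):
    # Wrap inputs in functional tensors and run the function with
    # functionalization enabled, instead of using torch.func.functionalize.
    x_f = torch._to_functional_tensor(x)
    y_f = torch._to_functional_tensor(y)
    torch._enable_functionalization(reapply_views=True)
    try:
        out = func(x_f, y_f)
    finally:
        torch._disable_functionalization()
    # Sync any pending mutations and unwrap the result.
    torch._sync(out)
    return torch._from_functional_tensor(out)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

with enable_python_dispatcher():
    gm = proxy_tensor.make_fx(functionalized, tracing_mode="symbolic")(x, y)
```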
This is very helpful and thorough. Thanks @bdhirsh! Can confirm the workaround works for all my test cases. One exception: I have found another potential bug with batch norm, possibly related to https://discuss.pytorch.org/t/does-torch-func-functionalize-support-fx-graphmodule-from-dynamo-export-w-aten-graph/177365 which I opened earlier. It only happens when the model is in training mode.
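A hedged sketch of the kind of batch-norm case implied (the concrete model and error were not quoted; the names here are illustrative):

```python
import torch
import torch.func
from torch.fx.experimental import proxy_tensor

class BNModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

m = BNModel()
x = torch.randn(1, 3, 8, 8)

# In training mode, batch norm updates its running stats in-place, which is
# where functionalization reportedly runs into trouble; switching to eval
# mode first is the workaround adopted in the export passes discussed below.
m.eval()
gm = proxy_tensor.make_fx(torch.func.functionalize(m), tracing_mode="symbolic")(x)
```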
@bdhirsh another reason to hurry up with aot export I guess ;)
…'aten_graph' arg for 'DynamoExporter' (#99667)

Summary
- Previously this was required by and entangled with `tracing_mode=symbolic` for `dynamic` tracing. That is resolved by #99555 and its follow-ups.
- The later decomposition pass will do graph lowering, so this step is duplicated.
- Updated `Functionalization` to work around #99774 (comment).

Todo
- Training vs eval in dynamo_export: we are effectively exporting all models in training mode by default, but for the sake of this export we are only interested in eval mode. The question is, should we call `model.eval()` in `dynamo_export`? Tests with models containing batch norm fail 'functionalization' in training mode. We are explicitly calling `model.eval()` for these models for now.
- Merge the decomp and functionalize passes. Both call into `make_fx`. Merging potentially increases performance; however, it is unclear if it would result in different behavior.

Fixes #99662. (For the functionalization issue. Still need missing op support.)

Pull Request resolved: #99667
Approved by: https://github.com/titaiwangms
I'm going to tentatively close this, since you should no longer need to run functionalization manually. @BowenBao feel free to re-open if you think the existing workarounds aren't the right fix.
Latest update
The matmul repro in the comments is the most distilled repro.
Below is the original issue post, before further discussion.
🐛 Describe the bug
Distilled repro; would greatly appreciate hints on how to approach/debug this.
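A hedged reconstruction of the repro, pieced together from the export and functionalize calls quoted in the comments above:

```python
import torch
import torch._dynamo
import torch.func
from torch.fx.experimental import proxy_tensor

def func(x, y):
    return torch.matmul(x, y)

x = torch.randn(2, 4, 3, 4)
y = torch.randn(2, 4, 4, 3)

# Export with dynamo, then try to functionalize and re-trace the exported graph.
gm, _ = torch._dynamo.export(func, x, y)

# RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides
gm = proxy_tensor.make_fx(torch.func.functionalize(gm), tracing_mode="symbolic")(x, y)
```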
Versions
Main on e978614
cc @bdhirsh @ezyang @msaroufim @wconstab @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @soumith @ngimel @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire