-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Closed
Labels
export-triaged (This tag is used to tag issues that have been looked at by the PT2 Export team and for which the next step has been determined), module: dynamic shapes, oncall: export, oncall: pt2
Description
🐛 Describe the bug
This issue is split from #137748.
Repro:
import torch


class TestModule(torch.nn.Module):
    """Minimal module that adds ``y`` to ``x`` after reshaping ``y`` to ``x.shape``.

    Used as an export repro: with ``x`` static at (4, 4) and both dims of
    ``y`` marked dynamic (s0, s1), the reshape is only valid when
    ``s0 * s1 == 16``. That matches the ``Eq(s0*s1, 16)`` guard the
    export logs report as ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y=None):
        """Return ``x + y.reshape(x.shape)``.

        Args:
            x: Tensor whose shape ``y`` is reshaped to.
            y: Tensor with the same number of elements as ``x``.

        NOTE: despite the ``None`` default, ``y`` is required; calling
        ``forward(x)`` raises ``AttributeError`` on ``None.reshape``.
        """
        # reshape requires y.numel() == x.numel(); under dynamic shapes this
        # becomes a symbolic guard on y's sizes.
        return x + y.reshape(x.shape)
import torch
from torch.export import Dim

# The original repro ran on CUDA. Fall back to CPU when CUDA is unavailable
# so the repro can at least be run anywhere (the dynamic-shape guard issue is
# presumably device-independent -- confirm on CPU before relying on that).
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TestModule().to(device)
x = torch.ones(4, 4, device=device)
# y has 2 * 8 == 16 elements, so y.reshape(x.shape) is valid in eager mode.
y = torch.rand(2, 8, device=device)

# Eager sanity check: the module runs with these inputs before export.
model(x, y)

# Non-strict export with x fully static and both dims of y dynamic. The
# reported logs show the guard Eq(s0*s1, 16) being ignored here.
_ = torch.export.export(
    model,
    (x, y),
    dynamic_shapes={"x": None, "y": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}},
    strict=False,
)
Logs (the warning, followed by the stack trace that emitted it):
torch/fx/experimental/symbolic_shapes.py:5956] Ignored guard Eq(s0*s1, 16) == True, this could result in accuracy problems
torch.export.export(
torch/export/__init__.py", line 368, in export
torch/fx/experimental/symbolic_shapes.py:5956] return _export(
torch/export/_trace.py", line 1002, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956] ep = fn(*args, **kwargs)
torch/export/exported_program.py", line 121, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956] return fn(*args, **kwargs)
torch/export/_trace.py", line 1942, in _export
torch/fx/experimental/symbolic_shapes.py:5956] export_artifact = export_func( # type: ignore[operator]
torch/export/_trace.py", line 1730, in _non_strict_export
torch/fx/experimental/symbolic_shapes.py:5956] aten_export_artifact = _to_aten_func( # type: ignore[operator]
torch/export/_trace.py", line 773, in _export_to_aten_ir
torch/fx/experimental/symbolic_shapes.py:5956] return _produce_aten_artifact(
torch/export/_trace.py", line 437, in _produce_aten_artifact
torch/fx/experimental/symbolic_shapes.py:5956] gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)
torch/_export/utils.py", line 555, in apply_runtime_assertion_pass
torch/fx/experimental/symbolic_shapes.py:5956] insert_deferred_runtime_asserts(
torch/fx/passes/runtime_assert.py", line 324, in insert_deferred_runtime_asserts
torch/fx/experimental/symbolic_shapes.py:5956] add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload]
torch/fx/passes/runtime_assert.py", line 256, in add_runtime_asserts
torch/fx/experimental/symbolic_shapes.py:5956] graph.call_function(
torch/fx/graph.py", line 1489, in call_function
torch/fx/experimental/symbolic_shapes.py:5956] return self.create_node(
torch/fx/graph.py", line 1154, in create_node
torch/fx/experimental/symbolic_shapes.py:5956] f(n)
torch/fx/passes/runtime_assert.py", line 175, in _node_metadata_hook
torch/fx/experimental/symbolic_shapes.py:5956] node.meta[val_key] = node.target(*fake_args) # type: ignore[operator]
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956] return self._op(*args, **kwargs)
torch_function__
torch/fx/experimental/symbolic_shapes.py:5956] return func(*args, **kwargs)
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956] return self._op(*args, **kwargs)
torch/utils/_stats.py", line 21, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956] return fn(*args, **kwargs)
torch_dispatch__
torch/fx/experimental/symbolic_shapes.py:5956] return self.dispatch(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 1797, in dispatch
torch/fx/experimental/symbolic_shapes.py:5956] return self._cached_dispatch_impl(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 1365, in _cached_dispatch_impl
torch/fx/experimental/symbolic_shapes.py:5956] output = self._dispatch_impl(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 2160, in _dispatch_impl
torch/fx/experimental/symbolic_shapes.py:5956] r = func(*args, **kwargs)
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956] return self._op(*args, **kwargs)
torch/fx/experimental/sym_node.py", line 515, in expect_true
torch/fx/experimental/symbolic_shapes.py:5956] return self.shape_env.defer_runtime_assert(
torch/fx/experimental/recording.py", line 263, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956] return retlog(fn(*args, **kwargs))
torch/fx/experimental/symbolic_shapes.py", line 6363, in defer_runtime_assert
torch/fx/experimental/symbolic_shapes.py:5956] self._check_frozen(expr, sympy.true)
torch/fx/experimental/symbolic_shapes.py", line 5956, in _check_frozen
torch/fx/experimental/symbolic_shapes.py:5956] log.warning(
cc @chauhang @penguinwu @ezyang @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @pianpwk
Versions
trunk
Metadata
Metadata
Assignees
Labels
export-triaged (This tag is used to tag issues that have been looked at by the PT2 Export team and for which the next step has been determined), module: dynamic shapes, oncall: export, oncall: pt2