Skip to content

[Export] dynamic shape divisibility problem results in excessive warnings from Ignored guard + stack trace #139408

@henrylhtsang

Description

@henrylhtsang

🐛 Describe the bug

This issue was split out from #137748.

repro:

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y=None):
        return x + y.reshape(x.shape)

model = TestModule().cuda()
x = torch.ones(4, 4).cuda()
y = torch.rand(2, 8).cuda()

model(x, y)
from torch.export import Dim

_ = torch.export.export(
    model,
    (x, y),
    dynamic_shapes={"x": None, "y": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}},
    strict=False,
)

logs:

torch/fx/experimental/symbolic_shapes.py:5956] Ignored guard Eq(s0*s1, 16) == True, this could result in accuracy problems
torch.export.export(
torch/export/__init__.py", line 368, in export
torch/fx/experimental/symbolic_shapes.py:5956]     return _export(
torch/export/_trace.py", line 1002, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956]     ep = fn(*args, **kwargs)
torch/export/exported_program.py", line 121, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956]     return fn(*args, **kwargs)
torch/export/_trace.py", line 1942, in _export
torch/fx/experimental/symbolic_shapes.py:5956]     export_artifact = export_func(  # type: ignore[operator]
torch/export/_trace.py", line 1730, in _non_strict_export
torch/fx/experimental/symbolic_shapes.py:5956]     aten_export_artifact = _to_aten_func(  # type: ignore[operator]
torch/export/_trace.py", line 773, in _export_to_aten_ir
torch/fx/experimental/symbolic_shapes.py:5956]     return _produce_aten_artifact(
torch/export/_trace.py", line 437, in _produce_aten_artifact
torch/fx/experimental/symbolic_shapes.py:5956]     gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)
torch/_export/utils.py", line 555, in apply_runtime_assertion_pass
torch/fx/experimental/symbolic_shapes.py:5956]     insert_deferred_runtime_asserts(
torch/fx/passes/runtime_assert.py", line 324, in insert_deferred_runtime_asserts
torch/fx/experimental/symbolic_shapes.py:5956]     add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]
torch/fx/passes/runtime_assert.py", line 256, in add_runtime_asserts
torch/fx/experimental/symbolic_shapes.py:5956]     graph.call_function(
torch/fx/graph.py", line 1489, in call_function
torch/fx/experimental/symbolic_shapes.py:5956]     return self.create_node(
torch/fx/graph.py", line 1154, in create_node
torch/fx/experimental/symbolic_shapes.py:5956]     f(n)
torch/fx/passes/runtime_assert.py", line 175, in _node_metadata_hook
torch/fx/experimental/symbolic_shapes.py:5956]     node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956]     return self._op(*args, **kwargs)
torch_function__
torch/fx/experimental/symbolic_shapes.py:5956]     return func(*args, **kwargs)
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956]     return self._op(*args, **kwargs)
torch/utils/_stats.py", line 21, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956]     return fn(*args, **kwargs)
torch_dispatch__
torch/fx/experimental/symbolic_shapes.py:5956]     return self.dispatch(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 1797, in dispatch
torch/fx/experimental/symbolic_shapes.py:5956]     return self._cached_dispatch_impl(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 1365, in _cached_dispatch_impl
torch/fx/experimental/symbolic_shapes.py:5956]     output = self._dispatch_impl(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py", line 2160, in _dispatch_impl
torch/fx/experimental/symbolic_shapes.py:5956]     r = func(*args, **kwargs)
torch/_ops.py", line 723, in __call__
torch/fx/experimental/symbolic_shapes.py:5956]     return self._op(*args, **kwargs)
torch/fx/experimental/sym_node.py", line 515, in expect_true
torch/fx/experimental/symbolic_shapes.py:5956]     return self.shape_env.defer_runtime_assert(
torch/fx/experimental/recording.py", line 263, in wrapper
torch/fx/experimental/symbolic_shapes.py:5956]     return retlog(fn(*args, **kwargs))
torch/fx/experimental/symbolic_shapes.py", line 6363, in defer_runtime_assert
torch/fx/experimental/symbolic_shapes.py:5956]     self._check_frozen(expr, sympy.true)
torch/fx/experimental/symbolic_shapes.py", line 5956, in _check_frozen
torch/fx/experimental/symbolic_shapes.py:5956]     log.warning(

cc @chauhang @penguinwu @ezyang @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @pianpwk

Versions

trunk

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions