Don't run addruntimeassertion pass #125948

Open
wants to merge 8 commits into base: gh/tugsbayasgalan/219/base
4 changes: 1 addition & 3 deletions test/export/test_export.py
@@ -2749,9 +2749,7 @@ def forward(self, x, y):

ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))

- # This is because we insert sym_constrain_range in the graph now
- error_msg = r"Invalid value range for -1 between"
- with self.assertRaisesRegex(RuntimeError, error_msg):
+ with self.assertRaisesRegex(RuntimeError, "Invalid value range"):
_ = ep.module()(torch.tensor(-1), torch.randn(4, 5))

self.assertTrue(
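The relaxed regex above only pins the stable prefix of the error message. As a rough illustration (the real test module's body is not visible in this hunk, so the one below is hypothetical), a graph that constrains a data-dependent scalar keeps a range check in the exported module, and calling it with an out-of-range value raises a RuntimeError:

import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x, y):
        n = x.item()
        torch._check_is_size(n)  # constrain n to be a valid (non-negative) size
        return y.new_zeros(n)


ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))
try:
    ep.module()(torch.tensor(-1), torch.randn(4, 5))
except RuntimeError as e:
    # The test matches only "Invalid value range" because the full wording
    # of the inserted assertion can change between releases.
    print(e)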
5 changes: 4 additions & 1 deletion test/export/test_passes.py
@@ -574,7 +574,10 @@ def forward(self, x):
new_inp = torch.tensor([1, 1, 1, 1])
self.assertEqual(mod(new_inp), ep.module()(new_inp))

+ # Currently runtime assertions don't work well with submodules
+ # Since this is not a real use case, we xfail this test for now
@unittest.skipIf(IS_WINDOWS, "Windows not supported")
+ @unittest.expectedFailure
def test_runtime_assert_inline_constraints_for_cond(self) -> None:
class M(torch.nn.Module):
def __init__(self):
@@ -602,7 +605,7 @@ def false_fn(x, y):
ep = export(mod, (torch.tensor(True), x, y))

with self.assertRaisesRegex(
RuntimeError, "is outside of inline constraint \\[2, 5\\]."
RuntimeError, "Invalid value range for 6 between \\[2, 5\\]."
):
ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

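For reference, a hedged sketch of the pattern the now-xfailed test covers (the class and function bodies below are illustrative, not the test's exact code): inline range constraints that live inside cond() branches are not re-asserted at runtime once the pass is no longer run, because deferred runtime assertions are not inserted into cond subgraphs.

import torch
from torch.export import export


class CondModule(torch.nn.Module):
    def forward(self, pred, x, y):
        def true_fn(x, y):
            a = x.item()
            torch._check(2 <= a)
            torch._check(a <= 5)
            return x + y * a

        def false_fn(x, y):
            b = y.item()
            torch._check(2 <= b)
            torch._check(b <= 5)
            return x - y * b

        return torch.cond(pred, true_fn, false_fn, (x, y))


ep = export(CondModule(), (torch.tensor(True), torch.tensor([3]), torch.tensor([4])))
try:
    # 6 violates the [2, 5] constraint inside the branch.
    ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))
    print("violation went undetected (the scenario the xfail documents)")
except RuntimeError as e:
    print("constraint re-checked at runtime:", e)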
4 changes: 0 additions & 4 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -644,10 +644,6 @@ def forward(self, x, y):
func, (torch.tensor([1]), torch.randn(3, 4))
)

- @pytorch_test_common.xfail_if_model_type_is_exportedprogram(
- error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}",
- reason="https://github.com/pytorch/pytorch/issues/112622",
- )
def test_operator_with_dynamic_output_shape(self):
class Foo(torch.nn.Module):
def forward(self, x):
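The xfail above is removed because, without the assertion pass, exported programs no longer carry aten._assert_async.msg nodes that the FX-to-ONNX exporter could not translate. A quick, hedged way to check this on a toy module with a data-dependent output shape (the module here is illustrative, not the test's exact code):

import torch
from torch.export import export


class Foo(torch.nn.Module):
    def forward(self, x):
        return x.nonzero()  # data-dependent output shape


ep = export(Foo(), (torch.randint(0, 2, (4,)),))
assert_async_nodes = [
    node
    for node in ep.graph.nodes
    if node.op == "call_function"
    and node.target is torch.ops.aten._assert_async.msg
]
print(assert_async_nodes)  # expected to be empty after this change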
26 changes: 0 additions & 26 deletions torch/export/_trace.py
@@ -27,9 +27,6 @@
_node_metadata_hook,
_set_node_metadata_hook,
)
- from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
- _AddRuntimeAssertionsForInlineConstraintsPass,
- )
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.passes.lift_constants_pass import (
ConstantAttrMap,
@@ -152,23 +149,6 @@ def _strip_root(x):
return x


- def _add_runtime_assertions_to_cond_in_subgraph(range_constraints, gm, fake_mode):
- # We can't get rid of this yet, since for some reason
- # insert_deferred_runtime_assertions doesn't add assertions to cond
- # subgraphs
- if len(range_constraints) > 0:
- stack_trace = (
- 'File "torch/_export/passes/add_runtime_assertions_for_constraints_pass.py", line 46, '
- "in _AddRuntimeAssertionsForInlineConstraintsPass"
- )
- with fake_mode, _set_node_metadata_hook(
- gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
- ):
- res = _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)(gm)
- assert res is not None
- gm = res.graph_module
-
-
def _rewrite_node(gm):
for node in gm.graph.nodes:
if node.target == torch.ops.higher_order._export_tracepoint:
@@ -1514,12 +1494,6 @@ def _export(
dynamic_shapes,
num_lifted,
)
- if strict:
- _add_runtime_assertions_to_cond_in_subgraph(
- range_constraints,
- gm,
- fake_mode,
- )

# Make module signatures.
module_call_signatures = {}
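The pass module itself still exists under torch/_export/passes; only the call site above is removed. Below is a minimal sketch, modeled on that deleted call site, of how a caller who still wants the old behavior (for example, assertions inside cond subgraphs) could in principle re-run the pass over an exported program; the signatures mirror the deleted code and may drift, and the fake-mode/metadata-hook context used during export is omitted here.

from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _AddRuntimeAssertionsForInlineConstraintsPass,
)


def readd_inline_constraint_asserts(ep):
    # ep is a torch.export.ExportedProgram; its range_constraints map unbacked
    # symbols to value ranges, which is what the deleted helper consumed.
    if not ep.range_constraints:
        return ep.graph_module
    res = _AddRuntimeAssertionsForInlineConstraintsPass(ep.range_constraints)(
        ep.graph_module
    )
    assert res is not None
    return res.graph_module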