From cc4c73ed111320d9e820cfdd810abb9e35fa5c5f Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 10 May 2024 11:53:32 -0700 Subject: [PATCH 1/4] Don't run addruntimeassertion pass [ghstack-poisoned] --- test/export/test_passes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 74aff27376c81..d5935585764d3 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -573,7 +573,10 @@ def forward(self, x): new_inp = torch.tensor([1, 1, 1, 1]) self.assertEqual(mod(new_inp), ep.module()(new_inp)) + # Currently runtime assertions don't work well with submodules + # Since it is not real use case, we xfail this test for now @unittest.skipIf(IS_WINDOWS, "Windows not supported") + @unittest.expectedFailure def test_runtime_assert_inline_constraints_for_cond(self) -> None: class M(torch.nn.Module): def __init__(self): @@ -601,7 +604,7 @@ def false_fn(x, y): ep = export(mod, (torch.tensor(True), x, y)) with self.assertRaisesRegex( - RuntimeError, "is outside of inline constraint \\[2, 5\\]." + RuntimeError, "Invalid value range for 6 between \\[2, 5\\]." ): ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) From 4d35ecf99e6b4ab74089af37b891441d403fe5f4 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 10 May 2024 11:59:11 -0700 Subject: [PATCH 2/4] Update on "Don't run addruntimeassertion pass" [ghstack-poisoned] --- torch/export/_trace.py | 18 ------------------ torch/fx/proxy.py | 3 +++ 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index d1ba830372a56..b7829d3d37e1c 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -23,9 +23,6 @@ produce_guards_and_solve_constraints, ) from torch._export.passes._node_metadata_hook import _node_metadata_hook -from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( - _AddRuntimeAssertionsForInlineConstraintsPass, -) from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass from torch._export.passes.lift_constants_pass import ( ConstantAttrMap, @@ -1340,21 +1337,6 @@ def forward(self, *args, **kwargs): assert res is not None gm = res.graph_module - # We can't get rid of this yet, since for some reason - # insert_deferred_runtime_assertions doesn't add assertions to cond - # subgraphs - if len(range_constraints) > 0: - stack_trace = ( - 'File "torch/_export/passes/add_runtime_assertions_for_constraints_pass.py", line 46, ' - "in _AddRuntimeAssertionsForInlineConstraintsPass" - ) - with dynamo_fake_mode, gm._set_create_node_hook( - functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - res = _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)(gm) - assert res is not None - gm = res.graph_module - assert orig_out_spec is not None _verify_nn_module_stack(gm) _verify_stack_trace(gm) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index f732b21080ddb..98989627ad6e9 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -133,6 +133,9 @@ def create_node(self, kind : str, target : Target, modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ + + if kind == "call_function" and "mul_2" in str(args): + breakpoint() if kind == 'call_function' and self.check_mutable_operations: check_for_mutable_operation(target, args, kwargs) From cd28dfdc5679d28394c805da971512d7fef85121 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 10 May 2024 12:01:48 -0700 Subject: [PATCH 3/4] Update on "Don't run addruntimeassertion pass" [ghstack-poisoned] --- torch/fx/proxy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 98989627ad6e9..f732b21080ddb 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -133,9 +133,6 @@ def create_node(self, kind : str, target : Target, modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ - - if kind == "call_function" and "mul_2" in str(args): - breakpoint() if kind == 'call_function' and self.check_mutable_operations: check_for_mutable_operation(target, args, kwargs) From 674560d03f0e5d35ac1b9055e27541077560352b Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 10 May 2024 13:09:50 -0700 Subject: [PATCH 4/4] Update on "Don't run addruntimeassertion pass" [ghstack-poisoned] --- test/export/test_export.py | 7 +------ test/onnx/test_fx_to_onnx_with_onnxruntime.py | 4 ---- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index f3852ed1c4523..7c6a4c3804e81 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2430,12 +2430,7 @@ def forward(self, x, y): ep = export(M(), (torch.tensor(1), torch.ones(4, 5))) - # This is because we insert sym_constrain_range in the graph now - if is_non_strict_test(self._testMethodName): - error_msg = "Invalid value range" - else: - error_msg = "is outside of inline constraint" - with self.assertRaisesRegex(RuntimeError, error_msg): + with self.assertRaisesRegex(RuntimeError, "Invalid value range"): _ = ep.module()(torch.tensor(-1), torch.randn(4, 5)) self.assertTrue( diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 3f8ba1de2bc3f..23dc5acdb52df 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -647,10 +647,6 @@ def forward(self, x, y): func, (torch.tensor([1]), torch.randn(3, 4)) ) - @pytorch_test_common.xfail_if_model_type_is_exportedprogram( - error_message="Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}", - reason="https://github.com/pytorch/pytorch/issues/112622", - ) def test_operator_with_dynamic_output_shape(self): class Foo(torch.nn.Module): def forward(self, x):