From e786dc26f0f32ea250df5d68bb49ec184582b5e3 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Thu, 30 Jan 2025 09:40:15 -0800 Subject: [PATCH] Update verifier (#8034) Summary: Fixes https://github.com/pytorch/executorch/issues/7998 Reviewed By: JacobSzwejbka Differential Revision: D68839524 --- exir/program/test/test_program.py | 39 +++++++++++++++++++++++++ exir/verification/test/test_verifier.py | 34 --------------------- exir/verification/verifier.py | 36 ++++++++++++----------- 3 files changed, 58 insertions(+), 51 deletions(-) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 046ad03e757..d5e0d15d4ad 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -313,6 +313,45 @@ def forward(self, x, y): ) edge_manager.to_executorch() + def test_data_dependent(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo1", + "(Tensor a, Tensor b) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo1", "cpu", lib=lib) + def foo_impl(a, b): + return a + b + + @torch.library.register_fake("mylib::foo1", lib=lib) + def mylib_foo_default_fake(*args, **kwargs): + ctx = torch.library.get_ctx() + fake_shape = ctx.new_dynamic_size() + return torch.empty(fake_shape, dtype=torch.float32, device="cpu") + + class M(torch.nn.Module): + def forward(self, a, b, c): + res = torch.ops.mylib.foo1(a, b) + + c_item = c.item() + torch._check_is_size(c_item) + torch._check(c_item < res.shape[0]) + return res[:c_item] + + inp = (torch.randn(10), torch.randn(10), torch.tensor(3)) + + ep = export(M(), inp) + edge = to_edge(ep) + self.assertTrue( + torch.allclose( + edge.exported_program().module()(*inp), + M()(*inp), + ) + ) + def test_edge_manager_transform(self): edge_manager: EdgeProgramManager = to_edge( get_exported_programs(), get_config_methods() diff --git a/exir/verification/test/test_verifier.py b/exir/verification/test/test_verifier.py index 369f976076d..f38072969a7 100644 --- a/exir/verification/test/test_verifier.py +++ b/exir/verification/test/test_verifier.py @@ -36,40 +36,6 @@ def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None: verifier.check_valid_edge_op(edge_op) verifier.check_valid_op(edge_op) - def test_edge_verifier_enablement(self) -> None: - class M(torch.nn.Module): - def forward(self, x, y): - z = y.item() - torch._check(z > 0) - torch._check(z < 4) - return x[z : z + y.shape[0]] - - ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])), strict=True) - - compile_config_with_disable_ir_validity = EdgeCompileConfig( - _check_ir_validity=False - ) - edge_manager = to_edge( - ep, compile_config=compile_config_with_disable_ir_validity - ) - - normal_verifier = EXIREdgeDialectVerifier() - disable_ir_validity_verifier = EXIREdgeDialectVerifier( - compile_config_with_disable_ir_validity - ) - - # exported model can not pass normal verifier due to - # aten.sym_constrain_range.default is illegal to be edge op - with self.assertRaises(SpecViolationError): - normal_verifier(edge_manager.exported_program()) - - # exported model can pass disable_ir_validity_verifier due to verifier - # is disabled by compile_config_with_disable_ir_validity - # (_check_ir_validity=False). Noted that this verifation has been done - # when calling `to_edge`. Explicitly calling verifier here just for better - # demonstration and is unnecessary in real world for ir verification. - disable_ir_validity_verifier(edge_manager.exported_program()) - def test_edge_verifier_check_edge_op(self) -> None: class Model(torch.nn.Module): def __init__(self): diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 166e6a758a5..bc510ff6849 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -16,6 +16,8 @@ from executorch.exir.error import ExportError, ExportErrorType from executorch.exir.lowered_backend_module import LoweredBackendModule from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap +from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS +from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST from executorch.exir.verification.arg_validator import ( EdgeOpArgValidator, RunHigherOrderOperatorError, @@ -99,16 +101,20 @@ def __init__(self) -> None: self._exception_list = exception_list if exception_list else [] def _get_exception_list(self) -> List[torch._ops.OpOverload]: - exception_list = [ - torch.ops.aten.mkldnn_rnn_layer.default, - torch.ops.aten._upsample_bilinear2d_aa.default, - torch.ops.aten.quantize_per_tensor.default, - torch.ops.aten.dequantize.self, - torch.ops.aten.max.default, # TODO(T188268054) - torch.ops.aten.min.default, # TODO(T188268054) - torch.ops.aten.full_like.default, # TODO(T183507359) - ] - exception_list += self._exception_list + exception_list = ( + [ + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.aten._upsample_bilinear2d_aa.default, + torch.ops.aten.quantize_per_tensor.default, + torch.ops.aten.dequantize.self, + torch.ops.aten.max.default, # TODO(T188268054) + torch.ops.aten.min.default, # TODO(T188268054) + torch.ops.aten.full_like.default, # TODO(T183507359) + ] + + list(_EXECUTORCH_SYM_OPS) + + DISALLOW_LIST + + self._exception_list + ) return exception_list @@ -249,13 +255,9 @@ def check_valid_edge_op(self, op): return if ( op - in [ - operator.getitem, - torch.ops.aten.sym_size.int, - torch.ops.aten.scalar_tensor.default, - torch.ops.aten._assert_async.msg, - torch.ops.aten._assert_scalar.default, - ] + in [operator.getitem] + + DISALLOW_LIST + + list(_EXECUTORCH_SYM_OPS) + self._exception_list ): return