diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py
index 148ad9ff352..fc224aaad30 100644
--- a/exir/program/test/test_program.py
+++ b/exir/program/test/test_program.py
@@ -6,9 +6,8 @@
 # pyre-strict

-import copy
 import unittest
-from typing import Any, Callable, Dict
+from typing import Any, Dict

 import torch
 from executorch.exir import ExecutorchBackendConfig
@@ -16,11 +15,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.error import ExportError
 from executorch.exir.lowered_backend_module import get_lowered_submodules
-from executorch.exir.pass_base import ExportPass, PassResult
-from executorch.exir.passes.replace_aten_with_edge_pass import (
-    aten_to_edge,
-    should_lower_to_edge,
-)
+from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._program import (
     EdgeProgramManager,
     ExecutorchProgramManager,
@@ -31,9 +26,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
-from torch import fx
 from torch.export import export, ExportedProgram
-from torch.fx import GraphModule, subgraph_rewriter


 def get_exported_programs() -> Dict[str, ExportedProgram]:
@@ -70,32 +63,13 @@ def bar():


 class AddToMulPassEdge(ExportPass):
-    def call(self, graph_module: GraphModule) -> PassResult:
-        """
-        Dummy pass that replaces add with mul
-        """
-
-        def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
-            gm = fx.symbolic_trace(f)
-            for node in gm.graph.nodes:
-                if node.op == "call_function" and should_lower_to_edge(node.target):
-                    node.target = aten_to_edge(node.target)
-            gm.recompile()
-            return gm
-
-        def pattern(x: torch.Tensor, y: torch.Tensor):
-            return torch.ops.aten.add.Tensor(x, y)
-
-        def replacement(x: torch.Tensor, y: torch.Tensor):
-            return torch.ops.aten.mul.Tensor(x, y)
-
-        new_graph_module = copy.deepcopy(graph_module)
-        subgraph_rewriter.replace_pattern_with_filters(
-            new_graph_module,
-            _trace_and_lower_to_edge_ops(pattern),
-            _trace_and_lower_to_edge_ops(replacement),
-        )
-        return PassResult(new_graph_module, True)
+    def call_operator(self, op, args, kwargs, meta):
+        if op == exir_ops.edge.aten.add.Tensor:
+            return super().call_operator(
+                exir_ops.edge.aten.mul.Tensor, args, kwargs, meta
+            )
+        else:
+            return super().call_operator(op, args, kwargs, meta)


 class TestProgramManagers(unittest.TestCase):
diff --git a/exir/tests/test_verification.py b/exir/tests/test_verification.py
index d7eb6fe0049..e42b49ed5d7 100644
--- a/exir/tests/test_verification.py
+++ b/exir/tests/test_verification.py
@@ -145,6 +145,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(len(res), len(res_interp))
         self.assertTrue(torch.allclose(res, res_interp))

+
+class TestEdgeVerification(unittest.TestCase):
     def test_edge_happy(self) -> None:
         class TestModel(torch.nn.Module):
             def __init__(self):
diff --git a/exir/verification/TARGETS b/exir/verification/TARGETS
index 8d05c8c8d41..6b7fdf2bc4f 100644
--- a/exir/verification/TARGETS
+++ b/exir/verification/TARGETS
@@ -51,8 +51,8 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/exir:delegate",
         "//executorch/exir:error",
+        "//executorch/exir:lowered_backend_module",
         "//executorch/exir/dialects/edge:lib",
         "//executorch/exir/emit:emit",
     ],
diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py
index c1eb53843c9..6bae14d68cc 100644
--- a/exir/verification/verifier.py
+++ b/exir/verification/verifier.py
@@ -4,12 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Optional
+import itertools
+import operator
+from typing import Any, List, Optional, Tuple, Type

 import torch
-from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType
+from executorch.exir.lowered_backend_module import LoweredBackendModule
 from executorch.exir.verification.arg_validator import (
     EdgeOpArgValidator,
     RunHigherOrderOperatorError,
@@ -17,7 +19,6 @@
 from torch._export.verifier import (
     _check_has_fake_tensor,
-    _check_tensors_are_contiguous,
     ATenDialectVerifier,
     SpecViolationError,
     Verifier,
@@ -28,30 +29,21 @@
 ALLOWED_META_KEYS = {"spec", "stack_trace"}


-VALID_BUILTIN_FUNCS = [
-    executorch_call_delegate,
-]
-
-
-class EXIRATenDialectVerifier(ATenDialectVerifier):
-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
-
-    # TODO(angelayi): Delete this function when we migrate all tests to
-    # because right now old tracer does not add ["val"] metadata
-    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
-
-        for node in gm.graph.nodes:
-            if node.op in {"call_module", "call_method"}:
+def _check_tensors_are_contiguous(gm: GraphModule) -> None:
+    # Tensors be of contiguous format
+    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
+        if isinstance(param, torch.Tensor):
+            if not param.is_contiguous():
                 raise SpecViolationError(
-                    "call_module is not valid: got a class '{}' ".format(node.target),
+                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                 )
-            if node.op == "call_function":
-                if node.target not in self.valid_builtin_funcs():
-                    self.check_valid_op(node.target)
+
+
+class EXIRATenDialectVerifier(ATenDialectVerifier):
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)


 def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
@@ -97,15 +89,21 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:


 class EXIREdgeDialectVerifier(Verifier):
-    def __init__(self, check_edge_ops: bool = False) -> None:
+    def __init__(self, check_edge_ops: bool = True) -> None:
         self.check_edge_ops = check_edge_ops

-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
+        if self.check_edge_ops:
+            self.check_valid_op = self.check_valid_edge_op
+        else:
+            self.check_valid_op = self.check_valid_aten_op
+
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)

     def check_valid_edge_op(self, op):
+        if op in [operator.getitem]:
+            return
+
         if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
             raise SpecViolationError(
                 "Operator {}.{} is not an Edge operator.".format(
@@ -116,33 +114,23 @@ def check_valid_edge_op(self, op):
     def check_valid_aten_op(self, op) -> None:
         super().check_valid_op(op)

-        op_name = op.name if hasattr(op, "name") else op.__name__
-
-        if not isinstance(op, OpOverload):
-            raise SpecViolationError(
-                "Operator '{}' is not a registered Op".format(op_name),
-            )
-
-        if (
-            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
-            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
-        ):
-            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
-            # discussion.
-            raise SpecViolationError(
-                "Operator {}.{} is not Aten Canonical.".format(
-                    op.__module__, op.__name__
+        if isinstance(op, OpOverload):
+            if (
+                torch.Tag.core not in op.tags  # type: ignore[attr-defined]
+                and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
+            ):
+                # NOTE(qihan): whether view_copy operators are marked as canonical is still under
+                # discussion.
+                raise SpecViolationError(
+                    "Operator {}.{} is not Aten Canonical.".format(
+                        op.__module__, op.__name__
+                    )
                 )
-            )

-    def check_valid(self, gm: GraphModule) -> None:
+    def check_additional(self, gm: GraphModule) -> None:
         if self.check_edge_ops:
-            self.check_valid_op = self.check_valid_edge_op
-            super().check_valid(gm)
             _check_tensors_are_contiguous(gm)
             _check_tensor_args_matching_op_allowed_dtype(gm)
-        else:
-            self.check_valid_op = self.check_valid_aten_op

             # Additionally, edge dialect's operator must have same input dtype
             for n in gm.graph.nodes:
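# Reviewer note (not part of the diff): a minimal sketch of how the rewritten
# AddToMulPassEdge is expected to be exercised end to end, assuming the
# to_edge / EdgeProgramManager.transform entry points imported by
# test_program.py above; the Add module and tensor values are hypothetical.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.program._program import to_edge
from torch.export import export


class AddToMulPassEdge(ExportPass):
    # ExportPass dispatches call_operator once per operator node, so the
    # rewrite needs no manual graph traversal or subgraph rewriting.
    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.add.Tensor:
            return super().call_operator(
                exir_ops.edge.aten.mul.Tensor, args, kwargs, meta
            )
        return super().call_operator(op, args, kwargs, meta)


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


x, y = torch.ones(2), torch.full((2,), 3.0)
edge = to_edge(export(Add(), (x, y)))
# transform() returns a new EdgeProgramManager with the pass applied
# (assumed list-of-passes signature).
transformed = edge.transform([AddToMulPassEdge()])
out = transformed.exported_program().module()(x, y)
assert torch.allclose(out, x * y)  # add was rewritten to mul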