diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py index d73cc9fcfbf..862ba4e977c 100644 --- a/backends/cadence/aot/program_builder.py +++ b/backends/cadence/aot/program_builder.py @@ -11,7 +11,9 @@ from executorch.exir.verification.verifier import EXIREdgeDialectVerifier from torch import Tensor from torch._export.verifier import Verifier +from torch._ops import OpOverload from torch.export import ExportedProgram +from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature from torch.export.graph_signature import ( ExportGraphSignature, InputKind, @@ -20,6 +22,7 @@ OutputSpec, TensorArgument, ) +from torch.utils import _pytree as pytree class IrMode(Enum): @@ -30,12 +33,19 @@ class IrMode(Enum): class ProgramBuilder(GraphBuilder): """Utility class to build a program from a graph module.""" - def __init__(self, mode: Optional[IrMode] = None) -> None: + def __init__( + self, + mode: Optional[IrMode] = None, + _core_aten_ops_exception_list: Optional[list[OpOverload]] = None, + ) -> None: self.input_specs: list[InputSpec] = [] self.output_specs: list[OutputSpec] = [] self.constants: dict[str, Tensor] = {} self.state_dict: dict[str, Tensor] = {} self.mode: IrMode = mode or IrMode.EXIR + self._core_aten_ops_exception_list: list[OpOverload] = ( + _core_aten_ops_exception_list or [] + ) super().__init__() def insert_input_spec( @@ -80,27 +90,42 @@ def get_verifiers(self) -> Optional[list[Verifier]]: return None return [ EXIREdgeDialectVerifier( - edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + edge_compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _core_aten_ops_exception_list=self._core_aten_ops_exception_list, + ), + core_aten_ops_exception_list=self._core_aten_ops_exception_list, class_only=True, ) ] def get_program(self) -> ExportedProgram: gm = self.get_graph_module() + graph_signature = ExportGraphSignature(self.input_specs, self.output_specs) + in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1] + out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1] return ExportedProgram( root=gm, graph=gm.graph, - graph_signature=ExportGraphSignature( - input_specs=self.input_specs, output_specs=self.output_specs - ), + graph_signature=graph_signature, # pyre-ignore[6]: Incompatible parameter type. constants=self.constants, state_dict=self.state_dict, range_constraints={}, - module_call_graph=[], + module_call_graph=[ + ModuleCallEntry( + "", + ModuleCallSignature( + inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec + ), + ) + ], # pyre-ignore[6]: Incompatible parameter type. verifiers=self.get_verifiers(), ) def get_edge_program(self) -> EdgeProgramManager: - return EdgeProgramManager(self.get_program()) + return EdgeProgramManager( + self.get_program(), + core_aten_ops_exception_list=self._core_aten_ops_exception_list, + ) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 7f493e1645d..75190b9c7be 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2242,10 +2242,26 @@ def call_operator(self, op, args, kwargs, meta): ) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): + """ + Replace aten linalg svd op with cadence custom op. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten._linalg_svd.default: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta + ) + + # This class encapsulates all the functions that replace/switch one op in the # graph with another. class CadenceReplaceOpsInGraph: passes = [ + ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, ReplacePermuteWithTransposePass, diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index bd02cb0ae11..ca5168db2be 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -22,6 +22,7 @@ ReplaceAddMMWithLinearPass, ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, + ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, ReplaceConvolutionOptionalArgsWithConcreteArgsPass, ReplaceConvWithChannelLastConvPass, @@ -2045,3 +2046,38 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: len(avg_pool2d_nodes), 0, ) + + +class TestReplaceLinalgSvdPass(unittest.TestCase): + @expand( + [ + ("2x2", (2, 2)), + ("3x3", (3, 3)), + ("4x5", (4, 5)), + ("10x10", (10, 10)), + ] + ) + @torch.no_grad() + def test_replace_aten_linalg_svd_with_cadence_linalg_svd( + self, _: str, shape: Tuple[int, int] + ) -> None: + x = torch.randn(shape, dtype=torch.float32) + original_gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten._linalg_svd.default, + args=(x, False, True), + kwargs={"driver": None}, + ) + + p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass() + graph_after_passes = cast(PassResult, p(original_gm)).graph_module + + # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten._linalg_svd.default), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default), + 1, + )