diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 74deed06286..47824380d55 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -48,10 +48,11 @@ python_library(
     ],
 )
 
+
 python_library(
-    name = "passes",
+    name = "pass_utils",
     srcs = [
-        "_passes.py",
+        "pass_utils.py",
     ],
     deps = [
         ":utils",
@@ -64,9 +65,9 @@ python_library(
 )
 
 python_library(
-    name = "pass_utils",
+    name = "passes",
     srcs = [
-        "pass_utils.py",
+        "passes.py",
     ],
     deps = [
         ":utils",
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 62bae4b21dd..7e5fd3fec27 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -8,28 +8,17 @@
 
 import logging
 from pathlib import Path
-from typing import Optional
+from typing import Callable, cast, Optional
 
 import torch
 
-from executorch.backends.cadence.aot._passes import (
-    InitializePipeline,
-    RemoveNopExpandOpPass,
-    RemoveZeroSizedCatArgsPass,
-    ReplaceLogicalNotBooleanWhereWithWherePass,
-    ReplacePT2DequantWithCadenceDequantPass,
-    ReplacePT2QuantWithCadenceQuantPass,
-    ReplaceSafeSoftmaxWithSoftmax,
-    ReplaceScalarTensorWithFullPass,
-    ReplaceSqueezeAndUnsqueezeWithViewPass,
-)
+from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
 from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
-from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
     EdgeCompileConfig,
@@ -37,12 +26,15 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.pass_base import PassResult
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
+from .passes import get_cadence_passes
+
 from .utils import print_ops_info
 
 
@@ -209,22 +201,16 @@ def export_to_cadence_edge_executorch(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     output_dir: Optional[str] = None,
+    opt_level: int = 1,
 ) -> ExecutorchProgramManager:
     edge_prog_manager = export_to_edge(model, inputs)
+    cadence_passes = get_cadence_passes(opt_level)
 
     # Run a couple required passes for quant/dequant ops
     cadence_prog_manager = edge_prog_manager.transform(
-        [
-            InitializePipeline(),
-            RemoveZeroSizedCatArgsPass(),
-            ReplaceLogicalNotBooleanWhereWithWherePass(),
-            ReplaceScalarTensorWithFullPass(),
-            RemoveCloneOpsTransform(),
-            RemoveNopExpandOpPass(),
-            ReplaceSqueezeAndUnsqueezeWithViewPass(),
-            ReplacePT2QuantWithCadenceQuantPass(),
-            ReplacePT2DequantWithCadenceDequantPass(),
-        ]
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
     )
 
     # Print some information to terminal
diff --git a/backends/cadence/aot/_passes.py b/backends/cadence/aot/passes.py
similarity index 74%
rename from backends/cadence/aot/_passes.py
rename to backends/cadence/aot/passes.py
index 83ef43d1510..bd872a85e09 100644
--- a/backends/cadence/aot/_passes.py
+++ b/backends/cadence/aot/passes.py
@@ -6,21 +6,74 @@
 
 # pyre-strict
 
-from typing import Any, cast, Dict, Sequence, Tuple
+from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type
 
 import torch
+import torch.fx
+import torch.utils._pytree as pytree
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    create_cadence_pass_filter,
+    register_cadence_pass,
+)
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
+from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from torch._subclasses import FakeTensor
 from torch.utils._pytree import tree_map_only
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class RemoveCloneOpsTransformImported(ExportPass):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        finalize_passes: List[PassType] = [
+            RemoveCloneOpsTransform(),
+        ]
+        result = PassManager(passes=finalize_passes)(graph_module)
+        dead_code_elimination_pass(result.graph_module)
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class InitializePipeline(ExportPass):
+    """
+    Initialize the Jarvis pipeline. This should invariably be the first pass to
+    run.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        dead_code_elimination_pass(graph_module)
+        result = SpecPropPass()(graph_module)
+        assert result is not None
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class FinalizePipeline(ExportPass):
+    """
+    The final cleanup pass after running the Jarvis pipeline.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        finalize_passes: List[PassType] = [
+            ScalarToTensorPass(),
+            SpecPropPass(),
+        ]
+        result = PassManager(passes=finalize_passes)(graph_module)
+        dead_code_elimination_pass(result.graph_module)
+        return result
+
+
 # Similar to what's done in executorch/exir/pass_base.py
 Argument = Any  # pyre-ignore
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
     Replace the pt2 quantization ops with custom cadence quantization ops.
@@ -44,6 +97,7 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     """
     Replace the pt2 dequantization ops with custom cadence dequantization ops.
@@ -67,6 +121,7 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceScalarTensorWithFullPass(ExportPass):
     """
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
@@ -96,6 +151,7 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
     """
     When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
@@ -131,7 +187,8 @@ def call_operator(
         )
 
 
-class RemoveZeroSizedCatArgsPass(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class RemoveZeroSizedCatArgsPass(ExportPass):
     def call_operator(
         self,
         op,  # pyre-ignore
@@ -176,6 +233,7 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class RemoveNopExpandOpPass(ExportPass):
     """
     For an expand op, if the operator shape matches the expand shape, then the
@@ -205,6 +263,7 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
     """
     A where op with a logical_not and a boolean tensor can be replaced
@@ -255,20 +314,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return result
 
 
-class InitializePipeline(ExportPass):
-    """
-    Initialize the Jarvis pipeline. This should invariably be the first pass to
-    run.
-    """
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        dead_code_elimination_pass(graph_module)
-        result = SpecPropPass()(graph_module)
-        assert result is not None
-        return result
-
-
-class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
     """
     Replace _safe_softmax with _softmax
     """
@@ -292,3 +339,33 @@ def call_operator(
             kwargs,
             meta,
         )
+
+
+def get_passes_in_default_order() -> List[Type[PassType]]:
+    passes = [
+        InitializePipeline,
+        RemoveZeroSizedCatArgsPass,
+        ReplaceLogicalNotBooleanWhereWithWherePass,
+        ReplaceScalarTensorWithFullPass,
+        RemoveCloneOpsTransformImported,
+        RemoveNopExpandOpPass,
+        ReplaceSqueezeAndUnsqueezeWithViewPass,
+        ReplacePT2QuantWithCadenceQuantPass,
+        ReplacePT2DequantWithCadenceDequantPass,
+        # TODO: add the rest of the passes here.
+    ]
+    return pytree.tree_flatten(passes)[0]
+
+
+def get_cadence_passes(
+    opt_level: int,
+) -> List[Optional[PassResult]]:
+    passes = get_passes_in_default_order()
+    pass_filter = create_cadence_pass_filter(opt_level)
+    filtered_passes = [
+        # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
+        filtered_pass()
+        # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
+        for filtered_pass in list(filter(pass_filter, passes))
+    ]
+    return filtered_passes
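
For reviewers following along, here is a minimal sketch of how the new registry is meant to be used. `pass_utils.py` is not part of this diff, so the helper semantics are assumptions inferred from the call sites above, and `MyCustomPass` is a hypothetical pass for illustration only:

```python
import torch

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)
from executorch.exir.pass_base import ExportPass, PassResult


# Assumption: create_cadence_pass_filter(opt_level) keeps a pass only when
# its registered opt_level is <= the requested one, so this hypothetical
# pass would run at opt_level >= 2 but be filtered out at the default of 1.
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class MyCustomPass(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Graph rewriting would go here; this stub leaves the graph untouched.
        return PassResult(graph_module, False)
```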
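
Similarly, a hypothetical call site for the new `opt_level` parameter on `export_to_cadence_edge_executorch`; the model and inputs below are placeholders. The level now selects which registered passes `get_cadence_passes` returns, replacing the old hardcoded list (every pass in this diff is registered at opt_level=0, so all of them survive the default filter):

```python
import torch

from executorch.backends.cadence.aot.compiler import (
    export_to_cadence_edge_executorch,
)

model = torch.nn.Linear(16, 16)  # placeholder module
inputs = (torch.randn(1, 16),)

# opt_level defaults to 1; raising it would additionally enable any passes
# registered at higher levels once they exist.
prog_manager = export_to_cadence_edge_executorch(model, inputs, opt_level=1)
```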