From 800405196233629889f6fa30f9ad33a889a930d0 Mon Sep 17 00:00:00 2001
From: Zonglin Peng
Date: Tue, 19 Nov 2024 16:26:54 -0800
Subject: [PATCH] add simplify ops to oss, update fuse simplify callsites, add ops_reg in oss compiler (#6881)

Summary: titled

Reviewed By: mcremon-meta

Differential Revision: D65980636
---
 backends/cadence/aot/TARGETS         |  17 ++++
 backends/cadence/aot/compiler.py     |   1 +
 backends/cadence/aot/fuse_ops.py     |   2 +-
 backends/cadence/aot/passes.py       |  15 ++++
 backends/cadence/aot/simplify_ops.py | 112 +++++++++++++++++++++++++++
 5 files changed, 146 insertions(+), 1 deletion(-)
 create mode 100644 backends/cadence/aot/simplify_ops.py

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index d0d540a3742..c0374faa7e9 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -38,6 +38,7 @@ python_library(
     deps = [
         ":passes",
         ":utils",
+        ":ops_registrations",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,6 +72,8 @@ python_library(
     ],
     deps = [
         ":utils",
+        ":fuse_ops",
+        ":simplify_ops",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
@@ -163,6 +166,20 @@ python_library(
     ],
 )
 
+python_library(
+    name = "simplify_ops",
+    srcs = [
+        "simplify_ops.py",
+    ],
+    typing = True,
+    deps = [
+        ":pass_utils",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 python_unittest(
     name = "test_graph_builder",
     srcs = [
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 6b799d99f9e..e53826b7b98 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import Callable, cast, Optional
 
+import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 
 from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 8738711777e..8137c1fdbd2 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return PassResult(graph_module, True)
 
 
-class FuseOpsInGraph:
+class CadenceFuseOpsInGraph:
     passes = [
         FuseMMWithAdd,
         FuseBatchNormWithConv,
diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py
index 265bf62bca1..e23e53bd2b1 100644
--- a/backends/cadence/aot/passes.py
+++ b/backends/cadence/aot/passes.py
@@ -11,11 +11,13 @@
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
+from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     create_cadence_pass_filter,
     register_cadence_pass,
 )
+from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -346,10 +348,23 @@ def get_passes_in_default_order() -> List[Type[PassType]]:
         ReplaceScalarTensorWithFullPass,
         RemoveCloneOpsTransformImported,
         RemoveNopExpandOpPass,
+        CadenceFuseOpsInGraph.passes,
         ReplaceSqueezeAndUnsqueezeWithViewPass,
         ReplacePT2QuantWithCadenceQuantPass,
         ReplacePT2DequantWithCadenceDequantPass,
+        CadenceSimplifyOpsInGraph.passes,
         # TODO: add the rest of the passes here.
+        # InitializePipeline,
+        # RemoveRedundantOps.passes,
+        # ReorderOpsInGraph.passes,
+        # RemoveJarvisNops.passes,
+        # CadenceFuseOpsInGraph.passes,
+        # ReplaceOpsInGraph.passes,
+        # SimplifyOpsInGraph.passes,
+        # FinalizePipeline,
+        # FuseFullThenReshapePass,
+        # FuseTransposeOpPairsPass,
+        # RemoveNopSliceOrViewOpPass,
     ]
 
     return pytree.tree_flatten(passes)[0]
diff --git a/backends/cadence/aot/simplify_ops.py b/backends/cadence/aot/simplify_ops.py
new file mode 100644
index 00000000000..a072201ead3
--- /dev/null
+++ b/backends/cadence/aot/simplify_ops.py
@@ -0,0 +1,112 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-unsafe
+
+
+# This file contains all the functions that simplify args of an op
+
+import sys
+from typing import Optional
+
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    register_cadence_pass,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, ProxyValue
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class SimplifySliceOpPass(ExportPass):
+    """
+    Simplify the start and end indices of slice and slice_scatter ops.
+    """
+
+    def adjust_slice_range(
+        self,
+        length: int,
+        start: Optional[int] = None,
+        end: Optional[int] = None,
+        step: int = 1,
+    ) -> tuple[int, int]:
+        # Get the start index and end index
+        start_val = start if start is not None else 0
+        end_val = end if end is not None else sys.maxsize  # 2^63 - 1
+
+        # If start_val and end_val are negative, add length to them
+        if start_val < 0:
+            start_val += length
+        if end_val < 0:
+            end_val += length
+
+        # If the start val is still outside the tensor_size along the sliced
+        # dimension, adjust it accordingly.
+        if start_val < 0:
+            start_val = 0
+        elif start_val >= length:
+            start_val = length
+
+        # If the end val is still outside the tensor_size along the sliced
+        # dimension, adjust it accordingly.
+        if end_val < start_val:
+            end_val = start_val
+        elif end_val >= length:
+            end_val = length
+
+        # Return the adjusted start and end indices
+        return (start_val, end_val)
+
+    def call_operator(self, op, args, kwargs, meta):
+        # We are only interested in slice_copy or slice_scatter ops
+        if op not in {
+            exir_ops.edge.aten.slice_copy.Tensor,
+            exir_ops.edge.aten.slice_scatter.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Check if it is a slice_scatter op or not. The slice_scatter op has
+        # an extra src argument at index 1.
+        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
+        # Parse the arguments
+        # Extract the tensor to be sliced, and the slicing dimension
+        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
+        # Make dim non-negative
+        dim = dim if dim >= 0 else dim + in_tensor.dim()
+        length = in_tensor.size(dim)
+
+        # Get the adjusted start and end indices
+        start_val = args[2 + slice_scatter] if len(args) > 2 + slice_scatter else None
+        end_val = args[3 + slice_scatter] if len(args) > 3 + slice_scatter else None
+        step = args[4 + slice_scatter] if len(args) > 4 + slice_scatter else 1
+        (start_val, end_val) = self.adjust_slice_range(length, start_val, end_val, step)
+
+        # If the start_val is geq end_val, then we can return an empty tensor
+        # for slice op, or input for slice_scatter op.
+        if start_val >= end_val and slice_scatter:
+            return args[0]
+        if start_val >= end_val:
+            empty_shape = [x for x in in_tensor.shape if x != 0]
+            empty_shape[dim] = 0
+            return super().call_operator(
+                exir_ops.edge.aten.full.default,
+                (tuple(empty_shape), 0),
+                {"dtype": in_tensor.dtype},
+                meta,
+            )
+
+        # Create new args
+        new_args = (
+            (args[0],)
+            + ((args[1],) if slice_scatter else ())
+            + (dim, start_val, end_val, step)
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+
+# This class encapsulates all the functions that simplify the op's args
+class CadenceSimplifyOpsInGraph:
+    passes = [
+        SimplifySliceOpPass,
+    ]
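
A few illustrative notes on the mechanics above follow; none of the code below is part of the patch. First, the new "import executorch.backends.cadence.aot.ops_registrations  # noqa" line in compiler.py is there purely for its side effects: loading the module runs its top-level registration calls so the custom Cadence ops are available before compilation starts. A minimal sketch of that pattern, using hypothetical module and op names rather than the real executorch registration API:

# ops_registry_sketch.py -- hypothetical stand-in for ops_registrations.
# Registration happens at import time because these calls sit at module top level.
OPS: dict[str, str] = {}

def register_op(name: str, schema: str) -> None:
    OPS[name] = schema

register_op("cadence::quantized_linear", "(Tensor x, Tensor w, Tensor b) -> Tensor")
register_op("cadence::quantized_relu", "(Tensor x) -> Tensor")

# A consumer such as compiler.py then only needs:
#     import ops_registry_sketch  # noqa: F401
# Nothing from the module is referenced directly (hence the lint suppression),
# but OPS is populated as a side effect of the import.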
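
Second, get_passes_in_default_order now mixes bare pass classes with whole collections such as CadenceFuseOpsInGraph.passes and CadenceSimplifyOpsInGraph.passes, which are lists. That is safe because the function returns pytree.tree_flatten(passes)[0], and tree_flatten recursively flattens nested lists into a single flat list of passes. A small check, where PassA, PassB and the stand-in SimplifySliceOpPass are placeholder classes, not the real ones:

import torch.utils._pytree as pytree

class PassA: ...
class PassB: ...
class SimplifySliceOpPass: ...  # placeholder for the real pass class

class CadenceSimplifyOpsInGraph:
    passes = [SimplifySliceOpPass]

default_order = [
    PassA,
    CadenceSimplifyOpsInGraph.passes,  # a nested list, not a single pass
    PassB,
]

# tree_flatten returns (leaves, spec); the leaves are the flattened passes.
flat, _spec = pytree.tree_flatten(default_order)
assert flat == [PassA, SimplifySliceOpPass, PassB]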
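
Third, the heart of the new SimplifySliceOpPass is adjust_slice_range, which defaults missing bounds, wraps negative indices, and clamps the result into [0, length]. The standalone sketch below copies that logic out of the pass and checks it against Python's own list slicing, which applies the same normalization when step is 1; the surrounding test harness is illustrative only and not part of the change:

import sys
from typing import Optional


def adjust_slice_range(
    length: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> tuple[int, int]:
    # Same rules as SimplifySliceOpPass.adjust_slice_range; step is accepted
    # but the normalization only concerns the start/end bounds.
    start_val = start if start is not None else 0
    end_val = end if end is not None else sys.maxsize
    if start_val < 0:
        start_val += length
    if end_val < 0:
        end_val += length
    if start_val < 0:
        start_val = 0
    elif start_val >= length:
        start_val = length
    if end_val < start_val:
        end_val = start_val
    elif end_val >= length:
        end_val = length
    return (start_val, end_val)


data = list(range(7))
cases = [(None, None), (2, 5), (-3, None), (None, -1), (5, 2), (-100, 100), (10, 20)]
for start, end in cases:
    lo, hi = adjust_slice_range(len(data), start, end)
    # With step == 1 the adjusted bounds must select exactly the same elements
    # as Python slicing selects for the original bounds.
    assert data[lo:hi] == data[start:end], (start, end, lo, hi)
print("all slice ranges normalized consistently")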
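
Finally, the early-return branch: once the bounds are normalized, start >= end means the slice selects nothing, so the pass emits a zero-sized aten.full for slice_copy and simply returns the input for slice_scatter. That equivalence can be spot-checked in eager PyTorch (assuming a torch build that exposes slice_scatter at the top level):

import torch

x = torch.randn(4, 6)

# A slice whose start >= end has size 0 along the sliced dimension, which is
# the shape the pass materializes via aten.full with the sliced dim set to 0.
empty = torch.ops.aten.slice_copy.Tensor(x, 1, 5, 2)
print(empty.shape)  # torch.Size([4, 0])

# slice_scatter over that empty slice writes nothing back, so the result
# equals the input; the pass short-circuits this case by returning args[0].
src = torch.ones(4, 0)
out = torch.slice_scatter(x, src, dim=1, start=5, end=2)
print(torch.equal(out, x))  # True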