From a113625616b28b86304bbd9571e50ea96610a1c0 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Wed, 30 Apr 2025 09:48:57 +0200
Subject: [PATCH] Arm backend: Decompose sum in pass

Moves the unrolling of sums over multiple dimensions from the sum node
visitor to a new DecomposeSumPass. KeepDimsFalseToSqueezePass is merged
into the new pass so that the sum op is fully decomposed in one pass.

This change introduces a new rescale for each reduced dim, requiring
decomposition before quantization to get proper quantization
parameters.

Change-Id: I1b113813f22c6b25aac56d63110d7eee4833167a
Signed-off-by: Adrian Lundell
---
 backends/arm/_passes/__init__.py            |   2 +-
 backends/arm/_passes/arm_pass_manager.py    |   8 +-
 backends/arm/_passes/decompose_sum_pass.py  |  79 ++++++++
 .../keep_dims_false_to_squeeze_pass.py      |  92 ---------
 backends/arm/operators/op_sum.py            | 180 +++++++-----------
 5 files changed, 156 insertions(+), 205 deletions(-)
 create mode 100644 backends/arm/_passes/decompose_sum_pass.py
 delete mode 100644 backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 35879d5026c..27722608c95 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -32,6 +32,7 @@
 from .decompose_softmax_pass import DecomposeSoftmaxPass  # noqa
 from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass  # noqa
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
+from .decompose_sum_pass import DecomposeSumPass  # noqa
 from .decompose_var_pass import DecomposeVarPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
@@ -44,7 +45,6 @@
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
-from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
 from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index c5ebace2834..21ff11b3598 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -37,6 +37,7 @@
     DecomposeSoftmaxPass,
     DecomposeSoftmaxUnstablePass,
     DecomposeSqrtPass,
+    DecomposeSumPass,
     DecomposeVarPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
@@ -45,7 +46,6 @@
     FuseQuantizedActivationPass,
     InsertRescalePass,
     InsertTableOpsPass,
-    KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
@@ -110,7 +110,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
@@ -163,7 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
@@ -220,4 +220,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ReplaceInfValues())
+        self.add_pass(DecomposeSumPass())
+
 
         return self._transform(graph_module)
diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py
new file mode 100644
index 00000000000..531b0d72a19
--- /dev/null
+++ b/backends/arm/_passes/decompose_sum_pass.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+def _get_sum_decomp(op):
+    match op:
+        case exir_ops.edge.aten.sum.dim_IntList:
+            return (
+                exir_ops.edge.aten.view_copy.default,
+                exir_ops.edge.aten.sum.dim_IntList,
+            )
+        case torch.ops.aten.sum.dim_IntList:
+            return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
+        case _:
+            raise RuntimeError("Invalid op in DecomposeSumPass")
+
+
+class DecomposeSumPass(ExportPass):
+    """
+    In PyTorch, the default behaviour of, for example, Tensor.sum is to squeeze
+    the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM
+    always preserves the rank of the input (keep_dim = True). To get a 1-1 mapping
+    in the sum lowering, normalize the keep_dim = False case to keep_dim = True
+    and lower the rank with a view op.
+
+    Since TOSA can only reduce one dimension at a time, sums over multiple dims
+    are additionally unrolled into multiple ops.
+
+    Original:
+        sum((dim_1, dim_2), keep_dim = False) -> squeezed_shape
+    After pass:
+        sum(dim_1, keep_dim = True) -> unsqueezed_shape
+        sum(dim_2, keep_dim = True) -> unsqueezed_shape
+        view(shape = squeezed_shape) -> squeezed_shape
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [
+            exir_ops.edge.aten.sum.dim_IntList,
+            torch.ops.aten.sum.dim_IntList,
+        ]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        match len(args):
+            case 3:
+                (
+                    input_node,
+                    dims,
+                    keepdims,
+                ) = args
+            case 2:
+                (
+                    input_node,
+                    dims,
+                ) = args
+                keepdims = False
+            case _:
+                raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
+
+        view_op, sum_op = _get_sum_decomp(op)
+
+        for dim in dims:
+            input_node = super().call_operator(
+                sum_op, (input_node, dim, True), kwargs, meta
+            )
+
+        if not keepdims:
+            shape = list(meta["val"].size())
+            input_node = super().call_operator(
+                view_op, (input_node, shape), kwargs, meta
+            )
+
+        return input_node
diff --git a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
deleted file mode 100644
index 744436cba9e..00000000000
--- a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-
-from typing import cast
-
-import torch
-import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import (
-    create_node,
-    get_node_arg,
-    set_node_arg,
-)
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-
-class KeepDimsFalseToSqueezePass(ExportPass):
-    """
-    In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
-    the dimension that is summed (keep_dim = False).
-    However, in TOSA, REDUCE_SUM always preserves the
-    rank of the input (keep_dim = True).
-    To get a 1-1 mapping in the sum lowering, normalize the
-    keep_dim = False case to keep_dim = True and add squeeze ops.
-
-    Original:
-        sum(dims, keep_dim = False)
-    After pass:
-        sum(dims, keep_dim = True)
-        squeeze(dim = dims)
-    """
-
-    # CURRENTLY NOT HANDLED OPS
-    # exir_ops.edge.aten.argmax,
-    # exir_ops.edge.aten.argmin,
-    # exir_ops.edge.aten.prod.dim_int,
-
-    # HANDLED OPS
-    # exir_ops.edge.aten.sum.dim_IntList
-    # exir_ops.edge.aten.any.default (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.any.dim (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.any.dims (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.amax (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
-    # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
-    # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        for node in graph_module.graph.nodes:
-            keep_dim_index = None
-
-            if node.op != "call_function":
-                continue
-            if node.target == exir_ops.edge.aten.sum.dim_IntList:
-                keep_dim_index = 2
-            else:
-                continue
-
-            sum_node = cast(torch.fx.Node, node)
-            keep_dim = get_node_arg(
-                # pyre-ignore[6]
-                sum_node.args,  # type: ignore[arg-type]
-                keep_dim_index,
-                False,
-            )
-
-            if keep_dim:
-                continue
-
-            dim_list = get_node_arg(sum_node.args, 1, [0])  # type: ignore[arg-type]  # pyre-ignore[6]
-
-            # Add keep_dim = True arg to sum node.
-            set_node_arg(sum_node, 2, True)
-
-            with graph_module.graph.inserting_after(sum_node):
-                squeeze_node = create_node(
-                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
-                )
-                sum_node.replace_all_uses_with(squeeze_node)
-                squeeze_node.args = (sum_node, dim_list)
-
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
-        return PassResult(graph_module, True)
diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py
index f232136fd9b..c0a436f4d99 100644
--- a/backends/arm/operators/op_sum.py
+++ b/backends/arm/operators/op_sum.py
@@ -5,7 +5,7 @@
 
 # pyre-unsafe
 
-from typing import Any, cast, List
+from typing import Any, List
 
 import executorch.backends.arm.tosa_quant_utils as tqutils
 import executorch.backends.arm.tosa_utils as tutils
@@ -45,41 +45,36 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 3)
 
-        input_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(input_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
         # Rescale input to 32 bit
         rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
             tosa_graph,
-            [inputs[0]],
+            [tensor],
             node,
         )
 
-        prev_node = rescaled_inputs[0]
-        reduced_shape = input_shape
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1.
-            reduced_shape[dim] = 1
-
-            attr = ts.TosaSerializerAttribute()
-            attr.AxisAttribute(inputs[0].dim_order.index(dim))
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
 
-            next_node = tosa_graph.addIntermediate(
-                tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                dtype=ts.DType.INT32,
-            )
+        intermediate = tosa_graph.addIntermediate(
+            tutils.tosa_shape(output_shape, tensor.dim_order),
+            dtype=ts.DType.INT32,
+        )
 
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
-            )
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [rescaled_inputs[0].name],
+            [intermediate.name],
+            attr,
+        )
 
-            prev_node = next_node
-        tqutils.insert_rescale_op_to_int8(tosa_graph, prev_node, scale, node)
+        tqutils.insert_rescale_op_to_int8(tosa_graph, intermediate, scale, node)
 
 
 @register_node_visitor
@@ -103,38 +98,27 @@ def define_node(
 
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, 3)
-
         if inputs[0].dtype == ts.DType.INT8:
             return super().define_node(node, tosa_graph, inputs, output)
 
-        input_name = inputs[0].name
-        reduced_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(reduced_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1
-            reduced_shape[dim] = 1
+        validate_num_inputs(self.target, inputs, 3)
 
-            attr = ts.TosaSerializerAttribute()
-            attr.AxisAttribute(inputs[0].dim_order.index(dim))
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
 
-            if dim == dim_list[-1]:
-                output_name = output.name
-            else:
-                output_name = tosa_graph.addIntermediate(
-                    tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                    dtype=ts.DType.FP32,
-                ).name
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
-            )
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
 
-            input_name = output_name
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [tensor.name],
+            [output.name],
+            attr,
+        )
@@ -160,45 +144,37 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 3)
 
-        input_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(input_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
         # Rescale input to 32 bit
        rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
             tosa_graph,
-            [inputs[0]],
+            [tensor],
             node,
-            self.tosa_specs,
         )
 
-        prev_node = rescaled_inputs[0]
-        reduced_shape = input_shape
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1.
-            reduced_shape[dim] = 1
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
 
-            attr = ts.TosaSerializerAttribute()
-            attr.ReduceSumAttribute(inputs[0].dim_order.index(dim))
-
-            next_node = tosa_graph.addIntermediate(
-                tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                dtype=ts.DType.INT32,
-            )
-
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
-            )
+        intermediate = tosa_graph.addIntermediate(
+            tutils.tosa_shape(output_shape, tensor.dim_order),
+            dtype=ts.DType.INT32,
+        )
 
-            prev_node = next_node
-        tqutils.insert_rescale_op_to_int8(
-            tosa_graph, prev_node, scale, node, self.tosa_specs
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [rescaled_inputs[0].name],
+            [intermediate.name],
+            attr,
         )
 
+        tqutils.insert_rescale_op_to_int8(tosa_graph, intermediate, scale, node)
+
 
 @register_node_visitor
 class SumVisitor_FP(SumVisitor_INT):
@@ -221,33 +197,19 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 3)
 
-        if inputs[0].dtype == ts.DType.INT8:
-            return super().define_node(node, tosa_graph, inputs, output)
-        input_name = inputs[0].name
-        reduced_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(reduced_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1
-            reduced_shape[dim] = 1
-
-            attr = ts.TosaSerializerAttribute()
-            attr.ReduceSumAttribute(inputs[0].dim_order.index(dim))
-
-            if dim == dim_list[-1]:
-                output_name = output.name
-            else:
-                output_name = tosa_graph.addIntermediate(
-                    tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                    dtype=ts.DType.FP32,
-                ).name
-
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
-            )
-
-            input_name = output_name
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
+
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [tensor.name],
+            [output.name],
+            attr,
+        )
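
Note for reviewers: the rewrite performed by DecomposeSumPass can be checked
standalone with plain torch (no ExecuTorch imports; the shape and dims below
are illustrative, not taken from the patch):

    import torch

    x = torch.randn(2, 3, 4)

    # Original graph: one multi-dim sum with keep_dim = False.
    expected = x.sum((1, 2), keepdim=False)  # shape: (2,)

    # After the pass: one keep_dim = True sum per reduced dim, since TOSA
    # REDUCE_SUM reduces exactly one axis and preserves rank, followed by a
    # single view back to the squeezed output shape.
    out = x.sum(1, keepdim=True)    # shape: (2, 1, 4)
    out = out.sum(2, keepdim=True)  # shape: (2, 1, 1)
    out = out.view(expected.shape)  # shape: (2,)

    assert torch.allclose(expected, out)

Keeping keep_dim = True at every step is what keeps the dim indices valid
across the unrolled chain; the rank is only lowered once, by the final view.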
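
The quantization motivation in the commit message can be sketched the same
way: after decomposition each single-dim sum is its own node, so the
quantizer assigns parameters to every intermediate instead of only to the
ends of one fused reduction. A rough int8 model of a single decomposed step,
with made-up scales (s_in, s_out and the helper name are hypothetical, not
part of the patch):

    import torch

    def reduce_sum_int8(x_q, dim, s_in, s_out):
        # Rescale to 32 bit, REDUCE_SUM, rescale back to int8 -- mirroring
        # the insert_rescale_ops_to_int32 / insert_rescale_op_to_int8 pair
        # that op_sum.py wraps around each REDUCE_SUM.
        acc = x_q.sum(dim, keepdim=True, dtype=torch.int32)
        return torch.clamp(torch.round(acc * (s_in / s_out)), -128, 127).to(torch.int8)

    x_q = torch.randint(-128, 128, (2, 3, 4), dtype=torch.int8)
    step1 = reduce_sum_int8(x_q, dim=1, s_in=0.02, s_out=0.05)    # own rescale pair
    step2 = reduce_sum_int8(step1, dim=2, s_in=0.05, s_out=0.12)  # own rescale pair

Each step needs its own output scale, which is why the pass is also added to
transform_for_annotation_pipeline, so the decomposition happens before the
quantizer picks parameters.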