diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 35879d5026c..27722608c95 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -32,6 +32,7 @@
 from .decompose_softmax_pass import DecomposeSoftmaxPass  # noqa
 from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass  # noqa
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
+from .decompose_sum_pass import DecomposeSumPass  # noqa
 from .decompose_var_pass import DecomposeVarPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
@@ -44,7 +45,6 @@
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
 from .insert_rescales_pass import InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
-from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
 from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index c5ebace2834..21ff11b3598 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -37,6 +37,7 @@
     DecomposeSoftmaxPass,
     DecomposeSoftmaxUnstablePass,
     DecomposeSqrtPass,
+    DecomposeSumPass,
     DecomposeVarPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
@@ -45,7 +46,6 @@
     FuseQuantizedActivationPass,
     InsertRescalePass,
     InsertTableOpsPass,
-    KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
@@ -110,7 +110,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
@@ -163,7 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
@@ -220,4 +220,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ReplaceInfValues())
 
+        self.add_pass(DecomposeSumPass())
+
         return self._transform(graph_module)
diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py
new file mode 100644
index 00000000000..531b0d72a19
--- /dev/null
+++ b/backends/arm/_passes/decompose_sum_pass.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+def _get_sum_decomp(op):
+    match op:
+        case exir_ops.edge.aten.sum.dim_IntList:
+            return (
+                exir_ops.edge.aten.view_copy.default,
+                exir_ops.edge.aten.sum.dim_IntList,
+            )
+        case torch.ops.aten.sum.dim_IntList:
+            return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
+        case _:
+            raise RuntimeError("Invalid op in DecomposeSumPass")
+
+
+class DecomposeSumPass(ExportPass):
+    """
+    In PyTorch, the default behaviour of, for example, Tensor.sum is to squeeze
+    the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM
+    always preserves the rank of the input (keep_dim = True). To get a 1-1 mapping
+    in the sum lowering, normalize the keep_dim = False case to keep_dim = True and
+    lower the rank afterwards with a view op.
+
+    Since TOSA can only reduce one dimension at a time, a sum over multiple dims is
+    additionally unrolled into multiple ops.
+
+    Original:
+        sum((dim_1, dim_2), keep_dim = False) -> squeezed_shape
+    After pass:
+        sum(dim_1, keep_dim = True) -> unsqueezed_shape
+        sum(dim_2, keep_dim = True) -> unsqueezed_shape
+        view(shape = squeezed_shape) -> squeezed_shape
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [
+            exir_ops.edge.aten.sum.dim_IntList,
+            torch.ops.aten.sum.dim_IntList,
+        ]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        match len(args):
+            case 3:
+                (
+                    input_node,
+                    dims,
+                    keepdims,
+                ) = args
+            case 2:
+                (
+                    input_node,
+                    dims,
+                ) = args
+                keepdims = False
+            case _:
+                raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
+
+        view_op, sum_op = _get_sum_decomp(op)
+
+        for dim in dims:
+            input_node = super().call_operator(
+                sum_op, (input_node, dim, True), kwargs, meta
+            )
+
+        if not keepdims:
+            shape = list(meta["val"].size())
+            input_node = super().call_operator(
+                view_op, (input_node, shape), kwargs, meta
+            )
+
+        return input_node
diff --git a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
deleted file mode 100644
index 744436cba9e..00000000000
--- a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-
-from typing import cast
-
-import torch
-import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import (
-    create_node,
-    get_node_arg,
-    set_node_arg,
-)
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-
-class KeepDimsFalseToSqueezePass(ExportPass):
-    """
-    In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
-    the dimension that is summed (keep_dim = False).
-    However, in TOSA, REDUCE_SUM always preserves the
-    rank of the input (keep_dim = True).
-    To get a 1-1 mapping in the sum lowering, normalize the
-    keep_dim = False case to keep_dim = True and add squeeze ops.
-
-    Original:
-        sum(dims, keep_dim = False)
-    After pass:
-        sum(dims, keep_dim = True)
-        squeeze(dim = dims)
-    """
-
-    # CURRENTLY NOT HANDLED OPS
-    # exir_ops.edge.aten.argmax,
-    # exir_ops.edge.aten.argmin,
-    # exir_ops.edge.aten.prod.dim_int,
-
-    # HANDLED OPS
-    # exir_ops.edge.aten.sum.dim_IntList
-    # exir_ops.edge.aten.any.default (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.any.dim (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.any.dims (decomposed in convert_any_default_dim_dims_pass)
-    # exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.amax (decomposed in convert_minmax_pass)
-    # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
-    # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
-    # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        for node in graph_module.graph.nodes:
-            keep_dim_index = None
-
-            if node.op != "call_function":
-                continue
-            if node.target == exir_ops.edge.aten.sum.dim_IntList:
-                keep_dim_index = 2
-            else:
-                continue
-
-            sum_node = cast(torch.fx.Node, node)
-            keep_dim = get_node_arg(
-                # pyre-ignore[6]
-                sum_node.args,  # type: ignore[arg-type]
-                keep_dim_index,
-                False,
-            )
-
-            if keep_dim:
-                continue
-
-            dim_list = get_node_arg(sum_node.args, 1, [0])  # type: ignore[arg-type]  # pyre-ignore[6]
-
-            # Add keep_dim = True arg to sum node.
-            set_node_arg(sum_node, 2, True)
-
-            with graph_module.graph.inserting_after(sum_node):
-                squeeze_node = create_node(
-                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
-                )
-                sum_node.replace_all_uses_with(squeeze_node)
-                squeeze_node.args = (sum_node, dim_list)
-
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
-        return PassResult(graph_module, True)
diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py
index f232136fd9b..c0a436f4d99 100644
--- a/backends/arm/operators/op_sum.py
+++ b/backends/arm/operators/op_sum.py
@@ -5,7 +5,7 @@
 
 # pyre-unsafe
 
-from typing import Any, cast, List
+from typing import Any, List
 
 import executorch.backends.arm.tosa_quant_utils as tqutils
 import executorch.backends.arm.tosa_utils as tutils
@@ -45,41 +45,36 @@ def define_node(
         validate_num_inputs(self.target, inputs, 3)
 
-        input_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(input_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
         # Rescale input to 32 bit
         rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
             tosa_graph,
-            [inputs[0]],
+            [tensor],
             node,
         )
-        prev_node = rescaled_inputs[0]
-        reduced_shape = input_shape
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1.
-            reduced_shape[dim] = 1
-
-            attr = ts.TosaSerializerAttribute()
-            attr.AxisAttribute(inputs[0].dim_order.index(dim))
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
 
-            next_node = tosa_graph.addIntermediate(
-                tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                dtype=ts.DType.INT32,
-            )
+        intermediate = tosa_graph.addIntermediate(
+            tutils.tosa_shape(output_shape, tensor.dim_order),
+            dtype=ts.DType.INT32,
+        )
 
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
-            )
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [rescaled_inputs[0].name],
+            [intermediate.name],
+            attr,
+        )
 
-            prev_node = next_node
-        tqutils.insert_rescale_op_to_int8(tosa_graph, prev_node, scale, node)
+        tqutils.insert_rescale_op_to_int8(tosa_graph, intermediate, scale, node)
 
 
 @register_node_visitor
@@ -103,38 +98,27 @@ def define_node(
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, 3)
-
         if inputs[0].dtype == ts.DType.INT8:
             return super().define_node(node, tosa_graph, inputs, output)
-        input_name = inputs[0].name
-        reduced_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(reduced_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
 
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1
-            reduced_shape[dim] = 1
+        validate_num_inputs(self.target, inputs, 3)
 
-            attr = ts.TosaSerializerAttribute()
-            attr.AxisAttribute(inputs[0].dim_order.index(dim))
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
 
-            if dim == dim_list[-1]:
-                output_name = output.name
-            else:
-                output_name = tosa_graph.addIntermediate(
-                    tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                    dtype=ts.DType.FP32,
-                ).name
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
-            )
+        attr = ts.TosaSerializerAttribute()
+        attr.AxisAttribute(tensor.dim_order.index(dim))
 
-            input_name = output_name
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [tensor.name],
+            [output.name],
+            attr,
+        )
 
 
 @register_node_visitor
@@ -160,45 +144,37 @@ def define_node(
         validate_num_inputs(self.target, inputs, 3)
 
-        input_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(input_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
 
         # Rescale input to 32 bit
        # Note: self.tosa_specs is kept here; the TOSA 1.0 rescale utilities take the
        # spec as an argument, as the removed lines above show.
         rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
             tosa_graph,
-            [inputs[0]],
+            [tensor],
             node,
             self.tosa_specs,
         )
-        prev_node = rescaled_inputs[0]
-        reduced_shape = input_shape
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1.
-            reduced_shape[dim] = 1
+        attr = ts.TosaSerializerAttribute()
+        attr.ReduceSumAttribute(tensor.dim_order.index(dim))
 
-            attr = ts.TosaSerializerAttribute()
-            attr.ReduceSumAttribute(inputs[0].dim_order.index(dim))
-
-            next_node = tosa_graph.addIntermediate(
-                tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                dtype=ts.DType.INT32,
-            )
-
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
-            )
+        intermediate = tosa_graph.addIntermediate(
+            tutils.tosa_shape(output_shape, tensor.dim_order),
+            dtype=ts.DType.INT32,
+        )
 
-            prev_node = next_node
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [rescaled_inputs[0].name],
+            [intermediate.name],
+            attr,
+        )
+
         tqutils.insert_rescale_op_to_int8(
-            tosa_graph, prev_node, scale, node, self.tosa_specs
+            tosa_graph, intermediate, scale, node, self.tosa_specs
         )
 
 
 @register_node_visitor
 class SumVisitor_FP(SumVisitor_INT):
@@ -221,33 +197,19 @@ def define_node(
         validate_num_inputs(self.target, inputs, 3)
 
-        if inputs[0].dtype == ts.DType.INT8:
-            return super().define_node(node, tosa_graph, inputs, output)
-        input_name = inputs[0].name
-        reduced_shape = list(inputs[0].shape)
-        dim_list = cast(list[int], inputs[1].special)
-        dim_list = [dim % len(reduced_shape) for dim in dim_list]
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
-
-        # Reduce all dims in dim_list one-by-one.
-        for dim in dim_list:
-            # When reduced, the size of the dim becomes 1
-            reduced_shape[dim] = 1
-
-            attr = ts.TosaSerializerAttribute()
-            attr.ReduceSumAttribute(inputs[0].dim_order.index(dim))
-
-            if dim == dim_list[-1]:
-                output_name = output.name
-            else:
-                output_name = tosa_graph.addIntermediate(
-                    tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
-                    dtype=ts.DType.FP32,
-                ).name
-
-            tosa_graph.addOperator(
-                ts.TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
-            )
-
-            input_name = output_name
+        tensor = inputs[0]
+        input_shape = list(tensor.shape)
+        dim = int(inputs[1].number % len(input_shape))
+
+        output_shape = input_shape
+        output_shape[dim] = 1  # Output shape is input shape with dim reduced
+
+        attr = ts.TosaSerializerAttribute()
+        attr.ReduceSumAttribute(tensor.dim_order.index(dim))
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().REDUCE_SUM,
+            [tensor.name],
+            [output.name],
+            attr,
+        )
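---
Reviewer note: the rewrite DecomposeSumPass performs can be sanity-checked with
plain torch ops, outside the pass infrastructure. Below is a minimal sketch;
`decomposed_sum` is a hypothetical helper standing in for the edge-dialect
sum.dim_IntList/view_copy sequence the pass emits, and the squeezed target
shape is recomputed directly here instead of being read from meta["val"]:

import torch


def decomposed_sum(x: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    # TOSA's REDUCE_SUM reduces a single axis and always preserves rank, so
    # unroll the dim list into one keepdim=True sum per axis...
    normalized = {d % x.dim() for d in dims}
    out = x
    for dim in sorted(normalized):
        out = out.sum(dim=dim, keepdim=True)
    # ...then restore the keepdim=False semantics with a single view at the end.
    if not keepdim:
        squeezed_shape = [s for i, s in enumerate(x.shape) if i not in normalized]
        out = out.view(squeezed_shape)
    return out


x = torch.randn(2, 3, 4)
assert torch.allclose(decomposed_sum(x, (1, 2)), x.sum(dim=(1, 2)))
assert torch.allclose(
    decomposed_sum(x, (0, -1), keepdim=True), x.sum(dim=(0, -1), keepdim=True)
)

This also makes explicit why the visitors in op_sum.py may now assume a single
integer dim and keep_dim = True.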