Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
from .decompose_sum_pass import DecomposeSumPass # noqa
from .decompose_var_pass import DecomposeVarPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
Expand All @@ -44,7 +45,6 @@
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DecomposeSoftmaxPass,
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
DecomposeSumPass,
DecomposeVarPass,
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
Expand All @@ -45,7 +46,6 @@
FuseQuantizedActivationPass,
InsertRescalePass,
InsertTableOpsPass,
KeepDimsFalseToSqueezePass,
MatchArgRanksPass,
MatchWhereSelfDtypePass,
QuantizeOperatorArguments,
Expand Down Expand Up @@ -110,7 +110,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(DecomposeSumPass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())
Expand Down Expand Up @@ -163,7 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(DecomposeSumPass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())
Expand Down Expand Up @@ -220,4 +220,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):

self.add_pass(ConvertMinMaxPass())
self.add_pass(ReplaceInfValues())
self.add_pass(DecomposeSumPass())

return self._transform(graph_module)
79 changes: 79 additions & 0 deletions backends/arm/_passes/decompose_sum_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def _get_sum_decomp(op):
match op:
case exir_ops.edge.aten.sum.dim_IntList:
return (
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.sum.dim_IntList,
)
case torch.ops.aten.sum.dim_IntList:
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
case _:
raise RuntimeError("Unvalid op in DecomposeSumPass")


class DecomposeSumPass(ExportPass):
    """
    In PyTorch, the default behaviour of, for example, ``Tensor.sum`` is to
    squeeze the dimensions that are summed (``keep_dim = False``). However, in
    TOSA, REDUCE_SUM always preserves the rank of the input
    (``keep_dim = True``). To get a 1-1 mapping in the sum lowering, normalize
    the ``keep_dim = False`` case to ``keep_dim = True`` and lower the rank
    with a view op.

    Since TOSA can only reduce one dimension at a time, multiple dims are
    additionally unrolled into multiple ops.

    Original:
        sum((dim_1, dim_2), keep_dim = False) -> squeezed_shape
    After pass:
        sum(dim_1, keep_dim = True) -> unsqueezed_shape
        sum(dim_2, keep_dim = True) -> unsqueezed_shape
        view(shape = squeezed_shape) -> squeezed_shape
    """

    def call_operator(self, op, args, kwargs, meta):
        """Decompose a multi-dim sum into a chain of single-dim keepdim sums.

        Non-sum ops pass through unchanged. Handles both the edge and aten
        ``sum.dim_IntList`` overloads (see ``_get_sum_decomp``).
        """
        if op not in [
            exir_ops.edge.aten.sum.dim_IntList,
            torch.ops.aten.sum.dim_IntList,
        ]:
            return super().call_operator(op, args, kwargs, meta)

        # sum.dim_IntList is (input, dims) with an optional trailing keepdim
        # argument; default to keepdim=False when it is absent, matching the
        # PyTorch schema default.
        match len(args):
            case 3:
                (
                    input_node,
                    dims,
                    keepdims,
                ) = args
            case 2:
                (
                    input_node,
                    dims,
                ) = args
                keepdims = False
            case _:
                raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")

        view_op, sum_op = _get_sum_decomp(op)

        # Unroll into one keepdim=True sum per dimension, chaining each
        # result into the next (TOSA REDUCE_SUM reduces a single axis).
        # NOTE(review): assumes dims is an iterable of ints — a dims=None
        # ("sum over all dims") call would fail here; confirm earlier passes
        # never produce that form.
        for dim in dims:
            input_node = super().call_operator(
                sum_op, (input_node, dim, True), kwargs, meta
            )

        # If the original op squeezed the reduced dims, restore that shape
        # with a single view. meta["val"] holds the original (squeezed)
        # output FakeTensor, so its size is the target shape.
        if not keepdims:
            shape = list(meta["val"].size())
            input_node = super().call_operator(
                view_op, (input_node, shape), kwargs, meta
            )

        return input_node
92 changes: 0 additions & 92 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

This file was deleted.

Loading
Loading