diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index f932ae7f4c4..0cd306086cb 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -20,6 +20,7 @@ def calculate_multiples(args): + """Returns expand args converted to repeat args, and whether the expand changes the rank""" input_node_or_tensor = args[0] if isinstance(input_node_or_tensor, torch.fx.node.Node): @@ -45,7 +46,7 @@ def calculate_multiples(args): multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 for i in range(expanded_rank) ] - return multiples + return multiples, expanded_rank != len(input_shape) class ConvertExpandCopyToRepeatPass(ArmPass): @@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta): if op != self.expand_copy: return super().call_operator(op, args, kwargs, meta) - multiples = calculate_multiples(args) + multiples, changes_rank = calculate_multiples(args) - if all((x == 1 for x in multiples)): + if all((x == 1 for x in multiples)) and not changes_rank: # All dimensions/repetitions occur only once. Remove node # altogether since it's in practice just a copy. logger.warning("Found redundant expand node (no-op). Removing it.") diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 83294369ae7..1b9a02fbb6a 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -110,8 +110,8 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: if node.target != exir_ops.edge.aten.expand_copy.default: return False else: - multiples = calculate_multiples(node.args) - return all(m == 1 for m in multiples) + multiples, changes_rank = calculate_multiples(node.args) + return all(m == 1 for m in multiples) and not changes_rank def is_partitioned(