diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py
index 71dbdb5b854..aee68f74eb5 100644
--- a/backends/arm/_passes/convert_expand_copy_to_repeat.py
+++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py
@@ -8,7 +8,6 @@
 
 from typing import cast
 
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -25,14 +24,14 @@ def call_operator(self, op, args, kwargs, meta):
         if op != self.expand_copy:
             return super().call_operator(op, args, kwargs, meta)
 
-        _, shape, _ = extract_tensor_meta(meta.data)
+        input_shape = args[0].data.shape
         multiples = cast(list[int], args[1])
         expanded_rank = len(multiples)
 
-        # Expanded shape is 'shape' front-padded with ones.
-        padding = expanded_rank - len(shape)
+        # Expanded shape is 'input_shape' front-padded with ones.
+        padding = expanded_rank - len(input_shape)
         extended_shape = [
-            shape[i] if i >= 0 else 1 for i in range(-padding, len(shape))
+            input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
         ]
 
         # To convert expand arg to repeat arg, non-repeated dims should have
diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py
index 283760e423b..73747d8313d 100644
--- a/backends/arm/_passes/decompose_var_pass.py
+++ b/backends/arm/_passes/decompose_var_pass.py
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
         sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
         full = super().call_operator(
             full_op,
-            ([1 for _ in shape], 1 / max(0, N - correction)),
+            ([], 1 / max(0, N - correction)),
             {"dtype": dtype},
             meta,
         )
diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py
index e230fde1a05..941d20c95a1 100644
--- a/backends/arm/_passes/match_arg_ranks_pass.py
+++ b/backends/arm/_passes/match_arg_ranks_pass.py
@@ -90,7 +90,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 continue
 
             # Calculate max rank of all inputs to node
-            max_rank = 1
+            max_rank = 0
            for arg in node.args:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py
index 084eec138f9..915b1fe7e00 100644
--- a/backends/arm/test/ops/test_expand.py
+++ b/backends/arm/test/ops/test_expand.py
@@ -34,12 +34,13 @@ class TestSimpleExpand(unittest.TestCase):
     class Expand(torch.nn.Module):
         # (input tensor, multiples)
         test_parameters = [
-            (torch.ones(1), (2,)),
-            (torch.ones(1, 4), (1, -1)),
-            (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
-            (torch.ones(1), (2, 2, 4)),
-            (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
-            (torch.ones(1, 1, 192), (1, -1, -1)),
+            (torch.rand(1), (2,)),
+            (torch.randn(1, 4), (1, -1)),
+            (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
+            (torch.randn(1), (2, 2, 4)),
+            (torch.rand(3, 2, 4, 1), (-1, -1, -1, 3)),
+            (torch.randn(1, 1, 192), (1, -1, -1)),
+            (torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
         ]
 
     def forward(self, x: torch.Tensor, multiples: Sequence):
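
For reference only, not part of the patch above: a minimal, self-contained Python sketch of the front-padding step that the first hunk rewrites to read the shape from args[0].data.shape instead of extract_tensor_meta(meta.data). The helper name front_pad_shape and the example values are hypothetical and only illustrate how the list comprehension pads the input shape with leading ones up to the rank of the expand multiples.

# Illustration only; not part of the patch. Helper name and values are hypothetical.
def front_pad_shape(input_shape, multiples):
    """Left-pad input_shape with ones to the rank of the expand multiples."""
    expanded_rank = len(multiples)
    padding = expanded_rank - len(input_shape)
    # Negative indices in the range correspond to the padded leading ones.
    return [
        input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
    ]

# A (1, 4) tensor expanded with multiples (2, 1, -1) is treated as rank 3:
assert front_pad_shape((1, 4), (2, 1, -1)) == [1, 1, 4]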