From daa681fdd27fd972dae927f07a10ae190c2eaf06 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 14 Oct 2025 14:00:06 +0200 Subject: [PATCH] Arm backend: Use reshape instead of view before edge The view operator can't handle non-contiguous strides, such as the result of an expand. Such strides are normalized after to_edge, but in the transform_for_annotation_pipeline we shouldn't use views for that reason. Reshape is the equivalent operator that can handle such strides. This issue was found in the roberta model. Signed-off-by: Erik Lundell Change-Id: I042b76ecf8b99e4e65bf951c00ec53a9d0d36c80 --- .../arm/_passes/decompose_embedding_pass.py | 2 +- .../arm/_passes/decompose_groupnorm_pass.py | 2 +- .../arm/_passes/decompose_layernorm_pass.py | 2 +- .../arm/_passes/decompose_meandim_pass.py | 2 +- backends/arm/_passes/decompose_sum_pass.py | 2 +- backends/arm/test/ops/test_embedding.py | 26 +++++++++++++++++-- .../passes/test_decompose_meandim_pass.py | 6 ++--- 7 files changed, 32 insertions(+), 10 deletions(-) diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index ac424230491..f58532d9297 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -44,7 +44,7 @@ class DecomposeEmbeddingPass(ArmPass): def get_decomposition(self, op): if op in self.aten_ops: return ( - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, torch.ops.aten.index_select.default, ) diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index 29d68234b29..bcf764b0b95 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -40,7 +40,7 @@ def get_group_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, ) 
raise RuntimeError(f"Can't get group_norm composition for op {op}") diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index c73806b0022..2e8ba7cc56d 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -40,7 +40,7 @@ def get_layer_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, ) raise RuntimeError(f"Can't get layer_norm composition for op {op}") diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 4d4c0ee75b1..a1940ef7d77 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -46,7 +46,7 @@ def get_view(op): if op == exir_ops.edge.aten.mean.dim: return exir_ops.edge.aten.view_copy.default if op == torch.ops.aten.mean.dim: - return torch.ops.aten.view_copy.default + return torch.ops.aten.reshape.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 59c352a0e07..9f3c7aaf390 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -19,7 +19,7 @@ def _get_sum_decomp(op): exir_ops.edge.aten.sum.dim_IntList, ) case torch.ops.aten.sum.dim_IntList: - return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList) + return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList) case _: raise RuntimeError("Unvalid op in DecomposeSumPass") diff --git a/backends/arm/test/ops/test_embedding.py b/backends/arm/test/ops/test_embedding.py index 901fbbc0916..23b14ae5c44 100644 --- a/backends/arm/test/ops/test_embedding.py +++ b/backends/arm/test/ops/test_embedding.py @@ -27,10 +27,17 @@ def 
forward(self, weights: torch.Tensor, indices: torch.Tensor): return torch.embedding(weights, indices) -input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype] +class ExpandEmbedding(Embedding): + example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32)) + + def forward(self, weights: torch.Tensor, indices: torch.Tensor): + return torch.embedding(weights, indices.expand(2, 3)) + + +input_params = Tuple[torch.Tensor, torch.Tensor] -test_input: dict[input_params] = { +test_input: dict[str, input_params] = { "test_1": ( torch.randn(10, 3), torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32), @@ -89,6 +96,21 @@ def test_embedding_tosa_INT(test_input: input_params): pipeline.run() +def test_expand_embedding_tosa_INT(): + op = ExpandEmbedding() + pipeline = TosaPipelineINT( + op, + ExpandEmbedding.example_inputs, + ExpandEmbedding.aten_op, + ExpandEmbedding.exir_op, + use_to_edge_transform_and_lower=True, + ) + pipeline.pop_stage("check.aten") + pipeline.pop_stage("check_count.exir") + + pipeline.run() + + @pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning") @common.parametrize("test_input", test_input) @common.SkipIfNoModelConverter diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index 22dda5d9244..e771d74b5c4 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -28,7 +28,7 @@ class MeanDim(torch.nn.Module): } ops_not_after_pass = u55_ops_not_after_pass = [ - "torch.ops.aten.view_copy.default", + "torch.ops.aten.reshape.default", "torch.ops.aten.avg_pool2d.default", "torch.ops.aten.mean.dim", ] @@ -52,7 +52,7 @@ class MeanDimTensor(torch.nn.Module): "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, "torch.ops.aten.avg_pool2d.default": 1, - "torch.ops.aten.view_copy.default": 1, + "torch.ops.aten.reshape.default": 1, } 
ops_not_after_pass = [ @@ -62,7 +62,7 @@ class MeanDimTensor(torch.nn.Module): u55_ops_after_pass = { "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, - "torch.ops.aten.view_copy.default": 1, + "torch.ops.aten.reshape.default": 1, } u55_ops_not_after_pass = [