Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_embedding_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DecomposeEmbeddingPass(ArmPass):
def get_decomposition(self, op):
if op in self.aten_ops:
return (
torch.ops.aten.view_copy.default,
torch.ops.aten.reshape.default,
torch.ops.aten.index_select.default,
)

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_groupnorm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_group_norm_decomposition(op) -> tuple:
torch.ops.aten.add.Tensor,
torch.ops.aten.rsqrt.default,
torch.ops.aten.mul.Tensor,
torch.ops.aten.view_copy.default,
torch.ops.aten.reshape.default,
)
raise RuntimeError(f"Can't get group_norm composition for op {op}")

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_layer_norm_decomposition(op) -> tuple:
torch.ops.aten.add.Tensor,
torch.ops.aten.rsqrt.default,
torch.ops.aten.mul.Tensor,
torch.ops.aten.view_copy.default,
torch.ops.aten.reshape.default,
)
raise RuntimeError(f"Can't get layer_norm composition for op {op}")

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_view(op):
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
return exir_ops.edge.aten.view_copy.default
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
return torch.ops.aten.view_copy.default
return torch.ops.aten.reshape.default
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_sum_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _get_sum_decomp(op):
exir_ops.edge.aten.sum.dim_IntList,
)
case torch.ops.aten.sum.dim_IntList:
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList)
case _:
raise RuntimeError("Unvalid op in DecomposeSumPass")

Expand Down
26 changes: 24 additions & 2 deletions backends/arm/test/ops/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor):
return torch.embedding(weights, indices)


input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype]
class ExpandEmbedding(Embedding):
example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32))

def forward(self, weights: torch.Tensor, indices: torch.Tensor):
return torch.embedding(weights, indices.expand(2, 3))


input_params = Tuple[torch.Tensor, torch.Tensor]


test_input: dict[input_params] = {
test_input: dict[str, input_params] = {
"test_1": (
torch.randn(10, 3),
torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32),
Expand Down Expand Up @@ -89,6 +96,21 @@ def test_embedding_tosa_INT(test_input: input_params):
pipeline.run()


def test_expand_embedding_tosa_INT():
op = ExpandEmbedding()
pipeline = TosaPipelineINT(
op,
ExpandEmbedding.example_inputs,
ExpandEmbedding.aten_op,
ExpandEmbedding.exir_op,
use_to_edge_transform_and_lower=True,
)
pipeline.pop_stage("check.aten")
pipeline.pop_stage("check_count.exir")

pipeline.run()


@pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning")
@common.parametrize("test_input", test_input)
@common.SkipIfNoModelConverter
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/test/passes/test_decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MeanDim(torch.nn.Module):
}

ops_not_after_pass = u55_ops_not_after_pass = [
"torch.ops.aten.view_copy.default",
"torch.ops.aten.reshape.default",
"torch.ops.aten.avg_pool2d.default",
"torch.ops.aten.mean.dim",
]
Expand All @@ -52,7 +52,7 @@ class MeanDimTensor(torch.nn.Module):
"torch.ops.aten.sum.dim_IntList": 2,
"torch.ops.aten.mul.Tensor": 1,
"torch.ops.aten.avg_pool2d.default": 1,
"torch.ops.aten.view_copy.default": 1,
"torch.ops.aten.reshape.default": 1,
}

ops_not_after_pass = [
Expand All @@ -62,7 +62,7 @@ class MeanDimTensor(torch.nn.Module):
u55_ops_after_pass = {
"torch.ops.aten.sum.dim_IntList": 2,
"torch.ops.aten.mul.Tensor": 1,
"torch.ops.aten.view_copy.default": 1,
"torch.ops.aten.reshape.default": 1,
}

u55_ops_not_after_pass = [
Expand Down
Loading