diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index f37bc06d16a..56e124d8d0a 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -161,6 +161,7 @@ def f(node_name_pre_computed): exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.eye.default, exir_ops.edge.aten.linspace.default, + torch.ops.aten.scalar_tensor.default, ] def __init__(self, exported_program: ExportedProgram) -> None: diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3f79f9f879b..c65c88a2e73 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -205,6 +205,7 @@ def is_node_supported( exir_ops.edge.aten.amin.default, exir_ops.edge.aten.eye.default, exir_ops.edge.aten.linspace.default, + torch.ops.aten.scalar_tensor.default, ] return supported diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 3d32454f8de..42da8fde7d2 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -35,7 +35,6 @@ class TestConformer(unittest.TestCase): "executorch_exir_dialects_edge__ops_aten_where_self": 4, "torch.ops.aten._assert_scalar.default": 10, "torch.ops.aten._local_scalar_dense.default": 1, - "torch.ops.aten.scalar_tensor.default": 2, "torch.ops.higher_order.executorch_call_delegate": 6, } @@ -92,7 +91,7 @@ def test_conformer_tosa_BI(self): ) ) - @conftest.expectedFailureOnFVP # TODO(MLETORCH-635) + @unittest.expectedFailure # TODO(MLETORCH-635) def test_conformer_u55_BI(self): tester = ( ArmTester( @@ -114,7 +113,7 @@ def test_conformer_u55_BI(self): inputs=get_test_inputs(self.dim, self.lengths, self.num_examples), ) - @conftest.expectedFailureOnFVP # TODO(MLETORCH-635) + @unittest.expectedFailure # TODO(MLETORCH-635) def test_conformer_u85_BI(self): tester = ( ArmTester( diff --git a/backends/arm/test/ops/test_scalar_tensor.py b/backends/arm/test/ops/test_scalar_tensor.py new file mode 100644 index 00000000000..ad9d385c1d1 --- /dev/null +++ b/backends/arm/test/ops/test_scalar_tensor.py @@ -0,0 +1,91 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +float_test_data_suite = { + "scalar_tensor_float_1": (3.7, torch.float32, torch.rand((1, 2, 3, 4))), + "scalar_tensor_float_2": (66, torch.float32, torch.rand((1, 2, 3))), +} + +int_test_data_suite = { + "scalar_tensor_int32": ( + 33, + torch.int32, + torch.randint(0, 10, (1, 2), dtype=torch.int32), + ), + "scalar_tensor_int8": ( + 8, + torch.int8, + torch.rand(1, 2, 3), + ), + "scalar_tensor_int16": ( + 16 * 16 * 16, + torch.int16, + torch.rand((1,)).unsqueeze(0), # Rank 0 inputs not supported + ), +} + + +class ScalarTensor(torch.nn.Module): + aten_op = "torch.ops.aten.scalar_tensor.default" + + def __init__(self, scalar, dtype=torch.float32): + super().__init__() + self.scalar = scalar + self.dtype = dtype + + def forward(self, x: torch.Tensor): + return torch.scalar_tensor(self.scalar, dtype=self.dtype) + x + + +@common.parametrize("test_data", int_test_data_suite | float_test_data_suite) +def test_scalar_tensor_tosa_MI(test_data): # Note TOSA MI supports all types + scalar, dtype, data = test_data + TosaPipelineMI(ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op).run() + + +@common.parametrize("test_data", int_test_data_suite | float_test_data_suite) +def test_scalar_tensor_tosa_BI(test_data): + scalar, dtype, data = test_data + pipeline: TosaPipelineBI = TosaPipelineBI( + ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +@common.parametrize("test_data", float_test_data_suite) +@common.XfailIfNoCorstone300 +def test_scalar_tensor_tosa_u55(test_data): + scalar, dtype, data = test_data + EthosU55PipelineBI( + ScalarTensor(scalar, dtype), + tuple(data), + ScalarTensor.aten_op, + symmetric_io_quantization=True, + run_on_fvp=True, + ).run() + + +@common.parametrize("test_data", float_test_data_suite) +@common.XfailIfNoCorstone320 +def test_scalar_tensor_tosa_u85(test_data): + scalar, dtype, data = test_data + EthosU85PipelineBI( + ScalarTensor(scalar, dtype), + tuple(data), + ScalarTensor.aten_op, + symmetric_io_quantization=True, + run_on_fvp=True, + ).run()