diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index 7b73cddad37..7ebe5e4c3a5 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -20,4 +20,5 @@
     slice_copy_support,
     to_dim_order_copy_support,
     tosa_supported_operators,
+    where_support,
 )
diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py
index d3207c65dff..814c72b42fc 100644
--- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py
+++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py
@@ -97,7 +97,6 @@
     exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.pow.Tensor_Tensor,
-    exir_ops.edge.aten.where.self,
     operator.getitem,
     exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
     exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
@@ -211,7 +210,6 @@
     exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.pow.Tensor_Scalar,
     exir_ops.edge.aten.pow.Tensor_Tensor,
-    exir_ops.edge.aten.where.self,
     operator.getitem,
     exir_ops.edge.aten.constant_pad_nd.default,
     exir_ops.edge.aten.amax.default,
diff --git a/backends/arm/operator_support/where_support.py b/backends/arm/operator_support/where_support.py
new file mode 100644
index 00000000000..2ec7c30827d
--- /dev/null
+++ b/backends/arm/operator_support/where_support.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+import torch.fx as fx
+from executorch.backends.arm.constants import DQ_OPS
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class WhereSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.where.self]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:  # type: ignore[override, misc]
+
+        if len(node.all_input_nodes) != 3:
+            self.reporter.report_reject(
+                node,
+                (
+                    "Expected exactly three input nodes, "
+                    f"got {len(node.all_input_nodes)} for {node.target}."
+                ),
+            )
+            return False
+
+        condition, x, y = node.all_input_nodes
+        if condition.meta["val"].dtype != torch.bool:
+            self.reporter.report_reject(
+                node,
+                f"Type of condition in {node.target} is not torch.bool",
+            )
+            return False
+
+        x_dtype, y_dtype = x.meta["val"].dtype, y.meta["val"].dtype
+        if tosa_spec.support_float():
+            if x_dtype in (torch.bool, torch.float16, torch.float32) and y_dtype in (
+                torch.bool,
+                torch.float16,
+                torch.float32,
+            ):
+                return True
+
+        if tosa_spec.support_integer():
+            if (
+                x_dtype in (torch.bool, torch.int8, torch.int16, torch.int32)
+                or (x_dtype == torch.float32 and x.target in DQ_OPS)
+            ) and (
+                y_dtype in (torch.bool, torch.int8, torch.int16, torch.int32)
+                or (y_dtype == torch.float32 and y.target in DQ_OPS)
+            ):
+                return True
+
+        self.reporter.report_reject(
+            node,
+            (
+                f"Tensor x dtype {x_dtype} and/or tensor y dtype {y_dtype} is not supported in {node.target} "
+                f"for TOSA specification {tosa_spec}"
+            ),
+        )
+
+        return False
diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
index 0e99f3f5bfa..49266beee63 100644
--- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
@@ -6,7 +6,6 @@
 
 import unittest
 
-import pytest
 import torch
 from executorch.backends.arm._passes import (
     ConvertInt64ConstOpsToInt32Pass,
@@ -28,16 +27,25 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
     CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
     """
 
-    # Adjust nbr below as we increase op support. Note: most of the delegates
-    # calls are directly consecutive to each other in the .pte. The reason
-    # for that is some assert ops are removed by passes in the
-    # .to_executorch step, i.e. after Arm partitioner.
-    ops_after_partitioner = {
+    # Adjust the numbers below as we increase op support.
+    ops_after_partitioner_FP = {
         "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
         "torch.ops.higher_order.executorch_call_delegate": 2,
     }
 
+    ops_after_partitioner_INT = {
+        "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_full_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_index_select_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_where_self": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
+        "torch.ops.aten.scalar_tensor.default": 1,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
+    }
+
     def _prepare_inputs(
         self,
         batch_size=12,
@@ -78,14 +86,13 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
             .export()
             .to_edge_transform_and_lower()
             .dump_operator_distribution()
-            .check_count(self.ops_after_partitioner)
+            .check_count(self.ops_after_partitioner_FP)
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=text_encoder_model_inputs,
             )
         )
 
-    @pytest.mark.xfail(raises=AssertionError, reason="Output difference.")
     def test_CLIPTextModelWithProjection_tosa_INT(self):
         text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
         with torch.no_grad():
@@ -99,8 +106,10 @@ def test_CLIPTextModelWithProjection_tosa_INT(self):
             .export()
             .to_edge_transform_and_lower()
             .dump_operator_distribution()
+            .check_count(self.ops_after_partitioner_INT)
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=text_encoder_model_inputs,
+                atol=0.8,
             )
         )
diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py
index ea036d26361..f27c8358cdc 100644
--- a/backends/arm/test/ops/test_where.py
+++ b/backends/arm/test/ops/test_where.py
@@ -139,8 +139,11 @@ def scalar_condition(input: torch.Tensor):
 
 test_modules_FP = {
     **test_modules_common,
-    "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
     "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
+}
+
+test_modules_FP_unsupported_dtype = {
+    "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
     "int32_scalar_cond": lambda: int32_scalar_cond,
 }
 
@@ -162,6 +165,17 @@ def test_where_self_tosa_FP(test_module):
     pipeline.run()
 
 
+@common.parametrize("test_module", test_modules_FP_unsupported_dtype)
+def test_where_self_tosa_FP_unsupported_dtype(test_module):
+    pipeline = OpNotSupportedPipeline[input_t](
+        test_module(),
+        test_module().get_inputs(),
+        {exir_op: 1},
+        n_expected_delegates=1,  # condition can be delegated
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_module", test_modules_INT)
 def test_where_self_tosa_INT(test_module):
     pipeline = TosaPipelineINT[input_t](
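
Note on the dtype gate introduced in where_support.py: on the INT profile a float32 operand of where.self is accepted only when it is produced by a dequantize op (DQ_OPS), while on the FP profile bool/float16/float32 pass directly. Below is a minimal standalone sketch of that rule outside the ExecuTorch classes; the function name, keyword arguments, and the *_is_dequantized flags are illustrative stand-ins (the patch checks `x.target in DQ_OPS` on the FX node), not part of the patch itself.

import torch

# Dtypes the new WhereSupported check accepts directly, per profile.
FP_DTYPES = (torch.bool, torch.float16, torch.float32)
INT_DTYPES = (torch.bool, torch.int8, torch.int16, torch.int32)


def where_dtypes_ok(
    x_dtype: torch.dtype,
    y_dtype: torch.dtype,
    *,
    support_float: bool,
    support_integer: bool,
    x_is_dequantized: bool = False,  # stand-in for `x.target in DQ_OPS`
    y_is_dequantized: bool = False,  # stand-in for `y.target in DQ_OPS`
) -> bool:
    # FP profile: both operands must be bool/float16/float32.
    if support_float and x_dtype in FP_DTYPES and y_dtype in FP_DTYPES:
        return True
    # INT profile: integer-ish dtypes pass; float32 only if it comes
    # from a dequantize node (i.e. it is really quantized data).
    if support_integer:
        x_ok = x_dtype in INT_DTYPES or (x_dtype == torch.float32 and x_is_dequantized)
        y_ok = y_dtype in INT_DTYPES or (y_dtype == torch.float32 and y_is_dequantized)
        if x_ok and y_ok:
            return True
    return False


# On TOSA-1.0+INT, plain float32 is rejected ...
assert not where_dtypes_ok(torch.float32, torch.int32, support_float=False, support_integer=True)
# ... but float32 fed by a dequantize node is accepted.
assert where_dtypes_ok(torch.float32, torch.int32, support_float=False, support_integer=True, x_is_dequantized=True)

Under this rule the quantized CLIP text encoder keeps where.self inside the delegate (hence the new ops_after_partitioner_INT expectations and the removed xfail), while the dtype combinations moved to test_modules_FP_unsupported_dtype are expected to stay undelegated, which is what test_where_self_tosa_FP_unsupported_dtype verifies via OpNotSupportedPipeline.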