1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -20,4 +20,5 @@
slice_copy_support,
to_dim_order_copy_support,
tosa_supported_operators,
where_support,
)
2 changes: 0 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -97,7 +97,6 @@
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.where.self,
operator.getitem,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
@@ -211,7 +210,6 @@
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.where.self,
operator.getitem,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.amax.default,
77 changes: 77 additions & 0 deletions backends/arm/operator_support/where_support.py
@@ -0,0 +1,77 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

import torch.fx as fx
from executorch.backends.arm.constants import DQ_OPS
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class WhereSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.where.self]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]

if len(node.all_input_nodes) != 3:
self.reporter.report_reject(
node,
(
"Expected exactly three input nodes, "
f"got {len(node.all_input_nodes)} for {node.target}."
),
)
return False

condition, x, y = node.all_input_nodes
if condition.meta["val"].dtype != torch.bool:
self.reporter.report_reject(
node,
f"Type of condition in {node.target} is not torch.bool",
)
return False

x_dtype, y_dtype = x.meta["val"].dtype, y.meta["val"].dtype
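        # FP profile: bool, float16 and float32 value tensors can be lowered directly.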
if tosa_spec.support_float():
if x_dtype in (torch.bool, torch.float16, torch.float32) and y_dtype in (
torch.bool,
torch.float16,
torch.float32,
):
return True

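        # INT profile: integer value tensors are supported directly; float32 is
        # accepted only when produced by a dequantize op, since it is then
        # lowered as a quantized tensor.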
if tosa_spec.support_integer():
if (
x_dtype in (torch.bool, torch.int8, torch.int16, torch.int32)
or (x_dtype == torch.float32 and x.target in DQ_OPS)
) and (
y_dtype in (torch.bool, torch.int8, torch.int16, torch.int32)
or (y_dtype == torch.float32 and y.target in DQ_OPS)
):
return True

self.reporter.report_reject(
node,
(
f"Tensor x dtype {x_dtype} and/or tensor y dtype {y_dtype} is not supported in {node.target} "
f"for tosa specification {tosa_spec}"
),
)

return False
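To make the dtype gating above concrete, here is a minimal sketch (illustrative only, not part of this diff) of `torch.where` calls that the new support check would accept or reject under each TOSA profile:

```python
import torch

cond = torch.tensor([True, False])  # the condition must always be torch.bool

# Accepted under TOSA-1.0+FP: bool, float16 or float32 value tensors.
fp_ok = torch.where(
    cond, torch.ones(2, dtype=torch.float16), torch.zeros(2, dtype=torch.float16)
)

# Accepted under TOSA-1.0+INT: bool, int8, int16 or int32 value tensors.
# (float32 passes on INT only when fed by a dequantize op, which a plain
# eager call like this cannot demonstrate.)
int_ok = torch.where(
    cond, torch.ones(2, dtype=torch.int32), torch.zeros(2, dtype=torch.int32)
)

# Rejected on both profiles (e.g. float64 values): the check calls
# report_reject and the node is left to run outside the delegate.
rejected = torch.where(
    cond, torch.ones(2, dtype=torch.float64), torch.zeros(2, dtype=torch.float64)
)
```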
@@ -6,7 +6,6 @@

import unittest

import pytest
import torch
from executorch.backends.arm._passes import (
ConvertInt64ConstOpsToInt32Pass,
@@ -28,16 +27,25 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
    CLIPTextModelWithProjection is one of the text encoders used by Stable Diffusion 3.5 Medium.
"""

    # Adjust nbr below as we increase op support. Note: most of the delegate
    # calls are directly consecutive to each other in the .pte. The reason
    # is that some assert ops are removed by passes in the
    # .to_executorch step, i.e. after the Arm partitioner.
ops_after_partitioner = {
# Adjust nbr below as we increase op support.
ops_after_partitioner_FP = {
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

ops_after_partitioner_INT = {
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
"executorch_exir_dialects_edge__ops_aten_index_select_default": 1,
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_where_self": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
"torch.ops.aten.scalar_tensor.default": 1,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
self,
batch_size=12,
@@ -78,14 +86,13 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
.export()
.to_edge_transform_and_lower()
.dump_operator_distribution()
.check_count(self.ops_after_partitioner)
.check_count(self.ops_after_partitioner_FP)
.to_executorch()
.run_method_and_compare_outputs(
inputs=text_encoder_model_inputs,
)
)

@pytest.mark.xfail(raises=AssertionError, reason="Output difference.")
def test_CLIPTextModelWithProjection_tosa_INT(self):
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
with torch.no_grad():
@@ -99,8 +106,10 @@ def test_CLIPTextModelWithProjection_tosa_INT(self):
.export()
.to_edge_transform_and_lower()
.dump_operator_distribution()
.check_count(self.ops_after_partitioner_INT)
.to_executorch()
.run_method_and_compare_outputs(
inputs=text_encoder_model_inputs,
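                    # Wide tolerance: INT lowering quantizes the model, so
                    # outputs are expected to deviate from the fp32 reference.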
atol=0.8,
)
)
16 changes: 15 additions & 1 deletion backends/arm/test/ops/test_where.py
@@ -139,8 +139,11 @@ def scalar_condition(input: torch.Tensor):

test_modules_FP = {
**test_modules_common,
"float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
"float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
}

test_modules_FP_unsupported_dtype = {
"float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
"int32_scalar_cond": lambda: int32_scalar_cond,
}

@@ -162,6 +165,17 @@ def test_where_self_tosa_FP(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules_FP_unsupported_dtype)
def test_where_self_tosa_FP_unsupported_dtype(test_module):
pipeline = OpNotSupportedPipeline[input_t](
test_module(),
test_module().get_inputs(),
{exir_op: 1},
n_expected_delegates=1, # condition can be delegated
)
pipeline.run()


@common.parametrize("test_module", test_modules_INT)
def test_where_self_tosa_INT(test_module):
pipeline = TosaPipelineINT[input_t](