97 changes: 83 additions & 14 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -116,7 +116,7 @@ def tosa_support_factory(

# Negative checks: Remove nodes from partitioning
negative_checks: list[OperatorSupportBase] = [
-CheckInt64Inputs(exported_program, reporter),
+CheckInt64InputsAndOutputs(exported_program, reporter),
CheckFloat64Inputs(exported_program, reporter),
RankCheck(reporter, max_rank=5),
*[
@@ -454,7 +454,18 @@ def is_node_supported(
return True


-class CheckInt64Inputs(OperatorSupportBase):
+class CheckInt64InputsAndOutputs(OperatorSupportBase):
"""TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned.
There are however some exceptions:
- Nodes with int64 output can be partitioned if they are constant, within int32,
and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
- Nodes with int64 output can be partitioned if all users are getitem with non-int64 output.
In this case, there are multiple outputs and the int64 ones are not used.
- Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
ops fulfilling the criteria above.
Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned
or not.
"""

def __init__(
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
@@ -465,27 +476,85 @@ def __init__(
if spec.kind == InputKind.USER_INPUT
]
self.reporter = reporter
self.int32_min = torch.iinfo(torch.int32).min
self.int32_max = torch.iinfo(torch.int32).max
super().__init__()

def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
"""Node is assumed to be call_function with int64 output."""
if isinstance(node.target, str):
return False
data = node.target(*node.args, **node.kwargs)
min_val, max_val = int(torch.min(data)), int(torch.max(data))
return min_val >= self.int32_min and max_val <= self.int32_max

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

vals = node.meta["val"]
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]

any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list)
# Don't partition nodes with int64 output...
if any_int64:
# ... Except for constant ops that are directly cast to something non-int64.
# This could be an explicit cast, or an op such as less-than that outputs a different dtype than its input.
users_output_non_int64 = all(
get_first_fake_tensor(output_node).dtype != torch.int64
for output_node in node.users
)
if (
node.target in ComputeConstantOpsAOT.targeted_ops
and users_output_non_int64
):
if not self.inside_int32_bounds(node):
self.reporter.report_reject(
node, "Constant node outside int32 range."
)
return False
# Will never have input nodes, safe to return True
return True

# ... Or ops with multiple outputs where only non-int64 are used.
users_are_getitem = all(
user.target == operator.getitem for user in node.users
)
if users_are_getitem and users_output_non_int64:
# Passed output check, go to input check.
pass
else:
self.reporter.report_reject(
node, "Non-constant node with int64 output."
)
return False

# Ops with int64 inputs are only partitioned if the input nodes are constant and will themselves be partitioned.
# If an input node is not partitioned, the partition would get an int64 input and fail.
for input_node in node.all_input_nodes:
# We can cast constant placeholders and constant ops AOT, so such int64 inputs are ok.
# Otherwise, don't partition if one or more inputs are int64.
tensor_in = get_first_fake_tensor(input_node)
if tensor_in.dtype != torch.int64:
continue
+# Constant placeholder
if (
-input_node.name in self.input_names
-or not input_node.op == "placeholder"
+input_node.op != "call_function"
+and input_node.name not in self.input_names
):
-tensor = get_first_fake_tensor(input_node)
-if tensor.dtype == torch.int64:
-if input_node.target not in ComputeConstantOpsAOT.targeted_ops:
-self.reporter.report_reject(
-node,
-f"Had int64 input {input_node.name} that couldn't be handled.",
-)
-return False
+continue
+# Constant operator
+if input_node.op == "call_function":
+if input_node.target in ComputeConstantOpsAOT.targeted_ops:
+# This is not perfect since the input_node can still be rejected by other checks but
+# this should cover the majority of cases.
+if self.is_node_supported(
+None, input_node  # type: ignore[arg-type] #(we don't use 'submodules')
+):
+continue
+self.reporter.report_reject(
+node, f"Non-constant int64 input {input_node.name}"
+)
+return False

return True


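The first exception in the `CheckInt64InputsAndOutputs` docstring hinges on a value-range check: a constant int64 tensor can only be cast to int32 ahead of time if every element fits in int32. A minimal standalone sketch of that check (illustrative only; it mirrors the `inside_int32_bounds` helper above rather than calling into ExecuTorch):

```python
import torch

INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max


def fits_in_int32(data: torch.Tensor) -> bool:
    """Return True if every element of an int64 tensor is representable as int32."""
    min_val, max_val = int(torch.min(data)), int(torch.max(data))
    return INT32_MIN <= min_val and max_val <= INT32_MAX


# A small constant such as torch.arange(0, 10) can safely be folded to int32 AOT ...
assert fits_in_int32(torch.arange(0, 10, dtype=torch.int64))
# ... while a constant biased by 2**40 overflows int32, so the op must be rejected.
assert not fits_in_int32(torch.arange(2**40, 2**40 + 10, dtype=torch.int64))
```

The `*_overflow` entries in the new test file below bias their constants by 2**40 precisely to land outside these bounds.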
116 changes: 116 additions & 0 deletions backends/arm/test/misc/test_int64.py
@@ -0,0 +1,116 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester


class ConstAdd(torch.nn.Module):
def __init__(self, dtype: torch.dtype, bias=0):
super().__init__()
self.dtype = dtype
self.bias = bias

def forward(self, x: torch.Tensor):
c = torch.arange(self.bias, self.bias + 10, 1, dtype=self.dtype)
# Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
return x + c.to(torch.float32)


class BufferAdd(torch.nn.Module):
def __init__(self, dtype: torch.dtype, bias=0):
super().__init__()
self.dtype = dtype
self.buffer = torch.arange(0, 10, 1, dtype=self.dtype) + bias
self.bias = bias

def forward(self, x: torch.Tensor):
c = self.buffer
# Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
return x + c.to(torch.float32)


class ConstChainAdd(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self.dtype = dtype

def forward(self, x: torch.Tensor):
c = torch.arange(0, 10, 1, dtype=self.dtype).reshape((2, 5)).unsqueeze(-1)
# Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
return x + c.to(torch.float32)


class BufferChainAdd(torch.nn.Module):
def __init__(self, dtype: torch.dtype):
super().__init__()
self.dtype = dtype
self.buffer = torch.arange(0, 10, 1, dtype=self.dtype)

def forward(self, x: torch.Tensor):
c = self.buffer.reshape((2, 5)).unsqueeze(-1)
# Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
return x + c.to(torch.float32)


test_data_suite = {
"fp32_in+int64_buffer": (BufferAdd(torch.int64), (torch.rand(10) - 0.5,)),
"fp32_in+int64_buffer_overflow": (
BufferAdd(torch.int64, 2**40),
(torch.rand(10) - 0.5,),
),
"fp32_in+int64_const": (ConstAdd(torch.int64), (torch.rand(10) - 0.5,)),
"fp32_in+int64_const_overflow": (
ConstAdd(torch.int64, 2**40),
(torch.rand(10) - 0.5,),
),
"int64_in+float_const": (
ConstAdd(torch.float32),
(torch.randint(0, 10, (10,)),),
),
"fp32_in+int64_buffer_chain": (
BufferChainAdd(torch.int64),
(torch.rand(2, 5, 3) - 0.5,),
),
"fp32_in+int64_const_chain": (
ConstChainAdd(torch.int64),
(torch.rand(2, 5, 3) - 0.5,),
),
"int64_in+float_const_chain": (
ConstChainAdd(torch.float32),
(torch.randint(0, 10, (2, 5, 3)),),
),
}


@common.parametrize("test_data", test_data_suite)
def test_int64_tosa_FP(test_data: Tuple):
model, inputs = test_data
(
ArmTester(
model,
inputs,
common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
)
.export()
.to_edge_transform_and_lower()
.to_executorch()
.run_method_and_compare_outputs(inputs)
)


@common.parametrize("test_data", test_data_suite)
def test_int64_tosa_INT(test_data: Tuple):
model, inputs = test_data
(
ArmTester(model, inputs, common.get_tosa_compile_spec("TOSA-1.0+INT"))
.quantize()
.export()
.to_edge_transform_and_lower()
.to_executorch()
.run_method_and_compare_outputs(inputs)
)
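The suite above exercises the constant-folding exception; the getitem exception from the docstring (nodes whose int64 outputs are only reached through getitem users with non-int64 results) is easiest to picture with a module like the one below. This is a hypothetical illustration, not one of the PR's test cases:

```python
import torch


class MaxValuesOnly(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.max(x, dim=...) produces two outputs: float values and int64 indices.
        # Only the float values are consumed; the int64 indices output is never used
        # downstream, which is the situation the getitem exception targets.
        values, _indices = torch.max(x, dim=0)
        return values + 1.0


print(MaxValuesOnly()(torch.rand(2, 5)))
```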
@@ -31,16 +31,18 @@ class TestT5EncoderModel(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
"executorch_exir_dialects_edge__ops_aten_abs_default": 1,
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 3,
"executorch_exir_dialects_edge__ops_aten_arange_start_step": 2,
"executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
"executorch_exir_dialects_edge__ops_aten_gt_Scalar": 1,
"executorch_exir_dialects_edge__ops_aten_lt_Scalar": 1,
"executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1,
"executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_where_self": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
"torch.ops.higher_order.executorch_call_delegate": 3,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
19 changes: 5 additions & 14 deletions backends/arm/test/models/test_nn_functional.py
@@ -81,19 +81,13 @@ def forward(self, *args):
@parametrize(
"test_data",
module_tests,
-xfails={
-"affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
-"unfold": "ValueError: Invalid TOSA graph",
-"fold": "ValueError: Invalid TOSA graph",
-},
)
def test_nn_functional_FP(test_data):
module, inputs = test_data
pipeline = TosaPipelineFP[input_t](
module, inputs, "", use_to_edge_transform_and_lower=False
)
pipeline.pop_stage("check.aten")
pipeline.dump_artifact("to_edge")
pipeline.pop_stage("check_count.exir")
try:
pipeline.run()
@@ -105,14 +99,11 @@ def test_nn_functional_FP(test_data):
raise e


-x_fails = {
-"normalize": "MLETORCH-852: Support aten.index_put.default",
-"unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
-"fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
-}


-@parametrize("test_data", module_tests, x_fails, strict=False)
+@parametrize(
+"test_data",
+module_tests,
+{"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
+)
def test_nn_functional_INT(test_data):
module, inputs = test_data
pipeline = TosaPipelineINT[input_t](
12 changes: 12 additions & 0 deletions backends/arm/test/ops/test_arange.py
@@ -12,6 +12,7 @@
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
OpNotSupportedPipeline,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
@@ -46,6 +47,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
(0.0, 10.0, 1.0, torch.int32),
),
}
test_reject: dict[str, test_data_t] = {
"int32_int64": (
lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
(0.0, 10.0, 1.0, torch.int64),
@@ -77,6 +80,15 @@ def test_arange_start_step_tosa_FP_dtypes(test_data: test_data_t):
pipeline.run()


@common.parametrize("test_data", ArangeAdd.test_reject)
def test_arange_start_step_tosa_FP_not_delegated(test_data: test_data_t):
input_data, init_data = test_data
pipeline = OpNotSupportedPipeline[input_t](
ArangeAdd(*init_data), input_data(), non_delegated_ops={ArangeAdd.exir_op: 1}
)
pipeline.run()


@common.parametrize("test_data", ArangeAdd.test_data)
def test_arange_start_step_tosa_INT(test_data: test_data_t):
input_data, init_data = test_data
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_ones.py
@@ -106,7 +106,6 @@ def test_ones_u85_INT(test_data: test_data_t):
xfails={
"fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
"fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
"int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
},
)
def test_ones_tosa_INT_not_delegated(test_data: test_data_t):
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_zeros.py
@@ -106,7 +106,6 @@ def test_zeros_u85_INT(test_data: test_data_t):
xfails={
"fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
"fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
"int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
},
)
def test_zeros_tosa_INT_not_delegated(test_data: test_data_t):
2 changes: 1 addition & 1 deletion backends/arm/tosa/dialect/ops/table.py
@@ -48,6 +48,6 @@ def TABLE(a, table):
raise TosaValueError(f"Table dtype {table.dtype} is not int32", op="TABLE")
return_dtype = torch.int32
else:
raise TosaValueError(f"Unsupported dtype for {tosa_spec}", op="TABLE")
raise TosaValueError(f"Unsupported dtype {a.dtype} for {tosa_spec}", op="TABLE")

return torch.empty_like(a, dtype=return_dtype)