diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 6d129af8278..004f2f1ccc9 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -116,7 +116,7 @@ def tosa_support_factory(
     # Negative checks: Remove nodes from partitioning
     negative_checks: list[OperatorSupportBase] = [
-        CheckInt64Inputs(exported_program, reporter),
+        CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
         RankCheck(reporter, max_rank=5),
         *[
@@ -454,7 +454,18 @@ def is_node_supported(
         return True
 
 
-class CheckInt64Inputs(OperatorSupportBase):
+class CheckInt64InputsAndOutputs(OperatorSupportBase):
+    """TOSA does not support int64 tensors, so in general ops with int64 inputs or outputs should not be partitioned.
+    There are, however, some exceptions:
+    - Nodes with int64 output can be partitioned if they are constant, within int32,
+      and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
+    - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output.
+      In this case, there are multiple outputs and the int64 ones are not used.
+    - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
+      ops fulfilling the criteria above.
+    Note that we don't check placeholders here; they are partitioned based on whether their users are partitioned
+    or not.
+    """
 
     def __init__(
         self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
@@ -465,27 +476,85 @@ def __init__(
             if spec.kind == InputKind.USER_INPUT
         ]
         self.reporter = reporter
+        self.int32_min = torch.iinfo(torch.int32).min
+        self.int32_max = torch.iinfo(torch.int32).max
         super().__init__()
 
+    def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
+        """Node is assumed to be call_function with int64 output."""
+        if isinstance(node.target, str):
+            return False
+        data = node.target(*node.args, **node.kwargs)
+        min_val, max_val = int(torch.min(data)), int(torch.max(data))
+        return min_val >= self.int32_min and max_val <= self.int32_max
+
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
+        vals = node.meta["val"]
+        tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
+
+        any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list)
+        # Don't partition nodes with int64 output...
+        if any_int64:
+            # ... Except for constant ops that are directly cast to something non-int64.
+            # This could be an explicit cast, or something like a less-than that outputs a different dtype than its input.
+            users_output_non_int64 = all(
+                get_first_fake_tensor(output_node).dtype != torch.int64
+                for output_node in node.users
+            )
+            if (
+                node.target in ComputeConstantOpsAOT.targeted_ops
+                and users_output_non_int64
+            ):
+                if not self.inside_int32_bounds(node):
+                    self.reporter.report_reject(
+                        node, "Constant node outside int32 range."
+                    )
+                    return False
+                # Will never have input nodes, safe to return True
+                return True
+
+            # ... Or ops with multiple outputs where only the non-int64 ones are used.
+            users_are_getitem = all(
+                user.target == operator.getitem for user in node.users
+            )
+            if users_are_getitem and users_output_non_int64:
+                # Passed output check, go to input check.
+                pass
+            else:
+                self.reporter.report_reject(
+                    node, "Non-constant node with int64 output."
+                )
+                return False
+
+        # Ops with int64 inputs are only partitioned if the input nodes are constant and will be partitioned.
+        # If an input node is not partitioned, the partition would get an int64 input and fail.
         for input_node in node.all_input_nodes:
-            # We can cast constant placeholders and constant ops AOT, such int64 are ok.
-            # Otherwise, don't partition if one or more inputs are int64.
+            tensor_in = get_first_fake_tensor(input_node)
+            if tensor_in.dtype != torch.int64:
+                continue
+            # Constant placeholder
             if (
-                input_node.name in self.input_names
-                or not input_node.op == "placeholder"
+                input_node.op != "call_function"
+                and input_node.name not in self.input_names
             ):
-                tensor = get_first_fake_tensor(input_node)
-                if tensor.dtype == torch.int64:
-                    if input_node.target not in ComputeConstantOpsAOT.targeted_ops:
-                        self.reporter.report_reject(
-                            node,
-                            f"Had int64 input {input_node.name} that couldn't be handled.",
-                        )
-                        return False
+                continue
+            # Constant operator
+            if input_node.op == "call_function":
+                if input_node.target in ComputeConstantOpsAOT.targeted_ops:
+                    # This is not perfect since the input_node can still be rejected by other checks, but
+                    # it should cover the majority of cases.
+                    if self.is_node_supported(
+                        None, input_node  # type: ignore[arg-type] # (we don't use 'submodules')
+                    ):
+                        continue
+            self.reporter.report_reject(
+                node, f"Non-constant int64 input {input_node.name}"
+            )
+            return False
+        return True
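Reviewer note: the interplay between `inside_int32_bounds` and the ahead-of-time downcast can be pictured with a small self-contained sketch. This is illustrative only — the real fold lives in `ComputeConstantOpsAOT`, and `fold_constant_int64` here is a made-up name:

```python
import torch

def fold_constant_int64(target, args, kwargs):
    # Materialize the constant op, as inside_int32_bounds does, by calling
    # the node's target with the node's args.
    data = target(*args, **kwargs)
    info = torch.iinfo(torch.int32)
    if not (info.min <= int(torch.min(data)) and int(torch.max(data)) <= info.max):
        # Corresponds to the "Constant node outside int32 range." reject.
        raise ValueError("constant outside int32 range")
    # Within bounds: safe to downcast ahead of time, so TOSA never sees int64.
    return data.to(torch.int32)

folded = fold_constant_int64(torch.arange, (0, 10, 1), {"dtype": torch.int64})
assert folded.dtype == torch.int32
```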
diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py
new file mode 100644
index 00000000000..d6d6d6cb39c
--- /dev/null
+++ b/backends/arm/test/misc/test_int64.py
@@ -0,0 +1,116 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+
+class ConstAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype, bias=0):
+        super().__init__()
+        self.dtype = dtype
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        c = torch.arange(self.bias, self.bias + 10, 1, dtype=self.dtype)
+        # Add explicit float cast to make quantization work; it will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class BufferAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype, bias=0):
+        super().__init__()
+        self.dtype = dtype
+        self.buffer = torch.arange(0, 10, 1, dtype=self.dtype) + bias
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        c = self.buffer
+        # Add explicit float cast to make quantization work; it will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class ConstChainAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype):
+        super().__init__()
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor):
+        c = torch.arange(0, 10, 1, dtype=self.dtype).reshape((2, 5)).unsqueeze(-1)
+        # Add explicit float cast to make quantization work; it will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class BufferChainAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype):
+        super().__init__()
+        self.dtype = dtype
+        self.buffer = torch.arange(0, 10, 1, dtype=self.dtype)
+
+    def forward(self, x: torch.Tensor):
+        c = self.buffer.reshape((2, 5)).unsqueeze(-1)
+        # Add explicit float cast to make quantization work; it will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+test_data_suite = {
+    "fp32_in+int64_buffer": (BufferAdd(torch.int64), (torch.rand(10) - 0.5,)),
+    "fp32_in+int64_buffer_overflow": (
+        BufferAdd(torch.int64, 2**40),
+        (torch.rand(10) - 0.5,),
+    ),
+    "fp32_in+int64_const": (ConstAdd(torch.int64), (torch.rand(10) - 0.5,)),
+    "fp32_in+int64_const_overflow": (
+        ConstAdd(torch.int64, 2**40),
+        (torch.rand(10) - 0.5,),
+    ),
+    "int64_in+float_const": (
+        ConstAdd(torch.float32),
+        (torch.randint(0, 10, (10,)),),
+    ),
+    "fp32_in+int64_buffer_chain": (
+        BufferChainAdd(torch.int64),
+        (torch.rand(2, 5, 3) - 0.5,),
+    ),
+    "fp32_in+int64_const_chain": (
+        ConstChainAdd(torch.int64),
+        (torch.rand(2, 5, 3) - 0.5,),
+    ),
+    "int64_in+float_const_chain": (
+        ConstChainAdd(torch.float32),
+        (torch.randint(0, 10, (2, 5, 3)),),
+    ),
+}
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_int64_tosa_FP(test_data: Tuple):
+    model, inputs = test_data
+    (
+        ArmTester(
+            model,
+            inputs,
+            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+        )
+        .export()
+        .to_edge_transform_and_lower()
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs)
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_int64_tosa_INT(test_data: Tuple):
+    model, inputs = test_data
+    (
+        ArmTester(model, inputs, common.get_tosa_compile_spec("TOSA-1.0+INT"))
+        .quantize()
+        .export()
+        .to_edge_transform_and_lower()
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs)
+    )
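The `.to(torch.float32)` comment repeated in the fixtures above rests on standard type promotion; a quick sanity check of that assumption:

```python
import torch

x = torch.rand(10)          # float32
c = torch.arange(0, 10, 1)  # int64 by default
# float32 + int64 promotes to float32, so export would insert the cast
# implicitly; the fixtures spell it out so quantization sees it explicitly.
assert (x + c).dtype == torch.float32
```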
diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
index aba58379a92..0567d32eebb 100644
--- a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
@@ -31,16 +31,18 @@ class TestT5EncoderModel(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
         "executorch_exir_dialects_edge__ops_aten_abs_default": 1,
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3,
+        "executorch_exir_dialects_edge__ops_aten_arange_start_step": 2,
         "executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
         "executorch_exir_dialects_edge__ops_aten_gt_Scalar": 1,
         "executorch_exir_dialects_edge__ops_aten_lt_Scalar": 1,
         "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
         "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
         "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2,
         "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_where_self": 1,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
-        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }
 
     def _prepare_inputs(
diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py
index c1a9f312d85..4896074b544 100644
--- a/backends/arm/test/models/test_nn_functional.py
+++ b/backends/arm/test/models/test_nn_functional.py
@@ -81,11 +81,6 @@ def forward(self, *args):
 @parametrize(
     "test_data",
     module_tests,
-    xfails={
-        "affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
-        "unfold": "ValueError: Invalid TOSA graph",
-        "fold": "ValueError: Invalid TOSA graph",
-    },
 )
 def test_nn_functional_FP(test_data):
     module, inputs = test_data
@@ -93,7 +88,6 @@ def test_nn_functional_FP(test_data):
         module, inputs, "", use_to_edge_transform_and_lower=False
     )
     pipeline.pop_stage("check.aten")
-    pipeline.dump_artifact("to_edge")
     pipeline.pop_stage("check_count.exir")
     try:
         pipeline.run()
@@ -105,14 +99,11 @@ def test_nn_functional_FP(test_data):
         raise e
 
 
-x_fails = {
-    "normalize": "MLETORCH-852: Support aten.index_put.default",
-    "unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
-    "fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
-}
-
-
-@parametrize("test_data", module_tests, x_fails, strict=False)
+@parametrize(
+    "test_data",
+    module_tests,
+    {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
+)
 def test_nn_functional_INT(test_data):
     module, inputs = test_data
     pipeline = TosaPipelineINT[input_t](
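The removed `affine_grid` xfail is a concrete case of the bug this PR fixes: per the old xfail text, the int64 output of an internal arange was split between two partitions. A minimal call for reproduction (the int64-arange detail comes from the xfail message, not re-derived here):

```python
import torch

theta = torch.eye(2, 3).unsqueeze(0)  # (N, 2, 3) affine matrices
grid = torch.nn.functional.affine_grid(theta, size=(1, 1, 4, 4), align_corners=False)
```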
diff --git a/backends/arm/test/ops/test_arange.py b/backends/arm/test/ops/test_arange.py
index ede00768f52..33cca542922 100644
--- a/backends/arm/test/ops/test_arange.py
+++ b/backends/arm/test/ops/test_arange.py
@@ -12,6 +12,7 @@
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
+    OpNotSupportedPipeline,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
@@ -46,6 +47,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
             (0.0, 10.0, 1.0, torch.int32),
         ),
+    }
+    test_reject: dict[str, test_data_t] = {
         "int32_int64": (
             lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
             (0.0, 10.0, 1.0, torch.int64),
@@ -77,6 +80,15 @@ def test_arange_start_step_tosa_FP_dtypes(test_data: test_data_t):
     pipeline.run()
 
 
+@common.parametrize("test_data", ArangeAdd.test_reject)
+def test_arange_start_step_tosa_FP_not_delegated(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = OpNotSupportedPipeline[input_t](
+        ArangeAdd(*init_data), input_data(), non_delegated_ops={ArangeAdd.exir_op: 1}
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", ArangeAdd.test_data)
 def test_arange_start_step_tosa_INT(test_data: test_data_t):
     input_data, init_data = test_data
diff --git a/backends/arm/test/ops/test_ones.py b/backends/arm/test/ops/test_ones.py
index 18204a8eaaa..ca2313b11a0 100644
--- a/backends/arm/test/ops/test_ones.py
+++ b/backends/arm/test/ops/test_ones.py
@@ -106,7 +106,6 @@ def test_ones_u85_INT(test_data: test_data_t):
     xfails={
         "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
         "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
-        "int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
     },
 )
 def test_ones_tosa_INT_not_delegated(test_data: test_data_t):
diff --git a/backends/arm/test/ops/test_zeros.py b/backends/arm/test/ops/test_zeros.py
index a1cf39c906f..88ac62cea1c 100644
--- a/backends/arm/test/ops/test_zeros.py
+++ b/backends/arm/test/ops/test_zeros.py
@@ -106,7 +106,6 @@ def test_zeros_u85_INT(test_data: test_data_t):
     xfails={
         "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
         "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
-        "int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
     },
 )
 def test_zeros_tosa_INT_not_delegated(test_data: test_data_t):
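Why `int32_int64` moved from an xfail to the explicit reject test above: the int64 arange feeds an add whose other operand is int32, so the add itself promotes to int64 and trips the new output check. The promotion assumption is easy to verify:

```python
import torch

x = torch.randint(0, 10, [10], dtype=torch.int32)
c = torch.arange(0.0, 10.0, 1.0, dtype=torch.int64)
# int32 + int64 promotes to int64: the add is a non-constant node with an
# int64 output and no cast among its users, so it cannot be partitioned.
assert (x + c).dtype == torch.int64
```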
diff --git a/backends/arm/tosa/dialect/ops/table.py b/backends/arm/tosa/dialect/ops/table.py
index 5fbbf55f910..3faf478893e 100644
--- a/backends/arm/tosa/dialect/ops/table.py
+++ b/backends/arm/tosa/dialect/ops/table.py
@@ -48,6 +48,6 @@ def TABLE(a, table):
             raise TosaValueError(f"Table dtype {table.dtype} is not int32", op="TABLE")
         return_dtype = torch.int32
     else:
-        raise TosaValueError(f"Unsupported dtype for {tosa_spec}", op="TABLE")
+        raise TosaValueError(f"Unsupported dtype {a.dtype} for {tosa_spec}", op="TABLE")
 
     return torch.empty_like(a, dtype=return_dtype)
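The `table.py` change is purely diagnostic: the reject message now names the offending input dtype. An illustrative before/after (the exact `TosaValueError` rendering may differ):

```python
import torch

tosa_spec = "TOSA-1.0+INT"  # illustrative spec string
a = torch.empty(4, dtype=torch.int64)

print(f"Unsupported dtype for {tosa_spec}")            # before: dtype unknown
print(f"Unsupported dtype {a.dtype} for {tosa_spec}")  # after: names torch.int64
```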