From 86a6ff9de7aa457316f55a12864416019cca94df Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 11 Nov 2024 15:37:23 +0000 Subject: [PATCH] Add support for torch.ops.aten._to_copy.default Lower torch.ops.aten._to_copy.default to TOSA CAST op. This resolves issues around arithmetic operators when using int scalars in unquantized networks (see new test cases in test_scalars.py). Note: Parameter 'memory_format' is not supported. Change-Id: I7a921ca510c5b46f15b5399218f9230ba0f93d88 --- backends/arm/operator_support/__init__.py | 1 + .../arm/operator_support/to_copy_support.py | 120 ++++++++++++++++++ backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_to_copy.py | 43 +++++++ backends/arm/test/ops/test_scalars.py | 16 ++- backends/arm/test/ops/test_to_copy.py | 70 ++++++++++ 6 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 backends/arm/operator_support/to_copy_support.py create mode 100644 backends/arm/operators/op_to_copy.py create mode 100644 backends/arm/test/ops/test_to_copy.py diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index c133ce8003a..297047963c6 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -8,6 +8,7 @@ from . import ( # noqa mean_dim_support, right_shift_support, + to_copy_support, tosa_supported_operators, var_correction_support, ) diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py new file mode 100644 index 00000000000..9bba274804e --- /dev/null +++ b/backends/arm/operator_support/to_copy_support.py @@ -0,0 +1,120 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import torch + +import torch.fx as fx + +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) + + +@register_tosa_support_check +class ToCopySupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten._to_copy.default] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + SupportedTypeDict = dict[torch.dtype, list[torch.dtype]] + + @staticmethod + def _merge_supported_types( + dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict + ) -> SupportedTypeDict: + merged_dtypes = dtypes1 + for k, v in dtypes2.items(): + merged_dtypes[k] = merged_dtypes.get(k, []) + v + return merged_dtypes + + SUPPORTED_INT_TYPES: SupportedTypeDict = { + torch.bool: [torch.int8, torch.int16, torch.int32], + torch.int8: [torch.bool, torch.int16, torch.int32], + torch.int16: [torch.bool, torch.int8, torch.int32], + torch.int32: [torch.bool, torch.int8, torch.int16], + } + SUPPORTED_FLOAT_TYPES: SupportedTypeDict = { + torch.int8: [torch.float16, torch.bfloat16, torch.float32], + torch.int16: [torch.float16, torch.bfloat16, torch.float32], + torch.int32: [torch.float16, torch.bfloat16, torch.float32], + torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32], + torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32], + torch.float32: [ + torch.int8, + torch.int16, + torch.int32, + torch.bfloat16, + torch.float16, + ], + } + ALL_SUPPORTED_TYPES = _merge_supported_types( + SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES + ) + POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32} + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + assert node.target in self.targets + + if tosa_spec not in self.tosa_specs: + return False + + assert tosa_spec.support_integer() + supported_dtypes = ( + self.ALL_SUPPORTED_TYPES + if tosa_spec.support_float() + else self.SUPPORTED_INT_TYPES + ) + # Take into account possible type conversions + supported_dtypes.update( + (k, supported_dtypes[v]) + for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items() + if v in supported_dtypes + ) + + # Check input type + assert len(node.all_input_nodes) == 1 + input_val = node.all_input_nodes[0].meta["val"] + assert isinstance(input_val, torch._subclasses.FakeTensor) + input_dtype = input_val.dtype + if input_dtype not in supported_dtypes: + logger.info( + f"Input dtype {input_val.dtype} is not supported in " + f"{node.target.name()}." + ) + return False + + # Check output type + output_val = node.meta["val"] + assert isinstance(output_val, torch._subclasses.FakeTensor) + if output_val.dtype not in supported_dtypes[input_dtype]: + logger.info( + f"Output dtype {output_val.dtype} is not supported in " + f"{node.target.name()} for input dtype {input_dtype}. " + f"Supported output types: " + f"{''.join(str(t) for t in supported_dtypes[input_dtype])}" + ) + return False + + # Check memory format + if "memory_format" in node.kwargs: + if node.kwargs["memory_format"] in (torch.preserve_format,): + logger.info( + f"Argument 'memory_format' is not supported for " + f"{node.target.name()} right now." + ) + return False + + return True diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a5c2dd8dc5f..8c4aa85e579 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -36,6 +36,7 @@ op_sub, op_sum, op_tanh, + op_to_copy, op_transpose, op_unsqueeze, op_upsample_nearest2d, diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py new file mode 100644 index 00000000000..15077d6df77 --- /dev/null +++ b/backends/arm/operators/op_to_copy.py @@ -0,0 +1,43 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +import tosa.Op as TosaOp + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg + + +@register_node_visitor +class ToCopyVisitor(NodeVisitor): + """ + Implement the type cast functionality of _to_copy. + + Other features like setting of the memory_format or moving a tensor to a + different device are not supported. + + Also note that the node should not be quantized. + """ + + target = "aten._to_copy.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert not is_quant_node, "Casting of quantized values is not supported." + assert inputs + tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index cd3dd72f608..455b484b948 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -153,9 +153,21 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): .run_method_and_compare_outputs(inputs=test_data) ) - # Most MI tests fail, just show one working for now. - @parameterized.expand((tensor_scalar_tests[6],)) + @parameterized.expand(tensor_scalar_tests) def test_MI(self, test_name: str, op: torch.nn.Module, x, y): + expected_exception = None + if any(token in test_name for token in ("Sub_int", "Sub__int")): + expected_exception = RuntimeError + elif test_name.endswith("_st"): + expected_exception = AttributeError + + if expected_exception: + with self.assertRaises( + expected_exception, msg=f"Test {test_name} is expected to fail." + ): + self._test_add_tosa_MI_pipeline(op, (x, y)) + return + self._test_add_tosa_MI_pipeline(op, (x, y)) # op(Scalar float, tensor) works if the scalar is constant. diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py new file mode 100644 index 00000000000..8499512e10d --- /dev/null +++ b/backends/arm/test/ops/test_to_copy.py @@ -0,0 +1,70 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Tests the _to_copy op which is interpreted as a cast for our purposes. +# + +import unittest + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from parameterized import parameterized + + +class Cast(torch.nn.Module): + def __init__(self, target_dtype): + super().__init__() + self.target_dtype = target_dtype + + def forward(self, x: torch.Tensor): + return x.to(dtype=self.target_dtype) + + +class TestToCopy(unittest.TestCase): + """ + Tests the _to_copy operation. + + Only test unquantized graphs as explicit casting of dtypes messes with the + quantization. + + Note: This is also covered by test_scalars.py. + """ + + _TO_COPY_TEST_DATA = ( + (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32), + (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.float32), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.int32), + (torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int8), + ) + + def _test_to_copy_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: torch.Tensor + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .dump_artifact() + .check_count({"torch.ops.aten._to_copy.default": 1}) + .to_edge() + .dump_artifact() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + @parameterized.expand(_TO_COPY_TEST_DATA) + def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_dtype): + self._test_to_copy_tosa_MI_pipeline(Cast(new_dtype), (test_tensor,))