Add rsqrt support to the Arm backend.

  * backends/arm/arm_partitioner.py: list exir_ops.edge.aten.rsqrt.default
    next to the other targets in is_node_supported.
  * backends/arm/operators/__init__.py: add op_rsqrt to the listed operator
    modules.
  * backends/arm/operators/op_rsqrt.py (new): RsqrtVisitor for
    "aten.rsqrt.default". A quantized node is emitted as a TOSA TABLE operator
    whose 256-entry table comes from rsqrt_table_8bit(); any other node is
    emitted as a TOSA RSQRT operator.
  * backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py:
    add torch.ops.aten.rsqrt.default to one_to_one_ops (now a set literal).
  * backends/arm/test/ops/test_rsqrt.py (new): TOSA MI/BI and U55/U85 BI
    pipeline tests.

NOTE(review): the blob ids on the "index" lines are kept from the original
patch; the "rqsrt" -> "rsqrt" spelling fix in op_rsqrt.py changes that file's
content without changing any hunk line counts.

diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 6b57c3d9658..4f3ca0830c9 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -58,6 +58,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.relu.default,
+            exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 7b94bfa837d..40eaf598948 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -28,6 +28,7 @@
     op_quant,
     op_relu,
     op_repeat,
+    op_rsqrt,
     op_sigmoid,
     op_slice,
     op_softmax,
diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py
new file mode 100644
index 00000000000..e8f3394ea2c
--- /dev/null
+++ b/backends/arm/operators/op_rsqrt.py
@@ -0,0 +1,70 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import List
+
+import numpy as np
+import serializer.tosa_serializer as ts
+import torch
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_quant_utils import (
+    dequantize_value,
+    get_quant_node_args,
+    QuantArgs,
+    quantize_value,
+)
+from serializer.tosa_serializer import TosaOp
+
+
+@register_node_visitor
+class RsqrtVisitor(NodeVisitor):
+    target = "aten.rsqrt.default"
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        if is_quant_node:
+            # Assume quantized input is 8 bit.
+            # Create attribute for 8 bit table lookup.
+            input_node = node.all_input_nodes[0]
+            in_quantargs = get_quant_node_args(input_node)
+            output_node = list(node.users)[0]
+            out_quantargs = get_quant_node_args(output_node)
+            table = rsqrt_table_8bit(in_quantargs, out_quantargs)
+            table_attr = ts.TosaSerializerAttribute()
+            table_attr.TableAttribute(table)
+            tosa_graph.addOperator(
+                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
+            )
+        else:
+            tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
+
+
+def rsqrt_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
+    """
+    Returns a table mapping 256 entries to rsqrt([qmin,qmax])
+    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_rsqrt
+    """
+
+    def rsqrt(x):
+        # Convert quantized input to floating point rsqrt input space.
+        v = dequantize_value(x, in_quantargs)
+        # Compute rsqrt.
+        v = 1 / np.sqrt(v)
+        # Convert rsqrt output back to quantized space.
+        return quantize_value(v, out_quantargs)
+
+    return [
+        rsqrt(x)
+        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
+    ]
diff --git a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
index 8d507c11ef3..3a189c0d8f1 100644
--- a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
@@ -35,7 +35,11 @@ def _annotate_one_to_one(
     Typical ops are ops implemented with a lookup table.
     """
     annotated_partitions = []
-    one_to_one_ops = (torch.ops.aten.exp.default, torch.ops.aten.log.default)
+    one_to_one_ops = {
+        torch.ops.aten.exp.default,
+        torch.ops.aten.log.default,
+        torch.ops.aten.rsqrt.default,
+    }
     for node in gm.graph.nodes:
         if node.op != "call_function" or node.target not in one_to_one_ops:
             continue
diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py
new file mode 100644
index 00000000000..2ccb7ec9916
--- /dev/null
+++ b/backends/arm/test/ops/test_rsqrt.py
@@ -0,0 +1,107 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Tests the rsqrt op.
+#
+
+import unittest
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+class TestRsqrt(unittest.TestCase):
+    class Rsqrt(torch.nn.Module):
+        test_parameters = [
+            (torch.ones(1, 10, 10, 10),),
+            (torch.rand(1, 10, 10, 10),),
+            (torch.rand(1, 5, 10, 20),),
+            (torch.rand(5, 10, 20),),
+        ]
+
+        def forward(self, x: torch.Tensor):
+            return x.rsqrt()
+
+    def _test_rsqrt_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check_count({"torch.ops.aten.rsqrt.default": 1})
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_rsqrt_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.rsqrt.default": 1})
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_rsqrt_ethosu_BI_pipeline(
+        self,
+        compile_spec: CompileSpec,
+        module: torch.nn.Module,
+        test_data: tuple[torch.Tensor],
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.rsqrt.default": 1})
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    @parameterized.expand(Rsqrt.test_parameters)
+    def test_rsqrt_tosa_MI(self, test_tensor: torch.Tensor):
+        self._test_rsqrt_tosa_MI_pipeline(self.Rsqrt(), (test_tensor,))
+
+    @parameterized.expand(Rsqrt.test_parameters)
+    def test_rsqrt_tosa_BI(self, test_tensor: torch.Tensor):
+        self._test_rsqrt_tosa_BI_pipeline(self.Rsqrt(), (test_tensor,))
+
+    @parameterized.expand(Rsqrt.test_parameters)
+    def test_rsqrt_u55_BI(self, test_tensor: torch.Tensor):
+        self._test_rsqrt_ethosu_BI_pipeline(
+            common.get_u55_compile_spec(), self.Rsqrt(), (test_tensor,)
+        )
+
+    @parameterized.expand(Rsqrt.test_parameters)
+    def test_rsqrt_u85_BI(self, test_tensor: torch.Tensor):
+        self._test_rsqrt_ethosu_BI_pipeline(
+            common.get_u85_compile_spec(), self.Rsqrt(), (test_tensor,)
+        )