From 1e3f84e40ac85ecaed96485ff63c6d341a2c0bb1 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Tue, 11 Jun 2024 09:37:12 +0200
Subject: [PATCH] Add sigmoid operator to Arm backend

Implemented node visitor, annotator, and tests.
TOSA MI and BI pass; U55 BI fails on compilation.

Change-Id: I229f3ceb8b2edeccdc23003be8aefe20f327d835
Signed-off-by: Erik Lundell
---
 backends/arm/arm_partitioner.py               |   1 +
 backends/arm/operators/__init__.py            |   3 +-
 backends/arm/operators/op_sigmoid.py          |  82 ++++++++++
 backends/arm/quantizer/arm_quantizer.py       |  10 +-
 .../quantization_annotation/__init__.py       |   1 +
 .../sigmoid_annotator.py                      |  54 +++++++
 backends/arm/quantizer/quantization_config.py |  23 ++-
 backends/arm/test/ops/test_sigmoid.py         | 152 ++++++++++++++++++
 backends/arm/tosa_quant_utils.py              |  14 ++
 9 files changed, 337 insertions(+), 3 deletions(-)
 create mode 100644 backends/arm/operators/op_sigmoid.py
 create mode 100644 backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py
 create mode 100644 backends/arm/test/ops/test_sigmoid.py

diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 8628399e084..b7a672cbd7d 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -42,6 +42,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 3566171f54a..c72a3b7d6fb 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2023 Arm Limited and/or its affiliates.
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -18,6 +18,7 @@
     op_mean_dim,
     op_permute,
     op_quant,
+    op_sigmoid,
     op_softmax,
     op_view,
 )
diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py
new file mode 100644
index 00000000000..884c803482b
--- /dev/null
+++ b/backends/arm/operators/op_sigmoid.py
@@ -0,0 +1,82 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import List
+
+import numpy as np
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_quant_utils import (
+    dequantize_value,
+    get_quant_node_args,
+    QuantArgs,
+    quantize_value,
+)
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SigmoidVisitor(NodeVisitor):
+    target = "aten.sigmoid.default"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+
+        assert len(node.all_input_nodes) == 1
+        assert len(node.users) == 1
+
+        if is_quant_node:
+            # Assume quantized input is 8 bit.
+
+            # Create attribute for 8 bit table lookup.
+            input_node = node.all_input_nodes[0]
+            in_quantargs = get_quant_node_args(input_node)
+            output_node = list(node.users)[0]
+            out_quantargs = get_quant_node_args(output_node)
+
+            table = sigmoid_table_8bit(in_quantargs, out_quantargs)
+            table_attr = ts.TosaSerializerAttribute()
+            table_attr.TableAttribute(table)
+
+            tosa_graph.addOperator(
+                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
+            )
+        else:
+            tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])
+
+
+def sigmoid_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
+    """
+    Returns a 256-entry table mapping the quantized input range [qmin, qmax] through sigmoid.
+    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_sigmoid
+    """
+
+    def sigmoid(x):
+        # Convert quantized input to floating point sigmoid input space.
+        v = dequantize_value(x, in_quantargs)
+        # Compute sigmoid.
+        v = 1.0 / (1.0 + np.exp(-v))
+        # Convert sigmoid output back to quantized space.
+        return quantize_value(v, out_quantargs)
+
+    return [
+        sigmoid(x)
+        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
+    ]
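
Note: as a sanity check, the table construction above can be reproduced standalone. The sketch below is illustrative, not part of the patch; it assumes the fixed int8 quantization parameters that sigmoid_annotator.py (later in this patch) hard-codes: input scale 6/128, output scale 1/128, both with zero point 0.

    import numpy as np

    IN_SCALE, IN_ZP = 6.0 / 128.0, 0    # input qspec from sigmoid_annotator.py
    OUT_SCALE, OUT_ZP = 1.0 / 128.0, 0  # output qspec from sigmoid_annotator.py

    def table_entry(q: int) -> int:
        x = (q - IN_ZP) * IN_SCALE             # dequantize the int8 input
        y = 1.0 / (1.0 + np.exp(-x))           # float sigmoid
        q_out = round(y / OUT_SCALE) + OUT_ZP  # requantize the result
        return int(np.clip(q_out, -128, 127))

    table = [table_entry(q) for q in range(-128, 128)]
    assert table[0] == 0      # q=-128 -> x=-6.0, sigmoid(x) ~= 0.0025
    assert table[-1] == 127   # q=127 -> x~=5.95, sigmoid(x) ~= 0.9974 (saturates)
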
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 9926e1d8d0e..1b2ba1a7038 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -241,7 +241,15 @@ class ArmQuantizer(Quantizer):
     # A list of supported static quantization ops (both PTQ and QAT)
     # The name must match the name used when registering the annotator.
     # Preserve the order so that fusions come before singular ops
-    STATIC_OPS = ["linear", "conv", "adaptive_avg_pool2d", "max_pool2d", "add", "mul"]
+    STATIC_OPS = [
+        "linear",
+        "conv",
+        "adaptive_avg_pool2d",
+        "max_pool2d",
+        "add",
+        "mul",
+        "sigmoid",
+    ]
 
     def __init__(self):
         super().__init__()
diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py
index d5499b052d1..5e372fe5b74 100644
--- a/backends/arm/quantizer/quantization_annotation/__init__.py
+++ b/backends/arm/quantizer/quantization_annotation/__init__.py
@@ -53,4 +53,5 @@ def decorator(annotator: AnnotatorType):
     linear_annotator,
     max_pool2d_annotator,
     mul_annotator,
+    sigmoid_annotator,
 )
diff --git a/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py b/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py
new file mode 100644
index 00000000000..eb5fe9d482a
--- /dev/null
+++ b/backends/arm/quantizer/quantization_annotation/sigmoid_annotator.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, List, Optional
+
+import torch
+from executorch.backends.arm.quantizer import arm_quantizer_utils
+from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from torch.ao.quantization.quantizer.utils import (
+    _annotate_input_qspec_map,
+    _annotate_output_qspec,
+)
+from torch.fx import Node
+
+
+@register_annotator("sigmoid")
+def _annotate_sigmoid(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+
+    # The input/output range of sigmoid is always the same -> quantize with a fixed qspec.
+    # This configuration maps input: (-128, 127) -> (-6.0, 5.95). Outside these bounds, sigmoid ~= const.
+    # Output: (-1, 0.99) -> (-128, 127). Sigmoid has output value range (0, 1), so only the upper
+    # half of the quantized output range is used. Note that this exact choice is somewhat arbitrary.
+
+    input_act_qspec = quantization_config.get_fixed_qspec(scale=6 / 128, zp=0)
+    output_act_qspec = quantization_config.get_fixed_qspec(scale=1 / 128, zp=0)
+
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target != torch.ops.aten.sigmoid.default:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+        input_node = node.args[0]
+
+        if not arm_quantizer_utils.is_annotated([node]):
+            _annotate_input_qspec_map(
+                node,
+                input_node,
+                input_act_qspec,
+            )
+            _annotate_output_qspec(node, output_act_qspec)
+
+            arm_quantizer_utils.mark_nodes_as_annotated([node])
+            annotated_partitions.append([node])
+
+    return annotated_partitions
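
Note: the numeric claims in the comment above can be checked directly. A quick sketch (not part of the patch) using only the chosen scales:

    import math

    in_scale, out_scale = 6.0 / 128.0, 1.0 / 128.0
    lo, hi = -128 * in_scale, 127 * in_scale  # (-6.0, 5.953125)
    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
    print(sigmoid(lo), sigmoid(hi))  # ~0.0025 and ~0.9974: effectively
                                     # constant outside [lo, hi]
    print(127 * out_scale)           # 0.9921875: largest representable output
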
diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py
index 64247b53d40..f94c3e18da6 100644
--- a/backends/arm/quantizer/quantization_config.py
+++ b/backends/arm/quantizer/quantization_config.py
@@ -8,7 +8,10 @@
 
 import torch
 
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torch.ao.quantization.quantizer import (
+    FixedQParamsQuantizationSpec,
+    QuantizationSpec,
+)
 
 
 @dataclass(eq=True, frozen=True)
@@ -56,3 +59,21 @@ def get_bias_qspec(self) -> QuantizationSpec | None:
             self.bias.dtype == torch.float
         ), "Only float dtype is supported for bias right now"
         return self.bias
+
+    def get_fixed_qspec(
+        self,
+        scale: float,
+        zp: int,
+        dtype: torch.dtype = torch.int8,
+        quant_min: int = -128,
+        quant_max: int = 127,
+    ) -> FixedQParamsQuantizationSpec:
+        """Returns a new FixedQParamsQuantizationSpec with the given parameters."""
+        return FixedQParamsQuantizationSpec(
+            dtype=dtype,
+            qscheme=torch.per_tensor_affine,
+            scale=scale,
+            zero_point=zp,
+            quant_min=quant_min,
+            quant_max=quant_max,
+        )
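
Note: a usage sketch for the new helper, mirroring what sigmoid_annotator.py above does. The scale and zero-point values are the ones that annotator hard-codes; the function name sigmoid_qspecs is hypothetical.

    from executorch.backends.arm.quantizer.quantization_config import (
        QuantizationConfig,
    )

    def sigmoid_qspecs(config: QuantizationConfig):
        # Fixed qspecs are data-independent: no observer or calibration is
        # attached for these activations.
        input_qspec = config.get_fixed_qspec(scale=6 / 128, zp=0)
        output_qspec = config.get_fixed_qspec(scale=1 / 128, zp=0)
        return input_qspec, output_qspec
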
diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py
new file mode 100644
index 00000000000..7a0435689f4
--- /dev/null
+++ b/backends/arm/test/ops/test_sigmoid.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import unittest
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+test_data_suite = [
+    # (test_name, test_data)
+    ("zeros", torch.zeros(10, 10, 10, 10)),
+    ("ones", torch.ones(10, 10, 10)),
+    ("rand", torch.rand(10, 10) - 0.5),
+    ("randn_pos", torch.randn(10) + 10),
+    ("randn_neg", torch.randn(10) - 10),
+    ("ramp", torch.arange(-16, 16, 0.2)),
+]
+
+
+class TestSigmoid(unittest.TestCase):
+    class Sigmoid(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, x):
+            return self.sigmoid(x)
+
+    class AddSigmoid(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, x):
+            return self.sigmoid(x + x)
+
+    class SigmoidAdd(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, x):
+            return x + self.sigmoid(x)
+
+    class SigmoidAddSigmoid(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.sigmoid = torch.nn.Sigmoid()
+
+        def forward(self, x, y):
+            return self.sigmoid(self.sigmoid(y) + self.sigmoid(x))
+
+    def _test_sigmoid_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check(["torch.ops.aten.sigmoid.default"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check(["torch.ops.aten.sigmoid.default"])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_sigmoid_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_u55_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.sigmoid.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_sigmoid_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+    ):
+        self._test_sigmoid_tosa_MI_pipeline(self.Sigmoid(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_sigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_sigmoid_tosa_BI_pipeline(self.Sigmoid(), (test_data,))
+
+    def test_add_sigmoid_tosa_BI(self):
+        self._test_sigmoid_tosa_BI_pipeline(self.AddSigmoid(), (test_data_suite[0][1],))
+
+    def test_sigmoid_add_tosa_BI(self):
+        self._test_sigmoid_tosa_BI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],))
+
+    def test_sigmoid_add_sigmoid_tosa_BI(self):
+        self._test_sigmoid_tosa_BI_pipeline(
+            self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1])
+        )
+
+    # Fails due to Vela diverging from the TOSA spec; expected to work with Regor.
+    @parameterized.expand(test_data_suite)
+    @unittest.expectedFailure
+    def test_sigmoid_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_sigmoid_tosa_u55_BI_pipeline(self.Sigmoid(), (test_data,))
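
Note: each @parameterized.expand suite above expands into one test per entry of test_data_suite. A sketch for running a single generated case directly (the method name suffix _0_zeros assumes parameterized's default naming scheme):

    import unittest

    from executorch.backends.arm.test.ops.test_sigmoid import TestSigmoid

    # Run the MI pipeline for the "zeros" input only.
    suite = unittest.TestSuite()
    suite.addTest(TestSigmoid("test_sigmoid_tosa_MI_0_zeros"))
    unittest.TextTestRunner(verbosity=2).run(suite)
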
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index 5cb5b6fe31a..f9c6b990730 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -8,6 +8,8 @@
 import math
 from typing import NamedTuple
 
+import numpy as np
+
 import serializer.tosa_serializer as ts
 import torch.fx
 from executorch.backends.arm.tosa_mapping import TosaArg
@@ -26,6 +28,18 @@ class QuantArgs(NamedTuple):
     qmax: int
 
 
+def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
+    return np.clip(
+        np.round(x / qargs.scale) + qargs.zp,
+        qargs.qmin,
+        qargs.qmax,
+    ).astype(dtype)
+
+
+def dequantize_value(qx, qargs: QuantArgs):
+    return (qx - qargs.zp) * qargs.scale
+
+
 def is_quant_node(node: torch.fx.Node):
     consumer_node = list(node.users)[0]
     input = node.all_input_nodes[0]
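
Note: a round-trip sketch for the two new helpers (illustrative values, assuming QuantArgs carries exactly the four fields used above):

    import numpy as np

    from executorch.backends.arm.tosa_quant_utils import (
        QuantArgs,
        dequantize_value,
        quantize_value,
    )

    qargs = QuantArgs(scale=6 / 128, zp=0, qmin=-128, qmax=127)

    q = quantize_value(np.array([-7.0, 0.0, 1.5]), qargs)
    print(q)                           # [-128    0   32]; -7.0 clips to qmin
    print(dequantize_value(q, qargs))  # [-6.   0.   1.5]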