From c711ad0d50c0feed889f1c7f53474487b2d52aab Mon Sep 17 00:00:00 2001
From: Simon Strycek
Date: Thu, 24 Apr 2025 16:41:53 +0200
Subject: [PATCH] NXP backend: Add support for Sigmoid operator conversion

---
 .../nxp/backend/edge_program_converter.py     |  1 +
 .../ops_converters/__init__.py                |  4 +
 .../ops_converters/relu_converter.py          |  2 +
 .../ops_converters/sigmoid_converter.py       | 42 ++++++++++
 backends/nxp/neutron_partitioner.py           |  1 +
 backends/nxp/quantizer/neutron_quantizer.py   |  2 +
 backends/nxp/quantizer/patterns.py            | 58 +++++++++-----
 .../node_converter/test_sigmoid_converter.py  | 76 +++++++++++++++++++
 backends/nxp/tests/models.py                  | 17 +++++
 9 files changed, 185 insertions(+), 18 deletions(-)
 create mode 100644 backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py
 create mode 100644 backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py

diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py
index a73c4af347e..1e930d37a6a 100644
--- a/backends/nxp/backend/edge_program_converter.py
+++ b/backends/nxp/backend/edge_program_converter.py
@@ -39,6 +39,7 @@
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
+    exir_ops.edge.aten.sigmoid.default: SigmoidConverter,  # noqa F405
 }
 
 
diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
index 25954b71595..8a0498810ce 100755
--- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
+++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
@@ -46,6 +46,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.relu_converter import (
     ReLUConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sigmoid_converter import (
+    SigmoidConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
     SoftmaxConverter,
 )
@@ -72,4 +75,5 @@
     "AbsConverter",
     "AdaptiveAvgPool2dConverter",
     "HardTanhConverter",
+    "SigmoidConverter",
 ]
diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py
index 5835667671f..d1af0ec2de5 100644
--- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py
+++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/relu_converter.py
@@ -25,6 +25,8 @@ def _is_supported_in_IR(
         return True
 
     def convert(self, node: Node):
+        self.assert_convertible(node)
+
         t_op = self._create_tflite_op_with_io_tensors(node)
         t_op.opcode_index = self.builder.op_code_index_for_op_type(BuiltinOperator.RELU)
 
diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py
new file mode 100644
index 00000000000..a9af12f60dd
--- /dev/null
+++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sigmoid_converter.py
@@ -0,0 +1,42 @@
+# Copyright 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
+    BuiltinOperator,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class SigmoidConverter(NodeConverter):
+    @staticmethod
+    def _is_supported_on_target(target: Target) -> bool:
+        match target:
+            case Target.RT700:
+                return True
+
+            case _:
+                return False
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        return True
+
+    def convert(self, node: Node):
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+        t_op.opcode_index = self.builder.op_code_index_for_op_type(
+            BuiltinOperator.LOGISTIC
+        )
+
+        self.builder.append_operators([t_op])
diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py
index 67d5d6f1f5d..952946ae26d 100644
--- a/backends/nxp/neutron_partitioner.py
+++ b/backends/nxp/neutron_partitioner.py
@@ -203,6 +203,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
+    exir_ops.edge.aten.sigmoid.default: SigmoidConverter,  # noqa F405
 }
 
 
diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py
index b2fe2c9bbac..7566da61c8d 100644
--- a/backends/nxp/quantizer/neutron_quantizer.py
+++ b/backends/nxp/quantizer/neutron_quantizer.py
@@ -32,6 +32,7 @@
     ReluPattern,
     ReshapePattern,
     SharedSpecPattern,
+    SigmoidPattern,
     SoftMaxPattern,
     ViewPattern,
 )
@@ -217,6 +218,7 @@ def __init__(self):
             NeutronAtenQuantizer(ReluPattern(), static_qconfig),
             NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
+            NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
             NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
             NeutronAtenQuantizer(ViewPattern(), static_qconfig),
         ]
diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
index 5d1351ac303..35649f0c0fc 100644
--- a/backends/nxp/quantizer/patterns.py
+++ b/backends/nxp/quantizer/patterns.py
@@ -408,6 +408,31 @@ def partition_types(self):
         return [torch.ops.aten.view.default]
 
 
+def get_anchors_for_softmax_like_operators(
+    fused_partition: List[fx.GraphModule],
+) -> PartitionAnchors:
+    node = fused_partition[0].nodes[-1]
+    assert len(fused_partition[0].input_nodes) == 1
+
+    qspec = FixedQParamsQuantizationSpec(
+        dtype=torch.int8,
+        scale=1.0 / 256.0,
+        zero_point=-128,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+    )
+
+    return PartitionAnchors(
+        inputs=[(node, 0)],
+        weights=[],
+        biases=[],
+        output=[
+            (node, qspec),
+        ],
+    )
+
+
 class SoftMaxPattern(QuantizationPattern):
     """
     Quantizer for Softmax operator.
@@ -421,23 +446,20 @@ def partition_types(self) -> List[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
     ) -> PartitionAnchors:
-        node = fused_partition[0].nodes[-1]
-        assert len(fused_partition[0].input_nodes) == 1
+        return get_anchors_for_softmax_like_operators(fused_partition)
 
-        qspec = FixedQParamsQuantizationSpec(
-            dtype=torch.int8,
-            scale=1.0 / 256.0,
-            zero_point=-128,
-            quant_min=-128,
-            quant_max=127,
-            qscheme=torch.per_tensor_affine,
-        )
 
-        return PartitionAnchors(
-            inputs=[(node, 0)],
-            weights=[],
-            biases=[],
-            output=[
-                (node, qspec),
-            ],
-        )
+class SigmoidPattern(QuantizationPattern):
+    """
+    Quantizer for Sigmoid operator.
+
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    """
+
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        return get_anchors_for_softmax_like_operators(fused_partition)
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py
new file mode 100644
index 00000000000..9139dd97f9a
--- /dev/null
+++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py
@@ -0,0 +1,76 @@
+# Copyright 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import numpy as np
+import pytest
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    ToNCHWPreprocess,
+    ToNHWCPreprocess,
+)
+from executorch.backends.nxp.tests.models import ConvWithSigmoid
+from torch import nn
+from torch.export import ExportedProgram
+
+
+@pytest.fixture(autouse=True)
+def reseed_model_per_test_run():
+    torch.manual_seed(23)
+    np.random.seed(23)
+
+
+def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
+    model = ConvWithSigmoid(conv_in_channels=input_shape[1])
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program,
+        tfl_model=tflite_flatbuffers_model,
+        tflite_input_preprocess=ToNHWCPreprocess(),
+        tflite_output_preprocess=ToNCHWPreprocess(),
+        input_data=input_data,
+        atol=1.0,
+    )
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        pytest.param((10,), id="Scalar"),
+        pytest.param((10, 25), id="1D"),
+        pytest.param((10, 25, 25), id="2D"),
+        pytest.param((10, 3, 25, 25), id="3D"),
+        pytest.param((10, 3, 25, 25, 25), id="4D"),
+    ],
+)
+def test_sigmoid_only(mocker, input_shape):
+    model = nn.Sigmoid()
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    to_quantized_edge_program(model, input_shape).exported_program()
+
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+    convert_run_compare(
+        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
+    )
diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py
index e1e4896a38f..3aafab36a95 100644
--- a/backends/nxp/tests/models.py
+++ b/backends/nxp/tests/models.py
@@ -85,6 +85,23 @@ def forward(self, x):
         return self.softmax(x)
 
 
+class ConvWithSigmoid(torch.nn.Module):
+    def __init__(self, conv_in_channels: int = 3):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                in_channels=conv_in_channels,
+                out_channels=3,
+                kernel_size=(2, 2),
+                stride=(2, 2),
+            ),
+            torch.nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
 class LinearModule(torch.nn.Module):
     def __init__(self, bias: bool):
         super().__init__()
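
Note on the fixed quantization parameters: SigmoidPattern in this patch pins the output qspec to scale 1/256, zero point -128, dtype int8. Since sigmoid outputs lie in (0, 1), the affine mapping q = round(y / scale) + zero_point places that range onto the int8 codes -128..127 with at most one quantization step of error. The sketch below is illustrative only, not part of the patch; it uses nothing beyond stock PyTorch.

import torch

# Fixed qparams from SigmoidPattern / get_anchors_for_softmax_like_operators.
scale, zero_point = 1.0 / 256.0, -128

x = torch.linspace(-8.0, 8.0, steps=1001)
y = torch.sigmoid(x)  # values in (0, 1)

# Affine quantization: q = clamp(round(y / scale) + zero_point, -128, 127).
q = torch.clamp(torch.round(y / scale) + zero_point, -128, 127).to(torch.int8)

# Dequantize and check the round-trip error. It stays within one step:
# the top code (127) saturates just below 1.0, which sigmoid never reaches.
y_hat = (q.to(torch.float32) - zero_point) * scale
assert torch.all((y_hat - y).abs() <= scale)

The same bound explains why SoftMaxPattern and SigmoidPattern can share one anchor helper: both operators produce outputs in [0, 1], so the same fixed qspec covers them.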