Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.relu_converter import (
ReLUConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sigmoid_converter import (
SigmoidConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
SoftmaxConverter,
)
Expand All @@ -72,4 +75,5 @@
"AbsConverter",
"AdaptiveAvgPool2dConverter",
"HardTanhConverter",
"SigmoidConverter",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def _is_supported_in_IR(
return True

def convert(self, node: Node):
self.assert_convertible(node)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change, but OK.


t_op = self._create_tflite_op_with_io_tensors(node)
t_op.opcode_index = self.builder.op_code_index_for_op_type(BuiltinOperator.RELU)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.node_converter import (
NodeConverter,
Target,
)
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
BuiltinOperator,
)
from torch.fx import Node
from torch.nn import Parameter


class SigmoidConverter(NodeConverter):
@staticmethod
def _is_supported_on_target(target: Target) -> bool:
match target:
case Target.RT700:
return True

case _:
return False

@staticmethod
def _is_supported_in_IR(
node: Node, parameters_mapping: dict[str, Parameter]
) -> bool:
return True

def convert(self, node: Node):
self.assert_convertible(node)

t_op = self._create_tflite_op_with_io_tensors(node)
t_op.opcode_index = self.builder.op_code_index_for_op_type(
BuiltinOperator.LOGISTIC
)

self.builder.append_operators([t_op])
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
}


Expand Down
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ReluPattern,
ReshapePattern,
SharedSpecPattern,
SigmoidPattern,
SoftMaxPattern,
ViewPattern,
)
Expand Down Expand Up @@ -217,6 +218,7 @@ def __init__(self):
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
]
Expand Down
58 changes: 40 additions & 18 deletions backends/nxp/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,31 @@ def partition_types(self):
return [torch.ops.aten.view.default]


def get_anchors_for_softmax_like_operators(
fused_partition: List[fx.GraphModule],
) -> PartitionAnchors:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1

qspec = FixedQParamsQuantizationSpec(
dtype=torch.int8,
scale=1.0 / 256.0,
zero_point=-128,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
)

return PartitionAnchors(
inputs=[(node, 0)],
weights=[],
biases=[],
output=[
(node, qspec),
],
)


class SoftMaxPattern(QuantizationPattern):
"""
Quantizer for Softmax operator.
Expand All @@ -421,23 +446,20 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1
return get_anchors_for_softmax_like_operators(fused_partition)

qspec = FixedQParamsQuantizationSpec(
dtype=torch.int8,
scale=1.0 / 256.0,
zero_point=-128,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
)

return PartitionAnchors(
inputs=[(node, 0)],
weights=[],
biases=[],
output=[
(node, qspec),
],
)
class SigmoidPattern(QuantizationPattern):
"""
Quantizer for Sigmoid operator.

The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
"""

def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.sigmoid.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_softmax_like_operators(fused_partition)
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
import pytest
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
ToNCHWPreprocess,
ToNHWCPreprocess,
)
from executorch.backends.nxp.tests.models import ConvWithSigmoid
from torch import nn
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
torch.manual_seed(23)
np.random.seed(23)


def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
model = ConvWithSigmoid(conv_in_channels=input_shape[1])

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

to_quantized_edge_program(model, input_shape).exported_program()

tflite_flatbuffers_model, io_formats = converter_spy.spy_return
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data = (np.random.random(input_shape) * 50).astype(np.int8)
convert_run_compare(
exported_program,
tfl_model=tflite_flatbuffers_model,
tflite_input_preprocess=ToNHWCPreprocess(),
tflite_output_preprocess=ToNCHWPreprocess(),
input_data=input_data,
atol=1.0,
)


@pytest.mark.parametrize(
"input_shape",
[
pytest.param((10,), id="Scalar"),
pytest.param((10, 25), id="1D"),
pytest.param((10, 25, 25), id="2D"),
pytest.param((10, 3, 25, 25), id="3D"),
pytest.param((10, 3, 25, 25, 25), id="4D"),
],
)
def test_sigmoid_only(mocker, input_shape):
model = nn.Sigmoid()

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

to_quantized_edge_program(model, input_shape).exported_program()

tflite_flatbuffers_model, io_formats = converter_spy.spy_return
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data = (np.random.random(input_shape) * 50).astype(np.int8)
convert_run_compare(
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
)
17 changes: 17 additions & 0 deletions backends/nxp/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ def forward(self, x):
return self.softmax(x)


class ConvWithSigmoid(torch.nn.Module):
def __init__(self, conv_in_channels: int = 3):
super().__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels=conv_in_channels,
out_channels=3,
kernel_size=(2, 2),
stride=(2, 2),
),
torch.nn.Sigmoid(),
)

def forward(self, x):
return self.block(x)


class LinearModule(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
Expand Down
Loading