From 6660b99c52bcae1d21338f091450e078f2962e37 Mon Sep 17 00:00:00 2001 From: Roman Janik Date: Wed, 17 Sep 2025 15:42:26 +0200 Subject: [PATCH 1/2] Add quantizer for aten.mm --- backends/nxp/quantizer/neutron_quantizer.py | 2 + backends/nxp/quantizer/patterns.py | 116 +++++++++----------- 2 files changed, 55 insertions(+), 63 deletions(-) diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index d9dd019c864..db19bcb8ba8 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -25,6 +25,7 @@ LinearPattern, MaxPoolPattern, MeanDimPattern, + MmPattern, NodeArgsIdx, PadPattern, PermutePattern, @@ -199,6 +200,7 @@ def __init__(self): NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig), NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig), NeutronAtenQuantizer(MeanDimPattern(), static_qconfig), + NeutronAtenQuantizer(MmPattern(), static_qconfig), NeutronAtenQuantizer(PadPattern(), static_qconfig), NeutronAtenQuantizer(PermutePattern(), static_qconfig), NeutronAtenQuantizer(ReluPattern(), static_qconfig), diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index e2d6f6dc9ea..34ee611b8b2 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -276,55 +276,20 @@ def get_anchors( ) -class Conv1dPattern(QuantizationPattern): - def partition_types(self) -> list[OpOverload]: - return [torch.ops.aten.conv1d.default] - - def get_anchors( - self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] - ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... - conv1d_node = fused_partition[0].nodes[-1] - - bias_qspec = DerivedQuantizationSpec( - derived_from=[ - (conv1d_node.args[0], conv1d_node), - (conv1d_node.args[1], conv1d_node), - ], - derive_qparams_fn=get_bias_qparams, - dtype=torch.int32, - quant_min=-(2**31), - quant_max=2**31 - 1, - qscheme=torch.per_tensor_affine, - ) - - # Keep bias empty if not supplied - bias = [] - if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None: - bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)] - - return PartitionAnchors( - inputs=[(conv1d_node, NodeArgsIdx(0))], - weights=[(conv1d_node, NodeArgsIdx(1))], - # pyre-fixme[6]: Incompatible parameter type - biases=bias, - output=[(conv1d_node,)], - ) - - -class Conv2dPattern(QuantizationPattern): +class ConvPattern(QuantizationPattern): + @abstractmethod def partition_types(self) -> list[OpOverload]: - return [torch.ops.aten.conv2d.default] + pass def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - conv2d_node = fused_partition[0].nodes[-1] + conv_node = fused_partition[0].nodes[-1] bias_quantization_qspec = DerivedQuantizationSpec( derived_from=[ - (conv2d_node.args[0], conv2d_node), - (conv2d_node.args[1], conv2d_node), + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), ], derive_qparams_fn=get_bias_qparams, dtype=torch.int32, @@ -346,17 +311,27 @@ def get_anchors( # Keep bias empty if not supplied bias = [] - if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None: - bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] return PartitionAnchors( - inputs=[(conv2d_node, NodeArgsIdx(0))], - weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)], + 
inputs=[(conv_node, NodeArgsIdx(0))], + weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)], biases=bias, - output=[(conv2d_node,)], + output=[(conv_node,)], ) +class Conv1dPattern(ConvPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv1d.default] + + +class Conv2dPattern(ConvPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv2d.default] + + class DropoutPattern(SharedSpecPattern): """ Quantizer for Dropout operator. @@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... linear_node = fused_partition[0].nodes[-1] bias_qspec = DerivedQuantizationSpec( @@ -455,7 +429,6 @@ def get_anchors( return PartitionAnchors( inputs=[(linear_node, NodeArgsIdx(0))], weights=[(linear_node, NodeArgsIdx(1))], - # pyre-fixme[6]: Incompatible parameter type biases=bias, output=[(linear_node,)], ) @@ -479,6 +452,23 @@ def partition_types(self): return [torch.ops.aten.mean.dim] +class MmPattern(QuantizationPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.mm.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + mm_node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(mm_node, NodeArgsIdx(0))], + weights=[(mm_node, NodeArgsIdx(1))], + biases=[], + output=[(mm_node,)], + ) + + class PadPattern(SharedSpecPattern): """ Quantizer for Pad operator. @@ -552,33 +542,33 @@ def get_anchors( ) -class TanhPattern(QuantizationPattern): +class SigmoidPattern(QuantizationPattern): """ - Quantizer for Tanh operator. + Quantizer for Sigmoid operator. - The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8. + The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8. """ - def partition_types(self): - return [torch.ops.aten.tanh.default] + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.sigmoid.default] def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 128.0, zero_point=0 + fused_partition, scale=1.0 / 256.0, zero_point=-128 ) -class TanhInPlacePattern(QuantizationPattern): +class TanhPattern(QuantizationPattern): """ - Quantizer for inplace version of Tanh operator (torch.tanh_). + Quantizer for Tanh operator. The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8. """ def partition_types(self): - return [torch.ops.aten.tanh_.default] + return [torch.ops.aten.tanh.default] def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] @@ -588,19 +578,19 @@ def get_anchors( ) -class SigmoidPattern(QuantizationPattern): +class TanhInPlacePattern(QuantizationPattern): """ - Quantizer for Sigmoid operator. + Quantizer for inplace version of Tanh operator (torch.tanh_). - The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8. + The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8. 
""" - def partition_types(self) -> list[OpOverload]: - return [torch.ops.aten.sigmoid.default] + def partition_types(self): + return [torch.ops.aten.tanh_.default] def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 256.0, zero_point=-128 + fused_partition, scale=1.0 / 128.0, zero_point=0 ) From 88f6c22b01efd5e37eff161b208239b1171039ff Mon Sep 17 00:00:00 2001 From: Roman Janik Date: Wed, 17 Sep 2025 16:12:44 +0200 Subject: [PATCH 2/2] Extend tests for Linear, Addmmm, Mm converters --- .../node_converter/test_addmm_converter.py | 89 +++++++++++++++++++ .../node_converter/test_linear_converter.py | 49 ---------- .../node_converter/test_mm_converter.py | 89 +++++++++++++++++++ backends/nxp/tests/models.py | 27 ++++++ 4 files changed, 205 insertions(+), 49 deletions(-) create mode 100644 backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py delete mode 100644 backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py create mode 100644 backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py diff --git a/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py new file mode 100644 index 00000000000..6571ef8773e --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py @@ -0,0 +1,89 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, +) +from executorch.backends.nxp.tests.models import AddmmModule, LinearModule +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +class TestAddmmConversion(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + def test_addmm_conversion(self): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + input_shape = (1, 32) + model = AddmmModule(input_shape[1]) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + def test_linear_conversion__with_bias(self): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + input_shape = (10, 32) + model = LinearModule(bias=True) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py deleted file mode 100644 index 858724522cd..00000000000 --- a/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import pytest -import torch - -from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program -from executorch.backends.nxp.tests.executors import convert_run_compare -from executorch.backends.nxp.tests.models import LinearModule -from executorch.exir.dialects._ops import ops as exir_ops - - -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) - - -def test_linear_conversion__with_bias(): - input_shape = (10, 32) - edge_program = to_edge_program( - LinearModule(bias=True), input_shape - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - nodes = list(edge_program.graph.nodes) - assert nodes[4].target == exir_ops.edge.aten.addmm.default - assert len(nodes[4].args) == 3 # Has bias. - - convert_run_compare(edge_program, input_data=input_data) - - -def test_linear_conversion__without_bias(): - input_shape = (10, 32) - edge_program = to_edge_program( - LinearModule(bias=False), input_shape - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - nodes = list(edge_program.graph.nodes) - assert nodes[3].target == exir_ops.edge.aten.mm.default - assert len(nodes[3].args) == 2 # No bias. - - convert_run_compare(edge_program, input_data=input_data) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py new file mode 100644 index 00000000000..609c0f6c78c --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py @@ -0,0 +1,89 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, +) +from executorch.backends.nxp.tests.models import LinearModule, MmModule +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +class TestMmConversion(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + def test_mm_conversion(self): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + input_shape = (1, 32) + model = MmModule(input_shape[1]) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + def test_linear_conversion__without_bias(self): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + input_shape = (10, 32) + model = LinearModule(bias=False) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index bdad9ddc4b4..e7b60b2566c 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import math from typing import Callable, Collection, Union import torch @@ -169,6 +170,32 @@ def forward(self, x): return self.linear(x) +class AddmmModule(torch.nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + self.bias = torch.nn.Parameter(torch.empty(in_channels)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + self.eval() + + def forward(self, x): + return torch.addmm(self.bias, x, self.weight) + + +class MmModule(torch.nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.eval() + + def forward(self, x): + return torch.mm(x, self.weight) + + class LinearSoftmaxModule(torch.nn.Module): def __init__(self): super().__init__()
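

Note for reviewers, below the patch itself: a minimal usage sketch of the new
MmPattern path, reusing MmModule and to_quantized_edge_program introduced in
this series. It is illustrative only (the 32-wide shapes are assumptions, not
part of either commit) and mirrors what test_mm_conversion asserts.

import torch
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.models import MmModule

# Quantize a model whose forward is a single aten.mm. MmPattern anchors the
# first mm operand as the activation input and the second as static weights,
# so the NeutronQuantizer can annotate the op and the partitioner can
# delegate it to the Neutron backend.
model = MmModule(in_channels=32)
edge_program = to_quantized_edge_program(model, (1, 32)).exported_program()

# After delegation, aten.mm.default should no longer appear in the top-level
# graph; a lowered_module node should be present instead.
print(edge_program.graph)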