Merged
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -25,6 +25,7 @@
LinearPattern,
MaxPoolPattern,
MeanDimPattern,
MmPattern,
NodeArgsIdx,
PadPattern,
PermutePattern,
@@ -199,6 +200,7 @@ def __init__(self):
NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
NeutronAtenQuantizer(MmPattern(), static_qconfig),
NeutronAtenQuantizer(PadPattern(), static_qconfig),
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
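With MmPattern registered above, a model whose forward uses torch.mm can be annotated and quantized through the usual PT2E flow. The sketch below is illustrative only: the toy MmModule, its shapes, and the export entry point are assumptions, not taken from this PR, and the exact export API may differ between PyTorch versions.

# Minimal usage sketch (assumptions noted above; not part of the diff).
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer


class MmModule(torch.nn.Module):  # hypothetical toy model hitting aten.mm.default
    def __init__(self):
        super().__init__()
        self.other = torch.nn.Parameter(torch.randn(32, 16))

    def forward(self, x):
        return torch.mm(x, self.other)


example_inputs = (torch.randn(8, 32),)
exported = torch.export.export_for_training(MmModule(), example_inputs).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())  # MmPattern annotates the mm node
prepared(*example_inputs)  # one calibration pass
quantized = convert_pt2e(prepared)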
116 changes: 53 additions & 63 deletions backends/nxp/quantizer/patterns.py
@@ -276,55 +276,20 @@ def get_anchors(
)


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv1d_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv1d_node.args[0], conv1d_node),
(conv1d_node.args[1], conv1d_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]

return PartitionAnchors(
inputs=[(conv1d_node, NodeArgsIdx(0))],
weights=[(conv1d_node, NodeArgsIdx(1))],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(conv1d_node,)],
)


class Conv2dPattern(QuantizationPattern):
class ConvPattern(QuantizationPattern):
@abstractmethod
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv2d.default]
pass

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
conv2d_node = fused_partition[0].nodes[-1]
conv_node = fused_partition[0].nodes[-1]

bias_quantization_qspec = DerivedQuantizationSpec(
derived_from=[
(conv2d_node.args[0], conv2d_node),
(conv2d_node.args[1], conv2d_node),
(conv_node.args[0], conv_node),
(conv_node.args[1], conv_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
@@ -346,17 +311,27 @@ def get_anchors(

# Keep bias empty if not supplied
bias = []
if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

return PartitionAnchors(
inputs=[(conv2d_node, NodeArgsIdx(0))],
weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
inputs=[(conv_node, NodeArgsIdx(0))],
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
biases=bias,
output=[(conv2d_node,)],
output=[(conv_node,)],
)


class Conv1dPattern(ConvPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv1d.default]


class Conv2dPattern(ConvPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.conv2d.default]


class DropoutPattern(SharedSpecPattern):
"""
Quantizer for Dropout operator.
@@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +429,6 @@ def get_anchors(
return PartitionAnchors(
inputs=[(linear_node, NodeArgsIdx(0))],
weights=[(linear_node, NodeArgsIdx(1))],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)
@@ -479,6 +452,23 @@ def partition_types(self):
return [torch.ops.aten.mean.dim]


class MmPattern(QuantizationPattern):
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.mm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
mm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
inputs=[(mm_node, NodeArgsIdx(0))],
weights=[(mm_node, NodeArgsIdx(1))],
biases=[],
output=[(mm_node,)],
)


class PadPattern(SharedSpecPattern):
"""
Quantizer for Pad operator.
@@ -552,33 +542,33 @@ def get_anchors(
)


class TanhPattern(QuantizationPattern):
class SigmoidPattern(QuantizationPattern):
"""
Quantizer for Tanh operator.
Quantizer for Sigmoid operator.

The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
"""

def partition_types(self):
return [torch.ops.aten.tanh.default]
def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.sigmoid.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 128.0, zero_point=0
fused_partition, scale=1.0 / 256.0, zero_point=-128
)


class TanhInPlacePattern(QuantizationPattern):
class TanhPattern(QuantizationPattern):
"""
Quantizer for inplace version of Tanh operator (torch.tanh_).
Quantizer for Tanh operator.

The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
"""

def partition_types(self):
return [torch.ops.aten.tanh_.default]
return [torch.ops.aten.tanh.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +578,19 @@ def get_anchors(
)


class SigmoidPattern(QuantizationPattern):
class TanhInPlacePattern(QuantizationPattern):
"""
Quantizer for Sigmoid operator.
Quantizer for inplace version of Tanh operator (torch.tanh_).

The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
"""

def partition_types(self) -> list[OpOverload]:
return [torch.ops.aten.sigmoid.default]
def partition_types(self):
return [torch.ops.aten.tanh_.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 256.0, zero_point=-128
fused_partition, scale=1.0 / 128.0, zero_point=0
)
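The fixed output specs used by SigmoidPattern and TanhPattern above follow the standard affine mapping real = (q - zero_point) * scale. A small illustrative check (not from the PR) that these choices cover the operators' output ranges:

# Illustrative check of the fixed output specs:
# dequantized = (q - zero_point) * scale, with int8 q in [-128, 127].
sigmoid_scale, sigmoid_zp = 1.0 / 256.0, -128
tanh_scale, tanh_zp = 1.0 / 128.0, 0

# Sigmoid outputs lie in [0, 1): covered as 0.0 .. 255/256 ~= 0.996
print((-128 - sigmoid_zp) * sigmoid_scale, (127 - sigmoid_zp) * sigmoid_scale)

# Tanh outputs lie in (-1, 1): covered as -1.0 .. 127/128 ~= 0.992
print((-128 - tanh_zp) * tanh_scale, (127 - tanh_zp) * tanh_scale)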
@@ -0,0 +1,89 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import kgb
import numpy as np
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import AddmmModule, LinearModule
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class TestAddmmConversion(unittest.TestCase):
@classmethod
def setUpClass(cls):
torch.manual_seed(23)
np.random.seed(42)

def test_addmm_conversion(self):
with kgb.spy_on(
EdgeProgramToIRConverter.convert_program, call_original=True
) as converter_spy:
input_shape = (1, 32)
model = AddmmModule(input_shape[1])

edge_program = to_quantized_edge_program(
model, input_shape
).exported_program()

# Make sure that all nodes were delegated.
assert not graph_contains_any_of_ops(
graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
)
assert any(
"lowered_module" in node.name for node in edge_program.graph.nodes
)

tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
np.int8
)
convert_run_compare(
exported_program,
input_data,
tfl_model=tflite_flatbuffers_model,
)

def test_linear_conversion__with_bias(self):
with kgb.spy_on(
EdgeProgramToIRConverter.convert_program, call_original=True
) as converter_spy:
input_shape = (10, 32)
model = LinearModule(bias=True)

edge_program = to_quantized_edge_program(
model, input_shape
).exported_program()

# Make sure that all nodes were delegated.
assert not graph_contains_any_of_ops(
graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
)
assert any(
"lowered_module" in node.name for node in edge_program.graph.nodes
)

tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
np.int8
)
convert_run_compare(
exported_program,
input_data,
tfl_model=tflite_flatbuffers_model,
)

This file was deleted.
