1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
@@ -38,6 +38,7 @@
    exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter,  # noqa F405
    exir_ops.edge.aten.mean.dim: MeanDimConverter,  # noqa F405
    exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
    exir_ops.edge.aten.mul.Tensor: MulTensorConverter,  # noqa F405
    exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
    exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
    exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
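For context on what this mapping matches: an ordinary `*` between two tensors is what the exporter records as `aten.mul.Tensor` (the edge dialect then refers to it as `exir_ops.edge.aten.mul.Tensor`). A minimal sketch, where `MulModule` and the input shapes are only illustrative:

import torch

class MulModule(torch.nn.Module):
    def forward(self, x, y):
        return x * y  # recorded as aten.mul.Tensor in the exported graph

# Shapes with a dimension divisible by 8 also satisfy the Neutron
# delegation constraint checked by MulTensorConverter below.
ep = torch.export.export(MulModule(), (torch.randn(1, 8), torch.randn(1, 8)))
print(ep.graph)  # contains a call to torch.ops.aten.mul.Tensor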
backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py
@@ -37,6 +37,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mm_converter import (
    MMConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mul_tensor_converter import (
    MulTensorConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.permute_copy_converter import (
    PermuteCopyConverter,
)
@@ -67,27 +70,28 @@
)

__all__ = [
"AbsConverter",
"AdaptiveAvgPool2dConverter",
"AddMMConverter",
"AddTensorConverter",
"AvgPool2dConverter",
"CatConverter",
"CloneConverter",
"ConstantPadNDConverter",
"ConvolutionConverter",
"HardTanhConverter",
"MaxPool2dConverter",
"MeanDimConverter",
"MMConverter",
"MulTensorConverter",
"PermuteCopyConverter",
"SoftmaxConverter",
"ViewCopyConverter",
"QDQPerTensorDequantizeConverter",
"QDQPerChannelDequantizeConverter",
"QDQPerTensorDequantizeConverter",
"QDQQuantizeConverter",
"ConstantPadNDConverter",
"ReLUConverter",
"MeanDimConverter",
"MaxPool2dConverter",
"AvgPool2dConverter",
"AddTensorConverter",
"SubTensorConverter",
"CloneConverter",
"AbsConverter",
"AdaptiveAvgPool2dConverter",
"HardTanhConverter",
"SigmoidConverter",
"SoftmaxConverter",
"SubTensorConverter",
"TanhConverter",
"ViewCopyConverter",
]
backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py
@@ -0,0 +1,61 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.conversion.common import (
    node_uses_shape_broadcasting,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
    CustomDelegationOptions,
    NodeConverter,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
    mul_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter


class MulTensorConverter(NodeConverter):
    @staticmethod
    def _is_supported_on_target(
        node: Node,
        neutron_target_spec: NeutronTargetSpec,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        if node_uses_shape_broadcasting(node):
            # Shape broadcasting may require the addition of `Transpose` ops during conversion.
            return False

        node_shape = node.meta["val"].shape

        # Neutron can only convert the op if at least one dimension is divisible
        # by the number of MACs (8 here), or if all dimensions are equal to 1.
        dim_divisible = any(s % 8 == 0 for s in node_shape) or all(
            s == 1 for s in node_shape
        )
        return dim_divisible

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        if len(node.args) != 2:
            return False

        return True

    # mul.Tensor Node format: (Tensor self, Tensor other, *)
    def convert(self, node: Node):
        """Convert the 'mul_tensor' operator to NeutronIR 'Mul'."""
        self.assert_convertible(node)

        t_op = self._create_tflite_op_with_io_tensors(node)
        t_op.builtin_options = mul_options.Mul()

        self.builder.append_operators([t_op])
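To make the delegation rule above concrete, here is a standalone sketch of the same check; `shape_supported`, `num_macs`, and the example shapes are illustrative, with 8 mirroring the hard-coded value in `_is_supported_on_target`:

import torch

def shape_supported(shape: torch.Size, num_macs: int = 8) -> bool:
    # Same rule as `_is_supported_on_target`: at least one dimension
    # divisible by the number of MACs, or every dimension equal to 1.
    return any(s % num_macs == 0 for s in shape) or all(s == 1 for s in shape)

assert shape_supported(torch.Size([1, 16, 3]))      # 16 % 8 == 0 -> delegated
assert shape_supported(torch.Size([1, 1, 1]))       # all ones -> delegated
assert not shape_supported(torch.Size([2, 3, 5]))   # no dim divisible by 8 -> rejected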
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
@@ -208,6 +208,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
    exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter,  # noqa F405
    exir_ops.edge.aten.mean.dim: MeanDimConverter,  # noqa F405
    exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
    exir_ops.edge.aten.mul.Tensor: MulTensorConverter,  # noqa F405
    exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
    exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
    exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -29,6 +29,7 @@
    MaxPoolPattern,
    MeanDimPattern,
    MmPattern,
    MulTensorPattern,
    NodeArgsIdx,
    PadPattern,
    PermutePattern,
@@ -208,6 +209,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
            NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
            NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
            NeutronAtenQuantizer(MmPattern(self), static_qconfig),
            NeutronAtenQuantizer(MulTensorPattern(), static_qconfig),
            NeutronAtenQuantizer(PadPattern(), static_qconfig),
            NeutronAtenQuantizer(PermutePattern(), static_qconfig),
            NeutronAtenQuantizer(ReluPattern(), static_qconfig),
43 changes: 43 additions & 0 deletions backends/nxp/quantizer/patterns.py
@@ -673,6 +673,49 @@ def get_anchors(
)


class MulTensorPattern(QuantizationPattern):
    """
    Quantization pattern for the 'mul.Tensor' operator. Accepts 1 or 2 input nodes.

    Applies basic quantization to all inputs and the output.
    """

    def partition_types(self) -> list[torch.nn.Module]:
        return [torch.ops.aten.mul.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        node = fused_partition[0].nodes[-1]
        input_nodes = node.all_input_nodes

        qspec = FixedQParamsQuantizationSpec(
            dtype=torch.int8,
            scale=1.0 / 256.0,
            zero_point=0,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_affine,
        )

        # The "Mul" operator in Neutron IR requires a specific scale and zero_point
        # (defined above) for its inputs.
        # Since these input nodes have already been annotated by their own patterns,
        # which didn't take the requirements of "Mul" into account, we need to
        # overwrite the existing "quantization_annotation".
        for input_node in input_nodes:
            input_node.meta["quantization_annotation"].output_qspec = qspec

        return PartitionAnchors(
            inputs=[(node, NodeArgsIdx(0), qspec), (node, NodeArgsIdx(1), qspec)],
            weights=[],
            biases=[],
            output=[
                (node,),
            ],
        )
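As a quick sanity check on the fixed qparams above: with scale 1/256 and zero point 0, the int8 range [-128, 127] maps to real values in [-0.5, 0.49609375]. A small arithmetic sketch (`dequantize` is an illustrative helper, not part of the PR):

scale, zero_point = 1.0 / 256.0, 0

def dequantize(q: int) -> float:
    # Per-tensor affine dequantization: real = (q - zero_point) * scale
    return (q - zero_point) * scale

print(dequantize(-128))  # -0.5
print(dequantize(127))   # 0.49609375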


class PadPattern(SharedSpecPattern):
    """
    Quantizer for Pad operator.