From 1531b5b30ea675963601969707002e63bbf1ec76 Mon Sep 17 00:00:00 2001 From: Jiri Ocenasek Date: Wed, 9 Jul 2025 13:54:28 +0200 Subject: [PATCH] NXP backend: Improve Neutron targets handling using NeutronTargetSpec class --- .../nxp/backend/edge_program_converter.py | 10 ++- .../ir/converter/builder/model_builder.py | 70 +++++++++------- .../backend/ir/converter/node_converter.py | 25 ++---- .../ops_converters/add_tensor_converter.py | 17 ++-- .../ops_converters/cat_converter.py | 81 ++++++++++--------- .../constant_pad_nd_converter.py | 22 ++--- .../ops_converters/convolution_converter.py | 67 +++++++-------- .../ops_converters/mean_dim_converter.py | 38 ++++----- .../ops_converters/softmax_converter.py | 17 +--- .../prune_transpose_operators.py | 2 +- .../nxp/backend/neutron_converter_manager.py | 45 +++++++---- backends/nxp/backend/neutron_target_spec.py | 64 +++++++++++++++ backends/nxp/neutron_partitioner.py | 44 +++++----- backends/nxp/nxp_backend.py | 20 ++--- backends/nxp/tests/executors.py | 10 +-- backends/nxp/tests/test_neutron_backend.py | 2 +- .../tests/test_neutron_converter_manager.py | 9 +-- 17 files changed, 289 insertions(+), 254 deletions(-) create mode 100644 backends/nxp/backend/neutron_target_spec.py diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 192798c151e..febcd03913a 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -18,6 +18,7 @@ from torch.fx import Node from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.backend.node_format_inference import ( NodeFormat, NodeFormatInference, @@ -54,12 +55,14 @@ class EdgeProgramToIRConverter: """ _default_conversion_config = ConversionConfig() + _default_target_spec = NeutronTargetSpec("imxrt700", "SDK_25_09") _default_delegation_options = CustomDelegationOptions() def convert_program( self, edge_program: ExportedProgram, - conversion_config=_default_conversion_config, + conversion_config: ConversionConfig = _default_conversion_config, + neutron_target_spec: NeutronTargetSpec = _default_target_spec, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, ) -> (bytes, dict): """ @@ -67,6 +70,7 @@ def convert_program( :param edge_program: Converter ExportedProgram. :param conversion_config: ConversionConfig instance. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. """ @@ -76,6 +80,7 @@ def convert_program( cc = self.build_conversion_context( parameters_mapping, node_formats, + neutron_target_spec, conversion_config, custom_delegation_options, ) @@ -173,11 +178,12 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet def build_conversion_context( parameters_mapping: dict, node_formats: dict[Node, NodeFormat], + neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, ) -> ConversionContext: tflite_builder = AtenModelBuilderDirector( - 3, "TFLite from EdgeProgram", conversion_config + 3, "TFLite from EdgeProgram", neutron_target_spec, conversion_config ) # Add "sentinel" buffer (defined in schema.fbs) diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 496fa752853..643a6231d15 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -48,6 +48,7 @@ FlexTranspose, ) from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class ModelBuilder: @@ -74,17 +75,21 @@ class ModelBuilder: _zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects - _default_conversion_config = ConversionConfig() + neutron_target_spec: NeutronTargetSpec conversion_config: ConversionConfig + _default_conversion_config = ConversionConfig() + def __init__( self, model_version: int, model_description: str, + neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, ) -> None: self._tfl_model = tflite_model.Model(model_version, model_description) + self.neutron_target_spec = neutron_target_spec self.conversion_config = conversion_config self.op_code_type_index_map = {} @@ -471,31 +476,7 @@ def finish(self) -> tflite_model.Model: return self._tfl_model - def _assign_tensor_and_buffer_indices( # noqa C901 - self, allow_inputs_stripping: bool - ): - """Correctly initialize all references via indices in all tensors and buffers.""" - - # Assign each buffer its index - for i, buffer in enumerate(self.get_buffers().vector): - buffer.tmp_index = i - - # Assign each tensor its index and its buffer index - for i, tensor in enumerate(self.get_tensors().vector): - if tensor.tmp_null_tensor: - # Using -1 as the index to the 'tensors' vector is way of telling the TFLite inference engine, that - # this tensor should not be used. - # https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98 - tensor.tmp_index = -1 - else: - tensor.tmp_index = i - - tensor.buffer = tensor.tmp_buffer.tmp_index - - # TODO Remove inputs and outputs that are not in the tensors collection - - # Assign 'Outputs' and 'Inputs' their tensor indices - outputs = self.get_sub_graph().outputs + def _assign_io_tensor_indices(self, inputs, outputs, allow_inputs_stripping: bool): for tensor in outputs.tmp_outputs: try: outputs.append(tensor.tmp_index) @@ -505,7 +486,6 @@ def _assign_tensor_and_buffer_indices( # noqa C901 f"The tensor '{tensor.name}' is among the model outputs, but does NOT appear in the graph!", ) - inputs = self.get_sub_graph().inputs for tensor in inputs.tmp_inputs: try: inputs.append(tensor.tmp_index) @@ -520,14 +500,46 @@ def _assign_tensor_and_buffer_indices( # noqa C901 f"The tensor '{tensor.name}' is among the model inputs, but does NOT appear in the graph!", ) - # Assign each operator its inputs and outputs indices - for operator in self.get_sub_graph().operators.vector: + def _assign_operators_io_tensor_indices(self, operators): + for operator in operators.vector: for inputTensor in operator.tmp_inputs: operator.inputs.append(inputTensor.tmp_index) for outputTensor in operator.tmp_outputs: operator.outputs.append(outputTensor.tmp_index) + def _assign_tensor_and_buffer_indices(self, allow_inputs_stripping: bool): + """Correctly initialize all references via indices in all tensors and buffers.""" + + # Assign each buffer its index + for i, buffer in enumerate(self.get_buffers().vector): + buffer.tmp_index = i + + # Assign each tensor its index and its buffer index + for i, tensor in enumerate(self.get_tensors().vector): + if tensor.tmp_null_tensor: + # Using -1 as the index to the 'tensors' vector is way of telling the TFLite inference engine, that + # this tensor should not be used. + # https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98 + tensor.tmp_index = -1 + else: + tensor.tmp_index = i + + tensor.buffer = tensor.tmp_buffer.tmp_index + + # TODO Remove inputs and outputs that are not in the tensors collection + + subgraph = self.get_sub_graph() + + # Assign 'Outputs' and 'Inputs' their tensor indices + self._assign_io_tensor_indices( + inputs=subgraph.inputs, + outputs=subgraph.outputs, + allow_inputs_stripping=allow_inputs_stripping, + ) + # Assign each operator its inputs and outputs indices + self._assign_operators_io_tensor_indices(operators=subgraph.operators) + def _build_operator_code( self, op_type: BuiltinOperator, version, custom_code: str = None ): diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index c44a6e19955..36266486aac 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from enum import Enum import torch @@ -16,6 +15,7 @@ AtenModelBuilderDirector, ) from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.exir.dialects._ops import ops as exir_ops from torch.fx import Node from torch.fx.passes.infra.partitioner import Partition @@ -42,17 +42,6 @@ def is_not_qdq_node(node: torch.fx.Node) -> bool: return not (_is_quant_node(node) or _is_dequant_node(node)) -class Target(Enum): - IGNORE = "ignore" # No target platform. Any target specific restrictions will be ignored. - - RT700 = "imxrt700" - IMX95 = "imx95" - - @classmethod - def values(cls) -> list[str]: - return [elt.value for elt in cls] - - class NodeConverter(ABC): """ Classes which implement conversion of torch.Node to TFLite should inherit from this class and overwrite the @@ -94,7 +83,7 @@ def _is_supported_in_IR( @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: @@ -103,31 +92,31 @@ def _is_supported_on_target( can be used by operators with no target specific requirements. :param node: The node (edge operator) to check. - :param target: Value of the `Target` enum representing the target platform to check for. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it). :param custom_delegation_options: Custom options which affect delegation. """ - return target == Target.RT700 + return True @classmethod def is_supported( cls, node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: """Check if the given `node` is supported in the IR and on the given `target` platform. :param node: torch.Node to check. - :param target: Value of the `Target` enum representing the target platform to check for. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param parameters_mapping: Dict mapping tensor names to their data. :param custom_delegation_options: Custom user options which affect node delegation. """ return cls._is_supported_in_IR( node, parameters_mapping, custom_delegation_options ) and cls._is_supported_on_target( - node, target, parameters_mapping, custom_delegation_options + node, neutron_target_spec, parameters_mapping, custom_delegation_options ) @classmethod diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index c74baa61f67..cd5aa2ead81 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -9,11 +9,11 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( add_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -22,20 +22,15 @@ class AddTensorConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - if node_uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False - - return True + if node_uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False - case _: - return False + return True @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 4f7f00fe5ba..22ca258cd4f 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -13,11 +13,11 @@ _is_dequant_node, _is_quant_node, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import ( Concatenation, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -72,51 +72,52 @@ def _all_io_shares_quantization_parameters(node: Node) -> bool: @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: if custom_delegation_options.force_delegate_cat: return True - match target: - case Target.RT700: - dim = CatConverter._get_normalized_dim(node) - - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491 - if dim == 0: - return False - - # Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the - # last dimension, depending on the formats of the node. The format, however, cannot be determined - # during conversion, as it depends on what other nodes are delegated. - input_channels = [ - # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it - # will still be the channels in the IR. - _get_shape(input_)[1] - for input_ in node.all_input_nodes - ] + [ - # If the inputs/outputs are channels first, the last dimension will be the channels. - _get_shape(input_)[-1] - for input_ in node.all_input_nodes - ] - if any((input_channel % 8) != 0 for input_channel in input_channels): - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 - return False - - output_channels = [_get_shape(node)[1], _get_shape(node)[-1]] - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 - if any((out_c % 8) != 0 for out_c in output_channels): - return False - - if len(node.all_input_nodes) < 2: # Not supported on Neutron - # TODO Try to skip the operator if this case is realistic. - return False - - return True - - case _: - return False + dim = CatConverter._get_normalized_dim(node) + + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491 + if dim == 0: + return False + + # Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the + # last dimension, depending on the formats of the node. The format, however, cannot be determined + # during conversion, as it depends on what other nodes are delegated. + input_channels = [ + # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it + # will still be the channels in the IR. + _get_shape(input_)[1] + for input_ in node.all_input_nodes + ] + [ + # If the inputs/outputs are channels first, the last dimension will be the channels. + _get_shape(input_)[-1] + for input_ in node.all_input_nodes + ] + if any( + (input_channel % neutron_target_spec.get_num_macs()) != 0 + for input_channel in input_channels + ): + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 + return False + + output_channels = [_get_shape(node)[1], _get_shape(node)[-1]] + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 + if any( + (out_c % neutron_target_spec.get_num_macs()) != 0 + for out_c in output_channels + ): + return False + + if len(node.all_input_nodes) < 2: # Not supported on Neutron + # TODO Try to skip the operator if this case is realistic. + return False + + return True @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py index f58df1a88d9..499541aa58c 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py @@ -17,7 +17,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( quantize_int8, @@ -27,6 +26,7 @@ pad_options, pad_v2_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -35,22 +35,16 @@ class ConstantPadNDConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # TODO: Consider different tensor formats (dim-order) - paddings = node.args[1] - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension, which is not supported on Neutron. - return False - - return True - - case _: - return False + paddings = node.args[1] + if len(paddings) > 4 and paddings[4:6] != [0, 0]: + # Attempt to Pad channels dimension, which is not supported on Neutron. + return False + + return True @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 8955b4c8fd4..f32b5a65cac 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -25,7 +25,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared import ( conv_utils, @@ -45,6 +44,7 @@ depthwise_conv_2d_options, reshape_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -53,45 +53,38 @@ class ConvolutionConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - activations = node.args[0] - weights = node.args[1] - groups = node.args[8] - - if activations.meta["val"].shape[0] != 1: - # Only batch size 1 is supported on neutron. - return False - - if groups == 1: # Regular convolution. - pass - elif conv_utils.group_conv_convertible_as_depthwise( - node, groups - ): # Depthwise convolution. - # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted - # weights. In case the weights are dynamic, a Transpose operator would have to be added, which - # is not supported on Neutron. - if not node_is_effectively_static_tensor( - weights, parameters_mapping - ): - return False - elif conv_utils.group_conv_convertible_into_multiple_convolutions( - node, groups - ): # Separable conv. This should never be reached, as the node should have been decomposed into - # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass. - logging.warning("Group convolution was not decomposed.") - return False - else: # Unexpected case (should never happen). - return False - - return True - - case _: + activations = node.args[0] + weights = node.args[1] + groups = node.args[8] + + if activations.meta["val"].shape[0] != 1: + # Only batch size 1 is supported on neutron. + return False + + if groups == 1: # Regular convolution. + pass + elif conv_utils.group_conv_convertible_as_depthwise( + node, groups + ): # Depthwise convolution. + # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted + # weights. In case the weights are dynamic, a Transpose operator would have to be added, which + # is not supported on Neutron. + if not node_is_effectively_static_tensor(weights, parameters_mapping): return False + elif conv_utils.group_conv_convertible_into_multiple_convolutions( + node, groups + ): # Separable conv. This should never be reached, as the node should have been decomposed into + # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass. + logging.warning("Group convolution was not decomposed.") + return False + else: # Unexpected case (should never happen). + return False + + return True @staticmethod def _is_supported_in_IR( @@ -238,7 +231,7 @@ def _convert_1d_conv( def _convert_unpadded_2D( self, t_op: tflite_model.Operator, conv_params: ConvParameters ) -> conv_utils.ConvConversionResult: - """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converter by the + """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the caller. """ common.assign_2d_strides(t_op.builtin_options, conv_params.stride) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py index f03c403876f..c1dd7b600be 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py @@ -12,7 +12,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reduce_utils import ( convert_axes_from_attribute, @@ -20,6 +19,7 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( mean_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -28,34 +28,20 @@ class MeanDimConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # TODO: Consider different tensor formats (dim-order) - dim = node.args[1] - keepdim = node.args[2] if len(node.args) >= 3 else False - rank = len(node.args[0].meta["val"].shape) - dim = [MeanDimConverter._to_neg_dim(d, rank) for d in dim] - - # Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron. - if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim: - return False - - return True - - case _: - return False + dim = node.args[1] + keepdim = node.args[2] if len(node.args) >= 3 else False + rank = len(node.args[0].meta["val"].shape) + dim = [d - rank if d > 0 else d for d in dim] - @staticmethod - def _to_pos_dim(d, rank): - return d + rank if d < 0 else d + # Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron. + if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim: + return False - @staticmethod - def _to_neg_dim(d, rank): - return d - rank if d > 0 else d + return True @staticmethod def _is_supported_in_IR( @@ -75,6 +61,10 @@ def _is_supported_in_IR( return True + @staticmethod + def _to_pos_dim(d: int, rank: int): + return d + rank if d < 0 else d + @staticmethod def _normalize_and_to_channel_last_dim(dim: list[int], rank: int) -> list[int]: # convert negative index to positive diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py index aa74c78ca24..5e4404d8476 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py @@ -7,13 +7,11 @@ CustomDelegationOptions, ) from executorch.backends.nxp.backend.edge_helper import input_rank -from executorch.backends.nxp.backend.ir.converter.node_converter import ( - NodeConverter, - Target, -) +from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( softmax_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -22,18 +20,11 @@ class SoftmaxConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # The eIQ Neutron NPU runtime software has a known issue with the SoftMax operation. - # As long as the issue is present, return False for the i.MX RT700 target also. - return False - - case _: - return False + return False @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py index dc9ad9999b4..0be46efcaa8 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index 2bc4380f89b..a6884a9ee24 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -7,8 +7,6 @@ import multiprocessing import pkgutil -from executorch.backends.nxp.backend.ir.converter.node_converter import Target - def convert_unsafe(neutron_converter, tflite_model, cctx, queue): """ @@ -27,16 +25,7 @@ class NeutronConverterManager: contains NeutronGraph nodes. """ - _supported_target_names = [Target.RT700.value] - - def convert( - self, tflite_model: bytes, target: str, neutron_converter_flavor: str - ) -> bytes: - # Neutron converter crashes if we provide invalid target -> verify. - if target not in self._supported_target_names: - raise RuntimeError( - f"Target '{target}' is not supported by NeutronConverterManager." - ) + def __init__(self, neutron_converter_flavor: str = "SDK_25_09"): neutron_converter_modules = [ module.name @@ -57,13 +46,34 @@ def convert( f"not found. Install 'neutron_converter_[flavor]' Python package." ) - neutron_converter = importlib.import_module( + self.neutron_converter = importlib.import_module( f"{requested_module_name}.neutron_converter" ) + self.neutron_library_utils = importlib.import_module( + f"{requested_module_name}.neutron_library_utils" + ) + + def get_converter(self): + return self.neutron_converter + + def get_library_utils(self): + return self.neutron_library_utils + + def verify_target(self, target: str): + if not self.neutron_library_utils.isNeutronTarget(target): + valid_targets = [ + target.name for target in self.neutron_library_utils.getNeutronTargets() + ] + raise ValueError( + f"Target `{target}` is not a valid target. Must be one of `{valid_targets}`." + ) + + def convert(self, tflite_model: bytes, target: str) -> bytes: + # Neutron converter crashes if we provide invalid target -> verify. + self.verify_target(target) - cctx = neutron_converter.CompilationContext() - cctx.targetOpts = neutron_converter.getNeutronTarget(target) - # New switch since Neutron Converter SDK_25.06 + cctx = self.neutron_converter.CompilationContext() + cctx.targetOpts = self.neutron_converter.getNeutronTarget(target) cctx.compilationOpts.minNumOpsPerGraph = 1 logger = multiprocessing.log_to_stderr() @@ -71,7 +81,8 @@ def convert( queue = multiprocessing.Manager().Queue() process = multiprocessing.Process( - target=convert_unsafe, args=(neutron_converter, tflite_model, cctx, queue) + target=convert_unsafe, + args=(self.neutron_converter, tflite_model, cctx, queue), ) process.start() process.join() # waits until the subprocess is complete diff --git a/backends/nxp/backend/neutron_target_spec.py b/backends/nxp/backend/neutron_target_spec.py new file mode 100644 index 00000000000..44399982e29 --- /dev/null +++ b/backends/nxp/backend/neutron_target_spec.py @@ -0,0 +1,64 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Target Spec for the NXP Neutron NPU + +from enum import Enum + +from executorch.backends.nxp.backend.neutron_converter_manager import ( + NeutronConverterManager, +) + + +class NeutronHWVersion(Enum): + N1 = 1 + N3 = 2 + + +class NeutronTargetSpec: + """ + The functionality for probing the properties of Neutron Target. + """ + + def __init__(self, target: str, neutron_converter_flavor: str): + + converter_manager = NeutronConverterManager(neutron_converter_flavor) + converter_manager.verify_target(target) + neutron_converter = converter_manager.get_converter() + self.neutron_target = neutron_converter.getNeutronTarget(target) + + if self.is_subsystem(): + raise ValueError( + f"Target `{target}` is not a neutron-C target. Only MCU targets are supported at the moment." + ) + + if self.get_hw_version() != NeutronHWVersion.N3: + raise ValueError( + f"Target `{target}` contains unsupported HW version. Only N3/N3+ targets are supported at the moment." + ) + + # Target name. + def get_name(self) -> str: + return self.neutron_target.name + + # Whether the target has subsystem (Neutron-S) or not (Neutron-C). + def is_subsystem(self) -> bool: + return self.neutron_target.subsystem + + # Number of compute units. + def get_num_units(self) -> int: + return self.neutron_target.numUnits + + # Number of compute pipelines. + def get_num_pipes(self) -> int: + return self.neutron_target.numPipes + + # Number of compute MACs. + def get_num_macs(self) -> int: + return self.neutron_target.numMacs + + # Neutron compute block hardware version. + def get_hw_version(self) -> NeutronHWVersion: + return NeutronHWVersion(self.neutron_target.hwVersion) diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 371b7474f58..917545e6c89 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -8,7 +8,7 @@ import logging import operator from dataclasses import dataclass -from typing import Dict, final, List, Mapping +from typing import final, Mapping import torch @@ -18,13 +18,13 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import Target from torch.export.exported_program import ExportedProgram from torch.fx import Graph from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupportBase from torch.nn import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.nxp_backend import NeutronBackend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -64,7 +64,7 @@ class QDQCluster: """ compute_node: torch.fx.Node - ops: List[torch.fx.Node] + ops: list[torch.fx.Node] QUANTIZE_OPERATORS = [ exir_ops.edge.quantized_decomposed.quantize_per_channel.default, @@ -97,7 +97,7 @@ def is_dequant_node(node: torch.fx.Node) -> bool: def is_auxiliary_node(node: torch.fx.Node) -> bool: return node.target in QDQClusterRecognizer.AUXILIARY_OPS - def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Return the list of nodes representing the input part of the QDQ cluster of the node `node`. Those are various dequantization nodes (see DEQUANTIZE_OPERATORS) optionally followed by auxiliary @@ -125,7 +125,7 @@ def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> List[torch.fx.Node] logging.debug(f"Dequant Cluster for {node} is: {qdq_cluster}") return qdq_cluster - def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Returns the list of nodes representing the output part of the QDQ cluster of the `node`. Those are various quantize nodes (see QUANTIZE_OPERATORS) preceded by auxiliary nodes. @@ -155,7 +155,7 @@ def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> List[torch.fx.Node logging.debug(f"Quant Cluster for {node} is {qdq_cluster}") return qdq_cluster - def get_qdq_cluster(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Returns the QDQ cluster of the operator, if quantized. If operator is not quantized, returns empty list. """ @@ -167,7 +167,7 @@ def get_qdq_cluster(self, node: torch.fx.Node) -> List[torch.fx.Node]: else: return [] - def tag_nodes(self, nodes: List[torch.fx.Node], cluster_name: str) -> None: + def tag_nodes(self, nodes: list[torch.fx.Node], cluster_name: str) -> None: """ Tags a node and its related dequant and quant nodes with a specified cluster name """ @@ -175,7 +175,7 @@ def tag_nodes(self, nodes: List[torch.fx.Node], cluster_name: str) -> None: logging.info(f"Tagging node {node} as {cluster_name}") node.meta["cluster"] = cluster_name - def tag_qdq_clusters(self, nodes: List[torch.fx.Node]): + def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): """ Identifies QDQ clusters and tag them based on compute operation inside. """ @@ -220,14 +220,14 @@ class NeutronSupportedOperators(OperatorSupportBase): def __init__( self, - qdq_clusters: Dict[str, QDQClusterRecognizer.QDQCluster], - target: Target, - operators_not_to_delegate: List[str], + qdq_clusters: dict[str, QDQClusterRecognizer.QDQCluster], + neutron_target_spec: NeutronTargetSpec, + operators_not_to_delegate: list[str], parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ): self.qdq_clusters = qdq_clusters - self.target = target + self.neutron_target_spec = neutron_target_spec self.operators_not_to_delegate = operators_not_to_delegate self.parameters_mapping = parameters_mapping self.custom_delegation_options = custom_delegation_options @@ -269,7 +269,7 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool: # TODO: `view_copy` node should be delegated only if it's not the only operator in the cluster. node_converter.is_supported( node, - self.target, + self.neutron_target_spec, self.parameters_mapping, self.custom_delegation_options, ) @@ -305,13 +305,16 @@ def is_node_supported( class NeutronPartitioner(Partitioner): def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: list[CompileSpec], custom_delegation_options: CustomDelegationOptions | None = None, ) -> None: self.delegation_spec = DelegationSpec(NeutronBackend.__name__, compile_spec) self.custom_delegation_options = ( custom_delegation_options or CustomDelegationOptions() ) + target = self.delegation_spec[1][2].value.decode() + converter_flavor = self.delegation_spec[1][3].value.decode() + self.neutron_target_spec = NeutronTargetSpec(target, converter_flavor) def validate_partitioning_result( self, @@ -343,22 +346,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags logging.info("NeutronPartitioner::partition") partition_tags = {} + partition_list = [] graph_module = exported_program.graph_module nodes = list(graph_module.graph.nodes) qdq_cluster_recognizer = QDQClusterRecognizer() qdq_cluster_recognizer.tag_qdq_clusters(nodes) + graph_module.recompile() - target = None - operators_not_to_delegate = "" - for spec in self.delegation_spec.compile_specs: - if spec.key == "target": - target = Target(spec.value.decode()) - if spec.key == "operators_not_to_delegate": - operators_not_to_delegate = spec.value.decode().split(",") - assert target is not None + operators_not_to_delegate = self.delegation_spec[1][4].value.decode().split(",") logging.info(f"Operators not to delegate: {operators_not_to_delegate}") parameters_mapping = EdgeProgramToIRConverter.map_inputs_to_parameters( @@ -368,7 +366,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: exported_program.graph_module, NeutronSupportedOperators( qdq_cluster_recognizer.cluster_map, - target, + self.neutron_target_spec, operators_not_to_delegate, parameters_mapping, self.custom_delegation_options, diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index fd1687d73fd..44e9a19d9f2 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,11 +18,11 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import Target from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.neutron_node_extraction import ( extract_artifacts_from_neutron_node, NeutronNodeArtifacts, @@ -36,9 +36,9 @@ class NeutronCompileSpecBuilder: + config: NeutronTargetSpec def __init__(self): - self.config: Target = None self.compile_spec: List[CompileSpec] = [] self.compiler_flags = [] self.output_format = None @@ -68,14 +68,9 @@ def neutron_compile_spec( extra_flags: Extra flags for the Neutron compiler operators_not_to_delegate: List of operators that should not be delegated """ - try: - self.config = Target(config) - except ValueError: - raise ValueError( - f"Config `{config}` is not a valid target. Must be one of `{Target.values()}`." - ) self.neutron_converter_flavor = neutron_converter_flavor + self.config = NeutronTargetSpec(config, neutron_converter_flavor) assert ( self.output_format is None @@ -101,7 +96,7 @@ def build(self): self.compile_spec += [ CompileSpec("output_format", "tflite".encode()), CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), - CompileSpec("target", self.config.value.encode()), + CompileSpec("target", self.config.get_name().encode()), CompileSpec( "neutron_converter_flavor", self.neutron_converter_flavor.encode() ), @@ -187,10 +182,11 @@ def preprocess( # noqa C901 # Convert the edge program to TFLite. tflite_model, io_formats = EdgeProgramToIRConverter().convert_program( edge_program, + neutron_target_spec=NeutronTargetSpec(target, neutron_converter_flavor), ) - neutron_model = NeutronConverterManager().convert( - tflite_model, target, neutron_converter_flavor + neutron_model = NeutronConverterManager(neutron_converter_flavor).convert( + tflite_model, target ) # Dump the tflite file if logging level is enabled diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index 592717c0b3b..9626a2779c4 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,10 +18,8 @@ create_channels_first_to_channels_last_permutation, create_channels_last_to_channels_first_permutation, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import ( - NodeConverter, - Target, -) +from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.export import ExportedProgram from torch.fx import Node from torch.fx.graph import Graph @@ -373,7 +371,7 @@ def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool: return any(node.target in ops for node in graph.nodes) -target_support_check_function = Callable[[Node, Target], bool] +target_support_check_function = Callable[[Node, NeutronTargetSpec], bool] class OverrideTargetSupportCheck: diff --git a/backends/nxp/tests/test_neutron_backend.py b/backends/nxp/tests/test_neutron_backend.py index 53e54ec2f56..c9917651fbd 100644 --- a/backends/nxp/tests/test_neutron_backend.py +++ b/backends/nxp/tests/test_neutron_backend.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/nxp/tests/test_neutron_converter_manager.py b/backends/nxp/tests/test_neutron_converter_manager.py index e10e8cca67b..2fcfd8cd987 100644 --- a/backends/nxp/tests/test_neutron_converter_manager.py +++ b/backends/nxp/tests/test_neutron_converter_manager.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -29,9 +29,7 @@ def test_conv2d_neutron_conversion__default_flavor(): ) neutron_converter_manager = NeutronConverterManager() - neutron_model = neutron_converter_manager.convert( - tflite_model, "imxrt700", "SDK_25_09" - ) + neutron_model = neutron_converter_manager.convert(tflite_model, "imxrt700") assert len( neutron_model @@ -50,9 +48,8 @@ def test__conv2d_neutron_conversion__invalid_flavor(): edge_program_manager.exported_program() ) - neutron_converter_manager = NeutronConverterManager() with pytest.raises(RuntimeError) as excinfo: - _ = neutron_converter_manager.convert(tflite_model, "imxrt700", "bad_flavor") + _ = NeutronConverterManager("bad_flavor").convert(tflite_model, "imxrt700") assert "Neutron Converter module with flavor 'bad_flavor' not found." in str( excinfo