diff --git a/backends/nxp/TARGETS b/backends/nxp/TARGETS index 875f9813f43..a5a0508b33c 100644 --- a/backends/nxp/TARGETS +++ b/backends/nxp/TARGETS @@ -32,6 +32,18 @@ runtime.python_library( ], ) +runtime.python_library( + name = "_passes", + srcs = glob([ + "_passes/*.py", + ]), + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:pass_manager", + ], +) + runtime.python_library( name = "quantizer", srcs = [ @@ -65,6 +77,7 @@ runtime.python_library( deps = [ ":neutron_sdk", ":aten_passes", + ":_passes", ":quantizer", "fbsource//third-party/pypi/flatbuffers:flatbuffers", "fbsource//third-party/pypi/ml-dtypes:ml-dtypes", diff --git a/backends/nxp/_passes/remove_getitem_pass.py b/backends/nxp/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..646f5083adf --- /dev/null +++ b/backends/nxp/_passes/remove_getitem_pass.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 NXP +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.nxp.backend.node_format_inference import ( + NodeFormat, + NXP_NODE_FORMAT, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RemoveGetItemPass(ExportPass): + """ + This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator, + that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator. 
+ Before Pass: + MaxPool2d ---> GetItem[max_values, max_indexes] + After Pass: + MaxPool2d -> max_values + """ + + def call(self, graph_module: torch.fx.GraphModule): + module = graph_module + for node in module.graph.nodes: + if node.op == "call_function": + if ( + node.target.__name__ == "aten.max_pool2d_with_indices.default" + or node.target.__name__ == "aten.max.dim" + ): + users = list(node.users.keys()) + + if len(users) != 1: + if len(users) == 2 and node.target.__name__ == "aten.max.dim": + # Two users is allowed for max.dim. For that case, + # rather than removing the getitem node in this + # pass, we handle the getitem nodes in the op's + # visitor when serializing + continue + else: + raise AssertionError( + f"Invalid number of users for {node.target.__name__}: {len(users)}" + ) + + getitem_node = list(node.users.keys())[0] + + if getitem_node.target.__name__ != "getitem": + raise AssertionError( + f"Expected max node's user to be getitem, got {getitem_node.target.__name__}" + ) + + getitem_index = getitem_node.args[1] + + with module.graph.inserting_before(node): + if ( + node.target.__name__ + == "aten.max_pool2d_with_indices.default" + ): + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.max_pool2d.default, + args=node.args, + kwargs=node.kwargs, + ) + + else: + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone." 
def _is_quantize(node_: Node) -> bool:
    """Return `True` if `node_` calls a `quantize_per_tensor` / `quantize_per_channel` operator.

    Both the edge dialect and the `torch.ops` variants of the `quantized_decomposed` operators are recognized.

    :param node_: Node of the graph to inspect.
    :return: `True` for a `call_function` node whose target is one of the quantize operators, `False` otherwise.
    """
    # Compare the target object itself. Comparing `node_.target.__name__` (a `str`) with the operator overload
    #  objects below can never match, which made this predicate always return `False`.
    return node_.op == "call_function" and node_.target in [
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
    ]
| None: + """Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards + starting with the `node.args[input_index]`, + """ + current_node = node.args[input_index] + while True: + if _is_quantize(current_node) or _is_dequantize(current_node): + current_node = current_node.args[0] + else: + return current_node diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index febcd03913a..fadd58d39f6 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -19,10 +19,7 @@ from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec -from executorch.backends.nxp.backend.node_format_inference import ( - NodeFormat, - NodeFormatInference, -) +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -74,12 +71,10 @@ def convert_program( :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. 
""" - node_formats = NodeFormatInference(edge_program).identify_node_formats() parameters_mapping = self.map_inputs_to_parameters(edge_program) cc = self.build_conversion_context( parameters_mapping, - node_formats, neutron_target_spec, conversion_config, custom_delegation_options, @@ -106,7 +101,7 @@ def convert_program( def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext): for node in nodes: if node.op == "placeholder": - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] if node.name in context.parameters_mapping: # Node is placeholder and has data -> append as static tensor with data @@ -119,7 +114,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "call_function": # Node is call function -> append only output as a tensor - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "output": # Nothing to do @@ -177,7 +172,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet @staticmethod def build_conversion_context( parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, @@ -193,7 +187,6 @@ def build_conversion_context( tflite_builder, conversion_config, parameters_mapping, - node_formats, custom_delegation_options, ) diff --git a/backends/nxp/backend/ir/conversion_context.py b/backends/nxp/backend/ir/conversion_context.py index 6fb7e98424e..d4746fbde01 100644 --- a/backends/nxp/backend/ir/conversion_context.py +++ b/backends/nxp/backend/ir/conversion_context.py @@ -10,8 +10,6 @@ from 
executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import ( AtenModelBuilderDirector, ) -from executorch.backends.nxp.backend.node_format_inference import NodeFormat -from torch import Node from torch.nn import Parameter @@ -19,7 +17,6 @@ class ConversionContext: tflite_builder: AtenModelBuilderDirector conversion_config: ConversionConfig parameters_mapping: dict[str, Parameter] - node_formats: dict[Node, NodeFormat] custom_delegation_options: CustomDelegationOptions def __init__( @@ -27,7 +24,6 @@ def __init__( tflite_builder: AtenModelBuilderDirector, conversion_config: ConversionConfig, parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], custom_delegation_options: CustomDelegationOptions, ): """ @@ -39,5 +35,4 @@ def __init__( self.tflite_builder = tflite_builder self.conversion_config = conversion_config self.parameters_mapping = parameters_mapping - self.node_formats = node_formats self.custom_delegation_options = custom_delegation_options diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index a420cea9aa7..51a4a226fc8 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -9,7 +9,7 @@ from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat from torch.fx import Node from torch.nn import Parameter diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 
22ca258cd4f..417be9eb3f4 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -8,7 +8,12 @@ from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) +from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node from executorch.backends.nxp.backend.ir.converter.conversion import translator +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + apply_permutation_to, + create_channels_first_to_channels_last_permutation, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( _is_dequant_node, _is_quant_node, @@ -18,7 +23,9 @@ Concatenation, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -79,38 +86,28 @@ def _is_supported_on_target( if custom_delegation_options.force_delegate_cat: return True - dim = CatConverter._get_normalized_dim(node) - - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491 - if dim == 0: - return False + # Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the + # last dimension, depending on the formats of the node. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # During conversion to IR, the shape will be permuted to channels last, and the dimension on index + # `1` will end up being the channels (last dim in NHWC). + channels_index = 1 + else: + # The shape will not be permuted during conversion, so the channels will remain the last dimension. + channels_index = -1 - # Neutron requires the channels to be a multiple of numMacs. 
@classmethod
def supports_partitioning_result(
    cls,
    node: Node,
    partition_list: list[Partition],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Decide whether the `cat` node can stay delegated given the proposed partitioning.

    :param node: The `cat` node. `node.args[0]` is iterated as the collection of its input nodes.
    :param partition_list: Proposed partitions. The first one containing `node` is used; an `IndexError` is raised
                            if no partition contains it.
    :param custom_delegation_options: Not used by this check.
    :return: `False` if at least one input producer is not a `call_function` node in the same partition as the
              `cat` AND some input shape has only `1`s in front of the concatenation dimension. `True` otherwise.
    """
    # There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by
    #  `dim` are `!= 1`, the `Concat` is not delegated.
    # This only happens when the inputs to the `Concat` are model inputs, and not outputs of other
    #  operators.
    cat_partition = [p for p in partition_list if node in p.nodes][0]
    # `previous_non_qdq_node` starts its search at `input.args[0]`, so each direct input is itself skipped.
    #  Presumably every direct input is a `dequantize` node here -- TODO confirm.
    cat_inputs = map(previous_non_qdq_node, node.args[0])

    if not all(
        input_.op == "call_function" and input_ in cat_partition.nodes
        for input_ in cat_inputs
    ):
        # Some inputs of the `cat` are NOT in the same partition as `cat`.
        dim = CatConverter._get_normalized_dim(node)
        input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
        if node.meta[NXP_NODE_FORMAT].is_channels_first():
            # Transform the shapes to channels last.
            to_nhwc_perm = create_channels_first_to_channels_last_permutation(
                len(node.meta["val"].shape), True
            )
            input_shapes = [
                apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
            ]

            # Transform the `dim` to refer to a channels last dimension.
            dim = to_nhwc_perm.index(dim)

        for input_shape in input_shapes:
            # For `dim == 0` the slice is empty, `any()` is `False`, and the node is therefore rejected.
            if not any(d != 1 for d in input_shape[:dim]):
                # Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
                return False

    return True
class NodeFormat(Enum):
    """Tensor format inferred for the output of a graph node."""

    # The node produces an NCHW (channels first) tensor.
    CHANNELS_FIRST = 0

    # The layout of the node's output carries no meaning.
    FORMATLESS = 1

    # No format has been identified for the node.
    NONE = 2

    def is_channels_first(self) -> bool:
        """Return `True` only for the `CHANNELS_FIRST` member."""
        return self is NodeFormat.CHANNELS_FIRST
ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} - _node_format_mapping: dict[Node, NodeFormat] - _type_changed_during_last_run: bool # Mapping between Node and its ancestors (inputs) @@ -53,11 +40,13 @@ class NodeFormatInference: # Mapping between Node and its children (outputs) _node_outputs: dict[Node, list[Node]] + # List of all edge operations, which are supported by the converter. + _known_targets: list[EdgeOpOverload] + def __init__(self, edge_program: ExportedProgram): self._edge_program = edge_program self._nodes = edge_program.graph.nodes - self._node_format_mapping = {} self._node_inputs = { node: node.all_input_nodes for node in edge_program.graph.nodes } @@ -67,7 +56,14 @@ def __init__(self, edge_program: ExportedProgram): self._type_changed_during_last_run = False - def identify_node_formats(self) -> dict[Node, NodeFormat]: + self._known_targets = list(functions_converters) + [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + operator.getitem, + ] + + def identify_node_formats(self): self._type_changed_during_last_run = True # Re-run format inference until there are no changes @@ -77,7 +73,15 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]: for node in self._nodes: self._infer_format_of_nodes(node) - return self._node_format_mapping + for node in self._nodes: + if self._get_node_op_type(node) is None: + continue + if not hasattr(node, "meta"): + logging.warning(f"Node `{node}` does not have the `meta` attribute.") + node.meta = {} + if NXP_NODE_FORMAT not in node.meta: + logging.warning(f"Node `{node}` does not have inferred format.") + node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE def _infer_format_of_nodes(self, node: Node): op_type = self._get_node_op_type(node) @@ -93,9 +97,19 @@ def _infer_format_of_nodes(self, node: Node): logger.error( f"Node format 
inference for node type: {op_type} not found!" ) - else: + elif node.op != "call_function" or ( + hasattr(node, "target") and node.target in self._known_targets + ): + # Generic node, or tensor. self._handle_node_which_can_use_any_node_format(node) + else: + # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide + # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these + # partitions, which would require extra transpositions. + for processed_node in self._node_inputs[node] + [node]: + self._assign_format_to_node(processed_node, NodeFormat.NONE) + def _infer_format_based_on_io_ranks(self, node: Node): """Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input and output. @@ -148,10 +162,14 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat): # Once CHANNEL_FIRST was assigned, we don't want to reassign return + if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE: + # A format has already been assigned to the node before. Don't replace it with `NONE`. 
def _get_node_format(self, node) -> NodeFormat:
    """Return the format stored in `node.meta`, or `NodeFormat.NONE` if none is stored.

    If `node` has no `meta` attribute, an empty dictionary is attached to it first.
    """
    if hasattr(node, "meta"):
        node_meta = node.meta
    else:
        # Attach the empty dictionary to the node, so later writes of the format have a place to go.
        node_meta = node.meta = {}

    return node_meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index 44e9a19d9f2..b133a588c03 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -14,6 +14,7 @@ import numpy as np import torch +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, @@ -28,7 +29,6 @@ NeutronNodeArtifacts, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.verification.verifier import EXIREdgeDialectVerifier diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index 9626a2779c4..65c2acbf9ce 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -20,11 +20,12 @@ ) from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from torch.export import ExportedProgram from torch.fx import Node from torch.fx.graph import Graph - # If executed on i.MX platform, there is no tensorflow module. 
class AddCatModule(torch.nn.Module):
    """Test model which adds every input to itself and concatenates the results along `dim`.

    The additions make sure the inputs of the `cat` are produced by other operators, not directly by model inputs.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, *inputs: torch.Tensor):
        doubled = []
        for tensor in inputs:
            doubled.append(tensor + tensor)

        return torch.cat(doubled, self.dim)
-@pytest.mark.parametrize("dim", [0, -4]) -@pytest.mark.parametrize("num_inputs", [2]) -def test_cat__unsupported_dim__imxrt700(dim, num_inputs): - input_shape = (2, 8, 6, 8) - +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__unsupported__imxrt700(dim, input_shape): + """This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`). + In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated. + """ + num_inputs = 2 quantized_program = to_quantized_edge_program( CatModule(dim), [input_shape] * num_inputs, target="imxrt700" ).exported_program() @@ -148,6 +176,32 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs): ) +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__context_dependent__imxrt700(dim, input_shape): + """This test is conjoined with the one above (`test_cat__unsupported__imxrt700`). + In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated. + """ + num_inputs = 2 + ep = to_quantized_edge_program( + AddCatModule(dim), [input_shape] * num_inputs, target="imxrt700" + ).exported_program() + + # Make sure the `Cat` was delegated. 
def test_cat__format_specific_support__channels_first(mocker):
    """Check that a `cat` is delegated when only dimension `1` of the input shape is `8`, and compare the outputs.

    Counterpart of `test_cat__format_specific_support__formatless`, where only the last dimension is `8`.
    """
    # The second dim will end up being the channels, as the format is `channels first`.
    # Only the second dim satisfies the Neutron requirements for the channels.
    input_shape = (3, 8, 3, 3)
    num_inputs = 2
    dim = 2

    input_shapes = [input_shape] * num_inputs

    # Spy on the converter to obtain the program it received and the model it produced.
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # With `dim = 2` the `else` branch is always taken. The `sum` branch only matters if `dim` is changed to the
    #  channels dimension (`1` / `-3`).
    channels = (
        sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1]
    )
    quantized_program = to_quantized_edge_program(
        CatConvModule(dim, channels), input_shapes
    ).exported_program()

    # Make sure the `Cat` was delegated.
    assert not graph_contains_any_of_ops(
        graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
    )
    assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)

    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
    exported_program: ExportedProgram = converter_spy.call_args.args[1]
    input_data = {
        i: (np.random.random(shape) * 50).astype(np.int8)
        for i, shape in enumerate(input_shapes)
    }
    # The NHWC/NCHW preprocessing is passed because the format is channels first. The formatless counterpart of this
    #  test does not pass it.
    convert_run_compare(
        exported_program,
        tfl_model=tflite_flatbuffers_model,
        input_data=input_data,
        tflite_input_preprocess=ToNHWCPreprocess(),
        tflite_output_preprocess=ToNCHWPreprocess(),
        atol=1,
    )
@@ -13,6 +13,7 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -20,6 +21,7 @@ ConstantPadNDConvModule, ConstantPadNDModule, ) +from executorch.exir.dialects._ops import ops as exir_ops @pytest.fixture(autouse=True) @@ -121,3 +123,51 @@ def test_constant_pad_nd__unsupported_paddings(input_shape, paddings): nodes = list(exec_program.graph.nodes) # There is at least one non-delegated Pad node assert any(node.name == "aten_constant_pad_nd_default" for node in nodes) + + +def test_constant_pad_nd__delegation__formatless__supported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 0, 1, 2, 3, 4] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__formatless__unsupported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 1] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__supported_padding(): + input_shape = (2, 4, 6, 8) # Channels first -> the second dim (4) will be padded. + paddings = [1, 2, 3, 4, 0, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. 
+ assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__unsupported_padding(): + input_shape = (2, 3, 6, 8) # Channels first -> the second dim (3) will be padded. + paddings = [0, 0, 0, 0, 1, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py index 92af90b923d..b2e00fefc5a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py @@ -11,6 +11,7 @@ EdgeProgramToIRConverter, ) from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program from executorch.backends.nxp.tests.executors import convert_run_compare from executorch.backends.nxp.tests.models import SoftmaxConvModule, SoftmaxModule @@ -56,6 +57,7 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # Currently this test not pass because the convertibility checker doesn't use tensor formats. 
with pytest.raises( @@ -78,6 +80,7 @@ def test_softmax_conversion_channel_last(input_shape, dim: int): model = SoftmaxConvModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # TODO (Robert Kalmar) Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -104,6 +107,7 @@ def test_softmax_conversion_unsupported_dims(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() with pytest.raises( AssertionError, match="`aten__softmax_default` is not convertible" diff --git a/backends/nxp/tests/test_neutron_converter_manager.py b/backends/nxp/tests/test_neutron_converter_manager.py index 2fcfd8cd987..5b105d7ef64 100644 --- a/backends/nxp/tests/test_neutron_converter_manager.py +++ b/backends/nxp/tests/test_neutron_converter_manager.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.models import Conv2dModule @@ -23,6 +24,7 @@ def test_conv2d_neutron_conversion__default_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() @@ -43,6 +45,7 @@ def test__conv2d_neutron_conversion__invalid_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = 
EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() diff --git a/backends/nxp/tests/test_node_format_inference.py b/backends/nxp/tests/test_node_format_inference.py index e2796187ce8..d0a73328037 100644 --- a/backends/nxp/tests/test_node_format_inference.py +++ b/backends/nxp/tests/test_node_format_inference.py @@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.node_format_inference import ( NodeFormat, NodeFormatInference, + NXP_NODE_FORMAT, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.models import ( @@ -27,7 +28,7 @@ def test_convolution(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "p_conv_weight": NodeFormat.CHANNELS_FIRST, @@ -37,8 +38,8 @@ def test_convolution(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_softmax(): @@ -48,7 +49,7 @@ def test_softmax(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.FORMATLESS, @@ -56,8 +57,8 @@ def test_softmax(): "output": NodeFormat.FORMATLESS, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_maxpool2d(): @@ -78,7 +79,7 @@ def 
test_maxpool2d(): # Remove MaxPool-related "getitem" nodes from graph edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.CHANNELS_FIRST, @@ -86,5 +87,5 @@ def test_maxpool2d(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]