diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 6eef2017ec5..4189ac2dc47 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -19,7 +19,7 @@ from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec -from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -63,7 +63,7 @@ def convert_program( conversion_config: ConversionConfig = _default_conversion_config, neutron_target_spec: NeutronTargetSpec = _default_target_spec, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, - ) -> (bytes, dict): + ) -> (bytes, dict[str, NodeFormat]): """ Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes. @@ -87,13 +87,16 @@ def convert_program( self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc) self._process_nodes(edge_program.graph.nodes, cc) - # Assign output - io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats( - edge_program.graph_signature - ) + # Assign the model its inputs and outputs. + cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature) - # TFLite model generation + # Apply optimizations and finalize the model. internal_tflite_model = cc.tflite_builder.finish() + + # Extract the formats of the model's inputs and outputs. + io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature) + + # TFLite model generation flatbuffers_builder = flatbuffers.Builder() internal_tflite_model.gen_tflite(flatbuffers_builder) diff --git a/backends/nxp/backend/ir/conversion_config.py b/backends/nxp/backend/ir/conversion_config.py index 622735e881f..4ba66adc942 100644 --- a/backends/nxp/backend/ir/conversion_config.py +++ b/backends/nxp/backend/ir/conversion_config.py @@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None): :param args: Optional dictionary with conversion arguments. Unknown arguments are ignored. """ - self.keep_io_format: bool = False + self.use_neutron_for_format_conversion: bool = True self.allow_inputs_stripping: bool = True self.qdq_aware_conversion: bool = True self.symbolic_dimensions_mapping: dict[str, int] | None = None diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index 51a4a226fc8..658b4fc93f7 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]): self.check_and_append_operator(op) - def assign_model_io_to_subgraph_and_get_io_formats( - self, graph_signature - ) -> dict[str, dict]: - """ - Assign model's inputs/outputs to SubGraph. + def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]: + """Get a mapping from tensor names to their formats. - :param graph_signature: Instance of GraphSignature. + :param graph_signature: Instance of GraphSignature. :returns: Mapping between IO tensors' names and their formats. """ io_formats = { "inputs": {}, "outputs": {}, } + for input_name in graph_signature.user_inputs: + tensor = self.tensor_for_name(input_name) + assert input_name == tensor.name, ( + "Program's input name doesn't match with tensor name in TFLite. " + "Input was probably redirected." + ) + io_formats["inputs"][tensor.name] = tensor.tensor_format + + for output_name in graph_signature.user_outputs: + tensor = self.tensor_for_name(output_name) + assert output_name == tensor.name, ( + "Program's output name doesn't match with tensor name in TFLite. " + "Output was probably redirected." + ) + io_formats["outputs"][tensor.name] = tensor.tensor_format + + return io_formats + + def assign_model_io_to_subgraph(self, graph_signature): + """ + Assign model's inputs/outputs to SubGraph. + + :param graph_signature: Instance of GraphSignature. + """ self.get_sub_graph().inputs = tflite_model.SubGraphInputs() for input_name in graph_signature.user_inputs: @@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Input was probably redirected." ) self.get_sub_graph().inputs.tmp_inputs.append(tensor) - io_formats["inputs"][tensor.name] = tensor.tensor_format self.get_sub_graph().outputs = tflite_model.SubGraphOutputs() for output_name in graph_signature.user_outputs: @@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Output was probably redirected." ) self.get_sub_graph().outputs.tmp_outputs.append(tensor) - - io_formats["outputs"][tensor.name] = tensor.tensor_format - - return io_formats diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 643a6231d15..cfd80d8e300 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -5,7 +5,9 @@ # License: MIT # See the LICENSE_MIT for more details. # + from copy import deepcopy +from itertools import chain from typing import Dict, List, Optional, Union import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator @@ -48,6 +50,9 @@ FlexTranspose, ) from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec @@ -218,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor): new_tensor.shape = translator.channels_last_shape_to_channels_first( t_tensor.shape ) - new_tensor.tensor_format = new_tensor.tensor_format.to_node_format() + new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST perm = translator.create_channels_last_to_channels_first_permutation( t_tensor.rank @@ -355,6 +360,19 @@ def _make_inputs_channels_first(self): if input_tensor.tensor_format.is_channels_last(): # Create a Transpose operator and replace the graph input + new_input_shape = translator.channels_last_shape_to_channels_first( + input_tensor.shape + ) + perm = translator.create_channels_first_to_channels_last_permutation( + input_tensor.rank + ) + + if not transposition_is_supported_on_neutron( + new_input_shape.vector, list(perm), self.neutron_target_spec + ): + new_inputs.append(input_tensor) + continue + if input_tensor.rank > 6: msg = ( f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has " @@ -365,14 +383,9 @@ def _make_inputs_channels_first(self): new_input = self.duplicate_tensor( input_tensor, input_tensor.name + "_channels_first" ) - new_input.shape = translator.channels_last_shape_to_channels_first( - input_tensor.shape - ) - new_input.tensor_format = input_tensor.tensor_format.to_node_format() + new_input.shape = new_input_shape + new_input.tensor_format = TensorFormat.CHANNELS_FIRST - perm = translator.create_channels_first_to_channels_last_permutation( - input_tensor.rank - ) transpose = self._create_transpose_operator( new_input, input_tensor, perm ) @@ -397,6 +410,16 @@ def _make_outputs_channels_first(self): if output_tensor.tensor_format.is_channels_last(): # Add a Transpose operator, to make the output channels first + shape = output_tensor.shape.vector + perm = translator.create_channels_last_to_channels_first_permutation( + len(shape), True + ) + if not transposition_is_supported_on_neutron( + shape, perm, self.neutron_target_spec + ): + new_outputs.append(output_tensor) + continue + if output_tensor.rank > 6: logger.e( logger.Code.IO_PRESERVATION_ERROR, @@ -437,6 +460,14 @@ def _keep_one_empty_buffer(self): # It's safe to replace the buffer. t.tmp_buffer = empty_buffer + def replace_io_tensor_format_with_node_format(self): + for t in chain( + self.get_sub_graph().inputs.tmp_inputs, + self.get_sub_graph().outputs.tmp_outputs, + ): + if isinstance(t.tensor_format, TensorFormat): + t.tensor_format = t.tensor_format.to_equal_node_format() + def finish(self) -> tflite_model.Model: """Finalize and optimize the converted TFLite model. Then return it. @@ -444,19 +475,23 @@ def finish(self) -> tflite_model.Model: :return: The final TFLite model. """ - if self.conversion_config.keep_io_format: + if self.conversion_config.use_neutron_for_format_conversion: # If the input or output is channels last, add a Transpose operator, to make is channels first. self._make_inputs_channels_first() self._make_outputs_channels_first() # Apply optimizations to the internal TFLite model. - optimizer.Optimizer(self, self.conversion_config).optimize( + optimizer.Optimizer( + self, self.conversion_config, self.neutron_target_spec + ).optimize( self.conversion_config.optimization_whitelist, self.conversion_config.optimization_blacklist, ) self._keep_one_empty_buffer() + self.replace_io_tensor_format_with_node_format() + # Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference. operator_outputs = [] for op in self.get_operators().vector: diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index 36266486aac..b69861f85b0 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector: """ return self.context.tflite_builder + @property + def neutron_target_spec(self) -> NeutronTargetSpec: + """ + Get an instance of NeutronTargetSpec from the conversion context. + :return: NeutronTargetSpec instance. + """ + return self.builder.neutron_target_spec + def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator: """ Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py index f0150b4bc1f..35bef6c8035 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py @@ -4,28 +4,438 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch +from executorch.backends.nxp.backend.edge_helper import ( + node_is_effectively_static_tensor, +) +from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext from executorch.backends.nxp.backend.ir.converter import quantization_utils +from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + NeutronTargetSpec, NodeConverter, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( transpose_options, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + is_tensor_invariant_permutation, + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter +Permutation = list[int] +PermutationSupportDict = dict[str, dict[str, bool | Permutation]] + + +def _get_shape(node: torch.fx.Node) -> list[int]: + return list(node.meta["val"].shape) + + +def get_supported_transpositions( + node: Node, neutron_target_spec: NeutronTargetSpec +) -> PermutationSupportDict: + """Since ExecuTorch and NeutronIR use different tensor formats, we must consider the different possible cases + which may occur. The main permutation is always done on channels_first/formatless data, and the output is + channels_first/formatless as well. If this is not the case, a `Transpose` is inserted before and/or + after the main `Transpose`, to make the input/output channels_first. These additional `Transpose` + ops must be supported by Neutron as well. Alternatively, consecutive `Transpose` ops can be fused + together. It is possible for a pair of unsupported permutation to result in a supported one. + Therefore, the merged permutations must also be considered. + + This function identifies which of these permutations are supported on neutron, and returns a dictionary with the + support summary and the corresponding permutations. + + :param node: The `permute_copy` node to base the support analysis from/ + :param neutron_target_spec: NeutronTagetSpec instance. + :return: A dictionary containing the support status and permutation, for all the possible permutations which may be + used during the conversion of the `node`. + """ + + input_shape = node.args[0].meta["val"].shape + output_shape = node.meta["val"].shape + perm = list(node.args[1]) + + to_nchw_perm = translator.create_channels_last_to_channels_first_permutation( + len(input_shape), True + ) + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + main_perm_supported = transposition_is_supported_on_neutron( + input_shape, perm, neutron_target_spec + ) + + # "To NCHW" permutation, in case the input is channels last. + separate_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ) + # The main permutation and the previous one merged. + merged_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, + merged_pre_transpose_permutation := translator.combine_permutations( + to_nchw_perm, perm + ), + neutron_target_spec, + ) + + # "To NHWC" permutation after the main `Transpose`. + separate_post_transpose_supported = transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ) + + # The main permutation and the previous one merged. + merged_post_transpose_supported = transposition_is_supported_on_neutron( + input_shape, + merged_post_transpose_permutation := translator.combine_permutations( + perm, to_nhwc_perm + ), + neutron_target_spec, + ) + + # "To NCHW", main permutation, and "to NHWC" all merged. + everything_merged_supported = transposition_is_supported_on_neutron( + input_shape, + everything_merged_permutation := translator.combine_permutations( + translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm + ), + neutron_target_spec, + ) + + return { + "main": {"supported": main_perm_supported, "perm": perm}, + "separate_pre": { + "supported": separate_pre_transpose_supported, + "perm": to_nchw_perm, + }, + "merged_pre": { + "supported": merged_pre_transpose_supported, + "perm": merged_pre_transpose_permutation, + }, + "separate_post": { + "supported": separate_post_transpose_supported, + "perm": to_nhwc_perm, + }, + "merged_post": { + "supported": merged_post_transpose_supported, + "perm": merged_post_transpose_permutation, + }, + "everything_merged": { + "supported": everything_merged_supported, + "perm": everything_merged_permutation, + }, + } + + +class PermuteCopyFormatHandler: + def __init__(self, context: ConversionContext): + self.context = context + + @property + def neutron_target_spec(self): + return self.context.tflite_builder.neutron_target_spec + + @property + def builder(self): + return self.context.tflite_builder + + def _handle_channels_first_input_and_formatless_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The input must be permuted. + # Either combine the permutations, or prepend a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. + perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_pre"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_pre"]["perm"] + + elif perm_dict["separate_pre"]["supported"] and perm_dict["main"]["supported"]: + # Prepend a `Transpose` operator to make the input channels first. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_channels_first_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The output must be permuted. + # Either combine the permutations, or append a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. + perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_post"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_post"]["perm"] + + elif perm_dict["main"]["supported"] and perm_dict["separate_post"]["supported"]: + # Append a `Transpose` operator to make the output channels first. + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_channels_first_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Both input and output must be permuted, or some merged permutations must be supported. + if perm_dict["everything_merged"]["supported"]: + # Combine all 3 permutations into 1. + perm = perm_dict["everything_merged"]["perm"] + + elif ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Combine the input and main permutations, and append a `Transpose` to handle the output permutation. + perm = perm_dict["merged_pre"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ): + # Prepend a `Transpose` to handle the input permutation, and combine the main and output permutations. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["everything_merged"]["supported"] + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Handle each permutation separately. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Neither the input nor the output have to be permuted. + if perm_dict["main"]["supported"]: + perm = perm_dict["main"]["perm"] + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + return perm + + def handle_tensor_formats(self, t_op: tflite_model.Operator, node: Node) -> OpsList: + """Due to the different tensor formats used by ExecuTorch and NeutronIR, it may be necessary to modify the + permutation, or insert extra permutations to equalize the tensor formats. + This method identifies the four possible cases of input/output formats, and finds the conversion solution + which minimizes the number of necessary `Transpose` operators. + """ + perm_dict = get_supported_transpositions(node, self.neutron_target_spec) + + ops = OpsList(middle_op=t_op) + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + perm = self._handle_channels_first_input_and_formatless_output( + perm_dict, node, t_op, ops + ) + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + perm = self._handle_formatless_input_and_channels_first_output( + perm_dict, node, t_op, ops + ) + + elif input_format.is_channels_first() and output_format.is_channels_first(): + perm = self._handle_channels_first_input_and_output( + perm_dict, node, t_op, ops + ) + + else: + perm = self._handle_formatless_input_and_output(perm_dict, node, t_op, ops) + + perm_tensor = self.builder.create_tensor_for_data( + np.array(perm, "int32"), "perm" + ) + + # Use the final permutation as the operator's second input. + t_op.tmp_inputs = [t_op.tmp_inputs[0], perm_tensor] + + return ops + class PermuteCopyConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_is_effectively_static_tensor(node.args[0], parameters_mapping): + return ( + True # The operator computes on static data. It will be removed later. + ) + + input_shape = _get_shape(node.args[0]) + perm = list(node.args[1]) + + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + if is_tensor_invariant_permutation( + input_shape, perm + ) and is_tensor_invariant_permutation(channels_last_input_shape, perm): + # The `permute_copy` can always be represented as a Reshape. + return True + + perm_dict = get_supported_transpositions(node, neutron_target_spec) + + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + # Just the input must be permuted. + return ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_pre"]["supported"] + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + # Just the output must be permuted. + return ( + perm_dict["separate_post"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_post"]["supported"] + + elif input_format.is_channels_first() and output_format.is_channels_first(): + # Both input and output must be permuted. + return ( + # Separate IO transpositions. + ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Separate input, merged output. + or ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ) + # Merged input, separate output. + or ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Merged input and output. + or perm_dict["everything_merged"]["supported"] + ) + else: + # Simplest case. No format changes required. + return perm_dict["main"]["supported"] + @staticmethod def _is_supported_in_IR( node: Node, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if not NodeConverter._has_shared_q_params_if_quantized(node): + return False + return True def convert(self, node: Node): @@ -53,13 +463,6 @@ def convert(self, node: Node): "match. This indicates error in quantizer." ) - perm = np.array(node.args[1], "int32") - perm_tensor = self.builder.create_tensor_for_data(perm, "perm") - - # Assign the operator its TFLite inputs and outputs - t_op.tmp_inputs = [x, perm_tensor] - t_op.tmp_outputs = [y] - - ops_to_add = OpsList(middle_op=t_op) + ops = PermuteCopyFormatHandler(self.context).handle_tensor_formats(t_op, node) - self.builder.append_operators(ops_to_add.flatten()) + self.builder.append_operators(ops.flatten()) diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py index 32967ff047a..71b697a0eba 100644 --- a/backends/nxp/backend/ir/tensor_formatting.py +++ b/backends/nxp/backend/ir/tensor_formatting.py @@ -38,8 +38,10 @@ def is_channels_last(self) -> bool: @staticmethod def from_node_format(node_format: NodeFormat): - if node_format.is_channels_first(): - return TensorFormat.CHANNELS_LAST + if node_format == NodeFormat.CHANNELS_FIRST: + return TensorFormat.CHANNELS_LAST # Format is swapped. + elif node_format == NodeFormat.CHANNELS_LAST: + return TensorFormat.CHANNELS_FIRST # Format is swapped. elif node_format == NodeFormat.FORMATLESS: return TensorFormat.FORMATLESS else: @@ -47,8 +49,21 @@ def from_node_format(node_format: NodeFormat): def to_node_format(self): if self == TensorFormat.CHANNELS_LAST: - return NodeFormat.CHANNELS_FIRST + return NodeFormat.CHANNELS_FIRST # Format is swapped. elif self == TensorFormat.FORMATLESS: return NodeFormat.FORMATLESS + elif self == TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_LAST # Format is swapped. else: return NodeFormat.NONE + + def to_equal_node_format(self): + match self: + case TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_FIRST + case TensorFormat.CHANNELS_LAST: + return NodeFormat.CHANNELS_LAST + case TensorFormat.FORMATLESS: + return NodeFormat.FORMATLESS + case _: + return NodeFormat.NONE diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py index 6001ca961b8..18e397cc1bd 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py @@ -12,16 +12,21 @@ InputTensorToOpsMap, OutputTensorToOpMap, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class BaseOptimization(ABC): _builder: "model_builder.ModelBuilder" def __init__( - self, builder: "model_builder.ModelBuilder", conversion_config: ConversionConfig + self, + builder: "model_builder.ModelBuilder", + conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self._conversion_config = conversion_config + self.neutron_target_spec = neutron_target_spec def _create_tensor_to_operator_dictionaries( self, diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py index 0be46efcaa8..053e53d9df8 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py @@ -24,10 +24,14 @@ TensorIsNotModelOutput, TensorsHaveData, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) class FuseTransposeOperators(BaseOptimization): - """Remove some `Transpose` operators in the following pattern. + """Remove some `Transpose` operators in the following pattern. This is only done if the resulting permutation is + supported on Neutron. │ 'x' ┌─────▼─────┐ @@ -61,12 +65,27 @@ def __call__(self) -> bool: ) in matcher.match_patterns(): x = tensor_map["x"] perm1 = tensor_map["perm1"].tmp_buffer.data + combined_perms = [] # Remove the leading transpose. for second_transpose in following_transposes: # Combine the permutations for a new permutation of the second `Transpose`. perm2 = second_transpose.tmp_inputs[1].tmp_buffer.data - combined_perm = np.array(combine_permutations(perm1, perm2), np.int32) + combined_perms.append( + np.array(combine_permutations(perm1, perm2), np.int32) + ) + + if not all( + transposition_is_supported_on_neutron( + x.shape.vector, list(perm), self.neutron_target_spec + ) + for perm in combined_perms + ): + continue # Avoid creating an unsupported permutation. + + for second_transpose, combined_perm in zip( + following_transposes, combined_perms + ): second_transpose.tmp_inputs[1] = self._builder.create_tensor_for_data( combined_perm, "perm" ) diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py index 52de6f224eb..1a96422e377 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py @@ -18,6 +18,7 @@ FuseTransposeOperators, RemoveIdentityTransposeOperators, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class Optimization(Enum): @@ -50,18 +51,19 @@ def __init__( self, builder: "model_builder.ModelBuilder", # noqa F821 conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self.optimization_map = { Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.REMOVE_IDENTITY_TRANSPOSE_OPERATORS: RemoveIdentityTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), } diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index a6884a9ee24..7124929411e 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import importlib import logging import multiprocessing @@ -75,6 +76,7 @@ def convert(self, tflite_model: bytes, target: str) -> bytes: cctx = self.neutron_converter.CompilationContext() cctx.targetOpts = self.neutron_converter.getNeutronTarget(target) cctx.compilationOpts.minNumOpsPerGraph = 1 + cctx.compilationOpts.excludeGraphPasses = "MergeTranspose" logger = multiprocessing.log_to_stderr() logger.setLevel(logging.WARNING) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py new file mode 100644 index 00000000000..cdb46870b2e --- /dev/null +++ b/backends/nxp/backend/neutron_operator_support.py @@ -0,0 +1,79 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + + +def is_tensor_invariant_permutation( + input_shape: list[int], permutation: list[int] +) -> bool: + def input_dim_is_not_one(index): + return input_shape[index] != 1 + + new_permutation = list(filter(input_dim_is_not_one, permutation)) + + return new_permutation == sorted(new_permutation) + + +def transposition_is_supported_on_neutron( + input_shape: list[int], + permutation: list[int], + neutron_target_spec: NeutronTargetSpec, +) -> bool: + """This function determines if the current NeutronSoftware properly supports a `Transpose` operator with given + `input_shape` and `permutation`. + + :param input_shape: The shape of the main input tensor of the `Transpose` operator. + :param permutation: The permutation the `Transpose` operator is computing. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. + """ + num_macs = neutron_target_spec.get_num_macs() + + if is_tensor_invariant_permutation(input_shape, permutation): + # The `Transpose` will be turned into a `Reshape` by Neutron. The check includes the identity permutation. + return True + + if permutation == [0, 3, 1, 2]: + # NHWC -> NCHW + n, h, w, c = input_shape + + if h * w * c % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 3, 1, 2] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + elif permutation == [0, 2, 3, 1]: + # NCHW -> NHWC + + n, c, h, w = input_shape + + if w % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 2, 3, 1] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + return False diff --git a/backends/nxp/backend/node_format.py b/backends/nxp/backend/node_format.py index 91049c200d7..fd54e2365ed 100644 --- a/backends/nxp/backend/node_format.py +++ b/backends/nxp/backend/node_format.py @@ -19,5 +19,8 @@ class NodeFormat(Enum): # Format has not been identified NONE = 2 + # NHWC + CHANNELS_LAST = 3 + def is_channels_first(self) -> bool: return self == NodeFormat.CHANNELS_FIRST diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 78f8dff8c32..244fd76d588 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -30,7 +30,10 @@ class NodeFormatInference: # A set of Edge Aten ops, which have the ability to change the format (for example - input nodes # are channels first but output is formatless). - ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} + ops_that_can_change_tensor_format = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + } _type_changed_during_last_run: bool @@ -88,11 +91,23 @@ def _infer_format_of_nodes(self, node: Node): if op_type in self.ops_with_channels_first_nodes: self._handle_node_which_uses_channels_first_format(node) + elif op_type in self.ops_that_can_change_tensor_format: - if op_type == exir_ops.edge.aten.view_copy.default: # view_copy + if op_type in [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + ]: + # Try to assign the `formatless` format to the input and output. The converter will then handle the + # transition. + # Note: If the format for the input/output has already been assigned as channels first, it will NOT be + # overwritten. self._assign_format_to_node( self._node_outputs[node][0], NodeFormat.FORMATLESS ) + self._assign_format_to_node( + self._node_inputs[node][0], NodeFormat.FORMATLESS + ) + else: logger.error( f"Node format inference for node type: {op_type} not found!" diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 6be4495d615..f89bac55bc5 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -208,12 +208,13 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405 exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 + exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405 + exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405 exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405 exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405 - exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 } diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index b133a588c03..457fa335ba6 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -14,16 +14,17 @@ import numpy as np import torch -from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NodeFormat from executorch.backends.nxp.neutron_node_extraction import ( extract_artifacts_from_neutron_node, NeutronNodeArtifacts, @@ -44,6 +45,7 @@ def __init__(self): self.output_format = None self.operators_not_to_delegate: List[str] = [] self.neutron_converter_flavor = None + self.use_neutron_for_format_conversion = True def _replace_colons(self, operator: str) -> str: """ @@ -57,6 +59,7 @@ def neutron_compile_spec( neutron_converter_flavor: str, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ): """ Generate compile spec for Neutron NPU @@ -67,6 +70,9 @@ def neutron_compile_spec( "'neutron_converter_SDK_25_09' has flavor 'SDK_25_09'. extra_flags: Extra flags for the Neutron compiler operators_not_to_delegate: List of operators that should not be delegated + use_neutron_for_format_conversion: If True, the EdgeProgramToIRConverter will insert `Transpose` ops to + ensure that the IO matches the executorch partition, which will be + delegated to Neutron. """ self.neutron_converter_flavor = neutron_converter_flavor @@ -86,6 +92,8 @@ def neutron_compile_spec( self._replace_colons(op) for op in operators_not_to_delegate ] + self.use_neutron_for_format_conversion = use_neutron_for_format_conversion + return self def build(self): @@ -104,6 +112,10 @@ def build(self): "operators_not_to_delegate", ",".join(self.operators_not_to_delegate).encode(), ), + CompileSpec( + "use_neutron_for_format_conversion", + f"{self.use_neutron_for_format_conversion}".encode(), + ), ] return self.compile_spec @@ -115,6 +127,7 @@ def generate_neutron_compile_spec( system_config: Optional[str] = None, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ) -> List[CompileSpec]: return ( NeutronCompileSpecBuilder() @@ -123,6 +136,7 @@ def generate_neutron_compile_spec( neutron_converter_flavor, extra_flags=extra_flags, operators_not_to_delegate=operators_not_to_delegate, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) .build() ) @@ -145,6 +159,7 @@ def preprocess( # noqa C901 binary = bytes() target = "" neutron_converter_flavor = "" + use_neutron_for_format_conversion = None for spec in compile_spec: if spec.key == "output_format": output_format = spec.value.decode() @@ -154,6 +169,8 @@ def preprocess( # noqa C901 compile_flags.append(spec.value.decode()) if spec.key == "neutron_converter_flavor": neutron_converter_flavor = spec.value.decode() + if spec.key == "use_neutron_for_format_conversion": + use_neutron_for_format_conversion = spec.value.decode() == "True" # Check that the output format is set in the compile spec if not output_format: @@ -180,9 +197,15 @@ def preprocess( # noqa C901 ).transform() # Convert the edge program to TFLite. + conversion_config = ConversionConfig( + {"use_neutron_for_format_conversion": use_neutron_for_format_conversion} + if use_neutron_for_format_conversion is not None + else {} + ) tflite_model, io_formats = EdgeProgramToIRConverter().convert_program( edge_program, neutron_target_spec=NeutronTargetSpec(target, neutron_converter_flavor), + conversion_config=conversion_config, ) neutron_model = NeutronConverterManager(neutron_converter_flavor).convert( @@ -241,7 +264,9 @@ def _format_string_for_array(self, array: np.ndarray) -> str: return f"{array.size}s{self._padding_format_string_for_array(array)}" - def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: + def _create_payload_header( + self, io_formats: dict[str, list[NodeFormat]], neutron_artifacts + ) -> np.ndarray: """ Create bytes header for returned payload. It contains information about input and output tensor formats. Tensors are ordered based on graph signature @@ -279,9 +304,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: for input_name in neutron_artifacts.input_names: try: header_data.append( - 1 - if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST - else 0 + 1 if inputs[input_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: raise AssertionError( @@ -292,7 +315,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: try: header_data.append( 1 - if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST + if outputs[output_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: @@ -331,7 +354,9 @@ def _pack_with_alignment( neutron_artifacts.kernels.tobytes(), ) - def get_binary_payload(self, io_formats, neutron_model) -> bytes: + def get_binary_payload( + self, io_formats: dict[str, list[NodeFormat]], neutron_model + ) -> bytes: """ Get binary payload for provided input/output tensor formats and neutron_model. Returned data have following structure: @@ -351,7 +376,7 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes: Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format). :param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries - mapping tensor name to TensorFormat. + mapping tensor name to NodeFormat. :param neutron_model: Neutron model with single NeutronGraph node. :return: 16 bytes aligned binary payload. """ diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index 74e10edc4c8..24fe13555ca 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -42,6 +42,7 @@ SubTensorPattern, TanhInPlacePattern, TanhPattern, + TransposeIntPattern, ViewPattern, ) from executorch.backends.nxp.quantizer.utils import ( @@ -217,6 +218,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec): NeutronAtenQuantizer(SubTensorPattern(), static_qconfig), NeutronAtenQuantizer(TanhPattern(), static_qconfig), NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig), + NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig), NeutronAtenQuantizer(ViewPattern(), static_qconfig), ] ) diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index ec8ecb83bb3..90c43d1971e 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -691,6 +691,15 @@ def partition_types(self): return [torch.ops.aten.permute.default] +class TransposeIntPattern(SharedSpecPattern): + """ + Quantizer for Transpose Int operator. + """ + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.transpose.int] + + class ReluPattern(SingleInputBasicPattern): """ Quantizer for Relu operator. diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index c6f9296b485..a2dd8cade7b 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -95,6 +95,7 @@ def to_quantized_edge_program( remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 get_quantizer_fn=None, + use_neutron_for_format_conversion=True, ) -> EdgeProgramManager: _neutron_target_spec = NeutronTargetSpec(target, neutron_converter_flavor) if get_quantizer_fn is None: @@ -118,6 +119,7 @@ def to_quantized_edge_program( target, operators_not_to_delegate=operators_not_to_delegate, neutron_converter_flavor=neutron_converter_flavor, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) partitioners = [ NeutronPartitioner( @@ -143,8 +145,13 @@ def to_quantized_edge_program( def to_quantized_executorch_program( model: torch.nn.Module, input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], + use_neutron_for_format_conversion: bool = True, ) -> ExecutorchProgramManager: - edge_program_manager = to_quantized_edge_program(model, input_spec) + edge_program_manager = to_quantized_edge_program( + model, + input_spec, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, + ) return edge_program_manager.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 315c76a7614..96b9abfe117 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -14,9 +14,10 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) + from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -67,7 +68,9 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -80,8 +83,8 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py index 9c8235f7eda..a80d2014487 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py @@ -47,7 +47,9 @@ def test_adaptive_avg_pool_2d_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = [str(node) for node in edge_program.graph.nodes] # Input size is a multiple of output size, can be converted to AveragePool, node is delegated @@ -91,7 +93,9 @@ def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated @@ -122,7 +126,9 @@ def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 2c3107eae77..02e799723d4 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -103,7 +103,9 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index bcdbd955c71..7aed0236043 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -91,6 +92,9 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -145,7 +149,9 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -172,7 +178,9 @@ def test_avg_pool_2d_quant_conversion__padded(mocker): ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture the converter operators. ops = ops_spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py index d2aafb570fa..427ddaf14a5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py @@ -113,7 +113,7 @@ def test_conv_dropout_quant(self, inplace_dropout: bool, input_shape: tuple[int] owner=EdgeProgramToIRConverter, ) as converter_spy: quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 56be613a664..bd1f894001c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -7,6 +7,7 @@ import pytest import torch +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, @@ -101,6 +102,9 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index ca4a12146fe..0fabbf615c9 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -10,6 +10,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -404,7 +405,9 @@ def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -446,6 +449,7 @@ def test_conv2d_conversion__depthwise__quantized( kernel_size=kernel_shape, ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -480,6 +484,9 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker): tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), atol=4e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) conversion_result = spy.spy_return ops = conversion_result.sub_graphs[0].operators.vector @@ -500,6 +507,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): group=group, in_channels=group, out_channels=group, padding=padding ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -576,7 +584,9 @@ def test_conv_transpose2d_conversion__quantized( ): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() # Make sure the `TransposeConv` was delegated. assert not graph_contains_any_of_ops( diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index c4bc559817b..dad8ce6a0e3 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -42,7 +42,9 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -79,7 +81,9 @@ def test_custom_hardtanh_quant( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 50bbf100980..8b938ef7fff 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -76,6 +77,9 @@ def test_max_pool_2d_conversion(input_shape, padding): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -103,7 +107,11 @@ def test_max_pool_2d_quant_conversion(mocker, input_shape, padding): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(MaxPool2dConvModule(padding=padding), input_shape) + _ = to_quantized_edge_program( + MaxPool2dConvModule(padding=padding), + input_shape, + use_neutron_for_format_conversion=False, + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index 4bbd89cc01d..ee69b1ea352 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -53,8 +53,9 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - ep = to_quantized_edge_program(model, input_shape).exported_program() - + ep = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() # Make sure the `mean.dim` was delegated. assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) assert any("lowered_module" in n.name for n in ep.graph.nodes) @@ -143,7 +144,9 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated diff --git a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py index d25e2759cc8..57d15aefdc0 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py @@ -3,8 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import unittest + +import kgb import numpy as np -import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -13,52 +15,312 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, ) from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized from torch.export import ExportedProgram -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) +class Conv2dTransposeModule(torch.nn.Module): + def __init__(self, in_channels: int, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + self.conv = Conv2dModule( + in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1) + ) + + def forward(self, x): + x = self.conv(x) + return torch.transpose(x, self.dim0, self.dim1) + + +class Conv2dPermuteModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = self.conv(x) + return torch.permute(x, self.perm) + + +class PermuteConv2dModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = torch.permute(x, self.perm) + return self.conv(x) -class Conv2dPermuteCopyModule(torch.nn.Module): - def __init__(self, new_dims: tuple[int, ...]): +class PermuteConv2dPermuteModule(torch.nn.Module): + def __init__( + self, in_channels: int, perm1: tuple[int, ...], perm2: tuple[int, ...] + ): super().__init__() - self.new_dims = new_dims - self.conv = Conv2dModule() + self.perm1 = perm1 + self.perm2 = perm2 + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) def forward(self, x): + x = torch.permute(x, self.perm1) x = self.conv(x) - return torch.permute(x, self.new_dims) + x = torch.permute(x, self.perm2) + return x -def test_permute_copy_quant_conversion__with_bias(mocker): - input_shape = (1, 4, 8, 8) - new_dims = (0, 2, 3, 1) +class LinearPermuteModule(torch.nn.Module): + def __init__(self, in_features: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.fc = torch.nn.Linear(in_features, in_features) + + def forward(self, x): + x = self.fc(x) + return torch.permute(x, self.perm) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - # Run conversion - _ = to_quantized_edge_program(Conv2dPermuteCopyModule(new_dims), input_shape) +class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return + @parameterized.expand( + [ + ["To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_input( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = Conv2dPermuteModule(input_shape[1], perm) - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) - convert_run_compare( - edge_program, - input_data, - tfl_model=tflite_flatbuffers_model, - atol=1.0, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["To channel first permutation", (1, 8, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 8, 8, 8), (0, 2, 3, 1)], + ] ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_output( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dModule(input_shape[1], perm) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["nchw->nhwc ... nchw->nhwc", (1, 8, 8, 8), (0, 2, 3, 1), (0, 2, 3, 1)], + ["nchw->nhwc ... nhwc->nchw", (1, 8, 8, 8), (0, 2, 3, 1), (0, 3, 1, 2)], + ["nhwc->nchw ... nhwc->nchw", (1, 8, 8, 8), (0, 3, 1, 2), (0, 3, 1, 2)], + ["nhwc->nchw ... nchw->nhwc", (1, 8, 8, 8), (0, 3, 1, 2), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_io( + self, _: str, input_shape, perm1, perm2 + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dPermuteModule(input_shape[1], perm1, perm2) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Permutation can be replaced by reshapes", (10, 1, 8), (0, 2, 1)], + ["Permutation can be replaced by reshapes", (10, 1, 1), (2, 1, 0)], + ["Permutation is identical and can be removed", (10, 1, 8), (0, 1, 2)], + ] + ) + def test_permute_copy_conversion__from_permute_3D__quantized( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + # Run conversion + edge_program = to_quantized_edge_program( + LinearPermuteModule(input_shape[2], perm), input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3)], + ["To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3)], + ["To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0)], + ["To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2)], + ] + ) + def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized( + self, _: str, input_shape, perm + ): + model = Conv2dPermuteModule(input_shape[1], perm) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2], + ["Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3], + ] + ) + def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized( + self, _: str, input_shape, dim0, dim1 + ): + model = Conv2dTransposeModule(input_shape[1], dim0, dim1) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index 8d903e3e0b5..cf0e0135ffe 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -67,7 +67,9 @@ def test_relu_with_conv_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(ConvReLUModule(), input_shape) + _ = to_quantized_edge_program( + ConvReLUModule(), input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py index c5d7d4d6a38..382266e9cb1 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py @@ -33,7 +33,9 @@ def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - to_quantized_edge_program(model, input_shape).exported_program() + to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 98566ff1ad6..336c3cc9afd 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -118,7 +118,9 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape): y_input_shape = (n, 8, h, w) # Run conversion - _ = to_quantized_edge_program(model, [x_input_shape, y_input_shape]) + _ = to_quantized_edge_program( + model, [x_input_shape, y_input_shape], use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index ca750719a32..eb5fc6600f5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -62,7 +62,7 @@ def test_conv_tanh( ) quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value exported_program: ExportedProgram = converter_spy.calls[-1].args[0] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 448a9753000..fac0a1fffee 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -12,6 +12,8 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -146,6 +148,9 @@ def test__channels_first_to_4d(mocker): input_data, tflite_input_preprocess=ToNHWCPreprocess(), atol=2.0e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) tflite_model = converter_spy.spy_return @@ -243,6 +248,7 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_ channels=input_shape[1], channels_view_out=channels_view_out ), input_shape, + use_neutron_for_format_conversion=False, ) # Capture generated model diff --git a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py index 17b040fbc3d..b5e701ab239 100644 --- a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py +++ b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py @@ -51,7 +51,10 @@ def test_remove_io_quant_ops_pass__cifarnet(): model = CifarNet().get_eager_model() input_shape = (1, 3, 32, 32) edge_program_manager = to_quantized_edge_program( - model, input_shape, remove_quant_io_ops=True + model, + input_shape, + remove_quant_io_ops=True, + use_neutron_for_format_conversion=False, ) exec_prog = edge_program_manager.to_executorch( diff --git a/backends/nxp/tests/test_integration.py b/backends/nxp/tests/test_integration.py index d31b22c9ce9..3bd5f3e1487 100644 --- a/backends/nxp/tests/test_integration.py +++ b/backends/nxp/tests/test_integration.py @@ -39,7 +39,9 @@ def test_conv_fc_softmax__to_executorch_program(): def test_cifarnet(): model = CifarNet().get_eager_model().eval() input_shape = (1, 3, 32, 32) - exec_prog = to_quantized_executorch_program(model, input_shape) + exec_prog = to_quantized_executorch_program( + model, input_shape, use_neutron_for_format_conversion=False + ) delegation_info = get_delegation_info(exec_prog.exported_program().graph_module) assert delegation_info.num_delegated_subgraphs == 1 diff --git a/backends/nxp/tests/test_move_activation_before_concatenation.py b/backends/nxp/tests/test_move_activation_before_concatenation.py index 114b17720f6..cede3e41994 100644 --- a/backends/nxp/tests/test_move_activation_before_concatenation.py +++ b/backends/nxp/tests/test_move_activation_before_concatenation.py @@ -631,7 +631,7 @@ def test_move_activation_before_concat_quantization__conv( ) edge_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() # Make sure that all nodes were delegated. @@ -822,7 +822,7 @@ def test_concat_cluster_quantization__conv( ) edge_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() # Make sure that all nodes were delegated. diff --git a/backends/nxp/tests/test_neutron_backend.py b/backends/nxp/tests/test_neutron_backend.py index c9917651fbd..08c66b22585 100644 --- a/backends/nxp/tests/test_neutron_backend.py +++ b/backends/nxp/tests/test_neutron_backend.py @@ -21,7 +21,9 @@ def test_neutron_backend__single_conv_model(): def test_neutron_backend__single_conv_model__payload_header_channels_last(): edge_program_manager = to_quantized_edge_program( - Conv2dModule(bias=False), (1, 4, 32, 32) + Conv2dModule(bias=False), + (1, 4, 32, 32), + use_neutron_for_format_conversion=False, ) payload = ( edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes diff --git a/backends/nxp/tests/test_neutron_backend_executor.py b/backends/nxp/tests/test_neutron_backend_executor.py index 3503403311f..6daf1570374 100644 --- a/backends/nxp/tests/test_neutron_backend_executor.py +++ b/backends/nxp/tests/test_neutron_backend_executor.py @@ -11,10 +11,13 @@ ) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOptions import BuiltinOptions from executorch.backends.nxp.backend.ir.lib.tflite.Model import Model +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.nxp_backend import PayloadComposer from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, EdgeProgramExecutor, + graph_contains_any_of_ops, TFLiteExecutor, ToNHWCPreprocess, ) @@ -108,3 +111,217 @@ def test_conv_fc__lowered_program_and_tflite_output_match(mocker): input_data=input_data, tflite_input_preprocess=ToNHWCPreprocess(), ) + + +def test_delegating_format_related_transpose_operators__unsupported_shapes(mocker): + # This test focuses on the case when Neutron would not support the inserted Transpose operators, so they are not + # inserted, so the runtime will permute the data. + + # Make sure none of the dimensions are multiples of `num_macs` (8), for proper testing. + model = Conv2dModule(in_channels=3, out_channels=3, padding=1, stride=1) + input_shape = (1, 3, 3, 3) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops are NOT in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 1]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_case(mocker): + # Make sure the output channels (channels for the trailing Transpose), and the last input dimension (channels for + # the leading Transpose) are multiples of `num_macs``. + + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, out_channels=num_macs, padding=1, stride=1 + ) + input_shape = (1, num_macs, num_macs, num_macs) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops ARE in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 4 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(3).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `0` means `channels_last`, which means the runtime will NOT transpose the data. + assert all(payload_header[3:5] == [0, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_output__unsupported_input( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=num_macs, # The output `Transpose` will be supported. + padding=1, + stride=1, + ) + input_shape = (1, num_macs, num_macs, 3) # The input `Transpose` is not supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 3 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_input__unsupported_output( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=3, # The output `Transpose` will NOT be supported. + stride=1, + ) + input_shape = (1, num_macs, 3, num_macs) # The input `Transpose` is supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [0, 1]) # [, ] diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py index b988fce470d..62cbef9e151 100644 --- a/backends/nxp/tests/test_per_channel_conversion.py +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -126,6 +126,7 @@ def test_per_channel_convolution(self): get_quantizer_fn=lambda: NeutronAtenQuantizer( Conv2dPatternPerChannel(is_per_channel=True), static_qconfig ), + use_neutron_for_format_conversion=False, ) tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value