diff --git a/backends/nxp/backend/ir/conversion_config.py b/backends/nxp/backend/ir/conversion_config.py index 4ac88eb467c..622735e881f 100644 --- a/backends/nxp/backend/ir/conversion_config.py +++ b/backends/nxp/backend/ir/conversion_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,7 +14,6 @@ def __init__(self, args: dict | None = None): :param args: Optional dictionary with conversion arguments. Unknown arguments are ignored. """ self.keep_io_format: bool = False - self.skip_shape_inference: bool = False self.allow_inputs_stripping: bool = True self.qdq_aware_conversion: bool = True self.symbolic_dimensions_mapping: dict[str, int] | None = None @@ -46,15 +45,6 @@ def __repr__(self): return "ConversionConfig[" + ", ".join(attrs) + "]" -class SkipShapeInferenceConfig(ConversionConfig): - - def __init__(self): - """ - Conversion config shortcut with disabled shape inference. - """ - super().__init__({"skip_shape_inference": True}) - - class QDQAwareConfig(ConversionConfig): def __init__(self): diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 4f036854138..496fa752853 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -1,6 +1,6 @@ # # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -795,29 +795,8 @@ def _remove_tensor_with_name(self, name): def append_new_tensor(self, t_tensor: tflite_model.Tensor, overwrite: bool = False): """Append the TFLite tensor 't_tensor' to the 'SubGraph.tensors' and register it.""" - - if t_tensor.name in self._tensor_name_map.keys(): - """Tensor has already been added. Sometimes however, ONNX models - will have tensors in their 'inputs' or 'outputs', which don't - belong there and are in fact static. I this case we need to - overwrite the existing tensors.""" - - if overwrite: - self._remove_tensor_with_name(t_tensor.name) - - # If the tenor previously appeared in ONNX 'inputs' or 'outputs', - # the old version MUST be removed from there. - self._remove_input_with_name(t_tensor.name) - self._remove_output_with_name(t_tensor.name) - - self.get_tensors().append(t_tensor) - self._tensor_name_map[t_tensor.name] = t_tensor - else: - logger.w(f"Tensor '{t_tensor.name}' is already in the tensors!") - - else: - self._tensor_name_map[t_tensor.name] = t_tensor - self.get_tensors().append(t_tensor) + self._tensor_name_map[t_tensor.name] = t_tensor + self.get_tensors().append(t_tensor) def append_new_buffer(self, buffer: tflite_model.Buffer): """Append the 'buffer' to the 'model.buffers'.""" @@ -1515,7 +1494,7 @@ def prepare_dynamic_tensor_for_correct_broadcasting_with_channels_first_tensors( # Prepend a partial identity, to keep leading dimensions unchanged. 
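The broadcasting preparation in these two helpers reduces to "prepend leading 1s to the lower-rank shape, then re-permute for the channels-last layout". A minimal standalone sketch of that shape arithmetic, assuming channels first/last mean NCHW/NHWC-style layouts (the helper names below are illustrative stand-ins, not the converter's actual `translator` functions):

```python
def dims_to_channels_last(dims):
    # NCHW -> NHWC style move: the channel dimension (index 1) goes last.
    return [dims[0]] + dims[2:] + [dims[1]]

def extend_and_convert(shape, output_rank):
    rank_diff = output_rank - len(shape)
    extended = [1] * rank_diff + shape       # Prepend 1s so the ranks match for broadcasting.
    return dims_to_channels_last(extended)   # Re-permute to the channels-last layout.

# A rank-3 ExecuTorch shape broadcast against a rank-4 channels-last tensor:
assert extend_and_convert([3, 2, 2], 4) == [1, 2, 2, 3]
```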
revert_perm = list(range(rank_diff)) + list(revert_perm) - # Now add a permutation to convert the extended ONNX shape to a TFLite shape + # Now add a permutation to convert the extended ExecuTorch shape to a TFLite shape to_tflite_perm = ( translator.create_channels_first_to_channels_last_permutation( output_rank @@ -1579,20 +1558,20 @@ def prepare_static_tensor_for_correct_broadcasting_with_channels_first_tensors( original_shape = translator.dims_to_channels_first( shape - ) # Same shape as in the ONNX model + ) # Same shape as in the ExecuTorch model # Prepend 1s to the shape - extended_onnx_shape = [1] * rank_diff + original_shape + extended_executorch_shape = [1] * rank_diff + original_shape # Convert the full shape to TFLite format - tflite_shape = translator.dims_to_channels_last(extended_onnx_shape) + tflite_shape = translator.dims_to_channels_last(extended_executorch_shape) tensor.shape = tflite_model.Shape(tflite_shape) # Statically transpose the data data = translator.convert_data_to_channels_first( data - ) # To the same shape as in the ONNX model - data = data.reshape(extended_onnx_shape) # Extend with leading 1s + ) # To the same shape as in the ExecuTorch model + data = data.reshape(extended_executorch_shape) # Extend with leading 1s tensor.tmp_buffer.data = translator.convert_data_to_channels_last( data ) # Convert to TFLite format @@ -1600,16 +1579,16 @@ def prepare_static_tensor_for_correct_broadcasting_with_channels_first_tensors( assert tflite_shape == list(tensor.tmp_buffer.data.shape) else: - # The tensor is the same as in the ONNX model. + # The tensor is the same as in the ExecuTorch model. - extended_onnx_shape = [1] * rank_diff + shape + extended_executorch_shape = [1] * rank_diff + shape # Convert the full shape to TFLite format - tflite_shape = translator.dims_to_channels_last(extended_onnx_shape) + tflite_shape = translator.dims_to_channels_last(extended_executorch_shape) tensor.shape = tflite_model.Shape(tflite_shape) # Statically transpose the data - data = data.reshape(extended_onnx_shape) # Extend with leading 1s + data = data.reshape(extended_executorch_shape) # Extend with leading 1s tensor.tmp_buffer.data = translator.convert_data_to_channels_last( data ) # Convert to TFLite format diff --git a/backends/nxp/backend/ir/converter/conversion/common.py b/backends/nxp/backend/ir/converter/conversion/common.py index 8230e39a7fa..318fe66dfbd 100755 --- a/backends/nxp/backend/ir/converter/conversion/common.py +++ b/backends/nxp/backend/ir/converter/conversion/common.py @@ -1,6 +1,6 @@ # # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -12,7 +12,7 @@ 'conversion/builtin/' directory. 
""" -from typing import Any, List, MutableSequence, Optional +from typing import List, MutableSequence, Optional import executorch.backends.nxp.backend.ir.logger as logger from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model @@ -22,28 +22,8 @@ max_pool_2d_options, transpose_conv_options, ) -from torch.fx import Node - - -def exactly_one_is_none(obj1: Optional, obj2: Optional) -> bool: - """Determine if exactly 1 of the arguments is None, or not.""" - return (obj1 is None and obj2 is not None) or (obj1 is not None and obj2 is None) - - -def contains_duplicates(list_to_check: List[Any]) -> bool: - """Determine if given list has duplicate elements or not.""" - return len(list_to_check) != len(set(list_to_check)) - - -def clamp(val: int, start: int, end: int) -> int: - """Clamp an int value between start and end (inclusive) and return it.""" - if val < start: - return start - - elif val > end: - return end - return val +from torch.fx import Node def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor | None: @@ -62,11 +42,6 @@ def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor tensor = t_op.tmp_inputs[idx] - if tensor.name == "": - # ONNX allows the name "" for optional tensors. It indicates that the tensor should be ignored, and a default - # value should be used. Just like if the tensor was omitted altogether. - return None - return tensor @@ -101,7 +76,7 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]): If 'strides' is None, assign 1s. :param options: TFLite AveragePool2D, Conv2D, MaxPool2D or TransposeConv options object. - :param strides: An optional list of ONNX strides attribute. + :param strides: An optional list of ExecuTorch strides attribute. """ if strides is None: @@ -115,8 +90,8 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]): else: logger.e( - logger.Code.INVALID_ONNX_OPERATOR_ATTRIBUTE, - f"ONNX operator has invalid 'strides' attribute! ('{strides}')", + logger.Code.INVALID_OPERATOR_ATTRIBUTE, + f"ExecuTorch operator has invalid 'strides' attribute! ('{strides}')", ) @@ -188,32 +163,6 @@ def node_uses_shape_broadcasting(node: Node) -> bool: ) -def uses_multiple_input_types(t_op: tflite_model.Operator) -> bool: - """Determine if the input tensors of given TFLite operator use different data types or not. - - :param t_op: TFLite operator with 'tmp_inputs' initialized. - :return: True, if any two input tensors have a different data type. - False, if all input tensors use the same data type. 
- """ - - if t_op.tmp_inputs is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "common.uses_multiple_input_types(): 'tmp_inputs' are None!", - ) - - if len(t_op.tmp_inputs) == 0: - logger.e( - logger.Code.INTERNAL_ERROR, - "common.uses_multiple_input_types(): Operator has no inputs!", - ) - - first_input_type = t_op.tmp_inputs[0].type - return any( - input_tensor.type != first_input_type for input_tensor in t_op.tmp_inputs[1:] - ) - - class OpsList: """ Holder of TFLite operator (middle_op) that can be prefixed (pre_ops) of suffixed (post_ops) diff --git a/backends/nxp/backend/ir/converter/conversion/translator.py b/backends/nxp/backend/ir/converter/conversion/translator.py index 4f327c6ac80..1fe195843c0 100755 --- a/backends/nxp/backend/ir/converter/conversion/translator.py +++ b/backends/nxp/backend/ir/converter/conversion/translator.py @@ -1,6 +1,5 @@ -# # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -9,10 +8,10 @@ translator Module contains functions for context-free conversion of various -things from ONNX to TFLite. +things from ExecuTorch to NeutronIR. """ -from typing import Any, Collection, List, Optional, Sequence, Tuple +from typing import Any, Collection, List, Optional, Sequence import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding import executorch.backends.nxp.backend.ir.logger as logger @@ -21,16 +20,12 @@ import numpy as np import torch from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat -from executorch.backends.nxp.backend.ir.tflite_generator.meta.types import ( - TensorFlowDataType, -) def permute_static_tensor(tensor: tflite_model.Tensor, perm: list[int]): - """Take a static TFLite tensor and permute its shape and data according to the permutation in 'perm'. + """Take a static NeutronIR tensor and permute its shape and data according to the permutation in 'perm'. - :param tensor: Static TFLite tensor to permute. + :param tensor: Static NeutronIR tensor to permute. :param perm: Permutation to apply to the tensor. """ @@ -53,7 +48,7 @@ def permute_static_tensor(tensor: tflite_model.Tensor, perm: list[int]): def get_tflite_tensor_shape_with_explicit_padding( tflite_shape: List[int], explicit_padding: List[List[int]] ) -> List[int]: - """Get the resulting shape of a tensor with shape 'tflite_shape' (in TFLite format), after 'explicit_padding' is + """Get the resulting shape of a tensor with shape 'tflite_shape' (in NeutronIR format), after 'explicit_padding' is applied to it. """ @@ -62,7 +57,7 @@ def get_tflite_tensor_shape_with_explicit_padding( ): logger.e( logger.Code.INTERNAL_ERROR, - f"Cannot apply padding '{explicit_padding}' to TFLite shape '{tflite_shape}'!", + f"Cannot apply padding '{explicit_padding}' to NeutronIR shape '{tflite_shape}'!", ) total_padding = [ @@ -90,24 +85,9 @@ def get_tflite_tensor_shape_with_explicit_padding( return padded_shape -def convert_tensor_format_to_tflite(tensor_format: TensorFormat) -> TensorFormat: - """Convert the format of a tensor from ONNX to TFLite. - :return: The tensor_format converted to TFLite. - """ - if tensor_format is TensorFormat.CHANNELS_FIRST: - return TensorFormat.CHANNELS_LAST - - elif tensor_format not in (TensorFormat.FORMATLESS, TensorFormat.NONE): - logger.d( - f"translator.convert_tensor_format(): Got unexpected format '{tensor_format}'." 
- ) - - return tensor_format - - def dims_to_channels_first(channels_last_dimensions: List[int]) -> List[int]: - """Convert a list of ints which represent dimensions in the channels last (TFLite) format to the channels first - (ONNX) format. + """Convert a list of ints which represent dimensions in the channels last (NeutronIR) format to the channels first + (ExecuTorch) format. """ assert len(channels_last_dimensions) > 0, "Dimensions list is empty!" @@ -122,8 +102,8 @@ def dims_to_channels_first(channels_last_dimensions: List[int]) -> List[int]: def dims_to_channels_last(channels_first_dimensions: List[int]) -> List[int]: - """Convert a list of ints which represent dimensions in the channels first (ONNX) format to the channels last - (TFLite) format. + """Convert a list of ints which represent dimensions in the channels first (ExecuTorch) format to the channels last + (NeutronIR) format. """ assert len(channels_first_dimensions) > 0, "Dimensions list is empty!" @@ -171,7 +151,7 @@ def _same_upper_equals_same_lower( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> bool: - """Determine if in a given particular setting, the values of the ONNX `auto_pads` attribute SAME_UPPER and + """Determine if in a given particular setting, the values of the ExecuTorch `auto_pads` attribute SAME_UPPER and SAME_LOWER represent the exact same padding. """ @@ -193,7 +173,7 @@ def _tflite_padding_compute_output_size( """ Calculates the output shape of the tensor with particular setting as tflite would. Implementation corresponds to tensorflow/lite/kernels/padding.h:ComputeOutSize() - :param padding: TFLite Padding value - 'Same' or 'Valid' + :param padding: NeutronIR Padding value - 'Same' or 'Valid' :param tflite_spatial_input_shape: input tensor shape :param tflite_kernel_shape: convolution kernel shape :param strides: strides (default is 1) @@ -229,7 +209,7 @@ def tflite_compute_padding_with_offset( dilations: Optional[List[int]] = None, ) -> (List[int], List[int]): """ - Calculate padding and offset for each dimension for particular convolution setting as TFLite. + Calculate padding and offset for each dimension for particular convolution setting as NeutronIR. Implementation corresponds to tensorflow/lite/kernels/padding.h:ComputePaddingWithOffset() :param tflite_input_shape: tensorflow lite input shape :param tflite_kernel_shape: tensorflow lite kernel shape @@ -272,14 +252,14 @@ def _is_same_padding( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> bool: - """Determine if given ONNX 'pads' padding can be represented exactly with the TFLite 'SAME' padding type. - - :param o_pads: ONNX 'pads' attribute. - :param tflite_input_shape: The shape of the main input of the operator in TFLite format. - :param tflite_output_shape: The shape of the main output of the operator in TFLite format. - :param o_kernel_shape: ONNX 'kernel_shape' attribute. - :param o_strides: ONNX 'strides' attribute. Can be omitted. - :param o_dilations: ONNX 'dilations' attribute. Can be omitted. + """Determine if given ExecuTorch 'pads' padding can be represented exactly with the NeutronIR 'SAME' padding type. + + :param o_pads: ExecuTorch 'pads' attribute. + :param tflite_input_shape: The shape of the main input of the operator in NeutronIR format. + :param tflite_output_shape: The shape of the main output of the operator in NeutronIR format. + :param o_kernel_shape: ExecuTorch 'kernel_shape' attribute. + :param o_strides: ExecuTorch 'strides' attribute. 
Can be omitted. + :param o_dilations: ExecuTorch 'dilations' attribute. Can be omitted. """ if len(tflite_input_shape) == 0 or len(tflite_output_shape) == 0: @@ -289,7 +269,7 @@ def _is_same_padding( f"'{tflite_input_shape}' and output shape '{tflite_output_shape}'.", ) - # Calculate if the output shape corresponds to Same padding setting in TFLite + # Calculate if the output shape corresponds to Same padding setting in NeutronIR tflite_spatial_input_shape = tflite_input_shape[1:-1] tmp_spatial_output_shape = _tflite_padding_compute_output_size( tflPadding.Padding.SAME, @@ -302,10 +282,10 @@ def _is_same_padding( return False # For every dimension, the padding is added to the start and end of the dimension. - # TFLite padding 'SAME' tries to split it evenly, but in case of odd padding, 'SAME' adds the excess 1 at the end. - # TFLite represents this in the offset. The offset is added to the end of particular dimension, + # NeutronIR padding 'SAME' tries to split it evenly, but in case of odd padding, 'SAME' adds the excess 1 at the end. + # NeutronIR represents this in the offset. The offset is added to the end of particular dimension, # i.e. bottom for H dim, right for W dim and so on. - # ONNX represents this in 'pads' as [x1_begin, x2_begin,... , x1_end, x2_end,...]. + # ExecuTorch represents this in 'pads' as [x1_begin, x2_begin,... , x1_end, x2_end,...]. padding, offset = tflite_compute_padding_with_offset( tflite_input_shape, o_kernel_shape, tflite_output_shape, o_strides, o_dilations ) @@ -319,30 +299,6 @@ def _is_same_padding( return True -def permutations_are_inverse( - permutation1: Sequence[int], permutation2: Sequence[int] -) -> bool: - """Determine if given Transpose permutations are inverse of each other. - i.e. when applied back to back, there will be no effect. - - Example: - 0 3 1 2 - 0 2 3 1 - """ - - if len(permutation1) != len(permutation2): - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.permutations_are_inverse(): permutations have different size!", - ) - - for i, perm2 in enumerate(permutation2): - if i != permutation1[perm2]: - return False - - return True - - def combine_permutations( permutation1: Sequence[int], permutation2: Sequence[int] ) -> List[int]: @@ -375,31 +331,35 @@ def shape_from_numpy(numpy_array): return tflite_model.Shape(dims) -def onnx_explicit_padding_to_tflite(onnx_pads: list[int]) -> list[list[int]]: - """Convert the attribute or input 'pads' of the ONNX 'Pad' operator to the 'paddings' input of the TFLite 'Pad' +def executorch_explicit_padding_to_tflite( + executorch_pads: list[int], +) -> list[list[int]]: + """Convert the attribute or input 'pads' of the ExecuTorch 'Pad' operator to the 'paddings' input of the NeutronIR 'Pad' class of operators. This function does NOT take tensor formats into consideration. """ - start_padding = onnx_pads[ - : len(onnx_pads) // 2 + start_padding = executorch_pads[ + : len(executorch_pads) // 2 ] # Padding at the start of each dimension - end_padding = onnx_pads[ - len(onnx_pads) // 2 : + end_padding = executorch_pads[ + len(executorch_pads) // 2 : ] # Padding at the end of each dimension return list(zip(start_padding, end_padding)) -def onnx_pads_to_tflite_explicit_padding(onnx_pads: List[int]) -> List[List[int]]: - """Convert an ONNX attribute 'pads' of operators such as Conv, MaxPool or AveragePool, to a list of ints which is - compatible with the TFLite 'Pad' operator. 
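The renamed replacement that follows keeps the original zip-and-pad logic. A standalone restatement with a worked example (a sketch only: the hypothetical `pads_to_tflite_explicit_padding` below returns plain lists where the real code returns `zip` tuples for the inner pairs):

```python
def pads_to_tflite_explicit_padding(pads):
    # 'pads' comes in as [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
    start, end = pads[: len(pads) // 2], pads[len(pads) // 2 :]
    paddings = [list(p) for p in zip(start, end)]
    # The TFLite-style 'Pad' operator also covers batch and channels,
    # which the ExecuTorch 'pads' attribute never pads.
    return [[0, 0]] + paddings + [[0, 0]]

# 2D spatial padding: 1 at the start and 2 at the end of each spatial dimension.
assert pads_to_tflite_explicit_padding([1, 1, 2, 2]) == [[0, 0], [1, 2], [1, 2], [0, 0]]
```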
+def executorch_pads_to_tflite_explicit_padding( + executorch_pads: List[int], +) -> List[List[int]]: + """Convert an ExecuTorch attribute 'pads' of operators such as Conv, MaxPool or AveragePool, to a list of ints which is + compatible with the NeutronIR 'Pad' operator. """ - tflite_padding = onnx_explicit_padding_to_tflite(onnx_pads) + tflite_padding = executorch_explicit_padding_to_tflite(executorch_pads) - # TFLite also allows padding to the 'batch' and 'channels'. ONNX does not + # NeutronIR also allows padding to the 'batch' and 'channels'. ExecuTorch does not tflite_padding.insert(0, [0, 0]) tflite_padding.append([0, 0]) @@ -413,15 +373,15 @@ def _get_explicit_tflite_padding_for_same_lower( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> List[List[int]]: - """Get the TFLite explicit padding required to represent ONNX 'SAME_LOWER' auto_pad for a particular setting. + """Get the NeutronIR explicit padding required to represent ExecuTorch 'SAME_LOWER' auto_pad for a particular setting. - :param tflite_input_shape: TFLite (NHWC) shape of the input tensor of the operator. - :param tflite_output_shape: TFLite (NHWC) shape of the output tensor of the operator. - :param o_kernel_shape: ONNX 'kernel_shape' attribute. - :param o_strides: Optional ONNX 'o_strides' attribute. - :param o_dilations: Optional ONNX 'o_dilations' attribute. + :param tflite_input_shape: NeutronIR (NHWC) shape of the input tensor of the operator. + :param tflite_output_shape: NeutronIR (NHWC) shape of the output tensor of the operator. + :param o_kernel_shape: ExecuTorch 'kernel_shape' attribute. + :param o_strides: Optional ExecuTorch 'o_strides' attribute. + :param o_dilations: Optional ExecuTorch 'o_dilations' attribute. - :return: A TFLite style explicit padding, compatible with the TFLite 'Pad' operator. + :return: A NeutronIR style explicit padding, compatible with the NeutronIR 'Pad' operator. """ padding, offset = tflite_compute_padding_with_offset( @@ -433,102 +393,15 @@ def _get_explicit_tflite_padding_for_same_lower( ] # In case of odd padding, the excess is added at the start end_padding = padding - onnx_explicit_padding = start_padding + end_padding - - # Return explicit ONNX padding converted to TFLite padding - return onnx_pads_to_tflite_explicit_padding(onnx_explicit_padding) - - -def convert_padding( - o_auto_pad: str, - o_pads: List[int], - tflite_input_shape: List[int], - tflite_output_shape: List[int], - o_kernel_shape: List[int], - o_strides: Optional[List[int]], - o_dilations: Optional[List[int]] = None, -) -> Tuple[tflPadding.Padding, Optional[List[List[int]]]]: - """Convert ONNX operator attributes 'pads' and 'auto_pad' to TFLite. - - :param o_auto_pad: ONNX operator attribute 'auto_pad' - :param o_pads: ONNX operator attribute 'pads' - :param tflite_input_shape: The shape of the main input tensor in the TFLite format. - :param tflite_output_shape: The shape of the main output tensor in the TFLite format. - :param o_kernel_shape: ONNX operator attribute 'kernel_shape' - :param o_strides: ONNX operator attribute 'strides' - :param o_dilations: ONNX operator attribute 'dilations' - - :return: A tuple. - The first element is the converted TFLite padding. - The second is None, if conversion is finished. Or it is a list of ints representing the explicit - padding in TFLite format (compatible with the 'Pad' operator), which needs to be provided by a - 'Pad' operator. Caller must add this operator using model_builder! 
- """ - - if o_auto_pad == "SAME_UPPER": - return tflPadding.Padding.SAME, None - - elif o_auto_pad == "SAME_LOWER": - if _same_upper_equals_same_lower( - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ): - return tflPadding.Padding.SAME, None - - else: - logger.d( - "'SAME_LOWER' auto_pad cannot be exactly represented in TFLite as padding 'SAME' or 'VALID'. " - "Inserting an extra 'Pad' operator." - ) - tflite_explicit_padding = _get_explicit_tflite_padding_for_same_lower( - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ) - return tflPadding.Padding.VALID, tflite_explicit_padding - - elif o_auto_pad == "VALID": - return tflPadding.Padding.VALID, None - - # auto_pad is NOTSET -> use explicit padding - elif o_pads is None or all(val == 0 for val in o_pads): - # No padding in any direction - return tflPadding.Padding.VALID, None - - elif _is_same_padding( - o_pads, - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ): - # Explicit padding can be represented with TFLite 'SAME' padding. - return tflPadding.Padding.SAME, None - - else: - # 'pads' cannot be converted directly. Return 'VALID' and the required explicit padding and caller must - # implement conversion by adding a 'Pad' operator. - - logger.d( - "Explicit ONNX 'pads' cannot be represented directly as 'SAME' or 'VALID'. " - "Inserting an extra 'Pad' operator." - ) - - # ONNX 'pads' uses different format than TFLite 'Pad' operator. Convert the explicit padding. - tflite_explicit_padding = onnx_pads_to_tflite_explicit_padding(o_pads) + executorch_explicit_padding = start_padding + end_padding - return tflPadding.Padding.VALID, tflite_explicit_padding + # Return explicit ExecuTorch padding converted to NeutronIR padding + return executorch_pads_to_tflite_explicit_padding(executorch_explicit_padding) def convert_data_to_channels_first(array: np.ndarray) -> np.ndarray: - """Convert a numpy array representing the data of a tensor from the channels last format (TFLite), to channels - first format (ONNX). + """Convert a numpy array representing the data of a tensor from the channels last format (NeutronIR), to channels + first format (ExecuTorch). :param array: Numpy array holding the tensor's data. :return: The transformed data. @@ -543,8 +416,8 @@ def convert_data_to_channels_first(array: np.ndarray) -> np.ndarray: def convert_data_to_channels_last(array: np.ndarray) -> np.ndarray: - """Convert a numpy array representing the data of a tensor from the channels first format (ONNX), to channels last - format (TFLite). + """Convert a numpy array representing the data of a tensor from the channels first format (ExecuTorch), to channels last + format (NeutronIR). :param array: Numpy array holding the tensor's data. :return: The transformed data. 
@@ -558,17 +431,6 @@ def convert_data_to_channels_last(array: np.ndarray) -> np.ndarray: return np.moveaxis(array, 1, -1) # Move the second axis (C), to the end -def channels_first_shape_to_channels_last( - channels_first_shape: tflite_model.Shape, -) -> tflite_model.Shape: - """Create a channels last version of a channels first 'tflite_model.Shape' object.""" - - dims = channels_first_shape.vector.copy() - dims = dims_to_channels_last(dims) - - return tflite_model.Shape(dims) - - def channels_last_shape_to_channels_first( nhwc_shape: tflite_model.Shape, ) -> tflite_model.Shape: @@ -580,23 +442,13 @@ def channels_last_shape_to_channels_first( return tflite_model.Shape(dims) -def convert_onnx_dimensions_to_tflite_shape(o_dims: List[int]) -> tflite_model.Shape: - """Convert list of ints representing the shape of an ONNX channels first Tensor to a TFLite 'Shape' object.""" - - dims = list(o_dims) # Copy just in case - - dims = dims_to_channels_last(dims) - - return tflite_model.Shape(dims) - - def create_channels_last_to_channels_first_permutation( rank: int, return_list: bool = False ) -> np.ndarray | list[int]: """Return a numpy array with data that describes the permutation, which would change a tensor from the channels - last (TFLite) format to the channels first (ONNX) format. + last (NeutronIR) format to the channels first (ExecuTorch) format. - This permutation is compatible with the TFLite `Transpose` operator. + This permutation is compatible with the NeutronIR `Transpose` operator. :param rank: The rank of the required permutation. :param return_list: If True, the function returns a list of ints. If False, a numpy array is returned. @@ -615,9 +467,9 @@ def create_channels_first_to_channels_last_permutation( rank: int, return_list: bool = False ) -> np.ndarray | list[int]: """Return a numpy array with data that describes the permutation, which would change a tensor from the channels - first (ONNX) format to the channels last (TFLite) format. + first (ExecuTorch) format to the channels last (NeutronIR) format. - This permutation is compatible with the TFLite `Transpose` operator. + This permutation is compatible with the NeutronIR `Transpose` operator. :param rank: The rank of the required permutation. :param return_list: If True, the function returns a list of ints. If False, a numpy array is returned. @@ -632,35 +484,8 @@ def create_channels_first_to_channels_last_permutation( return np.asarray(perm, np.int32) -def create_axis_to_last_perm(axis, num_dims): - """Create a numpy array representing the transpose permutations needed, to - make the 'axis' dimension, the last dimension. - """ - - dims = list(range(num_dims)) - - if axis == num_dims - 1: - return dims - elif axis >= num_dims or axis < 0: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.create_axis_to_last_perm({axis},{num_dims}). Inputs don't make sense!", - ) - - # Remember axis dimension - axis_dim = dims[axis] - - # Move dimensions after 'axis' to the left - dims[axis:-1] = dims[axis + 1 : -1] - - # Add axis dimension to the end - dims.append(axis_dim) - - return np.asarray(dims, np.int32) - - def apply_permutation_to(target: List[Any], permutation: Collection[int]) -> List: - """Permute a list according to a permutation. Uses the same permutation format as the TFLite Transpose operator. + """Permute a list according to a permutation. Uses the same permutation format as the NeutronIR Transpose operator. :param target: A list of any types, to permute. Must be same size as the permutation. 
:param permutation: The permutation to apply to the target. @@ -678,7 +503,7 @@ def apply_permutation_to(target: List[Any], permutation: Collection[int]) -> Lis def create_inverse_permutation(permutation: List[int]) -> List[int]: """Create and return a permutation, that is the inverse of the given 'permutation' parameter. - Uses the same permutation format as the TFLite Transpose operator. + Uses the same permutation format as the NeutronIR Transpose operator. :param permutation: The permutation to create the inverse of. :return: Inverse permutation. @@ -694,38 +519,8 @@ def create_inverse_permutation(permutation: List[int]) -> List[int]: return [permutation.index(perm) for perm in range(len(permutation))] -def get_max_value_for_type(dtype: np.dtype) -> any: - """Return the maximum possible value for given numpy type.""" - if dtype.kind in ("i", "u"): - return np.iinfo(dtype).max - - elif dtype.kind == "f": - return np.finfo(dtype).max - - else: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.get_max_value_for_type(): unexpected type {dtype.name}.", - ) - - -def get_min_value_for_type(dtype: np.dtype) -> any: - """Return the minimum possible value for given numpy type.""" - if dtype.kind in ("i", "u"): - return np.iinfo(dtype).min - - elif dtype.kind == "f": - return np.finfo(dtype).min - - else: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.get_min_value_for_type(): unexpected type {dtype.name}.", - ) - - def convert_data_type(torch_type: torch.TensorType) -> TensorType: - """Convert Torch DataType to TFLite TensorType""" + """Convert Torch DataType to NeutronIR TensorType""" if torch_type == torch.float32: return TensorType.FLOAT32 @@ -753,7 +548,7 @@ def convert_data_type(torch_type: torch.TensorType) -> TensorType: def torch_type_to_numpy_type(torch_type: torch.TensorType) -> np.ScalarType: - """Convert Torch DataType to TFLite TensorType""" + """Convert Torch DataType to NeutronIR TensorType""" if torch_type == torch.float32: return np.dtype(np.float32) @@ -778,10 +573,10 @@ def torch_type_to_numpy_type(torch_type: torch.TensorType) -> np.ScalarType: def numpy_type_to_tf_lite(numpy_type: np.dtype) -> TensorType: # noqa C901 - """Convert the numpy data type to a corresponding TFLite 'TensorType'. + """Convert the numpy data type to a corresponding NeutronIR 'TensorType'. :param numpy_type: Numpy dtype to convert. - :return: Corresponding TFLite TensorType. + :return: Corresponding NeutronIR TensorType. 
""" numpy_type = numpy_type.type @@ -835,12 +630,12 @@ def numpy_type_to_tf_lite(numpy_type: np.dtype) -> TensorType: # noqa C901 else: logger.e( logger.Code.CONVERSION_IMPOSSIBLE, - f"Cannot convert numpy data type '{numpy_type}' to TFLite.", + f"Cannot convert numpy data type '{numpy_type}' to NeutronIR.", ) def tf_lite_type_to_numpy(tfl_type: TensorType) -> np.ScalarType: # noqa C901 - """Convert TFLite TensorType to numpy dtype""" + """Convert NeutronIR TensorType to numpy dtype""" if tfl_type == TensorType.FLOAT32: return np.dtype(np.float32) @@ -890,72 +685,5 @@ def tf_lite_type_to_numpy(tfl_type: TensorType) -> np.ScalarType: # noqa C901 else: logger.e( logger.Code.CONVERSION_IMPOSSIBLE, - f"Cannot convert TFLite type '{tfl_type}' to numpy dtype.", + f"Cannot convert NeutronIR type '{tfl_type}' to numpy dtype.", ) - - -def tflite_type_to_tensor_flow_data_type(tfl_type: TensorType) -> TensorFlowDataType: - """Convert TFLite TensorType to the internal type of TensorFlow.""" - match tfl_type: - case TensorType.FLOAT16: - # There seems to be no counterpart in the TF DataType. - logger.e( - logger.Code.INTERNAL_ERROR, - "tflite_type_to_tensor_flow_data_type(): float16.", - ) - case TensorType.FLOAT32: - return TensorFlowDataType.DT_FLOAT.value - case TensorType.FLOAT64: - return TensorFlowDataType.DT_DOUBLE.value - - case TensorType.INT4: - return TensorFlowDataType.DT_INT4.value - case TensorType.INT8: - return TensorFlowDataType.DT_INT8.value - case TensorType.INT16: - return TensorFlowDataType.DT_INT16.value - case TensorType.INT32: - return TensorFlowDataType.DT_INT32.value - case TensorType.INT64: - return TensorFlowDataType.DT_INT64.value - - case TensorType.UINT8: - return TensorFlowDataType.DT_UINT8.value - case TensorType.UINT16: - return TensorFlowDataType.DT_UINT16.value - case TensorType.UINT32: - return TensorFlowDataType.DT_UINT32.value - case TensorType.UINT64: - return TensorFlowDataType.DT_UINT64.value - - case TensorType.COMPLEX64: - return TensorFlowDataType.DT_COMPLEX64.value - case TensorType.COMPLEX128: - return TensorFlowDataType.DT_COMPLEX128.value - - case TensorType.STRING: - return TensorFlowDataType.DT_STRING.value - - case TensorType.BOOL: - return TensorFlowDataType.DT_BOOL.value - - case TensorType.RESOURCE: - return TensorFlowDataType.DT_RESOURCE.value - case TensorType.VARIANT: - return TensorFlowDataType.DT_VARIANT.value - - case _: - # All TFLite types are covered. Must be an invalid type. - logger.e( - logger.Code.INTERNAL_ERROR, - f"tflite_type_to_tensor_flow_data_type(): invalid TFLite type `{tfl_type}`.", - ) - - -def infer_kernel_shape(weight_tensor: tflite_model.Tensor) -> list[int]: - """Returns the kernel shape inferred from the weight tensor. - - Weight tensors shape expected in TFlite Format, where the 0th index is output channels count, last is input channels - count. - """ - return weight_tensor.shape.vector[1:-1] diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py index 50b9aef6d18..52b895d60cd 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py @@ -1,19 +1,12 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from executorch.backends.nxp.backend.ir import logger from executorch.backends.nxp.backend.ir.converter.builder import model_builder from executorch.backends.nxp.backend.ir.converter.conversion import translator -from executorch.backends.nxp.backend.ir.converter.conversion.common import ( - OpsList, - try_get_input, -) +from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data -from executorch.backends.nxp.backend.ir.lib.tflite.ActivationFunctionType import ( - ActivationFunctionType, -) from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model @@ -25,12 +18,12 @@ def ensure_correct_tensor_formatting( or RNN operator. The LSTM/RNN may be using channels last tensors, because of the surrounding operators. LSTM/RNN requires its own - format, however I think the input tensors should be marked as 'FORMATLESS', because the main inputs of TFLite - and ONNX version of the operators have the same shape. + format, however I think the input tensors should be marked as 'FORMATLESS', because the main inputs of the + NeutronIR and the ExecuTorch version of the operators have the same shape. I believe that the cleanest and most robust way to solve this, is to mark LSTM/RNN as an operator which can change the formats of its tensors, and solve any format related issues in this module. - :param t_op: TFLite operator with inputs and outputs corresponding to the ONNX LSTM/RNN operator. + :param t_op: NeutronIR operator with inputs and outputs corresponding to the ExecuTorch LSTM/RNN operator. :param builder: ModelBuilder object. :param ops: OpsList object, with operators to add to the model. May already contain some operators. """ @@ -69,44 +62,3 @@ def ensure_correct_tensor_formatting( ops.post_ops.append(transpose) t_op.tmp_outputs[idx].tensor_format = TensorFormat.FORMATLESS - - -def get_activation_function_for_name( - name: str, op_type: str = "LSTM" -) -> ActivationFunctionType: - get_activation_function_for_name.map = { - "Tanh": ActivationFunctionType.TANH, - "Relu": ActivationFunctionType.RELU, - } - - if act_fun := get_activation_function_for_name.map.get(name, None): - return act_fun - - # Couldn't find a corresponding activation function - logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - f"Conversion of ONNX {op_type} with activation function '{name}' is not possible.", - ) - - -def check_sequence_lens( - t_op: tflite_model.Operator, seq_length: int, op_type: str = "LSTM" -): - """Check if the 'sequence_lens' operand of ONNX LSTM/RNN has an effect. If it does, exit with error. - - :param t_op: TFLite operator with inputs and outputs corresponding to the ONNX operator. - :param seq_length: The first dimension of the main LSTM input. - :param op_type: Operator type of 't_op'. Used only for printing a specific error message. - """ - if sequence_lens := try_get_input(t_op, 4): - # 'sequence_lens' allows each sequence to have a different length. As far as I can tell, TFLite doesn't support - # this. - if (not tensor_has_data(sequence_lens)) or any( - elt != seq_length for elt in sequence_lens.tmp_buffer.data - ): - # The 'sequence_lens' is either dynamic, or static with at least one value different from 'seq_length'. - # Conversion most likely impossible. 
- logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - f"Conversion of ONNX {op_type} with 'sequence_lens' input is not possible.", - ) diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py index 1dca3acea74..da92e359f1e 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np + from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -16,7 +17,7 @@ def convert_axes_from_attribute( t_op: tflite_model.Operator, builder: ModelBuilder, axes: list[int] | None ): - """Create an `axes` tensor and assign it as an input to the `t_op`, which is expected to represent an ONNX + """Create an `axes` tensor and assign it as an input to the `t_op`, which is expected to represent an ExecuTorch reduction operator. """ x = t_op.tmp_inputs[0] @@ -52,15 +53,15 @@ def ensure_reduce_transposition(builder, ops: OpsList): output_format = output_tensor.tensor_format if input_format.is_channels_last() and output_format.is_channels_last(): - to_onnx_perm = translator.create_channels_last_to_channels_first_permutation( - input_rank + to_executorch_perm = ( + translator.create_channels_last_to_channels_first_permutation(input_rank) ) to_tflite_perm = translator.create_channels_first_to_channels_last_permutation( output_rank, return_list=True ) transpose_before = builder.create_transpose_operator_before( - t_op, 0, to_onnx_perm + t_op, 0, to_executorch_perm ) transpose_before.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST ops.add_pre(transpose_before) @@ -72,7 +73,7 @@ def ensure_reduce_transposition(builder, ops: OpsList): ops.post_ops.insert(0, transpose_after) elif input_format.is_channels_last() and not output_format.is_channels_last(): - # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ONNX. + # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ExecuTorch. permutation = list( translator.create_channels_last_to_channels_first_permutation(input_rank) @@ -83,9 +84,9 @@ def ensure_reduce_transposition(builder, ops: OpsList): ops.add_pre(transpose) elif not input_format.is_channels_last() and output_format.is_channels_last(): - # The ReduceX introduces format to the tensor - # The ONNX ReduceX outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator - # must be added, to change the tensor to 'channels last'. + # The reduction operator introduces format to the tensor. + # The ExecuTorch reduction operator outputs a 'channels first' tensor. This has to stay the same, and then a + # Transpose operator must be added, to change the tensor to 'channels last'. 
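The permutation built immediately below follows the usual channels-first-to-channels-last convention of `[0, 2, 3, ..., rank - 1, 1]`: batch stays put and channels move last. A minimal sketch of that convention and its inverse (hypothetical helper names, assuming this is what the `translator` functions return):

```python
def cf_to_cl_perm(rank):
    # Assumed create_channels_first_to_channels_last_permutation convention: NCHW -> NHWC.
    return [0] + list(range(2, rank)) + [1]

def cl_to_cf_perm(rank):
    # The inverse direction: NHWC -> NCHW.
    return [0, rank - 1] + list(range(1, rank - 1))

assert cf_to_cl_perm(4) == [0, 2, 3, 1]
assert cl_to_cf_perm(4) == [0, 3, 1, 2]
# Composing the two (Transpose-style index composition) yields the identity.
assert [cf_to_cl_perm(4)[i] for i in cl_to_cf_perm(4)] == [0, 1, 2, 3]
```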
permutation = list( translator.create_channels_first_to_channels_last_permutation(output_rank) diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py b/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py index 0e55c27684b..55056614684 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py @@ -1,4 +1,4 @@ -# Copyright 2023 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -158,7 +158,7 @@ def ensure_reshape_transposition(builder, ops: OpsList) -> list[int]: new_shape = output_tensor.shape.vector if input_format.is_channels_last() and not output_format.is_channels_last(): - # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ONNX. + # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ExecuTorch. permutation = list( translator.create_channels_last_to_channels_first_permutation(input_rank) @@ -170,7 +170,7 @@ def ensure_reshape_transposition(builder, ops: OpsList) -> list[int]: elif not input_format.is_channels_last() and output_format.is_channels_last(): # The Reshape introduces format to the tensor (2D -> 4D for example) - # The ONNX Reshape outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator + # The `view_copy` outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator # must be added, to change the tensor to 'channels last'. permutation = list( diff --git a/backends/nxp/backend/ir/converter/quantization_utils.py b/backends/nxp/backend/ir/converter/quantization_utils.py index d9e7674d953..11de4eec13c 100755 --- a/backends/nxp/backend/ir/converter/quantization_utils.py +++ b/backends/nxp/backend/ir/converter/quantization_utils.py @@ -1,111 +1,19 @@ -# Copyright 2023 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from typing import Iterable, List, Optional - -import executorch.backends.nxp.backend.ir.converter.builder.model_builder as model_builder +from typing import List import numpy as np + from executorch.backends.nxp.backend.ir import logger as logger -from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( - tf_lite_type_to_numpy, -) -from executorch.backends.nxp.backend.ir.lib.tflite import TensorType as tflTensorType -from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType from executorch.backends.nxp.backend.ir.tflite_generator import ( tflite_model as tflite_model, ) -def quantization_is_equal( - x_scale: np.ndarray, - x_zp: np.ndarray, - x_type: TensorType, - y_scale: np.ndarray, - y_zp: np.ndarray, - y_type: TensorType, -) -> bool: - """Determine if provided quantization parameters of tensors 'x' and 'y' are the same. - - :param x_scale: Scale of the 'x' tensor. - :param x_zp: Zero point of the 'x' tensor. - :param x_type: TFLite data type of the 'x' tensor. - :param y_scale: Scale of the 'y' tensor. - :param y_zp: Zero point of the 'y' tensor. - :param y_type: TFLite data type of the 'y' tensor. - :return: True, if the quantization parameters are equal. 
- """ - if x_type != y_type: - return False - - if not (x_scale.size == x_zp.size == y_scale.size == y_zp.size): - return False - - x_scale, x_zp = quantization_params_to_lists(x_scale, x_zp) - y_scale, y_zp = quantization_params_to_lists(y_scale, y_zp) - - return all( - x_s == y_s and x_z == y_z - for x_s, y_s, x_z, y_z in zip(x_scale, y_scale, x_zp, y_zp) - ) - - -def quantization_params_to_lists( - scale: np.ndarray, zero_point: np.ndarray -) -> (List[float], List[int]): - if (scale is None) or (zero_point is None): - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing zero_point and/or scale quantization params when converting to list!", - ) - - if (scale.size == 1) and (zero_point.size == 1): - # Per tensor quantization - scale = [scale.item()] - zero_point = [zero_point.item()] - elif (scale.size != 1) and (zero_point.size != 1): - # Per channel quantization - scale = scale.tolist() - zero_point = zero_point.tolist() - else: - logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - "TFLite doesn't support combination of per-channel and per-tensor quantization params.", - ) - - return scale, zero_point - - -def is_quantization_valid(scale, zero_point): - return scale.size == zero_point.size - - -def is_per_tensor_quantized(scale, zero_point): - return (scale.size == 1) and (zero_point.size == 1) - - -def is_per_channel_quantized(scale, zero_point): - return is_quantization_valid(scale, zero_point) and not is_per_tensor_quantized( - scale, zero_point - ) - - -def get_symmetric_zero_point_for_type(tensor_type: TensorType): - match tensor_type: - case TensorType.INT8: - return 0 - case TensorType.UINT8: - return 128 - case _: - logger.e( - logger.Code.INTERNAL_ERROR, - f"Attempt to get zero point definition for type: {tensor_type}", - ) - - def _validate_or_set_quant_params( tensor: tflite_model.Tensor, quant: tflite_model.Quantization ) -> bool: @@ -130,7 +38,7 @@ def propagate_quantization( """ Propagates quantization parameters from from_tensor to to_tensor. If to_tensor already has the params set checks the consistency. - :raises: logger.Error - INVALID_ONNX_MODEL + :raises: logger.Error - INVALID_INPUT_MODEL """ if ( @@ -147,7 +55,7 @@ def propagate_quantization( # noinspection PyTypeChecker if not _validate_or_set_quant_params(to_tensor, from_tensor.quantization): logger.e( - logger.Code.INVALID_ONNX_MODEL, + logger.Code.INVALID_INPUT_MODEL, f'Mismatched quantization parameters between tensors "{from_tensor.name}" and "{to_tensor.name}"', ) @@ -161,16 +69,16 @@ def set_quantization_parameters_to_tensor( """Create a TFLite QuantizationParameters object, initialize it from given parameters and add it to the 'tflite_tensor'. :param tflite_tensor: The TFLite tensor in the model, to add the quantization to. - :param scale: The data of the tensor, which is an input of a quantized ONNX operator and represents the + :param scale: The data of the tensor, which is an input of a quantized ExecuTorch operator and represents the quantization scale. - :param zero_point: The data of the tensor, which is an input of a quantized ONNX operator and represents the + :param zero_point: The data of the tensor, which is an input of a quantized ExecuTorch operator and represents the quantization zero point. :param quantized_dimension: The quantized dimension attribute of TFLite QuantizationParameters. 
""" if (scale is None) or (zero_point is None): logger.e( logger.Code.NOT_IMPLEMENTED, - "Conversion of ONNX quantized operators is only supported when " + "Conversion of ExecuTorch quantized operators is only supported when " "the quantization parameters are static!", ) @@ -184,8 +92,8 @@ def set_quantization_parameters_to_tensor( if scale.size != zero_point.size: logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"The per channel quantization parameters of ONNX tensor " + logger.Code.INVALID_INPUT_MODEL, + f"The per channel quantization parameters of ExecuTorch tensor " f"'{tflite_tensor.name}' are of different sizes! ('{scale.size}'" f" != '{zero_point.size}')", ) @@ -193,8 +101,8 @@ def set_quantization_parameters_to_tensor( quantized_dimension_size = tflite_tensor.shape.get(quantized_dimension) if scale.size != quantized_dimension_size: logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"The ONNX per channel quantization parameter vectors do not " + logger.Code.INVALID_INPUT_MODEL, + f"The ExecuTorch per channel quantization parameter vectors do not " f"match the size of the quantized dimension! ('{scale.size}' != " f"'{quantized_dimension_size}')", ) @@ -205,8 +113,8 @@ def set_quantization_parameters_to_tensor( else: # Combination of per tensor and per channel quantization parameters logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"ONNX tensor '{tflite_tensor.name}' uses a combination of per " + logger.Code.INVALID_INPUT_MODEL, + f"ExecuTorch node '{tflite_tensor.name}' uses a combination of per " f"tensor and per channel quantization parameters. Conversion to " f"TFLite is not possible!", ) @@ -218,33 +126,12 @@ def set_quantization_parameters_to_tensor( ) if not _validate_or_set_quant_params(tflite_tensor, quant): logger.e( - logger.Code.INVALID_ONNX_MODEL, + logger.Code.INVALID_INPUT_MODEL, f'Mismatched quantization parameters between tensors: "{tflite_tensor.name}" already ' f"has the quantization params set", ) -def calculate_uint_to_int_re_quantization_zero_point( - data_type_byte_size: int, old_zero_point: Iterable[int] -) -> np.ndarray: - """ - Calculate the new zero points, after a quantized tensor with an unsigned int data type is re-quantized to - a signed type. - :param data_type_byte_size: Size of the data type that is used, in Bytes. For example 1 for INT8. - :param old_zero_point: The zero point quantisation parameter, of the original data, before re-quantization. - :return: The new zero point quantisation parameter, after re-quantization. 
- """ - data_type_bit_size = 8 * data_type_byte_size - zero_point_shift = 2 ** (data_type_bit_size - 1) - return np.asarray(np.subtract(np.array(old_zero_point, np.int32), zero_point_shift)) - - -def _re_quantize_uint8_to_int8(tensor_data: np.ndarray) -> np.ndarray: - """Re-quantize static uint8 data to int8.""" - int16_data = np.asarray(tensor_data, np.int16) - return np.array(int16_data - 128, np.int8) - - def quantize_int8( data: np.ndarray, scale: List[float], zero_point: List[int] ) -> np.ndarray: @@ -252,20 +139,6 @@ def quantize_int8( return np.clip(new_data, -128, 127).astype(np.int8) -def quantize_uint8( - data: np.ndarray, scale: List[float], zero_point: List[int] -) -> np.ndarray: - new_data = np.add(np.round(np.divide(data, scale)), zero_point) - return np.clip(new_data, 0, 255).astype(np.uint8) - - -def quantize_int32( - data: np.ndarray, scale: List[float], zero_point: List[int] -) -> np.ndarray: - new_data = np.add(np.round(np.divide(data, scale)), zero_point) - return np.clip(new_data, -2_147_483_648, 2_147_483_648).astype(np.int32) - - def dequantize( data: np.ndarray, scale: List[float], zero_point: List[int] ) -> np.ndarray: @@ -274,211 +147,3 @@ def dequantize( scale, dtype=np.float32, ) - - -def re_quantize_static_tensor( - builder: "model_builder.ModelBuilder", - tflite_tensor: tflite_model.Tensor, - to_type: tflTensorType.TensorType, - new_scale: Optional[List[float]] = None, - new_zero_point: Optional[List[int]] = None, -) -> tflite_model.Tensor: - """Create a new TFLite Tensor with new quantization parameters, type and data. - - :param builder: A ModelBuilder instance. - :param tflite_tensor: TFLite tensor to re-quantize. - :param to_type: The TFLite TensorType, that the tensor will be re-quantized to. - :param new_scale: New scale quantization parameter. Used only when re-quantizing to the same type. - :param new_zero_point: New zero point quantization parameter. Used only when re-quantizing to the same type. - :return: A new re-quantized tensor. 
- """ - if tflite_tensor.quantization is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.re_quantize_static_tensor(): Got tensor without quantization!", - ) - - if tflite_tensor.tmp_buffer.data is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.re_quantize_static_tensor(): Got tensor without static data!", - ) - - new_dtype = tf_lite_type_to_numpy(to_type) - re_quantized_tensor = builder.duplicate_tensor(tflite_tensor) - tensor_data = re_quantized_tensor.tmp_buffer.data - - if tensor_data.dtype == np.uint8 and new_dtype == np.int8: # INT8 -> UINT8 - re_quantized_tensor.tmp_buffer.data = _re_quantize_uint8_to_int8(tensor_data) - re_quantized_tensor.type = tflTensorType.TensorType.INT8 - calculated_zero_point = calculate_uint_to_int_re_quantization_zero_point( - 1, re_quantized_tensor.quantization.zero_point.vector - ) - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(calculated_zero_point) - ) - - elif tensor_data.dtype == np.int32 and new_dtype == np.int8: # INT32 -> INT8 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - int8_data = quantize_int8(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = int8_data - re_quantized_tensor.type = tflTensorType.TensorType.INT8 - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - elif tensor_data.dtype == np.int32 and new_dtype == np.uint8: # INT32 -> UINT8 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - uint8_data = quantize_uint8(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = uint8_data - re_quantized_tensor.type = tflTensorType.TensorType.UINT8 - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - elif tensor_data.dtype == np.int8 and new_dtype == np.int8: # INT8 -> INT8 - # Re-quantizing int8 tensor data with different quantization parameters - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - zero_point_data = re_quantized_tensor.quantization.zero_point.vector - scale_data = re_quantized_tensor.quantization.scale.vector - new_tensor_data = dequantize(tensor_data, scale_data, zero_point_data) - - re_quantized_tensor.tmp_buffer.data = quantize_int8( - new_tensor_data, new_scale, new_zero_point - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(new_scale) - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - new_zero_point - ) - - elif tensor_data.dtype == np.int32 and new_dtype == np.int32: # INT32 -> INT32 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when 
re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - int32_data = quantize_int32(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = int32_data - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - else: - logger.e( - logger.Code.NOT_IMPLEMENTED, - f"Re-quantization of static tensors from type '{tensor_data.dtype}' " - f"to type '{to_type}' is not yet implemented!", - ) - - return re_quantized_tensor - - -def quantize_static_float_tensor( - builder: "model_builder.ModelBuilder", - tflite_tensor: tflite_model.Tensor, - to_type: tflTensorType.TensorType, - scale: List[float], - zero_point: List[int], - quantized_dimension: int = 0, -) -> tflite_model.Tensor: - """Quantize tensor 'tflite_tensor' with passed quantization params. - - :param builder: A ModelBuilder instance. - :param tflite_tensor: TFLite tensor to quantize. - :param to_type: The TFLite TensorType, that the tensor will be quantized to. - :param scale: Scale quantization parameter. - :param zero_point: Zero point quantization parameter. - :param quantized_dimension: Quantized dimension. - """ - if tflite_tensor.quantization is not None: - logger.e(logger.Code.INTERNAL_ERROR, "Got tensor with quantization!") - - if tflite_tensor.tmp_buffer.data is None: - logger.e(logger.Code.INTERNAL_ERROR, "Got tensor without static data!") - - quantized_tensor = builder.duplicate_tensor(tflite_tensor) - tensor_data = quantized_tensor.tmp_buffer.data - - if zero_point is None or scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when quantizing tensor.", - ) - - new_dtype = tf_lite_type_to_numpy(to_type) - - if tensor_data.dtype == np.float32 and new_dtype == np.int8: - int8_data = quantize_int8(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = int8_data - quantized_tensor.type = tflTensorType.TensorType.INT8 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - elif tensor_data.dtype == np.float32 and new_dtype == np.uint8: - uint8_data = quantize_uint8(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = uint8_data - quantized_tensor.type = tflTensorType.TensorType.UINT8 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - elif tensor_data.dtype == np.float32 and new_dtype == np.int32: - int32_data = quantize_int32(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = int32_data - quantized_tensor.type = tflTensorType.TensorType.INT32 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - else: - 
-            logger.Code.NOT_IMPLEMENTED,
-            f"Quantization of static tensors from type '{tensor_data.dtype}' "
-            f"to type '{to_type}' is not yet implemented!",
-        )
-
-    return quantized_tensor
diff --git a/backends/nxp/backend/ir/logger.py b/backends/nxp/backend/ir/logger.py
index ce8da2a31df..8019fb4d780 100644
--- a/backends/nxp/backend/ir/logger.py
+++ b/backends/nxp/backend/ir/logger.py
@@ -1,6 +1,6 @@
 #
 # Copyright 2023 Martin Pavella
-# Copyright 2023 NXP
+# Copyright 2023-2025 NXP
 #
 # License: MIT
 # See the LICENSE_MIT for more details.
@@ -85,18 +85,18 @@ class Code(Enum):
     PREPROCESSING_ERROR = 4

     UNSUPPORTED_OPERATOR = 21
-    UNSUPPORTED_ONNX_TYPE = 22
+    # Code 22 was removed.
     UNSUPPORTED_OPERATOR_ATTRIBUTES = 23
     NOT_IMPLEMENTED = 24

     INVALID_TYPE = 31
     INVALID_TENSOR_SHAPE = 32
-    INVALID_ONNX_OPERATOR = 33
-    INVALID_ONNX_OPERATOR_ATTRIBUTE = 34
-    INVALID_ONNX_MODEL = 35
+    # Code 33 was removed.
+    INVALID_OPERATOR_ATTRIBUTE = 34
+    INVALID_INPUT_MODEL = 35

     CONVERSION_IMPOSSIBLE = 41
-    SHAPE_INFERENCE_ERROR = 42
+    # Code 42 was removed.
     IO_PRESERVATION_ERROR = 43

     INVALID_INPUT = 51
@@ -142,8 +142,6 @@ class BasicLoggingContext(LoggingContext):
     """

     GLOBAL = LoggingContext("global")
-    SHAPE_INFERENCE = LoggingContext("shape_inference")
-    ONNX_PARSER = LoggingContext("onnx_parser")
     OPERATOR_CONVERSION = LoggingContext("operator_conversion")
     TFLITE_GENERATOR = LoggingContext("tflite_generator")
     QDQ_QUANTIZER = LoggingContext("qdq_quantizer")
@@ -151,7 +149,7 @@ class NodeLoggingContext(LoggingContext):
     """
-    ONNX node specific context. Logs reported within this context are related to node with index 'node_id'.
+    ExecuTorch node specific context. Logs reported within this context are related to node with index 'node_id'.
     """

     def __init__(self, node_id):
@@ -213,7 +211,7 @@ def _get_node_error(self, node_id: int, dict_item: str) -> Code | str | None:
         Return first error log item that belongs to node with id 'node_id'.
         If no error is present, None is returned instead.

-        :param node_id: ONNX node id.
+        :param node_id: ExecuTorch node id.
         :param dict_item: Dictionary item to return from `log`.
         :return: Error code or None if there's no error related to node.
         """
@@ -230,7 +228,7 @@ def get_node_error_code(self, node_id: int) -> Code | None:
         Return first error code that belongs to node with id 'node_id'.
         If no error is present, None is returned instead.

-        :param node_id: ONNX node id.
+        :param node_id: ExecuTorch node id.
         :return: Error code or None if there's no error related to node.
         """
@@ -241,7 +239,7 @@ def get_node_error_message(self, node_id: int) -> str | None:
         Return first error message that belongs to node with id 'node_id'.
         If no error is present, None is returned instead.

-        :param node_id: ONNX node id
+        :param node_id: ExecuTorch node id.
         :return: Error message or None if there is no error related to node.
         """
@@ -256,7 +254,7 @@ class loggingContext:
     Context manager used to nest logging contexts.
     Usage:
         with loggingContext(BasicLoggingContext.GLOBAL):
-            with loggingContext(BasicLoggingContext.ONNX_PARSER):
+            with loggingContext(BasicLoggingContext.OPERATOR_CONVERSION):
                 logger.i("My log")  # this log is automatically assigned to both parent contexts
     """
diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py
index aab22c3c368..db24576e81f 100644
--- a/backends/nxp/backend/ir/tensor_formatting.py
+++ b/backends/nxp/backend/ir/tensor_formatting.py
@@ -1,6 +1,5 @@
-#
 # Copyright 2023 Martin Pavella
-# Copyright 2023-2024 NXP
+# Copyright 2023-2025 NXP
 #
 # License: MIT
 # See the LICENSE_MIT for more details.
@@ -26,7 +25,7 @@ class TensorFormat(Enum):
     TRANSPOSE_CONV_2D_WEIGHT_FORMAT = 13

     # No special format (matrices, vectors, shapes etc.). All tensors with the FORMATLESS format MUST have EXACTLY
-    # the same shape and data in the TFLite model and in the ONNX model.
+    # the same shape and data in the NeutronIR model and in the ExecuTorch model.
     FORMATLESS = 20

     NONE = 30  # Format has not been identified
diff --git a/backends/nxp/backend/ir/tflite_generator/tflite_model.py b/backends/nxp/backend/ir/tflite_generator/tflite_model.py
index a9384861178..76a50a2e177 100755
--- a/backends/nxp/backend/ir/tflite_generator/tflite_model.py
+++ b/backends/nxp/backend/ir/tflite_generator/tflite_model.py
@@ -1,6 +1,5 @@
-#
 # Copyright 2023 Martin Pavella
-# Copyright 2023-2024 NXP
+# Copyright 2023-2025 NXP
 #
 # License: MIT
 # See the LICENSE_MIT for more details.
@@ -272,8 +271,7 @@ def is_per_tensor(self) -> bool:
         return False

     def gen_tflite(self, builder: fb.Builder):
-        # Sometimes 1D per-tensor quantized tensors can have quantized_dimension != 0
-        # (residue from badly defined ONNX models). This would cause TFLite inference to crash.
+        # Sometimes 1D per-tensor quantized tensors can have quantized_dimension != 0. This would cause TFLite inference to crash.
         if not self.is_per_channel():
             self.quantized_dimension = 0
@@ -513,7 +511,7 @@ class Operator(meta.TFLiteObject):
     tmp_outputs: List[Tensor]
     tmp_version: int  # OperatorConverter uses this to assign the corresponding operator code with correct version.

-    # If `True`, this is an extra operator added during conversion. It was not present in the original ONNX model.
+    # If `True`, this is an extra operator added during conversion. It was not present in the original input model.
     tmp_added_extra: bool

     def __init__(
diff --git a/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py b/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py
index 253dc9c69a1..e861eff0d18 100755
--- a/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py
+++ b/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py
@@ -1,4 +1,4 @@
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -100,23 +100,3 @@ def __call__(
             operator_is_type(preceding_op, self.single_preceding_op_type, builder)
             for preceding_op in preceding_ops
         )
-
-
-@dataclass
-class WasNotInTheOriginalONNXModel(OpRule):
-    """Assures that this operator wasn't created by converting an ONNX operator from the original model, but instead
-    was added extra in order to convert a different operator.
-
-    This rule is currently only satisfied for operators added by ModelBuilder methods `create_..._before()` and
-    `create_..._after()`.
- """ - - def __call__( - self, - op: tflite_model.Operator, - tensor_map: NameToTensorMap, - input_to_ops_map: InputTensorToOpsMap, - output_to_op_map: OutputTensorToOpMap, - builder: "model_builder.ModelBuilder", - ) -> bool: - return op.tmp_added_extra diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py index 42eefc1ab56..ef76fad90de 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def __call__(self) -> bool: How it works: - The original model doesn't have the `Transpose`. It just has `Reshape` into `MatMul` (or `Gemm`...). - The `Transpose` is added, because the `Reshape` has a channels last input, which was originally - channels first (in the ONNX model), and so the 2D output of the `Reshape` would have the same data. + channels first (in the ExecuTorch model), and so the 2D output of the `Reshape` would have the same data. but at different locations. The `Transpose` makes the input channels first, which ensures correct output of the `Reshape`. - In the scenario in the graph above, it is possible to omit the `Transpose`, which causes the `Reshape` @@ -85,12 +85,12 @@ def __call__(self) -> bool: for (transpose, reshape, fc), tensor_map, _, _ in matcher.match_patterns(): # Make sure the `Transpose` is applying the expected permutation. y = tensor_map["y"] - to_onnx_perm = ( + to_executorch_perm = ( translator.create_channels_last_to_channels_first_permutation( y.shape.len() ) ) - if not np.allclose(to_onnx_perm, tensor_map["perm"].tmp_buffer.data): + if not np.allclose(to_executorch_perm, tensor_map["perm"].tmp_buffer.data): continue # The `Transpose` has an unexpected permutation. w = tensor_map["w"]