12 changes: 1 addition & 11 deletions backends/nxp/backend/ir/conversion_config.py
@@ -1,4 +1,4 @@
# Copyright 2024 NXP
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -14,7 +14,6 @@ def __init__(self, args: dict | None = None):
:param args: Optional dictionary with conversion arguments. Unknown arguments are ignored.
"""
self.keep_io_format: bool = False
self.skip_shape_inference: bool = False
self.allow_inputs_stripping: bool = True
self.qdq_aware_conversion: bool = True
self.symbolic_dimensions_mapping: dict[str, int] | None = None
@@ -46,15 +45,6 @@ def __repr__(self):
return "ConversionConfig[" + ", ".join(attrs) + "]"


class SkipShapeInferenceConfig(ConversionConfig):

def __init__(self):
"""
Conversion config shortcut with disabled shape inference.
"""
super().__init__({"skip_shape_inference": True})


class QDQAwareConfig(ConversionConfig):

def __init__(self):
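With SkipShapeInferenceConfig and the skip_shape_inference flag removed, a non-default configuration is now built by passing the argument dictionary to ConversionConfig directly (unknown keys are ignored, and known keys map onto the attributes listed above, just as the removed shortcut's {"skip_shape_inference": True} did). A minimal usage sketch with hypothetical argument values, not taken from this PR:

    from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig

    # Hypothetical values: keep the model's I/O format and pin a symbolic "batch" dimension.
    config = ConversionConfig(
        {"keep_io_format": True, "symbolic_dimensions_mapping": {"batch": 1}}
    )
    print(config)  # prints the ConversionConfig[...] summary produced by __repr__
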
47 changes: 13 additions & 34 deletions backends/nxp/backend/ir/converter/builder/model_builder.py
@@ -1,6 +1,6 @@
#
# Copyright 2023 Martin Pavella
# Copyright 2023-2024 NXP
# Copyright 2023-2025 NXP
#
# License: MIT
# See the LICENSE_MIT for more details.
@@ -795,29 +795,8 @@ def _remove_tensor_with_name(self, name):

def append_new_tensor(self, t_tensor: tflite_model.Tensor, overwrite: bool = False):
"""Append the TFLite tensor 't_tensor' to the 'SubGraph.tensors' and register it."""

if t_tensor.name in self._tensor_name_map.keys():
"""Tensor has already been added. Sometimes however, ONNX models
will have tensors in their 'inputs' or 'outputs', which don't
belong there and are in fact static. I this case we need to
overwrite the existing tensors."""

if overwrite:
self._remove_tensor_with_name(t_tensor.name)

# If the tenor previously appeared in ONNX 'inputs' or 'outputs',
# the old version MUST be removed from there.
self._remove_input_with_name(t_tensor.name)
self._remove_output_with_name(t_tensor.name)

self.get_tensors().append(t_tensor)
self._tensor_name_map[t_tensor.name] = t_tensor
else:
logger.w(f"Tensor '{t_tensor.name}' is already in the tensors!")

else:
self._tensor_name_map[t_tensor.name] = t_tensor
self.get_tensors().append(t_tensor)
self._tensor_name_map[t_tensor.name] = t_tensor
self.get_tensors().append(t_tensor)

def append_new_buffer(self, buffer: tflite_model.Buffer):
"""Append the 'buffer' to the 'model.buffers'."""
@@ -1515,7 +1494,7 @@ def prepare_dynamic_tensor_for_correct_broadcasting_with_channels_first_tensors(
# Prepend a partial identity, to keep leading dimensions unchanged.
revert_perm = list(range(rank_diff)) + list(revert_perm)

# Now add a permutation to convert the extended ONNX shape to a TFLite shape
# Now add a permutation to convert the extended ExecuTorch shape to a TFLite shape
to_tflite_perm = (
translator.create_channels_first_to_channels_last_permutation(
output_rank
@@ -1579,37 +1558,37 @@ def prepare_static_tensor_for_correct_broadcasting_with_channels_first_tensors(

original_shape = translator.dims_to_channels_first(
shape
) # Same shape as in the ONNX model
) # Same shape as in the ExecuTorch model

# Prepend 1s to the shape
extended_onnx_shape = [1] * rank_diff + original_shape
extended_executorch_shape = [1] * rank_diff + original_shape

# Convert the full shape to TFLite format
tflite_shape = translator.dims_to_channels_last(extended_onnx_shape)
tflite_shape = translator.dims_to_channels_last(extended_executorch_shape)
tensor.shape = tflite_model.Shape(tflite_shape)

# Statically transpose the data
data = translator.convert_data_to_channels_first(
data
) # To the same shape as in the ONNX model
data = data.reshape(extended_onnx_shape) # Extend with leading 1s
) # To the same shape as in the ExecuTorch model
data = data.reshape(extended_executorch_shape) # Extend with leading 1s
tensor.tmp_buffer.data = translator.convert_data_to_channels_last(
data
) # Convert to TFLite format

assert tflite_shape == list(tensor.tmp_buffer.data.shape)

else:
# The tensor is the same as in the ONNX model.
# The tensor is the same as in the ExecuTorch model.

extended_onnx_shape = [1] * rank_diff + shape
extended_executorch_shape = [1] * rank_diff + shape

# Convert the full shape to TFLite format
tflite_shape = translator.dims_to_channels_last(extended_onnx_shape)
tflite_shape = translator.dims_to_channels_last(extended_executorch_shape)
tensor.shape = tflite_model.Shape(tflite_shape)

# Statically transpose the data
data = data.reshape(extended_onnx_shape) # Extend with leading 1s
data = data.reshape(extended_executorch_shape) # Extend with leading 1s
tensor.tmp_buffer.data = translator.convert_data_to_channels_last(
data
) # Convert to TFLite format
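In the static-tensor path above, the shape is first extended with leading 1s in the channels-first (ExecuTorch) order and then both the shape and the data are permuted into TFLite's channels-last layout. A self-contained numpy sketch of the shape arithmetic for the "tensor is the same as in the ExecuTorch model" branch, with a hand-rolled permutation standing in for the translator helpers and the target rank fixed at 4:

    import numpy as np

    def dims_to_channels_last(dims):
        # Assumed NCHW -> NHWC style reordering: move dim 1 (channels) to the end.
        return [dims[0], *dims[2:], dims[1]]

    # A rank-2 static tensor that must broadcast against a rank-4 channels-last tensor.
    data = np.arange(15, dtype=np.float32).reshape(3, 5)
    rank_diff = 4 - data.ndim

    extended_executorch_shape = [1] * rank_diff + list(data.shape)   # [1, 1, 3, 5]
    tflite_shape = dims_to_channels_last(extended_executorch_shape)  # [1, 3, 5, 1]

    tflite_data = data.reshape(extended_executorch_shape).transpose(0, 2, 3, 1)
    assert list(tflite_data.shape) == tflite_shape
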
63 changes: 6 additions & 57 deletions backends/nxp/backend/ir/converter/conversion/common.py
@@ -1,6 +1,6 @@
#
# Copyright 2023 Martin Pavella
# Copyright 2023-2024 NXP
# Copyright 2023-2025 NXP
#
# License: MIT
# See the LICENSE_MIT for more details.
@@ -12,7 +12,7 @@
'conversion/builtin/' directory.
"""

from typing import Any, List, MutableSequence, Optional
from typing import List, MutableSequence, Optional

import executorch.backends.nxp.backend.ir.logger as logger
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
@@ -22,28 +22,8 @@
max_pool_2d_options,
transpose_conv_options,
)
from torch.fx import Node


def exactly_one_is_none(obj1: Optional, obj2: Optional) -> bool:
"""Determine if exactly 1 of the arguments is None, or not."""
return (obj1 is None and obj2 is not None) or (obj1 is not None and obj2 is None)


def contains_duplicates(list_to_check: List[Any]) -> bool:
"""Determine if given list has duplicate elements or not."""
return len(list_to_check) != len(set(list_to_check))


def clamp(val: int, start: int, end: int) -> int:
"""Clamp an int value between start and end (inclusive) and return it."""
if val < start:
return start

elif val > end:
return end

return val
from torch.fx import Node


def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor | None:
@@ -62,11 +42,6 @@ def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor

tensor = t_op.tmp_inputs[idx]

if tensor.name == "":
# ONNX allows the name "" for optional tensors. It indicates that the tensor should be ignored, and a default
# value should be used. Just like if the tensor was omitted altogether.
return None

return tensor


@@ -101,7 +76,7 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]):
If 'strides' is None, assign 1s.

:param options: TFLite AveragePool2D, Conv2D, MaxPool2D or TransposeConv options object.
:param strides: An optional list of ONNX strides attribute.
:param strides: An optional list of ExecuTorch strides attribute.
"""

if strides is None:
@@ -115,8 +90,8 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]):

else:
logger.e(
logger.Code.INVALID_ONNX_OPERATOR_ATTRIBUTE,
f"ONNX operator has invalid 'strides' attribute! ('{strides}')",
logger.Code.INVALID_OPERATOR_ATTRIBUTE,
f"ExecuTorch operator has invalid 'strides' attribute! ('{strides}')",
)
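Per its docstring, assign_2d_strides falls back to strides of 1 when the attribute is missing and reports anything else that is not a pair of values as an invalid operator attribute. A rough sketch of that control flow; the option-field names and the two-element check are assumptions standing in for the elided middle of the real function, and the sketch raises instead of calling logger.e:

    def assign_2d_strides_sketch(options, strides):
        if strides is None:
            strides = [1, 1]                             # default: stride 1 in both dimensions
        if len(strides) == 2:
            options.stride_h, options.stride_w = strides  # assumed field names
        else:
            raise ValueError(f"invalid 'strides' attribute: {strides}")
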


@@ -188,32 +163,6 @@ def node_uses_shape_broadcasting(node: Node) -> bool:
)


def uses_multiple_input_types(t_op: tflite_model.Operator) -> bool:
"""Determine if the input tensors of given TFLite operator use different data types or not.

:param t_op: TFLite operator with 'tmp_inputs' initialized.
:return: True, if any two input tensors have a different data type.
False, if all input tensors use the same data type.
"""

if t_op.tmp_inputs is None:
logger.e(
logger.Code.INTERNAL_ERROR,
"common.uses_multiple_input_types(): 'tmp_inputs' are None!",
)

if len(t_op.tmp_inputs) == 0:
logger.e(
logger.Code.INTERNAL_ERROR,
"common.uses_multiple_input_types(): Operator has no inputs!",
)

first_input_type = t_op.tmp_inputs[0].type
return any(
input_tensor.type != first_input_type for input_tensor in t_op.tmp_inputs[1:]
)


class OpsList:
"""
Holder of TFLite operator (middle_op) that can be prefixed (pre_ops) or suffixed (post_ops)
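The rest of common.py (elided above) still provides node_uses_shape_broadcasting, which decides whether a torch.fx node's inputs require shape broadcasting, plus the OpsList holder for an operator and its surrounding pre/post operators. The broadcasting check can be pictured with a short, self-contained sketch; reading the input shapes from meta["val"] is an illustrative assumption, not code from this diff:

    from torch.fx import Node

    def needs_shape_broadcasting_sketch(node: Node) -> bool:
        # Assumes every input node carries a FakeTensor-like object in meta["val"].
        shapes = [tuple(arg.meta["val"].shape) for arg in node.all_input_nodes]
        return any(shape != shapes[0] for shape in shapes[1:])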