10 changes: 8 additions & 2 deletions backends/nxp/backend/edge_program_converter.py
@@ -18,6 +18,7 @@
from torch.fx import Node
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
@@ -54,19 +55,22 @@ class EdgeProgramToIRConverter:
"""

_default_conversion_config = ConversionConfig()
_default_target_spec = NeutronTargetSpec("imxrt700", "SDK_25_09")
_default_delegation_options = CustomDelegationOptions()

def convert_program(
self,
edge_program: ExportedProgram,
conversion_config=_default_conversion_config,
conversion_config: ConversionConfig = _default_conversion_config,
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> (bytes, dict):
"""
Convert an ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

:param edge_program: ExportedProgram to convert.
:param conversion_config: ConversionConfig instance.
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
:param custom_delegation_options: Custom user options which affect node delegation.
:return: TFLite flatbuffers as bytes.
"""
@@ -76,6 +80,7 @@ def convert_program(
cc = self.build_conversion_context(
parameters_mapping,
node_formats,
neutron_target_spec,
conversion_config,
custom_delegation_options,
)
@@ -173,11 +178,12 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> ConversionContext:
tflite_builder = AtenModelBuilderDirector(
3, "TFLite from EdgeProgram", conversion_config
3, "TFLite from EdgeProgram", neutron_target_spec, conversion_config
)

# Add "sentinel" buffer (defined in schema.fbs)
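A minimal sketch of calling the updated entry point with an explicit target spec; the NeutronTargetSpec arguments mirror the defaults above, while the model and export steps are illustrative:

import torch
from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.exir import to_edge

# Illustrative module; any exportable model follows the same flow.
model = torch.nn.Linear(8, 8).eval()
exported = torch.export.export(model, (torch.randn(1, 8),))
edge_program = to_edge(exported).exported_program()

converter = EdgeProgramToIRConverter()
# Returns the TFLite flatbuffer as bytes plus a dict of conversion metadata.
tflite_bytes, metadata = converter.convert_program(
    edge_program,
    neutron_target_spec=NeutronTargetSpec("imxrt700", "SDK_25_09"),
)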
70 changes: 41 additions & 29 deletions backends/nxp/backend/ir/converter/builder/model_builder.py
@@ -48,6 +48,7 @@
FlexTranspose,
)
from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec


class ModelBuilder:
@@ -74,17 +75,21 @@ class ModelBuilder:

_zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects

_default_conversion_config = ConversionConfig()
neutron_target_spec: NeutronTargetSpec

conversion_config: ConversionConfig

_default_conversion_config = ConversionConfig()

def __init__(
self,
model_version: int,
model_description: str,
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
) -> None:
self._tfl_model = tflite_model.Model(model_version, model_description)
self.neutron_target_spec = neutron_target_spec
self.conversion_config = conversion_config

self.op_code_type_index_map = {}
@@ -471,31 +476,7 @@ def finish(self) -> tflite_model.Model:

return self._tfl_model

def _assign_tensor_and_buffer_indices( # noqa C901
self, allow_inputs_stripping: bool
):
"""Correctly initialize all references via indices in all tensors and buffers."""

# Assign each buffer its index
for i, buffer in enumerate(self.get_buffers().vector):
buffer.tmp_index = i

# Assign each tensor its index and its buffer index
for i, tensor in enumerate(self.get_tensors().vector):
if tensor.tmp_null_tensor:
# Using -1 as the index to the 'tensors' vector is a way of telling the TFLite inference engine that
# this tensor should not be used.
# https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98
tensor.tmp_index = -1
else:
tensor.tmp_index = i

tensor.buffer = tensor.tmp_buffer.tmp_index

# TODO Remove inputs and outputs that are not in the tensors collection

# Assign 'Outputs' and 'Inputs' their tensor indices
outputs = self.get_sub_graph().outputs
def _assign_io_tensor_indices(self, inputs, outputs, allow_inputs_stripping: bool):
for tensor in outputs.tmp_outputs:
try:
outputs.append(tensor.tmp_index)
@@ -505,7 +486,6 @@ def _assign_tensor_and_buffer_indices( # noqa C901
f"The tensor '{tensor.name}' is among the model outputs, but does NOT appear in the graph!",
)

inputs = self.get_sub_graph().inputs
for tensor in inputs.tmp_inputs:
try:
inputs.append(tensor.tmp_index)
@@ -520,14 +500,46 @@ def _assign_tensor_and_buffer_indices( # noqa C901
f"The tensor '{tensor.name}' is among the model inputs, but does NOT appear in the graph!",
)

# Assign each operator its inputs and outputs indices
for operator in self.get_sub_graph().operators.vector:
def _assign_operators_io_tensor_indices(self, operators):
for operator in operators.vector:
for inputTensor in operator.tmp_inputs:
operator.inputs.append(inputTensor.tmp_index)

for outputTensor in operator.tmp_outputs:
operator.outputs.append(outputTensor.tmp_index)

def _assign_tensor_and_buffer_indices(self, allow_inputs_stripping: bool):
"""Correctly initialize all references via indices in all tensors and buffers."""

# Assign each buffer its index
for i, buffer in enumerate(self.get_buffers().vector):
buffer.tmp_index = i

# Assign each tensor its index and its buffer index
for i, tensor in enumerate(self.get_tensors().vector):
if tensor.tmp_null_tensor:
# Using -1 as the index to the 'tensors' vector is a way of telling the TFLite inference engine that
# this tensor should not be used.
# https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98
tensor.tmp_index = -1
else:
tensor.tmp_index = i

tensor.buffer = tensor.tmp_buffer.tmp_index

# TODO Remove inputs and outputs that are not in the tensors collection

subgraph = self.get_sub_graph()

# Assign 'Outputs' and 'Inputs' their tensor indices
self._assign_io_tensor_indices(
inputs=subgraph.inputs,
outputs=subgraph.outputs,
allow_inputs_stripping=allow_inputs_stripping,
)
# Assign each operator its inputs and outputs indices
self._assign_operators_io_tensor_indices(operators=subgraph.operators)

def _build_operator_code(
self, op_type: BuiltinOperator, version, custom_code: str = None
):
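The index-assignment refactor above is behavior-preserving; the visible API change is that ModelBuilder now requires the target spec at construction. A sketch, with version and description matching build_conversion_context:

from executorch.backends.nxp.backend.ir.converter.builder.model_builder import (
    ModelBuilder,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

# The target spec is now a required argument; the conversion config
# keeps its default when omitted.
builder = ModelBuilder(
    model_version=3,
    model_description="TFLite from EdgeProgram",
    neutron_target_spec=NeutronTargetSpec("imxrt700", "SDK_25_09"),
)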
25 changes: 7 additions & 18 deletions backends/nxp/backend/ir/converter/node_converter.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from enum import Enum

import torch

@@ -16,6 +15,7 @@
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node
from torch.fx.passes.infra.partitioner import Partition
@@ -42,17 +42,6 @@ def is_not_qdq_node(node: torch.fx.Node) -> bool:
return not (_is_quant_node(node) or _is_dequant_node(node))


class Target(Enum):
IGNORE = "ignore" # No target platform. Any target specific restrictions will be ignored.

RT700 = "imxrt700"
IMX95 = "imx95"

@classmethod
def values(cls) -> list[str]:
return [elt.value for elt in cls]


class NodeConverter(ABC):
"""
Classes which implement conversion of torch.Node to TFLite should inherit from this class and overwrite the
@@ -94,7 +83,7 @@ def _is_supported_in_IR(
@staticmethod
def _is_supported_on_target(
node: Node,
target: Target,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
@@ -103,31 +92,31 @@ def _is_supported_on_target(
can be used by operators with no target specific requirements.

:param node: The node (edge operator) to check.
:param target: Value of the `Target` enum representing the target platform to check for.
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
:param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
:param custom_delegation_options: Custom options which affect delegation.
"""
return target == Target.RT700
return True

@classmethod
def is_supported(
cls,
node: Node,
target: Target,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
"""Check if the given `node` is supported in the IR and on the given `target` platform.

:param node: torch.Node to check.
:param target: Value of the `Target` enum representing the target platform to check for.
:param neutron_target_spec: Object for querying the target platform to retrieve its properties.
:param parameters_mapping: Dict mapping tensor names to their data.
:param custom_delegation_options: Custom user options which affect node delegation.
"""
return cls._is_supported_in_IR(
node, parameters_mapping, custom_delegation_options
) and cls._is_supported_on_target(
node, target, parameters_mapping, custom_delegation_options
node, neutron_target_spec, parameters_mapping, custom_delegation_options
)

@classmethod
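With the Target enum gone, converters express hardware constraints by querying the spec object. A hypothetical converter sketching the new override pattern (the class and its divisibility rule are invented for illustration; only the signature comes from this diff):

from executorch.backends.nxp.backend.ir.converter.node_converter import (
    CustomDelegationOptions,
    NodeConverter,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter


class ExampleConverter(NodeConverter):
    """Hypothetical converter illustrating a spec-driven target check."""

    @staticmethod
    def _is_supported_on_target(
        node: Node,
        neutron_target_spec: NeutronTargetSpec,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        # Illustrative rule in the style of CatConverter below: delegate
        # only when the last dimension fills the MAC array evenly.
        last_dim = node.meta["val"].shape[-1]
        return last_dim % neutron_target_spec.get_num_macs() == 0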
backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py
@@ -9,11 +9,11 @@
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
Target,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
add_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter

@@ -22,20 +22,15 @@ class AddTensorConverter(NodeConverter):
@staticmethod
def _is_supported_on_target(
node: Node,
target: Target,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
match target:
case Target.RT700:
if node_uses_shape_broadcasting(node):
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
return False

return True
if node_uses_shape_broadcasting(node):
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
return False

case _:
return False
return True

@staticmethod
def _is_supported_in_IR(
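node_uses_shape_broadcasting is imported from the backend's shared utilities and is not shown in this diff; a hypothetical stand-in capturing what it decides, assuming input shapes are recorded in node.meta:

from torch.fx import Node


def uses_shape_broadcasting(node: Node) -> bool:
    """Hypothetical stand-in for the backend's broadcasting check.

    Element-wise ops broadcast when their input shapes differ; on Neutron
    this may force extra `Transpose` ops during conversion, which is why
    the converter above refuses to delegate such nodes.
    """
    shapes = [tuple(inp.meta["val"].shape) for inp in node.all_input_nodes]
    return any(shape != shapes[0] for shape in shapes)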
backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py
@@ -13,11 +13,11 @@
_is_dequant_node,
_is_quant_node,
NodeConverter,
Target,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
Concatenation,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter

@@ -72,51 +72,52 @@ def _all_io_shares_quantization_parameters(node: Node) -> bool:
@staticmethod
def _is_supported_on_target(
node: Node,
target: Target,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if custom_delegation_options.force_delegate_cat:
return True

match target:
case Target.RT700:
dim = CatConverter._get_normalized_dim(node)

# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
if dim == 0:
return False

# Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
input_channels = [
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
# will still be the channels in the IR.
_get_shape(input_)[1]
for input_ in node.all_input_nodes
] + [
# If the inputs/outputs are channels first, the last dimension will be the channels.
_get_shape(input_)[-1]
for input_ in node.all_input_nodes
]
if any((input_channel % 8) != 0 for input_channel in input_channels):
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
return False

output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
if any((out_c % 8) != 0 for out_c in output_channels):
return False

if len(node.all_input_nodes) < 2: # Not supported on Neutron
# TODO Try to skip the operator if this case is realistic.
return False

return True

case _:
return False
dim = CatConverter._get_normalized_dim(node)

# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
if dim == 0:
return False

# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
input_channels = [
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
# will still be the channels in the IR.
_get_shape(input_)[1]
for input_ in node.all_input_nodes
] + [
# If the inputs/outputs are channels first, the last dimension will be the channels.
_get_shape(input_)[-1]
for input_ in node.all_input_nodes
]
if any(
(input_channel % neutron_target_spec.get_num_macs()) != 0
for input_channel in input_channels
):
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
return False

output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
if any(
(out_c % neutron_target_spec.get_num_macs()) != 0
for out_c in output_channels
):
return False

if len(node.all_input_nodes) < 2: # Not supported on Neutron
# TODO Try to skip the operator if this case is realistic.
return False

return True

@staticmethod
def _is_supported_in_IR(
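The rewritten check generalizes the old hard-coded `% 8`, presumably the RT700 MAC count, to whatever the queried target reports. A self-contained sketch of the rule, with made-up channel counts:

def channels_delegable(channel_counts: list[int], num_macs: int) -> bool:
    # Neutron requires every candidate channel dimension to be a
    # multiple of the target's numMacs.
    return all(count % num_macs == 0 for count in channel_counts)


# With num_macs == 8 this reproduces the old hard-coded behavior.
assert channels_delegable([16, 32], num_macs=8)
assert not channels_delegable([12, 32], num_macs=8)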