Skip to content
Merged
17 changes: 10 additions & 7 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
Expand Down Expand Up @@ -63,7 +63,7 @@ def convert_program(
conversion_config: ConversionConfig = _default_conversion_config,
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> (bytes, dict):
) -> (bytes, dict[str, NodeFormat]):
"""
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

Expand All @@ -87,13 +87,16 @@ def convert_program(
self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc)
self._process_nodes(edge_program.graph.nodes, cc)

# Assign output
io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(
edge_program.graph_signature
)
# Assign the model its inputs and outputs.
cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature)

# TFLite model generation
# Apply optimizations and finalize the model.
internal_tflite_model = cc.tflite_builder.finish()

# Extract the formats of the model's inputs and outputs.
io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature)

# TFLite model generation
flatbuffers_builder = flatbuffers.Builder()
internal_tflite_model.gen_tflite(flatbuffers_builder)

Expand Down
2 changes: 1 addition & 1 deletion backends/nxp/backend/ir/conversion_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None):

:param args: Optional dictionary with conversion arguments. Unknown arguments are ignored.
"""
self.keep_io_format: bool = False
self.use_neutron_for_format_conversion: bool = True
self.allow_inputs_stripping: bool = True
self.qdq_aware_conversion: bool = True
self.symbolic_dimensions_mapping: dict[str, int] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]):

self.check_and_append_operator(op)

def assign_model_io_to_subgraph_and_get_io_formats(
self, graph_signature
) -> dict[str, dict]:
"""
Assign model's inputs/outputs to SubGraph.
def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]:
    """Build a lookup of the model's IO tensor formats, keyed by tensor name.

    Every name in `graph_signature.user_inputs` and `graph_signature.user_outputs` is resolved
    through `self.tensor_for_name()`. The resolved tensor must carry the same name as the one
    that was requested.

    :param graph_signature: Instance of GraphSignature.
    :returns: Dictionary with the keys "inputs" and "outputs", each holding a mapping from a tensor's name to
               the value of its `tensor_format` attribute.
    :raises AssertionError: If a resolved tensor's name differs from the requested name.
    """

    def _formats_by_name(names, mismatch_message):
        # Resolve each name to its tensor and record the format under the tensor's own name.
        formats = {}
        for name in names:
            resolved = self.tensor_for_name(name)
            assert name == resolved.name, mismatch_message
            formats[resolved.name] = resolved.tensor_format
        return formats

    # Inputs are processed before outputs.
    input_formats = _formats_by_name(
        graph_signature.user_inputs,
        "Program's input name doesn't match with tensor name in TFLite. "
        "Input was probably redirected.",
    )
    output_formats = _formats_by_name(
        graph_signature.user_outputs,
        "Program's output name doesn't match with tensor name in TFLite. "
        "Output was probably redirected.",
    )

    return {"inputs": input_formats, "outputs": output_formats}

def assign_model_io_to_subgraph(self, graph_signature):
"""
Assign model's inputs/outputs to SubGraph.

:param graph_signature: Instance of GraphSignature.
"""

self.get_sub_graph().inputs = tflite_model.SubGraphInputs()
for input_name in graph_signature.user_inputs:
Expand All @@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats(
"Input was probably redirected."
)
self.get_sub_graph().inputs.tmp_inputs.append(tensor)
io_formats["inputs"][tensor.name] = tensor.tensor_format

self.get_sub_graph().outputs = tflite_model.SubGraphOutputs()
for output_name in graph_signature.user_outputs:
Expand All @@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats(
"Output was probably redirected."
)
self.get_sub_graph().outputs.tmp_outputs.append(tensor)

io_formats["outputs"][tensor.name] = tensor.tensor_format

return io_formats
55 changes: 45 additions & 10 deletions backends/nxp/backend/ir/converter/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# License: MIT
# See the LICENSE_MIT for more details.
#

from copy import deepcopy
from itertools import chain
from typing import Dict, List, Optional, Union

import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
Expand Down Expand Up @@ -48,6 +50,9 @@
FlexTranspose,
)
from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer
from executorch.backends.nxp.backend.neutron_operator_support import (
transposition_is_supported_on_neutron,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec


Expand Down Expand Up @@ -218,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor):
new_tensor.shape = translator.channels_last_shape_to_channels_first(
t_tensor.shape
)
new_tensor.tensor_format = new_tensor.tensor_format.to_node_format()
new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST

perm = translator.create_channels_last_to_channels_first_permutation(
t_tensor.rank
Expand Down Expand Up @@ -355,6 +360,19 @@ def _make_inputs_channels_first(self):
if input_tensor.tensor_format.is_channels_last():
# Create a Transpose operator and replace the graph input

new_input_shape = translator.channels_last_shape_to_channels_first(
input_tensor.shape
)
perm = translator.create_channels_first_to_channels_last_permutation(
input_tensor.rank
)

if not transposition_is_supported_on_neutron(
new_input_shape.vector, list(perm), self.neutron_target_spec
):
new_inputs.append(input_tensor)
continue

if input_tensor.rank > 6:
msg = (
f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has "
Expand All @@ -365,14 +383,9 @@ def _make_inputs_channels_first(self):
new_input = self.duplicate_tensor(
input_tensor, input_tensor.name + "_channels_first"
)
new_input.shape = translator.channels_last_shape_to_channels_first(
input_tensor.shape
)
new_input.tensor_format = input_tensor.tensor_format.to_node_format()
new_input.shape = new_input_shape
new_input.tensor_format = TensorFormat.CHANNELS_FIRST

perm = translator.create_channels_first_to_channels_last_permutation(
input_tensor.rank
)
transpose = self._create_transpose_operator(
new_input, input_tensor, perm
)
Expand All @@ -397,6 +410,16 @@ def _make_outputs_channels_first(self):
if output_tensor.tensor_format.is_channels_last():
# Add a Transpose operator, to make the output channels first

shape = output_tensor.shape.vector
perm = translator.create_channels_last_to_channels_first_permutation(
len(shape), True
)
if not transposition_is_supported_on_neutron(
shape, perm, self.neutron_target_spec
):
new_outputs.append(output_tensor)
continue

if output_tensor.rank > 6:
logger.e(
logger.Code.IO_PRESERVATION_ERROR,
Expand Down Expand Up @@ -437,26 +460,38 @@ def _keep_one_empty_buffer(self):
# It's safe to replace the buffer.
t.tmp_buffer = empty_buffer

def replace_io_tensor_format_with_node_format(self):
    """Swap the `TensorFormat` of every subgraph input/output tensor for its node format.

    The subgraph's `inputs.tmp_inputs` are visited first, followed by `outputs.tmp_outputs`. A tensor
    whose `tensor_format` is an instance of `TensorFormat` gets it replaced in place by the result of
    `to_equal_node_format()`. Tensors with any other kind of `tensor_format` are left untouched.
    """
    graph_inputs = self.get_sub_graph().inputs.tmp_inputs
    graph_outputs = self.get_sub_graph().outputs.tmp_outputs

    for io_tensor in chain(graph_inputs, graph_outputs):
        current_format = io_tensor.tensor_format
        if not isinstance(current_format, TensorFormat):
            # Nothing to convert for this tensor.
            continue

        io_tensor.tensor_format = current_format.to_equal_node_format()

def finish(self) -> tflite_model.Model:
"""Finalize and optimize the converted TFLite model. Then return it.

At least one of 'optimization_whitelist' and 'optimization_blacklist' must be 'None'.
:return: The final TFLite model.
"""

if self.conversion_config.keep_io_format:
if self.conversion_config.use_neutron_for_format_conversion:
# If the input or output is channels last, add a Transpose operator, to make it channels first.
self._make_inputs_channels_first()
self._make_outputs_channels_first()

# Apply optimizations to the internal TFLite model.
optimizer.Optimizer(self, self.conversion_config).optimize(
optimizer.Optimizer(
self, self.conversion_config, self.neutron_target_spec
).optimize(
self.conversion_config.optimization_whitelist,
self.conversion_config.optimization_blacklist,
)

self._keep_one_empty_buffer()

self.replace_io_tensor_format_with_node_format()

# Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference.
operator_outputs = []
for op in self.get_operators().vector:
Expand Down
8 changes: 8 additions & 0 deletions backends/nxp/backend/ir/converter/node_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector:
"""
return self.context.tflite_builder

@property
def neutron_target_spec(self) -> NeutronTargetSpec:
    """
    Expose the Neutron target specification held by the model builder.
    :return: The `neutron_target_spec` attribute of `self.builder`.
    """
    model_builder = self.builder
    return model_builder.neutron_target_spec

def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator:
"""
Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'.
Expand Down
Loading
Loading