Skip to content

Commit 0769abc

Browse files
NXP backend: Add support for aten.permute_copy.default. (#15099)
### Summary Add support for `aten.permute_copy.default`. ### Test plan Unit tests provided. cc @robert-kalmar @JakeStevens @digantdesai --------- Co-authored-by: Roman Janik <roman.janik@nxp.com>
1 parent 179a155 commit 0769abc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1317
-121
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
22-
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
22+
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
2323
from executorch.exir.dialects._ops import ops as exir_ops
2424

2525
# noinspection PyProtectedMember
@@ -63,7 +63,7 @@ def convert_program(
6363
conversion_config: ConversionConfig = _default_conversion_config,
6464
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
6565
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
66-
) -> (bytes, dict):
66+
) -> (bytes, dict[str, NodeFormat]):
6767
"""
6868
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.
6969
@@ -87,13 +87,16 @@ def convert_program(
8787
self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc)
8888
self._process_nodes(edge_program.graph.nodes, cc)
8989

90-
# Assign output
91-
io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(
92-
edge_program.graph_signature
93-
)
90+
# Assign the model its inputs and outputs.
91+
cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature)
9492

95-
# TFLite model generation
93+
# Apply optimizations and finalize the model.
9694
internal_tflite_model = cc.tflite_builder.finish()
95+
96+
# Extract the formats of the model's inputs and outputs.
97+
io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature)
98+
99+
# TFLite model generation
97100
flatbuffers_builder = flatbuffers.Builder()
98101
internal_tflite_model.gen_tflite(flatbuffers_builder)
99102

backends/nxp/backend/ir/conversion_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None):
1313
1414
:param args: Optional dictionary with conversion arguments. Unknown arguments are ignored.
1515
"""
16-
self.keep_io_format: bool = False
16+
self.use_neutron_for_format_conversion: bool = True
1717
self.allow_inputs_stripping: bool = True
1818
self.qdq_aware_conversion: bool = True
1919
self.symbolic_dimensions_mapping: dict[str, int] | None = None

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]):
8888

8989
self.check_and_append_operator(op)
9090

91-
def assign_model_io_to_subgraph_and_get_io_formats(
92-
self, graph_signature
93-
) -> dict[str, dict]:
94-
"""
95-
Assign model's inputs/outputs to SubGraph.
91+
def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]:
92+
"""Get a mapping from tensor names to their formats.
9693
97-
:param graph_signature: Instance of GraphSignature.
94+
:param graph_signature: Instance of GraphSignature.
9895
:returns: Mapping between IO tensors' names and their formats.
9996
"""
10097
io_formats = {
10198
"inputs": {},
10299
"outputs": {},
103100
}
101+
for input_name in graph_signature.user_inputs:
102+
tensor = self.tensor_for_name(input_name)
103+
assert input_name == tensor.name, (
104+
"Program's input name doesn't match with tensor name in TFLite. "
105+
"Input was probably redirected."
106+
)
107+
io_formats["inputs"][tensor.name] = tensor.tensor_format
108+
109+
for output_name in graph_signature.user_outputs:
110+
tensor = self.tensor_for_name(output_name)
111+
assert output_name == tensor.name, (
112+
"Program's output name doesn't match with tensor name in TFLite. "
113+
"Output was probably redirected."
114+
)
115+
io_formats["outputs"][tensor.name] = tensor.tensor_format
116+
117+
return io_formats
118+
119+
def assign_model_io_to_subgraph(self, graph_signature):
120+
"""
121+
Assign model's inputs/outputs to SubGraph.
122+
123+
:param graph_signature: Instance of GraphSignature.
124+
"""
104125

105126
self.get_sub_graph().inputs = tflite_model.SubGraphInputs()
106127
for input_name in graph_signature.user_inputs:
@@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats(
110131
"Input was probably redirected."
111132
)
112133
self.get_sub_graph().inputs.tmp_inputs.append(tensor)
113-
io_formats["inputs"][tensor.name] = tensor.tensor_format
114134

115135
self.get_sub_graph().outputs = tflite_model.SubGraphOutputs()
116136
for output_name in graph_signature.user_outputs:
@@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats(
120140
"Output was probably redirected."
121141
)
122142
self.get_sub_graph().outputs.tmp_outputs.append(tensor)
123-
124-
io_formats["outputs"][tensor.name] = tensor.tensor_format
125-
126-
return io_formats

backends/nxp/backend/ir/converter/builder/model_builder.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# License: MIT
66
# See the LICENSE_MIT for more details.
77
#
8+
89
from copy import deepcopy
10+
from itertools import chain
911
from typing import Dict, List, Optional, Union
1012

1113
import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
@@ -48,6 +50,9 @@
4850
FlexTranspose,
4951
)
5052
from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer
53+
from executorch.backends.nxp.backend.neutron_operator_support import (
54+
transposition_is_supported_on_neutron,
55+
)
5156
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
5257

5358

@@ -218,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor):
218223
new_tensor.shape = translator.channels_last_shape_to_channels_first(
219224
t_tensor.shape
220225
)
221-
new_tensor.tensor_format = new_tensor.tensor_format.to_node_format()
226+
new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST
222227

223228
perm = translator.create_channels_last_to_channels_first_permutation(
224229
t_tensor.rank
@@ -355,6 +360,19 @@ def _make_inputs_channels_first(self):
355360
if input_tensor.tensor_format.is_channels_last():
356361
# Create a Transpose operator and replace the graph input
357362

363+
new_input_shape = translator.channels_last_shape_to_channels_first(
364+
input_tensor.shape
365+
)
366+
perm = translator.create_channels_first_to_channels_last_permutation(
367+
input_tensor.rank
368+
)
369+
370+
if not transposition_is_supported_on_neutron(
371+
new_input_shape.vector, list(perm), self.neutron_target_spec
372+
):
373+
new_inputs.append(input_tensor)
374+
continue
375+
358376
if input_tensor.rank > 6:
359377
msg = (
360378
f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has "
@@ -365,14 +383,9 @@ def _make_inputs_channels_first(self):
365383
new_input = self.duplicate_tensor(
366384
input_tensor, input_tensor.name + "_channels_first"
367385
)
368-
new_input.shape = translator.channels_last_shape_to_channels_first(
369-
input_tensor.shape
370-
)
371-
new_input.tensor_format = input_tensor.tensor_format.to_node_format()
386+
new_input.shape = new_input_shape
387+
new_input.tensor_format = TensorFormat.CHANNELS_FIRST
372388

373-
perm = translator.create_channels_first_to_channels_last_permutation(
374-
input_tensor.rank
375-
)
376389
transpose = self._create_transpose_operator(
377390
new_input, input_tensor, perm
378391
)
@@ -397,6 +410,16 @@ def _make_outputs_channels_first(self):
397410
if output_tensor.tensor_format.is_channels_last():
398411
# Add a Transpose operator, to make the output channels first
399412

413+
shape = output_tensor.shape.vector
414+
perm = translator.create_channels_last_to_channels_first_permutation(
415+
len(shape), True
416+
)
417+
if not transposition_is_supported_on_neutron(
418+
shape, perm, self.neutron_target_spec
419+
):
420+
new_outputs.append(output_tensor)
421+
continue
422+
400423
if output_tensor.rank > 6:
401424
logger.e(
402425
logger.Code.IO_PRESERVATION_ERROR,
@@ -437,26 +460,38 @@ def _keep_one_empty_buffer(self):
437460
# It's safe to replace the buffer.
438461
t.tmp_buffer = empty_buffer
439462

463+
def replace_io_tensor_format_with_node_format(self):
464+
for t in chain(
465+
self.get_sub_graph().inputs.tmp_inputs,
466+
self.get_sub_graph().outputs.tmp_outputs,
467+
):
468+
if isinstance(t.tensor_format, TensorFormat):
469+
t.tensor_format = t.tensor_format.to_equal_node_format()
470+
440471
def finish(self) -> tflite_model.Model:
441472
"""Finalize and optimize the converted TFLite model. Then return it.
442473
443474
At least one of 'optimization_whitelist' and 'optimization_blacklist' must be 'None'.
444475
:return: The final TFLite model.
445476
"""
446477

447-
if self.conversion_config.keep_io_format:
478+
if self.conversion_config.use_neutron_for_format_conversion:
448479
# If the input or output is channels last, add a Transpose operator, to make is channels first.
449480
self._make_inputs_channels_first()
450481
self._make_outputs_channels_first()
451482

452483
# Apply optimizations to the internal TFLite model.
453-
optimizer.Optimizer(self, self.conversion_config).optimize(
484+
optimizer.Optimizer(
485+
self, self.conversion_config, self.neutron_target_spec
486+
).optimize(
454487
self.conversion_config.optimization_whitelist,
455488
self.conversion_config.optimization_blacklist,
456489
)
457490

458491
self._keep_one_empty_buffer()
459492

493+
self.replace_io_tensor_format_with_node_format()
494+
460495
# Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference.
461496
operator_outputs = []
462497
for op in self.get_operators().vector:

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector:
185185
"""
186186
return self.context.tflite_builder
187187

188+
@property
189+
def neutron_target_spec(self) -> NeutronTargetSpec:
190+
"""
191+
Get an instance of NeutronTargetSpec from the conversion context.
192+
:return: NeutronTargetSpec instance.
193+
"""
194+
return self.builder.neutron_target_spec
195+
188196
def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator:
189197
"""
190198
Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'.

0 commit comments

Comments
 (0)