13 changes: 13 additions & 0 deletions backends/nxp/TARGETS
@@ -32,6 +32,18 @@ runtime.python_library(
],
)

runtime.python_library(
name = "_passes",
srcs = glob([
"_passes/*.py",
]),
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:pass_manager",
],
)

runtime.python_library(
name = "quantizer",
srcs = [
@@ -65,6 +77,7 @@ runtime.python_library(
deps = [
":neutron_sdk",
":aten_passes",
":_passes",
":quantizer",
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
103 changes: 103 additions & 0 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NXP_NODE_FORMAT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
"""
    This pass removes the `getitem` operator that follows `max_pool2d_with_indices.default` (or `max.dim`) and
    replaces the pair with a single operator that produces only the first output, i.e. the max values.
    Before Pass:
        MaxPool2d ---> GetItem[max_values, max_indices]
    After Pass:
        MaxPool2d -> max_values
"""

def call(self, graph_module: torch.fx.GraphModule):
module = graph_module
for node in module.graph.nodes:
if node.op == "call_function":
if (
node.target.__name__ == "aten.max_pool2d_with_indices.default"
or node.target.__name__ == "aten.max.dim"
):
users = list(node.users.keys())

if len(users) != 1:
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
# Two users is allowed for max.dim. For that case,
# rather than removing the getitem node in this
# pass, we handle the getitem nodes in the op's
# visitor when serializing
continue
else:
raise AssertionError(
f"Invalid number of users for {node.target.__name__}: {len(users)}"
)

getitem_node = list(node.users.keys())[0]

if getitem_node.target.__name__ != "getitem":
raise AssertionError(
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
)

getitem_index = getitem_node.args[1]

with module.graph.inserting_before(node):
if (
node.target.__name__
== "aten.max_pool2d_with_indices.default"
):
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.max_pool2d.default,
args=node.args,
kwargs=node.kwargs,
)

else:
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.amax.default,
args=node.args,
kwargs=node.kwargs,
)

# MODIFIED PART START
# Make sure to preserve the inferred node format.
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
NXP_NODE_FORMAT, NodeFormat.NONE
)
# MODIFIED PART END

getitem_node.replace_all_uses_with(new_max_wd)

module.graph.erase_node(getitem_node)
module.graph.erase_node(node)

graph_module.recompile()
# Propagate metadata and retrace module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
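
For orientation, a minimal sketch of how the new pass can be exercised end to end. The module path is taken from this diff; the model, input shape, and the export/to_edge flow are illustrative assumptions.

# Sketch: run RemoveGetItemPass on an edge-dialect graph, where max_pool2d
# decomposes into max_pool2d_with_indices followed by a getitem on output 0.
import torch

from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
from executorch.exir import to_edge


class MaxPoolModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, kernel_size=2)


exported = torch.export.export(MaxPoolModel(), (torch.randn(1, 8, 16, 16),))
edge = to_edge(exported)
gm = edge.exported_program().graph_module

result = RemoveGetItemPass()(gm)  # ExportPass instances are callable and return a PassResult
print(result.graph_module.graph)  # the max_pool2d_with_indices + getitem pair is now a single max_pool2d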
32 changes: 32 additions & 0 deletions backends/nxp/backend/edge_helper.py
@@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import GraphModule, Node
from torch.nn import Parameter

@@ -87,3 +89,33 @@ def try_get_tensor_constant_from_node(
return None
attr_itr = getattr(attr_itr, atom)
return attr_itr


def _is_dequantize(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


def _is_quantize(node_: Node) -> bool:
    return node_.op == "call_function" and node_.target in [
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]


def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
"""Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards
starting with the `node.args[input_index]`,
"""
current_node = node.args[input_index]
while True:
if _is_quantize(current_node) or _is_dequantize(current_node):
current_node = current_node.args[0]
else:
return current_node
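
The helper walks backwards through chains of quantize/dequantize nodes. A hand-built illustration (the quantized_decomposed op targets are real; the graph wiring and the quantization constants are invented):

# Build placeholder -> quantize -> dequantize -> relu by hand, then check that
# previous_non_qdq_node(relu) skips the q/dq pair and returns the placeholder.
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers torch.ops.quantized_decomposed)

from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
from torch.fx import Graph

g = Graph()
x = g.placeholder("x")
q = g.call_function(
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    args=(x, 0.1, 0, -128, 127, torch.int8),
)
dq = g.call_function(
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    args=(q, 0.1, 0, -128, 127, torch.int8),
)
relu = g.call_function(torch.relu, args=(dq,))

assert previous_non_qdq_node(relu) is x  # traversed back through dq and q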
13 changes: 3 additions & 10 deletions backends/nxp/backend/edge_program_converter.py
@@ -19,10 +19,7 @@
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
)
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
@@ -74,12 +71,10 @@ def convert_program(
:param custom_delegation_options: Custom user options which affect node delegation.
:return: TFLite flatbuffers as bytes.
"""
node_formats = NodeFormatInference(edge_program).identify_node_formats()
parameters_mapping = self.map_inputs_to_parameters(edge_program)

cc = self.build_conversion_context(
parameters_mapping,
node_formats,
neutron_target_spec,
conversion_config,
custom_delegation_options,
@@ -106,7 +101,7 @@ def convert_program(
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
for node in nodes:
if node.op == "placeholder":
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]

if node.name in context.parameters_mapping:
# Node is placeholder and has data -> append as static tensor with data
@@ -119,7 +114,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "call_function":
# Node is call function -> append only output as a tensor
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "output":
# Nothing to do
@@ -177,7 +172,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
@staticmethod
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
@@ -193,7 +187,6 @@ def build_conversion_context(
tflite_builder,
conversion_config,
parameters_mapping,
node_formats,
custom_delegation_options,
)

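With this change, node formats are read from each node's metadata instead of a side dictionary built per conversion. A sketch of the lookup pattern, with imports taken from this diff (the helper function itself is hypothetical):

# Hypothetical helper showing the new access pattern for node formats.
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT


def node_format_of(node):
    # Nodes that were not annotated by the format inference fall back to NONE,
    # mirroring the .get(..., NodeFormat.NONE) fallback in RemoveGetItemPass.
    return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)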
5 changes: 0 additions & 5 deletions backends/nxp/backend/ir/conversion_context.py
@@ -10,24 +10,20 @@
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from torch import Node
from torch.nn import Parameter


class ConversionContext:
tflite_builder: AtenModelBuilderDirector
conversion_config: ConversionConfig
parameters_mapping: dict[str, Parameter]
node_formats: dict[Node, NodeFormat]
custom_delegation_options: CustomDelegationOptions

def __init__(
self,
tflite_builder: AtenModelBuilderDirector,
conversion_config: ConversionConfig,
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
custom_delegation_options: CustomDelegationOptions,
):
"""
@@ -39,5 +35,4 @@ def __init__(
self.tflite_builder = tflite_builder
self.conversion_config = conversion_config
self.parameters_mapping = parameters_mapping
self.node_formats = node_formats
self.custom_delegation_options = custom_delegation_options
@@ -9,7 +9,7 @@
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat
from torch.fx import Node
from torch.nn import Parameter

@@ -8,7 +8,12 @@
from executorch.backends.nxp.backend.custom_delegation_options import (
CustomDelegationOptions,
)
from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
apply_permutation_to,
create_channels_first_to_channels_last_permutation,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
_is_dequant_node,
_is_quant_node,
@@ -18,7 +23,9 @@
Concatenation,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.fx.passes.infra.partitioner import Partition
from torch.nn import Parameter


@@ -79,38 +86,28 @@ def _is_supported_on_target(
if custom_delegation_options.force_delegate_cat:
return True

dim = CatConverter._get_normalized_dim(node)

# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
if dim == 0:
return False
        # Neutron requires the channels to be a multiple of the target's `numMacs`. The channels could either be
        # the second or the last dimension, depending on the format of the node.
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
# `1` will end up being the channels (last dim in NHWC).
channels_index = 1
else:
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
channels_index = -1

# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
input_channels = [
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
# will still be the channels in the IR.
_get_shape(input_)[1]
for input_ in node.all_input_nodes
] + [
# If the inputs/outputs are channels first, the last dimension will be the channels.
_get_shape(input_)[-1]
for input_ in node.all_input_nodes
_get_shape(input_)[channels_index] for input_ in node.all_input_nodes
]
if any(
(input_channel % neutron_target_spec.get_num_macs()) != 0
for input_channel in input_channels
):
output_channels = _get_shape(node)[channels_index]

num_macs = neutron_target_spec.get_num_macs()
if any((input_channel % num_macs) != 0 for input_channel in input_channels):
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
return False

output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
if any(
(out_c % neutron_target_spec.get_num_macs()) != 0
for out_c in output_channels
):
if (output_channels % num_macs) != 0:
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
return False

if len(node.all_input_nodes) < 2: # Not supported on Neutron
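
A worked instance of the channel checks above, assuming numMacs is 8 (the real value comes from neutron_target_spec.get_num_macs(); all shapes are invented):

# Channels-first cat over dim=1 with NCHW inputs (1, 16, 8, 8) and (1, 24, 8, 8).
num_macs = 8  # assumed; queried from the target spec in the real code
channels_index = 1  # channels-first: dim 1 becomes the channels (last dim in NHWC)

input_shapes = [[1, 16, 8, 8], [1, 24, 8, 8]]
input_channels = [shape[channels_index] for shape in input_shapes]  # [16, 24]
output_channels = sum(input_channels)  # 40, as dim=1 is also the concat axis

delegable = all(c % num_macs == 0 for c in input_channels) and output_channels % num_macs == 0
print(delegable)  # True: 16, 24, and 40 are all multiples of 8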
@@ -132,6 +129,46 @@ def _is_supported_in_IR(

return True

@classmethod
def supports_partitioning_result(
cls,
node: Node,
partition_list: list[Partition],
custom_delegation_options: CustomDelegationOptions,
):
        # There is a bug in the NeutronConverter: if none of the input dimensions before the one referenced by
        # `dim` is `!= 1`, the NeutronConverter fails to delegate the `Concat`. This only happens when the inputs
        # of the `Concat` are model inputs, and not outputs of other operators.
cat_partition = [p for p in partition_list if node in p.nodes][0]
cat_inputs = map(previous_non_qdq_node, node.args[0])

if not all(
input_.op == "call_function" and input_ in cat_partition.nodes
for input_ in cat_inputs
):
# Some inputs of the `cat` are NOT in the same partition as `cat`.
dim = CatConverter._get_normalized_dim(node)
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# Transform the shapes to channels last.
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
len(node.meta["val"].shape), True
)
input_shapes = [
apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
]

# Transform the `dim` to refer to a channels last dimension.
dim = to_nhwc_perm.index(dim)

for input_shape in input_shapes:
if not any(d != 1 for d in input_shape[:dim]):
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
return False

return True

def convert(self, node: Node):
"""Convert the 'aten.cat' operator to TFLite 'Concatenation'."""
self.assert_convertible(node)
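To make the shape test in `supports_partitioning_result` concrete, a hand-run on an invented input: with dim=2 and an input shape of (1, 1, 5, 8), no dimension before `dim` differs from 1, so the `cat` must not be delegated when its inputs come from outside the partition.

# Hand-run of the guard (shapes and dim are invented for illustration).
dim = 2
input_shape = [1, 1, 5, 8]

has_non_one_before_dim = any(d != 1 for d in input_shape[:dim])
print(has_non_one_before_dim)  # False -> supports_partitioning_result returns False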