diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py
index 4db6435d9a2..f6e3c374b19 100644
--- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py
+++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -13,6 +13,9 @@
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
     FuseBatchNormWithLinearPass,
 )
+from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
+    RemoveNodesWithKnownOutputs,
+)
 from executorch.backends.nxp.aten_passes.split_group_convolution import (
     SplitGroupConvolution,
 )
@@ -34,6 +37,7 @@ def __init__(self, passes: list[PassType] = None):
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
             SplitGRUBasedOnNumLayers(),
+            RemoveNodesWithKnownOutputs(),
         ]
 
         super().__init__(passes)
diff --git a/backends/nxp/aten_passes/remove_nodes_with_known_outputs.py b/backends/nxp/aten_passes/remove_nodes_with_known_outputs.py
new file mode 100644
index 00000000000..3c08ac6c3fb
--- /dev/null
+++ b/backends/nxp/aten_passes/remove_nodes_with_known_outputs.py
@@ -0,0 +1,196 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Collection
+
+import torch
+
+from executorch.backends.nxp.backend.edge_helper import (
+    try_get_tensor_constant_from_node,
+)
+from torch._subclasses import FakeTensor, FakeTensorMode
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.export.unflatten import _assign_attr, _AttrKind
+from torch.fx import GraphModule, Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.nn import Parameter
+
+
+class RemoveNodesWithKnownOutputs(PassBase):
+    """In some situations, a node will always produce the same output data at runtime. If these cases are identified,
+    the nodes can simply be removed and replaced by a static parameter node, which holds the data the original
+    node would produce.
+    This pass identifies some of these cases and performs the replacement.
+    """
+
+    # Nodes which don't have the `.meta['val']` attribute. The datatype and shape of their inferred output data will
+    # therefore not be checked against the expected values in the `.meta['val']`.
+    nodes_without_val_meta = [
+        torch.ops.aten.empty.memory_format,
+    ]
+
+    module: GraphModule
+
+    def replace_nodes_in_list_with_their_data(self, list_of_args: list) -> list | None:
+        """Replace the nodes in `list_of_args` by their static data. If not all data is available, return `None`.
+
+        :param list_of_args: List of arguments of an aten operator. Can include nodes, generic arguments, lists...
+        :return:`list_of_args` but with tensors replaced by their static data, or `None` if not all data is available.
+        """
+        args_with_data = []
+        for arg in list_of_args:
+            match arg:
+                case Node():
+                    # `arg` is either another operator, a model input, or a static parameter.
+                    data = try_get_tensor_constant_from_node(self.module, arg)
+                    if data is None:
+                        # No static data is available.
+                        return None
+
+                    args_with_data.append(data)
+                case list():
+                    nested = self.replace_nodes_in_list_with_their_data(arg)
+                    if nested is None:
+                        return None
+                    args_with_data.append(nested)
+
+                case _:
+                    # Generic argument. Not an input from a previous node.
+                    args_with_data.append(arg)
+
+        return args_with_data
+
+    @staticmethod
+    def node_is_followed_only_by_getitem_nodes(node: Node) -> bool:
+        def _is_getitem(node_: Node) -> bool:
+            return node_.op == "call_function" and node_.target.__name__ == "getitem"
+
+        users = list(node.users.keys())
+        return all(_is_getitem(user) for user in users)
+
+    def replace_node_with_static_data(self, node: Node, static_data: Parameter):
+        """Remove the given `node` from the graph and replace it with a parameter node containing the `static_data`."""
+        # Generate a unique name for the new static parameter.
+        new_name = get_new_attr_name_with_prefix(node.name)(self.module)
+
+        # Create the node for the parameter.
+        param = torch.nn.Parameter(static_data, False)
+        _assign_attr(param, self.module, str(new_name), _AttrKind.PARAMETER)
+        with self.module.graph.inserting_before(node):
+            static_parameter_node = self.module.graph.get_attr(new_name)
+
+        with FakeTensorMode() as mode:
+            # Assign the parameter node its shape and data type.
+            static_parameter_node.meta["val"] = FakeTensor.from_tensor(
+                torch.empty(static_data.shape, dtype=static_data.dtype), mode
+            )
+
+        # Replace the old node with the new static parameter.
+        node.replace_all_uses_with(static_parameter_node)
+        self.module.graph.erase_node(node)
+
+    def replace_following_getitem_nodes_with_static_data(
+        self, root_node: Node, static_data: Collection[Parameter]
+    ) -> bool:
+        """Remove the `root_node` and all `GetItem` nodes that consume its output from the graph, and replace their
+        uses with parameter nodes containing the provided `static_data`.
+        If something other than just `GetItem` nodes follow after the `root_node`, nothing is done and `False` is
+        returned.
+
+        :param root_node: The main compute node which is followed only by `GetItem` nodes.
+        :param static_data: A tuple of static tensors with the data that will be used to replace the `GetItem` nodes
+            after the `root_node`.
+        :return: `True` if the replacement was successfully executed. `False` otherwise.
+        """
+
+        if not self.node_is_followed_only_by_getitem_nodes(root_node):
+            return False  # Unexpected case.
+
+        users = list(root_node.users.keys())
+        if len(users) != len(static_data):
+            return False  # Unexpected mismatch.
+
+        # Replace the individual `GetItem` nodes.
+        for get_item_node in users:
+            idx = get_item_node.args[1]
+            self.replace_node_with_static_data(get_item_node, static_data[idx])
+
+        # Finally remove the root node from the graph.
+        self.module.graph.erase_node(root_node)
+
+        return True
+
+    def data_matches_node_meta(self, node: Node, data: Parameter) -> bool:
+        """Verify that the provided `data` tensor has the same shape and datatype as the `node`."""
+        if node.target not in self.nodes_without_val_meta:
+            if node.meta["val"].shape != data.shape:
+                return False  # The inferred data has a different shape than expected.
+
+            if node.meta["val"].dtype != data.dtype:
+                return (
+                    False  # The inferred data has a different data type than expected.
+                )
+
+        return True
+
+    def data_matches_meta_of_following_getitem_nodes(
+        self, root_node: Node, data: Collection[Parameter]
+    ) -> bool:
+        """Verify that the provided `data` tensor has the same shape and datatype as the `GetItem` nodes which consume
+        the output of the `root_node`.
+        """
+        if not self.node_is_followed_only_by_getitem_nodes(root_node):
+            return False  # Unexpected case
+
+        users = list(root_node.users.keys())
+        return all(
+            self.data_matches_node_meta(get_item, data[get_item.args[1]])
+            for get_item in users
+        )
+
+    def call(self, module: GraphModule) -> PassResult:
+        self.module = module
+        made_changes = False
+
+        for node in module.graph.nodes:
+            if node.op != "call_function":
+                continue  # Not a compute operator.
+
+            # Try to access the static data for the inputs of the node.
+            args_with_data = self.replace_nodes_in_list_with_their_data(node.args)
+
+            if args_with_data is None:
+                # Output data inference is not possible.
+                continue
+
+            # All input data is static. Run the operator to compute the output it would produce at runtime.
+            # noinspection PyBroadException
+            try:
+                output = node.target(*args_with_data, **node.kwargs)
+
+                if isinstance(output, tuple) or isinstance(output, list):
+                    if not self.data_matches_meta_of_following_getitem_nodes(
+                        node, output
+                    ):
+                        continue  # The inferred data does not have the expected type/shape.
+                else:
+                    if not self.data_matches_node_meta(node, output):
+                        continue  # The inferred data does not have the expected type/shape.
+
+            except Exception:
+                continue  # Failed to infer the data. Continue with the other nodes.
+            # The output data appears to have been correctly inferred. Create a static parameter node for it.
+
+            if isinstance(output, tuple) or isinstance(output, list):
+                # The node produces multiple outputs (e.g. `split`). If the node is followed only by `GetItem` nodes
+                # which extract the individual outputs, replace them by the static data.
+                if self.replace_following_getitem_nodes_with_static_data(node, output):
+                    made_changes = True
+
+            else:
+                self.replace_node_with_static_data(node, output)
+                made_changes = True  # Indicate that changes were made.
+
+        return PassResult(module, made_changes)
diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py
index 3b74d86f599..061295ead79 100644
--- a/backends/nxp/backend/edge_helper.py
+++ b/backends/nxp/backend/edge_helper.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torch.fx import Node
+from torch.fx import GraphModule, Node
 from torch.nn import Parameter
 
 
@@ -71,3 +71,19 @@ def _is_dequantize(node_: Node) -> bool:
     return _is_dequantize(node) and node_is_static_tensor(
         node.args[0], parameters_mapping
     )
+
+
+def try_get_tensor_constant_from_node(
+    graph_module: GraphModule, node: Node
+) -> Parameter | None:
+    """Get the static data from a given node. If it doesn't have any data, return `None`."""
+    if node is None or node.op != "get_attr":
+        return None
+
+    target_atoms = node.target.split(".")
+    attr_itr = graph_module
+    for atom in target_atoms:
+        if not hasattr(attr_itr, atom):
+            return None
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
diff --git a/backends/nxp/tests/test_removing_nodes_with_known_outputs.py b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py
new file mode 100644
index 00000000000..8f5549c8526
--- /dev/null
+++ b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py
@@ -0,0 +1,109 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import numpy as np
+import torch
+
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)
+from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
+    RemoveNodesWithKnownOutputs,
+)
+from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
+    SplitGRUBasedOnNumLayers,
+)
+from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
+from parameterized import parameterized
+from torch import nn
+
+
+class GRUModel(nn.Module):
+    def __init__(self, num_layers=1):
+        super().__init__()
+        self.gru = torch.nn.GRU(8, 8, num_layers=num_layers)
+
+    def forward(self, input_):
+        # `input_` has shape [sequence_length, batch_size, input_size] ([8, 1, 8])
+        return self.gru(
+            input_, None
+        )  # The initial hidden state is `None`, which will result in a `Zeros` node being added.
+
+
+class TestRemovingNodesWithKnownOutputs(unittest.TestCase):
+    __test__ = False  # Prevent interfering with PyTest tests.
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(42)
+
+    def test_removing_nodes__zeros(self):
+        model = GRUModel()
+
+        input_shape = (8, 1, 8)
+        example_input = (torch.ones(input_shape),)
+
+        exir_program_aten = torch.export.export(model, example_input).module()
+
+        # Make sure the `aten.zeros` is in the model.
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Apply the optimization.
+        NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` is no longer in the model.
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Make sure the model still produces the exact same output.
+        assert np.allclose(outputs_before[0], outputs_after[0])
+        assert np.allclose(outputs_before[1], outputs_after[1])
+
+    @parameterized.expand([2, 8])
+    def test_removing_nodes__split(self, num_layers):
+        # `num_layers > 1` will result in a `split` operator being added. Its input will be a `zeros` operator, which
+        # provides the static 0s input data.
+        model = GRUModel(num_layers).eval()
+
+        input_shape = (8, 1, 8)
+        example_input = (torch.ones(input_shape),)
+
+        exir_program_aten = torch.export.export(model, example_input).module()
+
+        # Apply the pass to split the `aten.gru.input` into multiple instances, and add a `split` node.
+        NeutronAtenPassManager([SplitGRUBasedOnNumLayers()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` and `torch.split` are in the model.
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        assert graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.split.default]
+        )
+        outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Apply the optimization.
+        NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten)
+
+        # Make sure the `aten.zeros` and `torch.split` are no longer in the model.
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.zeros.default]
+        )
+        assert not graph_contains_any_of_ops(
+            exir_program_aten.graph, [torch.ops.aten.split.default]
+        )
+        outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]
+
+        # Make sure the model still produces the exact same output.
+        assert np.allclose(outputs_before[0], outputs_after[0])
+        assert np.allclose(outputs_before[1], outputs_after[1])