From bb6dbc39158c69268a6d8901439801970113f9f3 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 17 Oct 2025 12:50:14 +0200 Subject: [PATCH 1/2] Arm backend: break out tagging in partitioner This prepares the partitioner for partitioning submodules. Signed-off-by: Erik Lundell Change-Id: Ie3123c672ad4df93b8a9f835c3908d58668e27d2 --- backends/arm/tosa/partitioner.py | 157 ++++++++++++++++++------------- 1 file changed, 91 insertions(+), 66 deletions(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 6eb1dcbef72..6a20af55f79 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -36,7 +36,8 @@ from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx import GraphModule +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupportBase logger = logging.getLogger(__name__) @@ -110,6 +111,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: return all(m == 1 for m in multiples) +def is_partitioned( + node: torch.fx.Node, + tag: str, +) -> bool: + """Return True if the node currently belongs to the partition ``tag``. + + Args: + node (torch.fx.Node): FX node to check. + tag (str): Delegation tag identifying the partition. + + Returns: + bool: True if the node carries the matching delegation tag. + + """ + return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag + + +def reject_partition( + reason: str, partition: Partition, reporter: WhyNoPartitionReporter +) -> None: + """Remove a proposed partition and record the rejection reason. + + Args: + reason (str): Human-readable explanation for rejection. + partition (object): Proposed partition object from the + capability partitioner. + reporter (WhyNoPartitionReporter): used to report why nodes were rejected. + """ + for node in partition.nodes: + if "delegation_tag" in node.meta: + del node.meta["delegation_tag"] + reporter.report_reject( + node, + reason, + ) + + class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -142,97 +180,54 @@ def __init__( self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec - def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa - """Partition the program and tag TOSA-compatible subgraphs. - - Run the FX capability-based partitioner to propose subgraphs, then - refine tags by removing boundary-only quantize/dequantize nodes and by - rejecting partitions that would lower to no-ops. Emit a detailed report - of rejected nodes and their reasons. + def _tag_module( # noqa + self, + module: GraphModule, + containing_program: ExportedProgram, + reporter: WhyNoPartitionReporter, + ) -> set[str]: + """Tag nodes in a module, possibly a submodule, from the containing program. Args: - exported_program (ExportedProgram): Program to analyze and - partition. - + module: a GraphModule from `containing_program` to tag nodes in. + containing_program: The ExportedProgram that contains the module. + reporter: A reporter to report why nodes were rejected. Returns: - PartitionResult: The input program with nodes tagged for delegation - and a mapping of partition tags to delegation specs. - + A set of strings with the partition tags. """ - logger.info("TOSAPartitioner::partition") - partition_tags: dict[str, DelegationSpec] = {} - - logger.info( - f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" - ) - - reporter = WhyNoPartitionReporter() + tags: set[str] = set() operator_support = tosa_support_factory( - self.tosa_spec, exported_program, reporter, self.additional_checks + self.tosa_spec, containing_program, reporter, self.additional_checks ) capability_partitioner = CapabilityBasedPartitioner( - exported_program.graph_module, + module, operator_support, allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() - def reject_partition(reason: str, partition, tag) -> None: - """Remove a proposed partition and record the rejection reason. - - Args: - reason (str): Human-readable explanation for rejection. - partition (object): Proposed partition object from the - capability partitioner. - tag (str): Delegation tag associated with the partition. - - """ - for node in partition.nodes: - if "delegation_tag" in node.meta: - del node.meta["delegation_tag"] - reporter.report_reject( - node, - reason, - ) - partition_tags.pop(tag, None) - for partition in partition_list: tag = f"tag{partition.id}" - - def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: - """Return True if the node currently belongs to the partition ``tag``. - - Args: - node (torch.fx.Node): FX node to check. - tag (str): Delegation tag identifying the partition. - - Returns: - bool: True if the node carries the matching delegation tag. - - """ - return ( - "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag - ) + tags.add(tag) for node in partition.nodes: node.meta["delegation_tag"] = tag - partition_tags[tag] = self.delegation_spec # De-tag outermost q-nodes upwards and dq-nodes downwards. # De-tag if at least one input/output is not part of the partition. - for node in exported_program.graph_module.graph.nodes: - if not is_partitioned(node): + for node in module.graph.nodes: + if not is_partitioned(node, tag): continue if node.target in Q_OPS: for input in node.all_input_nodes: - if not is_partitioned(input): + if not is_partitioned(input, tag): del node.meta["delegation_tag"] break continue if node.target in DQ_OPS: for user in node.users: - if not is_partitioned(user): + if not is_partitioned(user, tag): del node.meta["delegation_tag"] break continue @@ -240,9 +235,9 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: if self.tosa_spec.support_float(): continue - if is_partitioned(node): + if is_partitioned(node, tag): for input in node.all_input_nodes: - if is_partitioned(input): + if is_partitioned(input, tag): continue if get_first_fake_tensor(input).dtype.is_floating_point: reporter.report_reject( @@ -265,8 +260,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: reject_partition( "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", partition, - tag, + reporter, ) + tags.remove(tag) + return tags + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """Partition the program and tag TOSA-compatible subgraphs. + + Run the FX capability-based partitioner to propose subgraphs, then + refine tags by removing boundary-only quantize/dequantize nodes and by + rejecting partitions that would lower to no-ops. Emit a detailed report + of rejected nodes and their reasons. + + Args: + exported_program (ExportedProgram): Program to analyze and + partition. + + Returns: + PartitionResult: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. + + """ + logger.info("TOSAPartitioner::partition") + logger.info( + f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" + ) + + reporter = WhyNoPartitionReporter() + tags = self._tag_module( + exported_program.graph_module, exported_program, reporter + ) + partition_tags = {tag: self.delegation_spec for tag in tags} tag_constant_data(exported_program) logger.info(f"The following nodes were rejected for {self.tosa_spec}:") From d468f4b7e44b9a6db1912946f7399701e5fec2eb Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 17 Oct 2025 13:36:43 +0200 Subject: [PATCH 2/2] Arm backend: Tag submodules in partitioner Recursively search for control_flow_modules and tag them in _tag_module function in partitioner. Signed-off-by: Erik Lundell Change-Id: Idb41e7f808013936dae8ca6909a69d053e834ca9 --- backends/arm/tosa/partitioner.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 6a20af55f79..3a1a79ec8de 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -15,6 +15,7 @@ """ import logging +from itertools import count from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -35,6 +36,7 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition @@ -185,6 +187,7 @@ def _tag_module( # noqa module: GraphModule, containing_program: ExportedProgram, reporter: WhyNoPartitionReporter, + tag_iterator: count | None = None, ) -> set[str]: """Tag nodes in a module, possibly a submodule, from the containing program. @@ -196,6 +199,17 @@ def _tag_module( # noqa A set of strings with the partition tags. """ tags: set[str] = set() + if tag_iterator is None: + tag_iterator = count(0) + for _, submodule, _ in get_control_flow_submodules(module): + submodule_tags = self._tag_module( + submodule, containing_program, reporter, tag_iterator + ) + if len(tags & submodule_tags) != 0: + raise RuntimeError( + "Got overlapping tags in two different modules, this shouldn't happen." + ) + tags = tags | submodule_tags operator_support = tosa_support_factory( self.tosa_spec, containing_program, reporter, self.additional_checks ) @@ -207,7 +221,7 @@ def _tag_module( # noqa partition_list = capability_partitioner.propose_partitions() for partition in partition_list: - tag = f"tag{partition.id}" + tag = f"tag{next(tag_iterator)}" tags.add(tag) for node in partition.nodes: