diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 6eb1dcbef72..3a1a79ec8de 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -15,6 +15,7 @@ """ import logging +from itertools import count from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -35,8 +36,10 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_control_flow_submodules from torch.export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx import GraphModule +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupportBase logger = logging.getLogger(__name__) @@ -110,6 +113,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: return all(m == 1 for m in multiples) +def is_partitioned( + node: torch.fx.Node, + tag: str, +) -> bool: + """Return True if the node currently belongs to the partition ``tag``. + + Args: + node (torch.fx.Node): FX node to check. + tag (str): Delegation tag identifying the partition. + + Returns: + bool: True if the node carries the matching delegation tag. + + """ + return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag + + +def reject_partition( + reason: str, partition: Partition, reporter: WhyNoPartitionReporter +) -> None: + """Remove a proposed partition and record the rejection reason. + + Args: + reason (str): Human-readable explanation for rejection. + partition (Partition): Proposed partition object from the + capability partitioner. + reporter (WhyNoPartitionReporter): Reporter used to record why nodes were rejected. 
+ """ + for node in partition.nodes: + if "delegation_tag" in node.meta: + del node.meta["delegation_tag"] + reporter.report_reject( + node, + reason, + ) + + class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -142,97 +182,66 @@ def __init__( self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec - def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa - """Partition the program and tag TOSA-compatible subgraphs. - - Run the FX capability-based partitioner to propose subgraphs, then - refine tags by removing boundary-only quantize/dequantize nodes and by - rejecting partitions that would lower to no-ops. Emit a detailed report - of rejected nodes and their reasons. + def _tag_module( # noqa + self, + module: GraphModule, + containing_program: ExportedProgram, + reporter: WhyNoPartitionReporter, + tag_iterator: count | None = None, + ) -> set[str]: + """Tag nodes in a module, possibly a submodule, from the containing program. Args: - exported_program (ExportedProgram): Program to analyze and - partition. - + module: a GraphModule from `containing_program` to tag nodes in. + containing_program: The ExportedProgram that contains the module. + reporter: A reporter to report why nodes were rejected. Returns: - PartitionResult: The input program with nodes tagged for delegation - and a mapping of partition tags to delegation specs. - + A set of strings with the partition tags. 
""" - logger.info("TOSAPartitioner::partition") - partition_tags: dict[str, DelegationSpec] = {} - - logger.info( - f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" - ) - - reporter = WhyNoPartitionReporter() + tags: set[str] = set() + if tag_iterator is None: + tag_iterator = count(0) + for _, submodule, _ in get_control_flow_submodules(module): + submodule_tags = self._tag_module( + submodule, containing_program, reporter, tag_iterator + ) + if len(tags & submodule_tags) != 0: + raise RuntimeError( + "Got overlapping tags in two different modules, this shouldn't happen." + ) + tags = tags | submodule_tags operator_support = tosa_support_factory( - self.tosa_spec, exported_program, reporter, self.additional_checks + self.tosa_spec, containing_program, reporter, self.additional_checks ) capability_partitioner = CapabilityBasedPartitioner( - exported_program.graph_module, + module, operator_support, allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() - def reject_partition(reason: str, partition, tag) -> None: - """Remove a proposed partition and record the rejection reason. - - Args: - reason (str): Human-readable explanation for rejection. - partition (object): Proposed partition object from the - capability partitioner. - tag (str): Delegation tag associated with the partition. - - """ - for node in partition.nodes: - if "delegation_tag" in node.meta: - del node.meta["delegation_tag"] - reporter.report_reject( - node, - reason, - ) - partition_tags.pop(tag, None) - for partition in partition_list: - tag = f"tag{partition.id}" - - def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: - """Return True if the node currently belongs to the partition ``tag``. - - Args: - node (torch.fx.Node): FX node to check. - tag (str): Delegation tag identifying the partition. - - Returns: - bool: True if the node carries the matching delegation tag. 
- - """ - return ( - "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag - ) + tag = f"tag{next(tag_iterator)}" + tags.add(tag) for node in partition.nodes: node.meta["delegation_tag"] = tag - partition_tags[tag] = self.delegation_spec # De-tag outermost q-nodes upwards and dq-nodes downwards. # De-tag if at least one input/output is not part of the partition. - for node in exported_program.graph_module.graph.nodes: - if not is_partitioned(node): + for node in module.graph.nodes: + if not is_partitioned(node, tag): continue if node.target in Q_OPS: for input in node.all_input_nodes: - if not is_partitioned(input): + if not is_partitioned(input, tag): del node.meta["delegation_tag"] break continue if node.target in DQ_OPS: for user in node.users: - if not is_partitioned(user): + if not is_partitioned(user, tag): del node.meta["delegation_tag"] break continue @@ -240,9 +249,9 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: if self.tosa_spec.support_float(): continue - if is_partitioned(node): + if is_partitioned(node, tag): for input in node.all_input_nodes: - if is_partitioned(input): + if is_partitioned(input, tag): continue if get_first_fake_tensor(input).dtype.is_floating_point: reporter.report_reject( @@ -265,8 +274,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: reject_partition( "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", partition, - tag, + reporter, ) + tags.remove(tag) + return tags + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """Partition the program and tag TOSA-compatible subgraphs. + + Run the FX capability-based partitioner to propose subgraphs, then + refine tags by removing boundary-only quantize/dequantize nodes and by + rejecting partitions that would lower to no-ops. Emit a detailed report + of rejected nodes and their reasons. 
+ + Args: + exported_program (ExportedProgram): Program to analyze and + partition. + + Returns: + PartitionResult: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. + + """ + logger.info("TOSAPartitioner::partition") + logger.info( + f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" + ) + + reporter = WhyNoPartitionReporter() + tags = self._tag_module( + exported_program.graph_module, exported_program, reporter + ) + partition_tags = {tag: self.delegation_spec for tag in tags} tag_constant_data(exported_program) logger.info(f"The following nodes were rejected for {self.tosa_spec}:")