Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 106 additions & 67 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import logging
from itertools import count
from typing import Callable, List, Optional, Sequence, Tuple

import torch
Expand All @@ -35,8 +36,10 @@
)
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupportBase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,6 +113,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
return all(m == 1 for m in multiples)


def is_partitioned(
    node: torch.fx.Node,
    tag: str,
) -> bool:
    """Check whether ``node`` is currently tagged as part of partition ``tag``.

    Args:
        node (torch.fx.Node): FX node whose ``meta`` mapping is inspected.
        tag (str): Delegation tag identifying the partition of interest.

    Returns:
        bool: True when ``node.meta`` has a ``"delegation_tag"`` entry equal
            to ``tag``; False when the entry is missing or differs.

    """
    node_meta = node.meta
    # An untagged node cannot belong to any partition.
    if "delegation_tag" not in node_meta:
        return False
    return node_meta["delegation_tag"] == tag


def reject_partition(
    reason: str, partition: Partition, reporter: WhyNoPartitionReporter
) -> None:
    """Un-tag every node of a proposed partition and report the rejection.

    Each node in ``partition.nodes`` has its ``"delegation_tag"`` meta entry
    dropped (if it has one) and is then reported to ``reporter`` together
    with ``reason``.

    Args:
        reason (str): Human-readable explanation for the rejection.
        partition (Partition): Proposed partition from the capability
            partitioner whose nodes should be un-tagged.
        reporter (WhyNoPartitionReporter): Collector that records, per node,
            why it was rejected.

    """
    for rejected_node in partition.nodes:
        # Nodes may already have been de-tagged earlier; dropping the key
        # must therefore tolerate its absence.
        rejected_node.meta.pop("delegation_tag", None)
        reporter.report_reject(rejected_node, reason)


class TOSAPartitioner(Partitioner):
"""Partition an exported program into TOSA-delegable subgraphs.

Expand Down Expand Up @@ -142,107 +182,76 @@ def __init__(
self.additional_checks = additional_checks
self.tosa_spec = compile_spec.tosa_spec

def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
"""Partition the program and tag TOSA-compatible subgraphs.

Run the FX capability-based partitioner to propose subgraphs, then
refine tags by removing boundary-only quantize/dequantize nodes and by
rejecting partitions that would lower to no-ops. Emit a detailed report
of rejected nodes and their reasons.
def _tag_module( # noqa
self,
module: GraphModule,
containing_program: ExportedProgram,
reporter: WhyNoPartitionReporter,
tag_iterator: count | None = None,
) -> set[str]:
"""Tag nodes in a module, possibly a submodule, from the containing program.

Args:
exported_program (ExportedProgram): Program to analyze and
partition.

module: a GraphModule from `containing_program` to tag nodes in.
containing_program: The ExportedProgram that contains the module.
reporter: A reporter to report why nodes were rejected.
Returns:
PartitionResult: The input program with nodes tagged for delegation
and a mapping of partition tags to delegation specs.

A set of strings with the partition tags.
"""
logger.info("TOSAPartitioner::partition")
partition_tags: dict[str, DelegationSpec] = {}

logger.info(
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
)

reporter = WhyNoPartitionReporter()
tags: set[str] = set()
if tag_iterator is None:
tag_iterator = count(0)
for _, submodule, _ in get_control_flow_submodules(module):
submodule_tags = self._tag_module(
submodule, containing_program, reporter, tag_iterator
)
if len(tags & submodule_tags) != 0:
raise RuntimeError(
"Got overlapping tags in two different modules, this shouldn't happen."
)
tags = tags | submodule_tags
operator_support = tosa_support_factory(
self.tosa_spec, exported_program, reporter, self.additional_checks
self.tosa_spec, containing_program, reporter, self.additional_checks
)
capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
module,
operator_support,
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()

def reject_partition(reason: str, partition, tag) -> None:
"""Remove a proposed partition and record the rejection reason.

Args:
reason (str): Human-readable explanation for rejection.
partition (object): Proposed partition object from the
capability partitioner.
tag (str): Delegation tag associated with the partition.

"""
for node in partition.nodes:
if "delegation_tag" in node.meta:
del node.meta["delegation_tag"]
reporter.report_reject(
node,
reason,
)
partition_tags.pop(tag, None)

for partition in partition_list:
tag = f"tag{partition.id}"

def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
"""Return True if the node currently belongs to the partition ``tag``.

Args:
node (torch.fx.Node): FX node to check.
tag (str): Delegation tag identifying the partition.

Returns:
bool: True if the node carries the matching delegation tag.

"""
return (
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
)
tag = f"tag{next(tag_iterator)}"
tags.add(tag)

for node in partition.nodes:
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

# De-tag outermost q-nodes upwards and dq-nodes downwards.
# De-tag if at least one input/output is not part of the partition.
for node in exported_program.graph_module.graph.nodes:
if not is_partitioned(node):
for node in module.graph.nodes:
if not is_partitioned(node, tag):
continue
if node.target in Q_OPS:
for input in node.all_input_nodes:
if not is_partitioned(input):
if not is_partitioned(input, tag):
del node.meta["delegation_tag"]
break
continue

if node.target in DQ_OPS:
for user in node.users:
if not is_partitioned(user):
if not is_partitioned(user, tag):
del node.meta["delegation_tag"]
break
continue

if self.tosa_spec.support_float():
continue

if is_partitioned(node):
if is_partitioned(node, tag):
for input in node.all_input_nodes:
if is_partitioned(input):
if is_partitioned(input, tag):
continue
if get_first_fake_tensor(input).dtype.is_floating_point:
reporter.report_reject(
Expand All @@ -265,8 +274,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
reject_partition(
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
partition,
tag,
reporter,
)
tags.remove(tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""Partition the program and tag TOSA-compatible subgraphs.

Run the FX capability-based partitioner to propose subgraphs, then
refine tags by removing boundary-only quantize/dequantize nodes and by
rejecting partitions that would lower to no-ops. Emit a detailed report
of rejected nodes and their reasons.

Args:
exported_program (ExportedProgram): Program to analyze and
partition.

Returns:
PartitionResult: The input program with nodes tagged for delegation
and a mapping of partition tags to delegation specs.

"""
logger.info("TOSAPartitioner::partition")
logger.info(
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
)

reporter = WhyNoPartitionReporter()
tags = self._tag_module(
exported_program.graph_module, exported_program, reporter
)
partition_tags = {tag: self.delegation_spec for tag in tags}

tag_constant_data(exported_program)
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")
Expand Down
Loading