Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions backends/arm/test/misc/test_partitioner_tag_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from types import SimpleNamespace

from executorch.backends.arm.tosa import partitioner as tosa_partitioner
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner


class _FakeCapabilityBasedPartitioner:
def __init__(self, *args, **kwargs) -> None:
pass

def propose_partitions(self):
return [
SimpleNamespace(nodes=[SimpleNamespace(meta={}, target=f"op{idx}")])
for idx in range(3)
]


def _make_reporter() -> SimpleNamespace:
return SimpleNamespace(
report_reject=lambda *args, **kwargs: None,
get_table_report=lambda: "",
)


def test_tag_module_preserves_partition_discovery_order(monkeypatch):
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))

monkeypatch.setattr(
tosa_partitioner, "get_cond_while_submodules_nested", lambda module: []
)
monkeypatch.setattr(
tosa_partitioner, "tosa_support_factory", lambda *args, **kwargs: object()
)
monkeypatch.setattr(
tosa_partitioner,
"CapabilityBasedPartitioner",
_FakeCapabilityBasedPartitioner,
)
monkeypatch.setattr(
partitioner,
"_partition_has_invalid_uint8",
lambda partition, tag: False,
)
monkeypatch.setattr(
partitioner,
"_preserve_io_quantization_enabled",
lambda: False,
)

tags = partitioner._tag_module(
SimpleNamespace(graph=SimpleNamespace(nodes=[])),
SimpleNamespace(),
_make_reporter(),
)

assert tags == ["tag0", "tag1", "tag2"]


def test_partition_preserves_tag_discovery_order(monkeypatch):
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))

monkeypatch.setattr(
partitioner,
"_tag_module",
lambda *args, **kwargs: ["tag2", "tag10"],
)
monkeypatch.setattr(tosa_partitioner, "tag_constant_data", lambda program: None)
monkeypatch.setattr(tosa_partitioner, "WhyNoPartitionReporter", _make_reporter)

result = partitioner.partition(SimpleNamespace(graph_module=SimpleNamespace()))

assert list(result.partition_tags) == ["tag2", "tag10"]
19 changes: 13 additions & 6 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _tag_module( # noqa
containing_program: ExportedProgram,
reporter: WhyNoPartitionReporter,
tag_iterator: count | None = None,
) -> set[str]:
) -> list[str]:
"""Tag nodes in a module or submodule from the containing program.

Args:
Expand All @@ -298,21 +298,25 @@ def _tag_module( # noqa
reporter: A reporter to report why nodes were rejected.

Returns:
A set of strings with the partition tags.
A list of strings with the partition tags in discovery order.

"""
tags: set[str] = set()
# Preserve discovery order so backend lowering sees a deterministic
# partition order across Python processes.
tags: list[str] = []
seen_tags: set[str] = set()
if tag_iterator is None:
tag_iterator = count(0)
for _, submodule, _ in get_cond_while_submodules_nested(module):
submodule_tags = self._tag_module(
submodule, containing_program, reporter, tag_iterator
)
if len(tags & submodule_tags) != 0:
if any(tag in seen_tags for tag in submodule_tags):
raise RuntimeError(
"Got overlapping tags in two different modules, this shouldn't happen."
)
tags = tags | submodule_tags
tags.extend(submodule_tags)
seen_tags.update(submodule_tags)
operator_support = tosa_support_factory(
self.tosa_spec, containing_program, reporter, self.additional_checks
)
Expand All @@ -335,7 +339,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

for partition in partition_list:
tag = f"tag{next(tag_iterator)}"
tags.add(tag)
tags.append(tag)
seen_tags.add(tag)

for node in partition.nodes:
node.meta["delegation_tag"] = tag
Expand Down Expand Up @@ -364,6 +369,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
reporter,
)
tags.remove(tag)
seen_tags.remove(tag)
continue

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
Expand All @@ -385,6 +391,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
reporter,
)
tags.remove(tag)
seen_tags.remove(tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down
Loading