Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions backends/aoti/aoti_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from executorch.exir.backend.utils import (
get_non_lowered_nodes,
tag_constant_data,
tag_mutated_buffer,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export.exported_program import ExportedProgram

Expand Down Expand Up @@ -60,8 +64,17 @@ def is_control_flow(node: torch.fx.Node) -> bool:
torch.ops.higher_order.while_loop,
]

# Nodes already lowered by an earlier partitioner (e.g. a preceding
# TensorRT partition) appear as executorch_call_delegate calls and their
# output getitems; re-delegating them would nest a foreign delegate. Tag
# only the remaining non-lowered ops so this partitioner composes after
# others.
non_lowered_nodes = set(get_non_lowered_nodes(exported_program.graph))

for node in exported_program.graph.nodes:
if node.op == "call_function":
if node not in non_lowered_nodes:
continue
node.meta["delegation_tag"] = tag
# Tag get_attr nodes that are used by control flow operations
elif node.op == "get_attr":
Expand All @@ -76,17 +89,22 @@ def is_control_flow(node: torch.fx.Node) -> bool:
tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

# Tag constant placeholders that have no users
# tag_constant_data only tags constants that have users with delegation_tag
# but we need to tag all constants for this partition
# A constant that still has users feeds only a prior delegate; tagging it
# would fail backend lowering's same-tag check (its user keeps the prior
# tag). tag_constant_data already claimed the ones this partition uses, so
# tag only the genuinely unused constants here.
for node in exported_program.graph.nodes:
if node.op == "placeholder" and (
is_param(exported_program, node)
or is_buffer(exported_program, node)
or is_lifted_tensor_constant(exported_program, node)
if (
node.op == "placeholder"
and not node.users
and "delegation_tag" not in node.meta
and (
is_param(exported_program, node)
or is_buffer(exported_program, node)
or is_lifted_tensor_constant(exported_program, node)
)
):
if "delegation_tag" not in node.meta:
node.meta["delegation_tag"] = tag
node.meta["delegation_tag"] = tag

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
Expand Down
98 changes: 98 additions & 0 deletions backends/cuda/tests/test_cuda_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Tuple

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.partitioner import PartitionResult
from executorch.exir.delegate import executorch_call_delegate
from torch._export.utils import is_buffer
from torch.export import export


Expand Down Expand Up @@ -222,3 +225,98 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
expected_tag,
f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'",
)

def test_does_not_retag_already_lowered_delegate(self) -> None:
    """
    Previously lowered nodes must be skipped by the CUDA partitioner.

    An earlier partitioner (e.g. TensorRT) leaves behind an
    executorch_call_delegate call and a getitem that reads its output.
    Running the CUDA partitioner afterwards must leave both of those
    untagged and tag only the op that has not been lowered yet.
    """

    class AddModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + x

    sample_inputs = (torch.randn(3, 4),)
    exported_program = export(AddModule(), sample_inputs, strict=True)
    gm = exported_program.graph_module
    fx_graph = gm.graph

    # Look the nodes up before editing the graph, so the fake delegate nodes
    # spliced in below cannot be matched by these searches.
    user_input = next(n for n in fx_graph.nodes if n.op == "placeholder")
    remaining_op = next(
        n
        for n in fx_graph.nodes
        if n.op == "call_function" and n.target != operator.getitem
    )

    # Emulate the output of a preceding partitioner: a get_attr for the
    # lowered module, the delegate call, and the getitem on its result.
    gm.lowered_module_0 = torch.nn.Module()
    with fx_graph.inserting_before(remaining_op):
        lowered_attr = fx_graph.get_attr("lowered_module_0")
        fake_delegate = fx_graph.call_function(
            executorch_call_delegate, args=(lowered_attr, user_input)
        )
        fake_delegate_out = fx_graph.call_function(
            operator.getitem, args=(fake_delegate, 0)
        )
    fx_graph.lint()

    CudaPartitioner([]).partition(exported_program)

    # Neither half of the already-lowered delegate may be claimed...
    for lowered_node in (fake_delegate, fake_delegate_out):
        self.assertNotIn("delegation_tag", lowered_node.meta)
    # ...while the op that is still un-lowered must be.
    self.assertIn("delegation_tag", remaining_op.meta)

def test_does_not_tag_constant_used_only_by_prior_delegate(self) -> None:
    """
    A constant whose only consumer is a previously lowered delegate must stay
    untagged. Tagging it would give it this partition's tag while its user
    keeps the prior delegate's, which backend lowering rejects. Only ops this
    partitioner claims and genuinely unused constants may be tagged.

    The lowered delegate itself (the executorch_call_delegate call and the
    getitem reading its output) must stay untagged as well.
    """

    class AddModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("w", torch.randn(3, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.w

    exported_program = export(AddModule(), (torch.randn(3, 4),), strict=True)
    graph_module = exported_program.graph_module
    graph = graph_module.graph

    # Resolve all nodes of interest before mutating the graph so the fake
    # delegate nodes added below cannot be matched by these searches.
    buffer_placeholder = next(
        n
        for n in graph.nodes
        if n.op == "placeholder" and is_buffer(exported_program, n)
    )
    input_placeholder = next(
        n
        for n in graph.nodes
        if n.op == "placeholder" and not is_buffer(exported_program, n)
    )
    aten_node = next(
        n
        for n in graph.nodes
        if n.op == "call_function" and n.target != operator.getitem
    )

    # Make the buffer feed only a fake, already-lowered delegate (as a
    # preceding TensorRT partition would): rewire the aten op off the buffer,
    # then splice the delegate consuming it.
    aten_node.replace_input_with(buffer_placeholder, input_placeholder)
    graph_module.lowered_module_0 = torch.nn.Module()
    with graph.inserting_before(aten_node):
        lowered = graph.get_attr("lowered_module_0")
        delegate = graph.call_function(
            executorch_call_delegate, (lowered, buffer_placeholder)
        )
        # Keep a handle on the output getitem so it can be checked below,
        # mirroring test_does_not_retag_already_lowered_delegate.
        delegate_output = graph.call_function(operator.getitem, (delegate, 0))
    graph.lint()

    CudaPartitioner([]).partition(exported_program)

    # The buffer still has a user (the prior delegate), so it must not be
    # claimed by this partition.
    self.assertNotIn("delegation_tag", buffer_placeholder.meta)
    # Both halves of the already-lowered delegate must be left alone.
    self.assertNotIn("delegation_tag", delegate.meta)
    self.assertNotIn("delegation_tag", delegate_output.meta)
    # The remaining un-lowered op is still claimed.
    self.assertIn("delegation_tag", aten_node.meta)
Loading