From 934c245dbd262a4a61f41736e0b003470225add2 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Sat, 6 Jun 2026 12:28:01 -0700 Subject: [PATCH] Make CUDA/AOTI partitioner composable after another delegate (#20077) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: `AotiPartitioner.partition` tagged every `call_function` node, including `executorch_call_delegate` calls already lowered by an earlier partitioner. So when `CudaPartitioner` runs as a second partitioner — e.g. after a TensorRT partition in a stacked `.pte` where TensorRT lowers the ops it can and the CUDA backend handles the rest — it tried to re-delegate the foreign delegate node, producing a malformed nested delegate. This is the blocker to composing the two backends in one `.pte`. Tag only the non-lowered nodes, reusing the existing `get_non_lowered_nodes` helper (which already excludes `executorch_call_delegate` calls and their output getitems), so the partitioner claims just the remaining ops and composes cleanly after another backend. In the single-partitioner case there are no delegate nodes, so `get_non_lowered_nodes` returns every `call_function` and behavior is unchanged. The same composition gap existed for constants: the final loop tagged every untagged param/buffer/lifted constant with this partition's tag, including ones consumed only by the foreign delegate. Backend lowering rejected those, since it requires every user of a tagged constant to share that tag while the foreign delegate's call keeps the prior one. Now only genuinely unused constants are tagged here — `tag_constant_data` already claims the ones this partition uses, and a constant feeding only a prior delegate is left untagged. Mirrored in fbcode and xplat. Differential Revision: D107690797 --- backends/aoti/aoti_partitioner.py | 38 ++++++-- backends/cuda/tests/test_cuda_partitioner.py | 98 ++++++++++++++++++++ 2 files changed, 126 insertions(+), 10 deletions(-) diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py index aa56d3507e9..b263d0f9c81 100644 --- a/backends/aoti/aoti_partitioner.py +++ b/backends/aoti/aoti_partitioner.py @@ -14,7 +14,11 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from executorch.exir.backend.utils import ( + get_non_lowered_nodes, + tag_constant_data, + tag_mutated_buffer, +) from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.export.exported_program import ExportedProgram @@ -60,8 +64,17 @@ def is_control_flow(node: torch.fx.Node) -> bool: torch.ops.higher_order.while_loop, ] + # Nodes already lowered by an earlier partitioner (e.g. a preceding + # TensorRT partition) appear as executorch_call_delegate calls and their + # output getitems; re-delegating them would nest a foreign delegate. Tag + # only the remaining non-lowered ops so this partitioner composes after + # others. + non_lowered_nodes = set(get_non_lowered_nodes(exported_program.graph)) + for node in exported_program.graph.nodes: if node.op == "call_function": + if node not in non_lowered_nodes: + continue node.meta["delegation_tag"] = tag # Tag get_attr nodes that are used by control flow operations elif node.op == "get_attr": @@ -76,17 +89,22 @@ def is_control_flow(node: torch.fx.Node) -> bool: tag_constant_data(exported_program) tag_mutated_buffer(exported_program) - # Tag constant placeholders that have no users - # tag_constant_data only tags constants that have users with delegation_tag - # but we need to tag all constants for this partition + # A constant that still has users feeds only a prior delegate; tagging it + # would fail backend lowering's same-tag check (its user keeps the prior + # tag). tag_constant_data already claimed the ones this partition uses, so + # tag only the genuinely unused constants here. for node in exported_program.graph.nodes: - if node.op == "placeholder" and ( - is_param(exported_program, node) - or is_buffer(exported_program, node) - or is_lifted_tensor_constant(exported_program, node) + if ( + node.op == "placeholder" + and not node.users + and "delegation_tag" not in node.meta + and ( + is_param(exported_program, node) + or is_buffer(exported_program, node) + or is_lifted_tensor_constant(exported_program, node) + ) ): - if "delegation_tag" not in node.meta: - node.meta["delegation_tag"] = tag + node.meta["delegation_tag"] = tag return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/backends/cuda/tests/test_cuda_partitioner.py b/backends/cuda/tests/test_cuda_partitioner.py index c08c0e6ff56..0ee345be08a 100644 --- a/backends/cuda/tests/test_cuda_partitioner.py +++ b/backends/cuda/tests/test_cuda_partitioner.py @@ -4,12 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator import unittest from typing import Tuple import torch from executorch.backends.cuda.cuda_partitioner import CudaPartitioner from executorch.exir.backend.partitioner import PartitionResult +from executorch.exir.delegate import executorch_call_delegate +from torch._export.utils import is_buffer from torch.export import export @@ -222,3 +225,98 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expected_tag, f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'", ) + + def test_does_not_retag_already_lowered_delegate(self) -> None: + """ + A node already lowered by a previous partitioner appears as an + executorch_call_delegate call plus its output getitem. The CUDA + partitioner must not re-tag those, so it can run after another backend + (e.g. TensorRT) and only claim the remaining ops. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + exported_program = export(AddModule(), (torch.randn(3, 4),), strict=True) + graph_module = exported_program.graph_module + graph = graph_module.graph + + placeholder = next(n for n in graph.nodes if n.op == "placeholder") + aten_node = next( + n + for n in graph.nodes + if n.op == "call_function" and n.target != operator.getitem + ) + + # Splice in a fake, already-lowered delegate (call + output getitem), as a + # preceding partitioner (e.g. TensorRT) would have produced. + graph_module.lowered_module_0 = torch.nn.Module() + with graph.inserting_before(aten_node): + lowered = graph.get_attr("lowered_module_0") + delegate = graph.call_function( + executorch_call_delegate, (lowered, placeholder) + ) + delegate_output = graph.call_function(operator.getitem, (delegate, 0)) + graph.lint() + + CudaPartitioner([]).partition(exported_program) + + self.assertNotIn("delegation_tag", delegate.meta) + self.assertNotIn("delegation_tag", delegate_output.meta) + self.assertIn("delegation_tag", aten_node.meta) + + def test_does_not_tag_constant_used_only_by_prior_delegate(self) -> None: + """ + A constant whose only consumer is a previously lowered delegate must stay + untagged. Tagging it would give it this partition's tag while its user + keeps the prior delegate's, which backend lowering rejects. Only ops this + partitioner claims and genuinely unused constants may be tagged. + """ + + class AddModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("w", torch.randn(3, 4)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.w + + exported_program = export(AddModule(), (torch.randn(3, 4),), strict=True) + graph_module = exported_program.graph_module + graph = graph_module.graph + + buffer_placeholder = next( + n + for n in graph.nodes + if n.op == "placeholder" and is_buffer(exported_program, n) + ) + input_placeholder = next( + n + for n in graph.nodes + if n.op == "placeholder" and not is_buffer(exported_program, n) + ) + aten_node = next( + n + for n in graph.nodes + if n.op == "call_function" and n.target != operator.getitem + ) + + # Make the buffer feed only a fake, already-lowered delegate (as a + # preceding TensorRT partition would): rewire the aten op off the buffer, + # then splice the delegate consuming it. + aten_node.replace_input_with(buffer_placeholder, input_placeholder) + graph_module.lowered_module_0 = torch.nn.Module() + with graph.inserting_before(aten_node): + lowered = graph.get_attr("lowered_module_0") + delegate = graph.call_function( + executorch_call_delegate, (lowered, buffer_placeholder) + ) + graph.call_function(operator.getitem, (delegate, 0)) + graph.lint() + + CudaPartitioner([]).partition(exported_program) + + self.assertNotIn("delegation_tag", buffer_placeholder.meta) + self.assertNotIn("delegation_tag", delegate.meta) + self.assertIn("delegation_tag", aten_node.meta)