Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backends/aoti/aoti_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export.exported_program import ExportedProgram


Expand Down Expand Up @@ -61,6 +62,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

# Tag constant placeholders that have no users
# tag_constant_data only tags constants that have users with delegation_tag
# but we need to tag all constants for this partition
for node in exported_program.graph.nodes:
if node.op == "placeholder" and (
is_param(exported_program, node)
or is_buffer(exported_program, node)
or is_lifted_tensor_constant(exported_program, node)
):
if "delegation_tag" not in node.meta:
node.meta["delegation_tag"] = tag

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
Expand Down
83 changes: 83 additions & 0 deletions backends/cuda/tests/test_cuda_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
fully_partitioned,
"Graph should be fully partitioned with all operators having the same tag",
)

def test_unused_constant_tagging(self):
"""
Test that constant nodes without users are properly tagged with delegation_tag.

When a graph contains constants (parameters, buffers, or lifted tensor constants)
that are not used by any operations, the CUDA partitioner should still tag them
with the delegation_tag. This ensures all constant data is properly handled during
delegation, even if they have no users in the graph.
"""

class ModuleWithUnusedConst(torch.nn.Module):
def __init__(self):
super().__init__()
# Register a buffer that won't be used in forward
self.register_buffer("unused_buffer", torch.randn(10, 10))
# Also register a used parameter
self.weight = torch.nn.Parameter(torch.randn(5, 5))

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Only use the weight parameter, not the unused_buffer
return x + self.weight

module = ModuleWithUnusedConst()
inputs = (torch.randn(5, 5),)

# Get partition result
partition_result = self._get_partition_result(module, inputs)

# Find all placeholder nodes (these represent constants, parameters, buffers, and inputs)
constant_placeholders = []
input_placeholders = []

for node in partition_result.tagged_exported_program.graph.nodes:
if node.op == "placeholder":
# Check if this is a constant (param, buffer, or lifted tensor constant)
from torch._export.utils import (
is_buffer,
is_lifted_tensor_constant,
is_param,
)

is_constant = (
is_param(partition_result.tagged_exported_program, node)
or is_buffer(partition_result.tagged_exported_program, node)
or is_lifted_tensor_constant(
partition_result.tagged_exported_program, node
)
)

if is_constant:
constant_placeholders.append(node)
else:
input_placeholders.append(node)

# Verify we have constant placeholders
self.assertGreater(
len(constant_placeholders),
0,
"Expected to find constant placeholders in the graph",
)

# Check that all constant placeholders are tagged, including unused ones
untagged_constants = []
for node in constant_placeholders:
if "delegation_tag" not in node.meta:
untagged_constants.append(node.name)

self.assertEqual(
len(untagged_constants),
0,
f"All constant placeholders should be tagged. Found untagged constants: {untagged_constants}",
)

# Verify all tagged constants have the expected tag
expected_tag = "tag0"
for node in constant_placeholders:
actual_tag = node.meta.get("delegation_tag")
self.assertEqual(
actual_tag,
expected_tag,
f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'",
)
Loading