15 changes: 15 additions & 0 deletions backends/cadence/aot/memory_constraints.py
@@ -417,6 +417,10 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
return not self.constraint.is_alias_of(source_info.source, node)
return False

def has_relative_placement_constraint(self, node: torch.fx.Node) -> bool:
"""Return if `node` already has any relative placement constraint."""
return self.constraint.get_relative_placement_source(node) is not None

# Return true if the cat node performs concatenation along outermost dimension
def is_cat_along_outermost_dim(
self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
@@ -481,6 +485,17 @@ def is_removable_cat_op(
if any(self.is_slice_view(arg) for arg in cat_tensors):
return False

# If any of the tensors already has a relative placement constraint,
# we cannot add a new constraint for this cat without conflicting.
# This can happen when a tensor is used in multiple cat operations.
if any(self.has_relative_placement_constraint(arg) for arg in cat_tensors):
return False

# If the same tensor appears multiple times in the cat inputs,
# we cannot place it at multiple different offsets relative to the output.
if len(cat_tensors) != len(set(cat_tensors)):
return False

# Many ops in HiFi require the input to be aligned to 8-byte boundary.
# If the cat is not the graph's output, then ensure that the relative
# offset of any concatenated non-placeholder tensor is a multiple of
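For intuition, here is a minimal standalone sketch of the two new guards. The `placement_source` dict is a hypothetical stand-in for the pass's real relative-placement bookkeeping (`get_relative_placement_source`); this is illustrative only, not the backend's API.

from typing import Dict, Hashable, List

# Hypothetical stand-in for the pass's relative-placement records.
placement_source: Dict[Hashable, Hashable] = {}


def has_relative_placement_constraint(tensor: Hashable) -> bool:
    # Mirrors the new helper: a tensor is constrained once a placement
    # source has been recorded for it.
    return placement_source.get(tensor) is not None


def can_remove_cat(cat_tensors: List[Hashable]) -> bool:
    # Guard 1: a tensor already placed by another cat cannot take a second,
    # potentially conflicting offset.
    if any(has_relative_placement_constraint(t) for t in cat_tensors):
        return False
    # Guard 2: the same tensor cannot sit at two different offsets inside
    # one concatenated output.
    if len(cat_tensors) != len(set(cat_tensors)):
        return False
    return True


if __name__ == "__main__":
    placement_source["x2"] = "cat1_out"        # x2 was already placed by cat1
    print(can_remove_cat(["x2", "x3"]))        # False: existing constraint
    print(can_remove_cat(["add_x", "add_x"]))  # False: duplicate input
    print(can_remove_cat(["a", "b"]))          # True: safe to turn into a nop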
104 changes: 104 additions & 0 deletions backends/cadence/aot/tests/test_memory_passes.py
@@ -947,6 +947,110 @@ def test_cat_then_cat(self) -> None:
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_cat_with_duplicate_input_tensor(self) -> None:
"""
Test that cat is NOT optimized when the same tensor appears multiple
times in the cat input list. This is because we cannot place the same
tensor at multiple different offsets relative to the output.
"""
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
to_add_to_x = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([3, 6], 123.0),
kwargs={"dtype": torch.float32},
)
add_x = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add_to_x),
)
pre_created_output = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([6, 6], 0.0),
kwargs={"dtype": torch.float32},
)
# Same tensor (add_x) appears twice in the cat inputs
cat = builder.call_operator(
op=torch.ops.aten.cat.out,
args=([add_x, add_x],),
kwargs={"dim": 0, "out": pre_created_output},
)
builder.output([cat])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(original)
graph_module.graph.eliminate_dead_code()

# Assert that cat op is NOT optimized away since the same tensor
# appears multiple times in the input list
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_cat_with_tensor_having_existing_constraint(self) -> None:
"""
Test that the second cat is NOT optimized when a tensor already has a
relative placement constraint from a previous cat operation.
"""
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(8, 8, dtype=torch.float32))
to_add = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([8, 8], 1.0),
kwargs={"dtype": torch.float32},
)
x1 = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x, to_add),
)
x2 = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x1, to_add),
)
x3 = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(x2, to_add),
)
# First cat: cat(x1, x2) - this will give x1 and x2 relative placement constraints
pre_created_output1 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([16, 8], 0.0),
kwargs={"dtype": torch.float32},
)
cat1 = builder.call_operator(
op=torch.ops.aten.cat.out,
args=([x1, x2],),
kwargs={"dim": 0, "out": pre_created_output1},
)
# Second cat: cat(x2, x3) - x2 already has a constraint from cat1,
# so this cat cannot be optimized
pre_created_output2 = builder.call_operator(
op=exir_ops.edge.aten.full.default,
args=([16, 8], 0.0),
kwargs={"dtype": torch.float32},
)
cat2 = builder.call_operator(
op=torch.ops.aten.cat.out,
args=([x2, x3],),
kwargs={"dim": 0, "out": pre_created_output2},
)
# Use both cat results to keep them alive
graph_output = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor,
args=(cat1, cat2),
)
builder.output([graph_output])
original = builder.get_graph_module()
graph_module = self.run_memory_planning(
original, opt_level=3, alloc_graph_input=False
)
graph_module.graph.eliminate_dead_code()

# The first cat should be optimized to _cat_nop, but the second cat
# cannot be optimized because x2 already has a relative placement constraint
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)
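
To see why x2's existing constraint blocks the second cat, here is a hypothetical offset calculation grounded in the shapes used above (two 8x8 float32 inputs concatenated along dim 0). The numbers are illustrative only; the pass simply bails out rather than trying to reconcile the two placements.

# Offsets x2 would need under each cat if both were turned into nops.
x1_nbytes = 8 * 8 * 4                         # float32 [8, 8] = 256 bytes
cat1_offsets = {"x1": 0, "x2": x1_nbytes}     # cat1 pins x2 at byte 256 of its buffer
cat2_offsets = {"x2": 0, "x3": x1_nbytes}     # a nop cat2 would pin x2 at byte 0 of a different buffer
print(cat1_offsets["x2"], cat2_offsets["x2"])  # 256 0: x2 cannot satisfy both relative
                                               # placements at once, so cat2 stays a real copy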

def test_view_for_unallocated_output(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32))