From 7b157b14798589dabb4a18ec736a9cc56f900c37 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Tue, 21 Jan 2025 14:01:35 -0800 Subject: [PATCH] Cleanup memory passes tests. (#7788) Summary: Add verifiers for memory allocation. Reviewed By: zonglinpeng, mcremon-meta Differential Revision: D68446633 --- .../cadence/aot/tests/test_memory_passes.py | 162 +++++++++++++++++- 1 file changed, 161 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index d1971ea6051..f1f01ea9296 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -1,7 +1,9 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import logging import math import unittest +from typing import cast import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -110,7 +112,121 @@ def forward(self, x): class TestMemTransform(unittest.TestCase): - def test_optimize_cat(self): + def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None: + spec = node.meta.get("spec", None) + self.assertIsNotNone(spec) + dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0 + outer_size = math.prod(spec.shape[:dim]) + self.assertEqual( + outer_size, + 1, + f"{node=} has wrong outer size: {outer_size=}, expected 1.", + ) + inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize + dim_offset = 0 + for arg in cast(list[torch.fx.Node], node.args[0]): + arg_spec = arg.meta.get("spec", None) + self.assertEqual(arg_spec.mem_id, spec.mem_id) + self.assertEqual( + arg_spec.mem_offset, + spec.mem_offset + dim_offset * inner_dim_elements, + f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}", + ) + dim_offset += arg_spec.shape[dim] + + def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None: + spec = node.meta.get("spec", None) + self.assertIsNotNone(spec) + dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0 + outer_size = math.prod(spec.shape[:dim]) + self.assertEqual( + outer_size, + 1, + f"{node=} has wrong outer size: {outer_size=}, expected 1.", + ) + inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize + start: int = ( + cast(int, node.args[2]) + if (len(node.args) > 2 and node.args[2] is not None) + else 0 + ) + arg = cast(torch.fx.Node, node.args[0]) + arg_spec = arg.meta.get("spec", None) + self.assertEqual(arg_spec.mem_id, spec.mem_id) + self.assertEqual( + spec.mem_offset, + arg_spec.mem_offset + start * inner_dim_elements, + f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}", + ) + + def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None: + spec = node.meta.get("spec", None) + self.assertIsNotNone(spec) + dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0 + outer_size = math.prod(spec.shape[:dim]) + self.assertEqual( + outer_size, + 1, + f"{node=} has wrong outer size: {outer_size=}, expected 1.", + ) + inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize + index: int = ( + cast(int, node.args[2]) + if (len(node.args) > 2 and node.args[2] is not None) + else 0 + ) + arg = cast(torch.fx.Node, node.args[0]) + arg_spec = arg.meta.get("spec", None) + self.assertEqual(arg_spec.mem_id, spec.mem_id) + self.assertEqual( + spec.mem_offset, + arg_spec.mem_offset + index * inner_dim_elements, + f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} for select on {dim=} {index=}, " + f"but output has {spec.mem_offset=}" + f"{spec=} {arg_spec=}", + ) + + def verify_nop_memory_alloc(self, graph_module): + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.aten._cat_nop.out + ): + self._verify_cat_nop_memory_alloc(node) + + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out + ): + self._verify_slice_nop_memory_alloc(node) + + for node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.aten._select_copy_nop.int_out + ): + self._verify_select_nop_memory_alloc(node) + + def test_optimize_cat_on_placeholders(self): + class Cat(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aten.cat((x, y)) + + x = torch.ones(3, 6) + y = torch.ones(2, 6) + # Optimizing cat ops is only at opt_level 2+, and requires the memory planning + # pass to run: + graph_module = ( + compiler.export_to_executorch_gen_etrecord( + Cat(), (x, y), opt_level=2, mem_algo=1 + ) + .exported_program() + .graph_module + ) + logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}") + graph_module.graph.eliminate_dead_code() + # Assert that cat op is optimized away + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + # Assert that cat op is replaced by its nop version post optimization + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.verify_nop_memory_alloc(graph_module) + + def test_optimize_cat_outermost(self): class OptimizeCatFeasible1(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -135,7 +251,9 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) # Assert that cat op is replaced by its nop version post optimization self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.verify_nop_memory_alloc(graph_module) + def test_optimize_cat_non_outermost(self): class OptimizeCatFeasible2(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -160,7 +278,9 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) # Assert that cat op is replaced by its nop version post optimization self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.verify_nop_memory_alloc(graph_module) + def test_no_optimize_cat_non_outermost(self): class OptimizeCatInfeasible1(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -184,7 +304,9 @@ def forward(self, x, y): # Assert that cat op is not optimized away, since the concat is not # along the outermost dim self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.verify_nop_memory_alloc(graph_module) + def test_no_optimize_cat_non_outermost1(self): class OptimizeCatInfeasible2(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -209,6 +331,7 @@ def forward(self, x, y): # offsets are not multiple of 8 bytes, and the cat is not the output # of the graph. self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.verify_nop_memory_alloc(graph_module) def test_optimize_cat_with_slice(self): class OptimizeCatSliceFeasible(torch.nn.Module): @@ -237,6 +360,7 @@ def forward(self, x): graph_module.graph.eliminate_dead_code() # Assert that cat op is optimized away self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.verify_nop_memory_alloc(graph_module) def test_optimize_cat_with_slice_infeasible(self): class OptimizeCatSliceInfeasible(torch.nn.Module): @@ -262,6 +386,7 @@ def forward(self, x, y): graph_module.graph.eliminate_dead_code() # Assert that cat op is not optimized away self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.verify_nop_memory_alloc(graph_module) def test_optimize_slice_Tensor(self): class SliceTensor(torch.nn.Module): @@ -323,6 +448,7 @@ def forward(self, x, y, z): self.assertEqual( count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3 ) + self.verify_nop_memory_alloc(graph_module) def test_optimize_select_Tensor(self): class SelectTensor(torch.nn.Module): @@ -387,6 +513,7 @@ def forward(self, x, y, z): self.assertEqual( count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3 ) + self.verify_nop_memory_alloc(graph_module) # TODO: Test fails due to memory planning @unittest.expectedFailure @@ -416,6 +543,32 @@ def forward(self, x, y): graph_module.graph.eliminate_dead_code() # Assert that cat op is not optimized away self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) + self.verify_nop_memory_alloc(graph_module) + + def test_optimize_cat_then_slice_on_mutable_buffer(self): + class CatWithPadding(torch.nn.Module): + def __init__(self, padding_shape): + super().__init__() + zeros = torch.zeros(padding_shape) + self.register_buffer("padding", zeros) + + def forward(self, x, y): + x = x.view(3, 5) + cat = torch.ops.aten.cat((x, self.padding.clone())) + slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0]) + self.padding.copy_(slice_copy) + return cat.view(-1) + y + + x = torch.ones(15) + y = torch.ones(1) + et_prog_manager = compiler.export_to_executorch_gen_etrecord( + CatWithPadding((1, 5)), (x, y), opt_level=3 + ) + graph_module = et_prog_manager.exported_program().graph_module + logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}") + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.verify_nop_memory_alloc(graph_module) def test_optimize_cat_with_view(self): class CatViewFeasible(torch.nn.Module): @@ -442,6 +595,7 @@ def forward(self, x, y): # Assert that cat op is optimized away self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + self.verify_nop_memory_alloc(graph_module) def test_no_optimize_cat_with_repeated_args(self): class CatViewInfeasible(torch.nn.Module): @@ -465,6 +619,7 @@ def forward(self, x): # Assert that cat op is not optimized away self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) + self.verify_nop_memory_alloc(graph_module) def test_no_optimize_cat_with_placeholder(self): class CatViewInfeasible(torch.nn.Module): @@ -492,6 +647,7 @@ def forward(self, x, y): # Assert that cat op is not optimized away self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) + self.verify_nop_memory_alloc(graph_module) def test_no_optimize_cat(self) -> None: class Model(torch.nn.Module): @@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor: count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2 ) self.assertEqual(count_node(graph_module, memory.view), 2) + self.verify_nop_memory_alloc(graph_module) def test_optimize_slice_copy(self) -> None: class Model(torch.nn.Module): @@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor: count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0 ) self.assertEqual(count_node(graph_module, memory.view), 2) + self.verify_nop_memory_alloc(graph_module) def test_cat_then_cat(self) -> None: class Model(torch.nn.Module): @@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor: graph_module.print_readable() self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2) self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + self.verify_nop_memory_alloc(graph_module) def test_view_for_unallocated_output(self): class Model(torch.nn.Module): @@ -602,3 +761,4 @@ def forward(self, x, y): .graph_module ) self.assertEqual(count_node(graph_module, memory.view), 1) + self.verify_nop_memory_alloc(graph_module)