From 506dd08e7abd721333afed05e44fd9459a09c323 Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Thu, 6 Feb 2025 08:13:19 -0800
Subject: [PATCH] Fix memory planning not skipping None values

Summary: If a None is encountered, skip it, but also don't increment the spec
counter, since Nones don't get specs.

Differential Revision: D69211071
---
 exir/passes/memory_planning_pass.py |  5 ++-
 exir/tests/test_memory_planning.py  | 47 +++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index 0fe137d54a7..f5431df431a 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -77,7 +77,8 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                 out_alloc_node.meta["spec"] = node.meta["spec"]
                 continue
             specs = get_node_tensor_specs(node)
-            for i, out_arg in enumerate(out_arg_names):
+            i = 0
+            for out_arg in out_arg_names:
                 out_alloc_node = node.kwargs[out_arg]
                 if out_alloc_node is None:
                     warnings.warn(
@@ -85,6 +86,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                         stacklevel=1,
                     )
+                    # don't increment i; Nones don't get specs
                     continue
                 internal_assert(
                     out_alloc_node.op == "call_function"
                     and out_alloc_node.target == alloc,
@@ -95,6 +97,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                     f"Out-var's allocation node {out_alloc_node} already has a spec assigned",
                 )
                 out_alloc_node.meta["spec"] = specs[i]
+                i += 1
 
     @deprecated(
         "MemoryPlanningPass.call() is deprecated as it does not handle graphs \
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index e4c47855b23..1e9a7e29f30 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -49,6 +49,7 @@
     prepare_fx,
 )
 from torch.export import export
+from torch.export.experimental import _export_forward_backward
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import functional as F
@@ -724,3 +725,49 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             self.assertIsNone(node.meta["spec"].mem_offset)
             self.assertIsNone(node.meta["spec"].mem_id)
         self.assertEqual(constants, 2)
+
+    def test_none_output(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(6, 6, 5)
+                self.linear = nn.Linear(6, 2)
+
+            def forward(self, x):
+                return self.linear(self.conv1(x).flatten(1))
+
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label)
+
+        net = TrainingNet(Net())
+        inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64))
+
+        ep = export(net, inputs)
+        ep = _export_forward_backward(ep)
+        ep = to_edge(ep)
+        ep = ep.to_executorch()
+
+        ep.dump_executorch_program(True)
+
+        # 155 just so happens to be the index of the user_grad output arg of
+        # convolution_backward.out. This is fairly fragile.
+        # Check that the None output is not memory planned.
+        self.assertEqual(
+            ep.executorch_program.execution_plan[0]
+            .values[155]
+            .val.data_buffer_idx,  # pyright: ignore
+            0,
+        )
+        self.assertEqual(
+            ep.executorch_program.execution_plan[0]
+            .values[155]
+            .val.allocation_info,  # pyright: ignore
+            None,
+        )
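
For readers outside the ExecuTorch codebase, the indexing bug fixed above can be shown with a minimal, self-contained sketch. The specs list only has entries for non-None outputs ("Nones don't get specs" per the commit summary), while enumerate(out_arg_names) keeps counting even for the out kwargs that are skipped, so the index into specs drifts past the end once a None is encountered. The names below are hypothetical and this is not the actual pass, just an illustration of the pattern:

# Minimal sketch of the indexing bug (hypothetical names; not the real
# ExecuTorch pass). `specs` only has entries for non-None outputs.

out_arg_values = ["alloc_a", None, "alloc_b"]  # out kwargs of an op; one is None
specs = ["spec_for_a", "spec_for_b"]           # one spec per non-None output


def assign_specs_buggy(values, specs):
    assigned = {}
    for i, value in enumerate(values):  # i advances even for skipped Nones
        if value is None:
            continue
        assigned[value] = specs[i]      # indexes past the end after a skip
    return assigned


def assign_specs_fixed(values, specs):
    assigned = {}
    i = 0
    for value in values:
        if value is None:
            continue                    # skip, and don't advance i
        assigned[value] = specs[i]
        i += 1                          # only advance when a spec is consumed
    return assigned


if __name__ == "__main__":
    print(assign_specs_fixed(out_arg_values, specs))
    # {'alloc_a': 'spec_for_a', 'alloc_b': 'spec_for_b'}
    try:
        assign_specs_buggy(out_arg_values, specs)
    except IndexError as err:
        print("buggy version fails:", err)  # list index out of range

The patch follows the second pattern: a manual counter that only advances when a spec is actually consumed, so outputs that come after a skipped None still get the right spec.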