diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 661f8cf0d41..5cc00fa5ab0 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -374,3 +374,37 @@ python_unittest(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+
+python_library(
+    name = "memory_planning",
+    srcs = [
+        "memory_planning.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/tabulate:tabulate",
+        ":memory_constraints",
+        ":pass_utils",
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory_planning",
+        "//executorch/exir:tensor",
+        "//executorch/exir/passes:lib",
+    ],
+)
+
+
+python_library(
+    name = "memory_constraints",
+    srcs = [
+        "memory_constraints.py",
+    ],
+    deps = [
+        ":pass_utils",
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:memory",
+        "//executorch/exir:pass_manager",
+        "//executorch/exir:tensor",
+    ],
+)
diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py
index 667c1286334..83a00518eb4 100644
--- a/backends/cadence/aot/memory_planning.py
+++ b/backends/cadence/aot/memory_planning.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import collections
 import itertools
 import logging
@@ -331,14 +333,15 @@ def find_peak_memory_usage(
 # | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
 # +-------------------------------------+---------------+---------+
 def print_memory_planning_info(
-    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
     executorch_prog: ExecutorchProgramManager,
     memory_config: MemoryConfig,
+    opt_level: int,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> None:
     # Get the peak memory usages per memory space
     mem_constraints = MemConstraints(
+        opt_level=opt_level,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
     )
@@ -406,6 +409,7 @@ class CadenceMemoryPlanning:
     def __init__(
         self,
         memory_config: MemoryConfig,
+        opt_level: int,
         mem_algo: int,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -421,6 +425,7 @@ def __init__(
         self._init_mem_algos()
 
         self.memory_config = memory_config
+        self.opt_level = opt_level
         self.mem_algo = mem_algo
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
@@ -434,6 +439,7 @@ def _init_mem_algos(self) -> None:
 
     def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         mem_constraints = MemConstraints(
+            opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
@@ -448,7 +454,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # True.
         mem_planning = MemoryPlanningPass(
             algo,
-            allow_lifetime_and_storage_overlap=False,
+            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
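
For reference, a minimal sketch of how the new opt_level argument threads through the planner after this change: it is forwarded to MemConstraints and, when opt_level >= 2, enables lifetime/storage overlap in MemoryPlanningPass. The import path for CadenceMemoryPlanning matches the file touched above, but MemoryConfig's location and constructor fields are assumptions made for illustration, not verified API:

    # Hypothetical usage sketch; names marked "assumed" are not confirmed by this diff.
    from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning
    from executorch.backends.cadence.aot.utils import MemoryConfig  # assumed location

    memory_config = MemoryConfig(memory_sizes=[0x200000])  # assumed constructor fields

    planner = CadenceMemoryPlanning(
        memory_config,
        opt_level=2,  # opt_level >= 2 now permits lifetime and storage overlap
        mem_algo=0,   # index into the algorithms registered by _init_mem_algos()
        alloc_graph_input=True,
        alloc_graph_output=True,
    )
    # Applying the planner to an exported torch.fx.GraphModule returns a PassResult:
    # result = planner(graph_module)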