@@ -331,14 +331,15 @@ def find_peak_memory_usage(
 # | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
 # +-------------------------------------+---------------+---------+
 def print_memory_planning_info(
-    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
     executorch_prog: ExecutorchProgramManager,
     memory_config: MemoryConfig,
+    opt_level: int,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> None:
     # Get the peak memory usages per memory space
     mem_constraints = MemConstraints(
+        opt_level=opt_level,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
     )
@@ -406,6 +407,7 @@ class CadenceMemoryPlanning:
     def __init__(
         self,
         memory_config: MemoryConfig,
+        opt_level: int,
         mem_algo: int,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -421,6 +423,7 @@ def __init__(
         self._init_mem_algos()

         self.memory_config = memory_config
+        self.opt_level = opt_level
         self.mem_algo = mem_algo
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
@@ -434,6 +437,7 @@ def _init_mem_algos(self) -> None:

     def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         mem_constraints = MemConstraints(
+            opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
@@ -448,7 +452,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # True.
         mem_planning = MemoryPlanningPass(
             algo,
-            allow_lifetime_and_storage_overlap=False,
+            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
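
A minimal usage sketch (hypothetical caller; memory_config and graph_module are assumed to already exist) of how opt_level now threads from the pass into MemConstraints and gates storage overlap:

    # opt_level >= 2 enables allow_lifetime_and_storage_overlap in MemoryPlanningPass;
    # lower levels keep the previous conservative behavior (no overlap).
    planning = CadenceMemoryPlanning(
        memory_config=memory_config,
        opt_level=2,  # hypothetical value, compared against the >= 2 threshold
        mem_algo=0,   # presumably selects one of the algorithms set up by _init_mem_algos()
    )
    result = planning(graph_module)  # torch.fx.GraphModule in, PassResult out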