@@ -331,14 +331,15 @@ def find_peak_memory_usage(
331331# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
332332# +-------------------------------------+---------------+---------+
333333def print_memory_planning_info (
334- # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
335334 executorch_prog : ExecutorchProgramManager ,
336335 memory_config : MemoryConfig ,
336+ opt_level : int ,
337337 alloc_graph_input : bool ,
338338 alloc_graph_output : bool ,
339339) -> None :
340340 # Get the peak memory usages per memory space
341341 mem_constraints = MemConstraints (
342+ opt_level = opt_level ,
342343 alloc_graph_input = alloc_graph_input ,
343344 alloc_graph_output = alloc_graph_output ,
344345 )
@@ -406,6 +407,7 @@ class CadenceMemoryPlanning:
406407 def __init__ (
407408 self ,
408409 memory_config : MemoryConfig ,
410+ opt_level : int ,
409411 mem_algo : int ,
410412 alloc_graph_input : bool = True ,
411413 alloc_graph_output : bool = True ,
@@ -421,6 +423,7 @@ def __init__(
421423 self ._init_mem_algos ()
422424
423425 self .memory_config = memory_config
426+ self .opt_level = opt_level
424427 self .mem_algo = mem_algo
425428 self .alloc_graph_input = alloc_graph_input
426429 self .alloc_graph_output = alloc_graph_output
@@ -448,7 +451,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
448451 # True.
449452 mem_planning = MemoryPlanningPass (
450453 algo ,
451- allow_lifetime_and_storage_overlap = False ,
454+ allow_lifetime_and_storage_overlap = ( self . opt_level >= 2 ) ,
452455 alloc_graph_input = self .alloc_graph_input ,
453456 alloc_graph_output = self .alloc_graph_output ,
454457 )
0 commit comments