diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 9cc8a1e809c..fe18e49a623 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -1640,13 +1640,26 @@ def placeholder(  # noqa: C901
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
-        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
-            if spec.extra_tensor_info is None:
-                spec.extra_tensor_info = ExtraTensorInfo(
-                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+
+        if is_mutable_buffer:
+            # Emit names if we are supposed to.
+            if self.emitter_state.emit_mutable_buffer_names:
+                if spec.extra_tensor_info is None:
+                    spec.extra_tensor_info = ExtraTensorInfo(
+                        fully_qualified_name=fqn,
+                        location=TensorDataLocation.SEGMENT,
+                    )
+                else:
+                    spec.extra_tensor_info.fully_qualified_name = fqn
+            # If we aren't emitting the name, then it needs to be memory planned.
+            elif spec.mem_id is None or spec.mem_offset is None:
+                raise InternalError(
+                    self._emit_node_specific_error(
+                        self.node,
+                        # [2:] to remove the b_ prefix buffers get
+                        f'Mutable buffer "{target[2:]}" must have a memory id and offset if we are emitting it without a name. Please either memory plan your mutable buffers or call to_executorch with config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)',
+                    )
                 )
-            else:
-                spec.extra_tensor_info.fully_qualified_name = fqn
 
         # From the fqn find the corresponding tensor
         real_tensor = None
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index a2df53e8cf3..186c5a402ab 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1838,8 +1838,40 @@ def forward(self, x):
         ep = to_edge(ep)
         # Lower the graph to executorch.
         ep = ep.to_executorch(
-            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+            config=ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
+            )
         )
         for val in ep.executorch_program.execution_plan[0].values:
             if isinstance(val, Tensor) and val.extra_tensor_info:
                 self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
+                self.assertEqual(val.allocation_info, None)
+
+    def test_emit_mutable_buffer_names_fails(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        # Must emit mutable buffer names if we don't allocate mutable buffers
+        with self.assertRaises(InternalError):
+            ep.to_executorch(
+                config=ExecutorchBackendConfig(
+                    emit_mutable_buffer_names=False,
+                    memory_planning_pass=MemoryPlanningPass(
+                        alloc_mutable_buffers=False
+                    ),
+                )
+            )
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 17640a9f7aa..83598940882 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -44,12 +44,14 @@ def __init__(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
         graph_signature: Optional[ExportGraphSignature] = None,
     ) -> None:
         self.graph_module = graph_module
         self.graph_signature = graph_signature
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
 
     @classmethod
     def mem_obj_id_match(
@@ -149,6 +151,7 @@ def verify_storage_reuse(
             ignore_const=True,
             ignore_graph_input=not self.alloc_graph_input,
             ignore_graph_output=not self.alloc_graph_output,
+            ignore_mutable_buffers=not self.alloc_mutable_buffers,
             do_assertion=False,
             ignore_out_var_node=False,
             dedup=True,
@@ -374,6 +377,7 @@ def collect_specs_from_nodes(  # noqa: C901
     graph_signature: Optional[ExportGraphSignature] = None,
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
+    ignore_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,
@@ -414,6 +418,9 @@ def collect_specs_from_nodes(  # noqa: C901
         if _is_inplace_node(node):
             continue
 
+        if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
+            continue
+
         if do_assertion:
             internal_assert(
                 node.op in ("placeholder", "output")
@@ -469,6 +476,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
+
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -1053,6 +1061,7 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffers: bool = True,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
@@ -1065,12 +1074,10 @@ def apply_algo(
     storage with tensors in the outer module.
     TODO: make these optimizations once we have some baseline working.
     """
-
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
-
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
@@ -1078,6 +1085,7 @@ def apply_algo(
         do_assertion=False,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
+        ignore_mutable_buffers=not alloc_mutable_buffers,
     )
 
     # Get extra padding for XNNPACK if needed
diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index f4b3ad8a8a7..9bd4ab20bf5 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -44,6 +44,7 @@ def __init__(
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
+        alloc_mutable_buffers: bool = True,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -54,10 +55,11 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
-        self.memory_planning_algo = memory_planning_algo
+        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
         self.alignment = alignment
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -124,13 +126,15 @@ def run(
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
+
         _ = apply_algo(
-            self.memory_planning_algo,  # pyre-ignore[6]
+            self.memory_planning_algo,
            graph_module,
            self.alignment,
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers,
        )
 
        # TODO: make the verifier do the work recursively to handle
@@ -139,6 +143,7 @@ def run(
            graph_module,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers,
            graph_signature,
        )
 
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index 52986aaa04c..b87ae2dfb58 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -241,6 +241,7 @@ def maketest(
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffer: bool = True,
     has_unused_graph_input: bool = False,
 ) -> Callable[..., None]:
     # parameterized.expand is not compatible with maketest. I'll just loop thru
@@ -282,10 +283,17 @@ def wrapper(self: "TestMemoryPlanning") -> None:
             )(graph_module).graph_module
 
         self.verify_reuse(
-            graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
+            graph_module,
+            expect_reuse,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )
         self.verify_graph_input_output(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )
 
         self.verify_overlap_placeholders(has_unused_graph_input, graph_module)
@@ -306,6 +314,7 @@ def verify_reuse(
         expect_reuse: bool,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffer: bool,
     ) -> None:
         r"""
         Do sanity check and verify tensor storage reuse.
@@ -321,6 +330,7 @@ def verify_reuse(
             graph_module,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            alloc_mutable_buffers=alloc_mutable_buffer,
         ).verify_storage_reuse()
 
         print(f"num_reuse_pairs is {num_reuse_pairs}")
@@ -334,9 +344,10 @@ def verify_graph_input_output(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
     ) -> None:
         Verifier(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, alloc_graph_input, alloc_graph_output, alloc_mutable_buffers
         ).verify_graph_input_output()
 
     def verify_overlap_placeholders(
@@ -404,13 +415,16 @@ def verify_overlap_placeholders(
         )
 
     def test_graph_input_output(self) -> None:
-        for alloc_graph_input, alloc_graph_output in itertools.product(
-            [True, False], [True, False]
-        ):
+        for (
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffers,
+        ) in itertools.product([True, False], [True, False], [True, False]):
             case = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
+                alloc_mutable_buffer=alloc_mutable_buffers,
             )
             case(self)
 
@@ -535,6 +549,7 @@ def test_multiple_pools(
             graph_module,
             alloc_graph_input=True,
             alloc_graph_output=True,
+            alloc_mutable_buffers=True,
         )
         verifier.verify_storage_reuse()
         verifier.verify_graph_input_output()
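
Taken together, these changes establish a contract: a mutable buffer must either be memory planned (`alloc_mutable_buffers=True`, the default) or be emitted with its fully qualified name (`emit_mutable_buffer_names=True`) so the runtime can bind its state later; configuring neither now fails fast in the emitter. A minimal sketch of the intended usage, adapted from the tests in this diff (the `Net` module is copied from the new test; the import paths follow the ExecuTorch test files and may need adjusting for your tree):

```python
import torch
import torch.nn as nn
from torch.export import export

from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.passes import MemoryPlanningPass


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.register_buffer("buffer", torch.zeros(1, 2))

    def forward(self, x):
        self.buffer.add_(1)  # mutates the buffer in place
        return self.linear(x) + self.buffer


# Supported: skip memory planning for the mutable buffer, but emit its
# fully qualified name so the runtime can bind the state by name.
et = to_edge(export(Net(), (torch.randn(1, 2),), strict=True)).to_executorch(
    config=ExecutorchBackendConfig(
        emit_mutable_buffer_names=True,
        memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
    )
)

# Rejected: with no name and no memory plan the emitter cannot locate the
# buffer at runtime, so to_executorch raises InternalError.
to_edge(export(Net(), (torch.randn(1, 2),), strict=True)).to_executorch(
    config=ExecutorchBackendConfig(
        emit_mutable_buffer_names=False,
        memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
    )
)
```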
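
And a quick sanity check of the emitted program, mirroring the assertions added to the first test above: with planning skipped, the buffer's tensor value should carry its fully qualified name and no `allocation_info`. The `Tensor` schema import is an assumption based on `test_emit.py`:

```python
from executorch.exir.schema import Tensor

# Continuing from the `et` program above: the tensor values that carry
# extra_tensor_info should be named but left unallocated.
for val in et.executorch_program.execution_plan[0].values:
    if isinstance(val, Tensor) and val.extra_tensor_info:
        assert val.extra_tensor_info.fully_qualified_name == "buffer"
        assert val.allocation_info is None
```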