From b67b19358c5e0da2a8a6725fb972f5e4e035a70c Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Tue, 21 Jan 2025 11:55:57 -0800 Subject: [PATCH] Add pass to extract mutable weights into a .ptd (#7798) Summary: Cleaned up the existing pass and fixed a typing error (EP -> PassResult), added another option in backend config to extract only mutable weights (training workflows will do this), fixed the ordering of ET passes and added a warning not to add stuff after memory planning (this pass was actually fine but in general we like having the invariant that memory planning is last), fixed the emitter to prioritize making it external vs mutable. Down the line we will need to support intermixing of mutable and non mutable in the same .ptd (memory regressions not correctness are the stakes), but no one needs that today so deferring. Reviewed By: lucylq Differential Revision: D68121580 --- exir/capture/_config.py | 4 ++ exir/emit/_emitter.py | 56 +++++++++++++------------- exir/emit/test/test_emit.py | 51 +++++++++++++++++++++++ exir/passes/external_constants_pass.py | 55 ++++++++++++++++++++++--- exir/program/_program.py | 17 ++++++-- 5 files changed, 146 insertions(+), 37 deletions(-) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index a4247579837..c838ab88e48 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -92,3 +92,7 @@ class ExecutorchBackendConfig: # If set to true, all constant tensors will be stored in a separate file, # external to the PTE file. external_constants: bool = False + + # If set to true, all trainable weights will be stored in a separate file, + # external to the PTE file. + external_mutable_weights: bool = False diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index c40a00b2407..562ed145699 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -387,38 +387,36 @@ def _save_new_const_tensor( # Update buffer_idx to point to the end of the list where we are adding the new buffer. buffer = Buffer(storage=buffer_data) - # Tensor is mutable with initial state. - if allocation_info: + # Tensor is stored outside of the PTE file. + if ( + spec.extra_tensor_info is not None + and spec.extra_tensor_info.fully_qualified_name is not None + and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL + ): + assert ( + constant_tag is not None + ), "Constant tag is not set for external tensor" + # TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file. + # We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this. + + buffer_idx = len(self.program_state.external_constant_buffer) + self.program_state.external_constant_hash[hashed] = buffer_idx + self.program_state.external_constant_buffer.append(buffer_data) + if constant_tag not in self.program_state.external_constant_map: + self.program_state.external_constant_map[constant_tag] = {} + self.program_state.external_constant_map[constant_tag][ + spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. + ] = buffer_idx + # Tensor is mutable with initial state. Place into mutable segment + elif allocation_info: buffer_idx = len(self.program_state.mutable_buffer) self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx self.program_state.mutable_buffer.append(buffer) - - # Tensor is constant. + # Tensor is stored in the PTE file. else: - # Tensor is stored outside of the PTE file. - if ( - spec.extra_tensor_info is not None - and spec.extra_tensor_info.fully_qualified_name is not None - and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL - ): - assert ( - constant_tag is not None - ), "Constant tag is not set for external tensor" - - buffer_idx = len(self.program_state.external_constant_buffer) - self.program_state.external_constant_hash[hashed] = buffer_idx - self.program_state.external_constant_buffer.append(buffer_data) - if constant_tag not in self.program_state.external_constant_map: - self.program_state.external_constant_map[constant_tag] = {} - self.program_state.external_constant_map[constant_tag][ - spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. - ] = buffer_idx - - # Tensor is stored in the PTE file. - else: - buffer_idx = len(self.program_state.constant_buffer) - self.program_state.cached_spec_hash_values[hashed] = buffer_idx - self.program_state.constant_buffer.append(buffer) + buffer_idx = len(self.program_state.constant_buffer) + self.program_state.cached_spec_hash_values[hashed] = buffer_idx + self.program_state.constant_buffer.append(buffer) return buffer_idx @@ -458,7 +456,7 @@ def _tensor_spec_to_evalue( hashed = hashlib.sha256(buffer_data).hexdigest() - if allocation_info: + if allocation_info and spec.extra_tensor_info is None: buffer_idx = self.program_state.cached_spec_mutable_hash_values.get( hashed, -1 ) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3fca3958feb..117781546c5 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -67,6 +67,7 @@ from torch import nn from torch.export import Dim, export, export_for_training +from torch.export.experimental import _export_forward_backward class WrapperModule(torch.nn.Module): @@ -1733,3 +1734,53 @@ def forward(self, x): self.assertEqual( len(edge_program_manager.executorch_program.backend_delegate_data), 1 ) + + def test_constant_tagged_mutable_tensors(self) -> None: + class Net(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + # On device training requires the loss to be embedded in the model (and be the first output). + # We wrap the original model here and add the loss calculation. This will be the model we export. + class TrainingNet(nn.Module): + def __init__(self, net): + super().__init__() + self.net = net + self.loss = nn.CrossEntropyLoss() + + def forward(self, input, label): + pred = self.net(input) + return self.loss(pred, label), pred.detach().argmax(dim=1) + + net = TrainingNet(Net()) + + # Captures the forward graph. The graph will look similar to the model definition now. + # Will move to export_for_training soon which is the api planned to be supported in the long term. + ep = export( + net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True + ) + # Captures the backward graph. The exported_program now contains the joint forward and backward graph. + ep = _export_forward_backward(ep) + # Lower the graph to edge dialect. + ep = to_edge(ep) + # Lower the graph to executorch. + ep = ep.to_executorch( + config=ExecutorchBackendConfig(external_mutable_weights=True) + ) + + emitter_output = ep._emitter_output + # Check that constant_buffer is empty besides the non-constant placeholder 0. + self.assertEqual(len(emitter_output.program.constant_buffer), 1) + # Check that constant weights are in the external constant buffer. + self.assertEqual(len(emitter_output.external_constant_buffer), 2) + # Setting external_mutable_weights=True, saves all constants with an associated gradient to the key + # '_default_external_constant'. + external_map = emitter_output.external_constant_map[ + "_default_external_constant" + ] + self.assertEqual(external_map["net.linear.weight"], 0) + self.assertEqual(external_map["net.linear.bias"], 1) diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py index 1429e15cbb1..bc0126a482d 100644 --- a/exir/passes/external_constants_pass.py +++ b/exir/passes/external_constants_pass.py @@ -7,17 +7,20 @@ # pyre-strict import torch +from executorch.exir.pass_base import PassResult from executorch.exir.tensor import TensorSpec -from torch.export.exported_program import ExportedProgram +from torch.export.exported_program import ExportedProgram, OutputKind +from torch.fx import GraphModule def external_constants_pass( - ep: ExportedProgram, -) -> ExportedProgram: + gm: GraphModule, +) -> PassResult: """ Move all constants to external file. """ - for module in ep.graph_module.modules(): + mutated = False + for module in gm.modules(): if not isinstance(module, torch.fx.GraphModule): continue @@ -26,4 +29,46 @@ def external_constants_pass( spec = node.meta.get("spec") if isinstance(spec, TensorSpec) and spec.const: node.meta["constant_tag"] = "_default_external_constant" - return ep + mutated = True + return PassResult(gm, mutated) + + +def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool: + grad_targets = [ + spec.target + for spec in ep.graph_signature.output_specs + if spec.kind == OutputKind.GRADIENT_TO_PARAMETER + ] + return ( + node.op == "placeholder" + and node.target in ep.graph_signature.inputs_to_parameters.keys() + and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets + ) + + +def external_mutable_weights_pass( + gm: GraphModule, + ep: ExportedProgram, +) -> PassResult: + """ + Move all mutable weights to external file. + """ + # pass the gm and the ep seperately as the gm is being mutated by a bunch of passes in to_executorch, + # so the gm in the ep is lagging the graph signature is still correct. + # This is really tech debt and all the passes should be refactored to just mutate the ep. + mutated = False + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + for node in module.graph.nodes: + if node.op == "placeholder": + spec = node.meta.get("spec") + if ( + isinstance(spec, TensorSpec) + and spec.const + and _is_mutable_weight(node, ep) + ): + node.meta["constant_tag"] = "_default_external_constant" + mutated = True + return PassResult(gm, mutated) diff --git a/exir/program/_program.py b/exir/program/_program.py index e8cee0b5da8..a87400cf7df 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -35,7 +35,10 @@ MemoryFormatOpsPass, OpReplacePass, ) -from executorch.exir.passes.external_constants_pass import external_constants_pass +from executorch.exir.passes.external_constants_pass import ( + external_constants_pass, + external_mutable_weights_pass, +) from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) @@ -1395,6 +1398,14 @@ def to_executorch( # TODO(who?) p.update_placeholder_tensor_specs(program, new_gm) + # Extract constants if the config says too. + if config.external_constants: + new_gm_res = external_constants_pass(new_gm) + new_gm = new_gm_res.graph_module + elif config.external_mutable_weights: + new_gm_res = external_mutable_weights_pass(new_gm, program) + new_gm = new_gm_res.graph_module + if isinstance(config.memory_planning_pass, dict): memory_planning_pass = config.memory_planning_pass.get( name, ExecutorchBackendConfig().memory_planning_pass @@ -1409,8 +1420,8 @@ def to_executorch( else: new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] - if config.external_constants: - new_gm_res = external_constants_pass(new_gm_res) + # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS. + # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER. assert new_gm_res is not None new_gm = new_gm_res.graph_module