diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index a4247579837..c838ab88e48 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index c40a00b2407..562ed145699 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -387,38 +387,36 @@ def _save_new_const_tensor(
         # Update buffer_idx to point to the end of the list where we are adding the new buffer.
         buffer = Buffer(storage=buffer_data)
 
-        # Tensor is mutable with initial state.
-        if allocation_info:
+        # Tensor is stored outside of the PTE file.
+        if (
+            spec.extra_tensor_info is not None
+            and spec.extra_tensor_info.fully_qualified_name is not None
+            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
+        ):
+            assert (
+                constant_tag is not None
+            ), "Constant tag is not set for external tensor"
+            # TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file.
+            # We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this.
+
+            buffer_idx = len(self.program_state.external_constant_buffer)
+            self.program_state.external_constant_hash[hashed] = buffer_idx
+            self.program_state.external_constant_buffer.append(buffer_data)
+            if constant_tag not in self.program_state.external_constant_map:
+                self.program_state.external_constant_map[constant_tag] = {}
+            self.program_state.external_constant_map[constant_tag][
+                spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+            ] = buffer_idx
+        # Tensor is mutable with initial state. Place into mutable segment
+        elif allocation_info:
             buffer_idx = len(self.program_state.mutable_buffer)
             self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
             self.program_state.mutable_buffer.append(buffer)
-
-        # Tensor is constant.
+        # Tensor is stored in the PTE file.
         else:
-            # Tensor is stored outside of the PTE file.
-            if (
-                spec.extra_tensor_info is not None
-                and spec.extra_tensor_info.fully_qualified_name is not None
-                and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
-            ):
-                assert (
-                    constant_tag is not None
-                ), "Constant tag is not set for external tensor"
-
-                buffer_idx = len(self.program_state.external_constant_buffer)
-                self.program_state.external_constant_hash[hashed] = buffer_idx
-                self.program_state.external_constant_buffer.append(buffer_data)
-                if constant_tag not in self.program_state.external_constant_map:
-                    self.program_state.external_constant_map[constant_tag] = {}
-                self.program_state.external_constant_map[constant_tag][
-                    spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
-                ] = buffer_idx
-
-            # Tensor is stored in the PTE file.
-            else:
-                buffer_idx = len(self.program_state.constant_buffer)
-                self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-                self.program_state.constant_buffer.append(buffer)
+            buffer_idx = len(self.program_state.constant_buffer)
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
+            self.program_state.constant_buffer.append(buffer)
 
         return buffer_idx
 
@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(
 
             hashed = hashlib.sha256(buffer_data).hexdigest()
 
-            if allocation_info:
+            if allocation_info and spec.extra_tensor_info is None:
                 buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                     hashed, -1
                 )
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 3fca3958feb..117781546c5 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -67,6 +67,7 @@
 from torch import nn
 from torch.export import Dim, export, export_for_training
+from torch.export.experimental import _export_forward_backward
 
 
 class WrapperModule(torch.nn.Module):
@@ -1733,3 +1734,53 @@ def forward(self, x):
         self.assertEqual(
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )
+
+    def test_constant_tagged_mutable_tensors(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # On-device training requires the loss to be embedded in the model (and be the first output).
+        # We wrap the original model here and add the loss calculation. This will be the model we export.
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label), pred.detach().argmax(dim=1)
+
+        net = TrainingNet(Net())
+
+        # Captures the forward graph. The graph will look similar to the model definition now.
+        # Will move to export_for_training soon, which is the API planned to be supported in the long term.
+        ep = export(
+            net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
+        )
+        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
+        ep = _export_forward_backward(ep)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(
+            config=ExecutorchBackendConfig(external_mutable_weights=True)
+        )
+
+        emitter_output = ep._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
+        # Setting external_mutable_weights=True saves all constants with an associated gradient to the key
+        # '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(external_map["net.linear.weight"], 0)
+        self.assertEqual(external_map["net.linear.bias"], 1)
diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py
index 1429e15cbb1..bc0126a482d 100644
--- a/exir/passes/external_constants_pass.py
+++ b/exir/passes/external_constants_pass.py
@@ -7,17 +7,20 @@
 # pyre-strict
 
 import torch
+from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.fx import GraphModule
 
 
 def external_constants_pass(
-    ep: ExportedProgram,
-) -> ExportedProgram:
+    gm: GraphModule,
+) -> PassResult:
     """
     Move all constants to external file.
     """
-    for module in ep.graph_module.modules():
+    mutated = False
+    for module in gm.modules():
         if not isinstance(module, torch.fx.GraphModule):
             continue
 
@@ -26,4 +29,46 @@
                 spec = node.meta.get("spec")
                 if isinstance(spec, TensorSpec) and spec.const:
                     node.meta["constant_tag"] = "_default_external_constant"
-    return ep
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    grad_targets = [
+        spec.target
+        for spec in ep.graph_signature.output_specs
+        if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
+    ]
+    return (
+        node.op == "placeholder"
+        and node.target in ep.graph_signature.inputs_to_parameters.keys()
+        and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets
+    )
+
+
+def external_mutable_weights_pass(
+    gm: GraphModule,
+    ep: ExportedProgram,
+) -> PassResult:
+    """
+    Move all mutable weights to external file.
+    """
+    # Pass the gm and the ep separately, as the gm is being mutated by a bunch of passes in to_executorch,
+    # so the gm in the ep is lagging behind; the graph signature is still correct.
+    # This is really tech debt and all the passes should be refactored to just mutate the ep.
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in module.graph.nodes:
+            if node.op == "placeholder":
+                spec = node.meta.get("spec")
+                if (
+                    isinstance(spec, TensorSpec)
+                    and spec.const
+                    and _is_mutable_weight(node, ep)
+                ):
+                    node.meta["constant_tag"] = "_default_external_constant"
+                    mutated = True
+    return PassResult(gm, mutated)
diff --git a/exir/program/_program.py b/exir/program/_program.py
index e8cee0b5da8..a87400cf7df 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -35,7 +35,10 @@
     MemoryFormatOpsPass,
     OpReplacePass,
 )
-from executorch.exir.passes.external_constants_pass import external_constants_pass
+from executorch.exir.passes.external_constants_pass import (
+    external_constants_pass,
+    external_mutable_weights_pass,
+)
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )
@@ -1395,6 +1398,14 @@ def to_executorch(
             # TODO(who?)
             p.update_placeholder_tensor_specs(program, new_gm)
 
+            # Extract constants if the config says to.
+            if config.external_constants:
+                new_gm_res = external_constants_pass(new_gm)
+                new_gm = new_gm_res.graph_module
+            elif config.external_mutable_weights:
+                new_gm_res = external_mutable_weights_pass(new_gm, program)
+                new_gm = new_gm_res.graph_module
+
             if isinstance(config.memory_planning_pass, dict):
                 memory_planning_pass = config.memory_planning_pass.get(
                     name, ExecutorchBackendConfig().memory_planning_pass
@@ -1409,8 +1420,8 @@ def to_executorch(
             else:
                 new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
 
-            if config.external_constants:
-                new_gm_res = external_constants_pass(new_gm_res)
+            # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
+            # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
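
Usage sketch (illustrative, mirroring the added test_constant_tagged_mutable_tensors): the flow below shows how the new external_mutable_weights flag is exercised end to end. Net and TrainingNet are illustrative stand-ins, and _emitter_output is a private attribute inspected here only to show where the tagged weights land; this is a sketch, not a supported API surface.

# Sketch only: follows the flow of the new test.
import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward

from executorch.exir import ExecutorchBackendConfig, to_edge


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


# On-device training needs the loss inside the exported graph (and as the first output).
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, inputs, labels):
        pred = self.net(inputs)
        return self.loss(pred, labels), pred.detach().argmax(dim=1)


ep = export(
    TrainingNet(Net()),
    (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)),
    strict=True,
)
ep = _export_forward_backward(ep)  # joint forward/backward graph
et = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(external_mutable_weights=True)
)

# Weights that receive gradients are tagged '_default_external_constant' and are
# emitted into the external constant buffer rather than the PTE's constant segment.
fqn_to_index = et._emitter_output.external_constant_map["_default_external_constant"]
print(fqn_to_index)  # e.g. {'net.linear.weight': 0, 'net.linear.bias': 1}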