diff --git a/exir/capture/_config.py b/exir/capture/_config.py
index b2252e122c9..3fbc8ae7ef3 100644
--- a/exir/capture/_config.py
+++ b/exir/capture/_config.py
@@ -6,7 +6,7 @@
 # pyre-unsafe
 
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 
@@ -94,9 +94,14 @@ class ExecutorchBackendConfig:
     # Moreover, static views will be elided from the ExecuTorch graph
     remove_view_copy: bool = True
 
-    # If set to true, all constant tensors will be stored in a separate file,
-    # external to the PTE file.
-    external_constants: bool = False
+    # Bool: if True, all constant tensors are stored in a separate file,
+    # external to the PTE file; if False, they are stored in the PTE file.
+    # Callable: a function from torch.fx.Node to Optional[str], called for each
+    # placeholder (constant tensor) node. If it returns None, the constant tensor
+    # is stored in the PTE file; if it returns a string, the node is tagged with
+    # that string and the tensor is stored in a file of that name. E.g.,
+    # lambda x: "model_weights" saves all constants into "model_weights.ptd".
+    external_constants: Union[bool, Callable[[torch.fx.Node], Optional[str]]] = False
 
     # If set to true, all trainable weights will be stored in a separate file,
     # external to the PTE file.
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 4844088c0c2..165bc2951f7 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1717,9 +1717,38 @@ def forward(self, x):
         external_map = emitter_output.external_constant_map[
             "_default_external_constant"
         ]
+        self.assertEqual(len(external_map), 2)
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
 
+    def test_constant_tagged_tensors_custom(self) -> None:
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = to_edge(
+            export(LinearModule(), (torch.ones(5, 5),), strict=True)
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                external_constants=lambda x: (
+                    "linear_weight" if "weight" in x.name else None
+                ),
+            )
+        )
+        emitter_output = model._emitter_output
+        # constant_buffer contains the placeholder (index 0) and the linear bias.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 2)
+        # external_constant_buffer contains the linear weight.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
+        # The lambda tags the linear weight under the key 'linear_weight'.
+        external_map = emitter_output.external_constant_map["linear_weight"]
+        self.assertEqual(len(external_map), 1)
+        self.assertEqual(external_map["linear.weight"], 0)
+
     def test_constant_tagged_tensor_dedup(self) -> None:
         class ConstantModule(nn.Module):
             def __init__(self):
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 72a3cd5e4be..5a96e02082b 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -1737,11 +1737,19 @@ def to_executorch(  # noqa (FLAKE8) C901
             # TODO(who?)
             p.update_placeholder_tensor_specs(program, new_gm)
 
-            # Extract constants if the config says too.
-            if config.external_constants:
+            # Tag constant weights.
+            if (
+                isinstance(config.external_constants, bool)
+                and config.external_constants
+            ):
                 new_gm_res = external_constants_pass(new_gm)
                 new_gm = new_gm_res.graph_module
-            elif config.external_mutable_weights:
+            elif callable(config.external_constants):
+                new_gm_res = external_constants_pass(new_gm, config.external_constants)
+                new_gm = new_gm_res.graph_module
+
+            # Tag mutable weights.
+            if config.external_mutable_weights:
                 new_gm_res = external_mutable_weights_pass(new_gm, program)
                 new_gm = new_gm_res.graph_module
 
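
For reviewers, a minimal usage sketch of the new callable form of external_constants. It mirrors the added test above; the TwoLinear module, the tag names, and the exact import paths are illustrative assumptions rather than part of this change. Per the updated config comment, each distinct string returned by the callable names a separate .ptd file, and returning None keeps that tensor inside the PTE file.

    # Sketch only: assumes to_edge and ExecutorchBackendConfig are importable from
    # executorch.exir as in the test above; module and tag names are made up.
    from typing import Optional

    import torch
    from executorch.exir import ExecutorchBackendConfig, to_edge
    from torch.export import export


    class TwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Linear(5, 5)
            self.decoder = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.decoder(self.encoder(x))


    def tag_constants(node: torch.fx.Node) -> Optional[str]:
        # Route encoder constants to "encoder_weights.ptd" and decoder constants
        # to "decoder_weights.ptd"; returning None keeps a tensor in the PTE file.
        if "encoder" in node.name:
            return "encoder_weights"
        if "decoder" in node.name:
            return "decoder_weights"
        return None


    program = to_edge(
        export(TwoLinear(), (torch.ones(5, 5),), strict=True)
    ).to_executorch(
        config=ExecutorchBackendConfig(external_constants=tag_constants)
    )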