From 4ad792ebdde8dd57d023d13c94bfe691457afb38 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 4 Dec 2024 22:18:30 -0800 Subject: [PATCH] [executorch][emitter] Emit FQNs Emit FQNs for external tensors. In the emitter, store external tensors as: ``` // list of unique tensors external_constant_buffer: List[bytes] // map of {constant_tag: {fqn: index into external_constant_buffer}} // constant_tag: may want to save multiple external constant files; group them together via the tag. // {fqn: index}; there may be multiple fqns pointing to the same data buffer. This is for deduplication. external_constant_map: Dict[str, Dict[str, int]] ``` Differential Revision: [D66523226](https://our.internmc.facebook.com/intern/diff/D66523226/) [ghstack-poisoned] --- exir/emit/_emit_program.py | 9 ++++++++ exir/emit/_emitter.py | 47 ++++++++++++++++++++++++++++++++++++-- exir/schema.py | 2 +- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index 9c8c9dfd067..f9571143a1b 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -47,6 +47,13 @@ class EmitterOutput: mutable_data: Optional[List[Buffer]] + # Constants are optionally stored in external files. + # Aggregate unique external constants into one buffer. + external_constant_buffer: List[bytes] + # Each constant_tag groups a set of constants together. 
+ # {constant_tag: {fqn: index into external_constant_buffer}} + external_constant_map: Optional[Dict[str, Dict[str, int]]] + def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule: gm = exported_program.graph_module @@ -199,4 +206,6 @@ def emit_program( if len(program_state.mutable_buffer) > 1 else None ), + external_constant_buffer=program_state.external_constant_buffer, + external_constant_map=program_state.external_constant_map, ) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 165cad7fd07..897fbdcf811 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -62,6 +62,7 @@ DoubleList, EValue, ExecutionPlan, + ExtraTensorInfo, FreeCall, Instruction, Int, @@ -120,6 +121,14 @@ class _ProgramState: # and should be copied to Program.backend_delegate_data. backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list) + # Constants are optionally stored in external files. + # Aggregate unique external constants into one buffer. + external_constant_buffer: List[bytes] = field(default_factory=list) + external_constant_hash: Dict[str, int] = field(default_factory=dict) + # Each constant_tag groups a set of constants together. 
+ # {constant_tag: {fqn: index into external_constant_buffer}} + external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict) + @dataclass class _EmitterState: @@ -328,7 +337,9 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue: ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}" ) - def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: + def _tensor_spec_to_evalue( + self, spec: TensorSpec, constant_tag: Optional[str] = None + ) -> EValue: """Constructs an EValue from the given TensorSpec.""" allocation_info = None @@ -389,6 +400,8 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: buffer_idx = self.program_state.cached_spec_mutable_hash_values.get( hashed, -1 ) + elif spec.location == DataLocation.EXTERNAL: + buffer_idx = self.program_state.external_constant_hash.get(hashed, -1) else: buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) @@ -405,6 +418,23 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: buffer_idx ) self.program_state.mutable_buffer.append(buffer) + + # Constant tensor, stored in external file. + elif spec.location == DataLocation.EXTERNAL: + assert ( + spec.extra_tensor_info is not None + and spec.extra_tensor_info.fully_qualified_name is not None + ), "Fully qualified name is not set for external tensor" + buffer_idx = len(self.program_state.external_constant_buffer) + self.program_state.external_constant_hash[hashed] = buffer_idx + self.program_state.external_constant_buffer.append(buffer_data) + if constant_tag: + if constant_tag not in self.program_state.external_constant_map: + self.program_state.external_constant_map[constant_tag] = {} + self.program_state.external_constant_map[constant_tag][ + spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. + ] = buffer_idx + # Constant tensor, stored in PTE. 
else: buffer_idx = len(self.program_state.constant_buffer) self.program_state.cached_spec_hash_values[hashed] = buffer_idx @@ -1539,11 +1569,24 @@ def placeholder( https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder """ spec = self.node.meta["spec"] + constant_tag = self.node.meta.get("constant_tag", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) + # If the placeholder has a constant_tag, it is external to the PTE file + # and requires a fqn and location=DataLocation.EXTERNAL + if constant_tag is not None: + assert ( + fqn is not None + ), "constant tagged tensors require a fully qualified name" + if spec.extra_tensor_info is None: + spec.extra_tensor_info = ExtraTensorInfo(fully_qualified_name=fqn) + else: + spec.extra_tensor_info.fully_qualified_name = fqn + spec.location = DataLocation.EXTERNAL + # From the fqn find the corresponding tensor real_tensor = None if fqn in self.exported_program.state_dict: @@ -1581,7 +1624,7 @@ def placeholder( spec.const = not (is_user_input or is_mutable_buffer) evalue = ( - self._tensor_spec_to_evalue(spec) + self._tensor_spec_to_evalue(spec, constant_tag) if isinstance(spec, TensorSpec) else self._constant_to_evalue(spec, None) ) diff --git a/exir/schema.py b/exir/schema.py index cbcd140b258..7a6064c5f62 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -49,7 +49,7 @@ class ExtraTensorInfo: Check program.fbs for explanations of this enum. """ - mutable_data_segments_idx: Optional[int] = None + mutable_data_segments_idx: Optional[int] = 0 fully_qualified_name: Optional[str] = None