From 0b2ffea586401251654325cf91236da5f15e48b1 Mon Sep 17 00:00:00 2001
From: lucylq
Date: Wed, 12 Nov 2025 18:23:55 -0800
Subject: [PATCH] Do not restore constant_buffer, and use bytes instead of Buffer

All constants are serialized in the segment (none in the Program).

This PR:
1. Places constant data into the PTEFile class instead of restoring it into the Program.
2. Uses List[bytes] instead of List[Buffer] for constant and mutable data. Buffer was
   initially used to maintain alignment; now, constants are serialized with alignment
   in the segment, and Buffer is not required.

Update tests.

After this, we can mark 'constant_buffer' as deprecated, as it's no longer being used
in deserialization or the emitter.

Differential Revision: [D86913756](https://our.internmc.facebook.com/intern/diff/D86913756/)

[ghstack-poisoned]
---
 .../bundled_program/test/test_bundle_data.py | 19 ++++-
 exir/_serialize/_program.py                  | 67 +++++++++--------
 exir/_serialize/_serialize.py                |  1 +
 exir/_serialize/test/test_program.py         | 32 ++++-----
 exir/emit/_emit_program.py                   | 14 +++-
 exir/emit/_emitter.py                        | 13 ++--
 exir/emit/test/test_emit.py                  | 71 ++++++++++---------
 exir/tests/test_verification.py              | 19 +++--
 exir/verification/interpreter.py             | 19 +++--
 9 files changed, 141 insertions(+), 114 deletions(-)

diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py
index 9fdeb4a776d..35212f98f6a 100644
--- a/devtools/bundled_program/test/test_bundle_data.py
+++ b/devtools/bundled_program/test/test_bundle_data.py
@@ -70,11 +70,16 @@ def test_bundled_program(self) -> None:
                 method_test_case.expected_outputs,
             )
 
+        emitter_output = executorch_program._emitter_output
         self.assertEqual(
             bundled_program.serialize_to_schema().program,
             bytes(
                 _serialize_pte_binary(
-                    pte_file=_PTEFile(program=executorch_program.executorch_program)
+                    pte_file=_PTEFile(
+                        program=executorch_program.executorch_program,
+                        constant_data=emitter_output.constant_data,
+                        mutable_data=emitter_output.mutable_data,
+                    )
                 )
             ),
         )
@@ -116,10 +121,18 @@ def test_bundled_program_from_pte(self) -> None:
                 bundled_program_ioset.expected_outputs,
                 method_test_case.expected_outputs,
             )
-
+        emitter_output = executorch_program._emitter_output
         self.assertEqual(
             bundled_program.serialize_to_schema().program,
-            executorch_program.buffer,
+            bytes(
+                _serialize_pte_binary(
+                    pte_file=_PTEFile(
+                        program=executorch_program.executorch_program,
+                        constant_data=emitter_output.constant_data,
+                        mutable_data=emitter_output.mutable_data,
+                    )
+                )
+            ),
         )
 
     def test_bundled_miss_methods(self) -> None:
diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py
index 80374409a71..95dd2f52556 100644
--- a/exir/_serialize/_program.py
+++ b/exir/_serialize/_program.py
@@ -33,7 +33,6 @@
 from executorch.exir.schema import (
     BackendDelegateDataReference,
     BackendDelegateInlineData,
-    Buffer,
     DataLocation,
     DataSegment,
     NamedData,
@@ -56,9 +55,8 @@ class PTEFile:
     """
 
     program: Program
-    # TODO(lfq): add constant data (currently restored in the program)
-    # TODO(lfq): update this to List[bytes]
-    mutable_data: Optional[List[Buffer]] = None
+    constant_data: Optional[List[bytes]] = None
+    mutable_data: Optional[List[bytes]] = None
     named_data: Optional[NamedDataStoreOutput] = None
 
 
@@ -346,14 +344,14 @@ def _extract_delegate_segments(
 
 
 def _extract_constant_segment(
-    constant_buffer: List[Buffer],
+    constant_buffer: List[bytes],
     tensor_alignment: Optional[int] = None,
 ) -> Tuple[Cord, List[int]]:
     """Copies the tensors from the provided list into a Cord and
    tracks the offsets of each tensor.

     Args:
-        constant_buffer: list of Buffers from which to extract constants from. Not modified.
+        constant_buffer: list of bytes from which to extract constants from. Not modified.
         tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
             with this value. Defaults to ALIGNMENT.
@@ -365,8 +363,8 @@ def _extract_constant_segment(
     current_offset: int = 0
     for i in range(len(constant_buffer)):
         buffer = constant_buffer[i]
-        constant_segment_data.append(buffer.storage)
-        buffer_length = len(buffer.storage)
+        constant_segment_data.append(buffer)
+        buffer_length = len(buffer)
         pad_length = (
             padding_required(buffer_length, tensor_alignment)
             if tensor_alignment is not None
@@ -460,23 +458,22 @@ def serialize_pte_binary(
     # This may be constant data, delegate data or named data.
     segments: List[AlignedData] = []
 
-    constant_segment_data, constant_segment_offsets = _extract_constant_segment(
-        program.constant_buffer, tensor_alignment=constant_tensor_alignment
-    )
-
-    # If there are no constants, len(constant_segment_data) = 0. However, there may
-    # be non-constants, in which case len(constant_segment_offsets) = 1, containing
-    # the placeholder value 0. Ensure the placeholder value is put into
-    # program.constant_segment.offsets.
-    if len(constant_segment_offsets) > 0:
-        # Update program.constant_segment with constant subsegment offset information.
-        program.constant_segment = SubsegmentOffsets(
-            segment_index=len(segments), offsets=constant_segment_offsets
+    if pte_file.constant_data is not None:
+        constant_segment_data, constant_segment_offsets = _extract_constant_segment(
+            pte_file.constant_data, tensor_alignment=constant_tensor_alignment
         )
-        # Clear the constant buffer, as constant data will be stored in segments.
-        program.constant_buffer = []
-        # Add to the aggregate segments cord.
-        segments.append(AlignedData(constant_segment_data))
+
+        # If there are no constants, len(constant_segment_data) = 0. However, there may
+        # be non-constants, in which case len(constant_segment_offsets) = 1, containing
+        # the placeholder value 0. Ensure the placeholder value is put into
+        # program.constant_segment.offsets.
+        if len(constant_segment_offsets) > 0:
+            # Update program.constant_segment with constant subsegment offset information.
+            program.constant_segment = SubsegmentOffsets(
+                segment_index=len(segments), offsets=constant_segment_offsets
+            )
+            # Add to the aggregate segments cord.
+            segments.append(AlignedData(constant_segment_data))
 
     if pte_file.mutable_data is not None:
         mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
@@ -637,8 +634,9 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
         )
 
     # Replace constants from constant_segment into constant_buffer.
+    constant_data = None
     if program.constant_segment and len(program.constant_segment.offsets) > 0:
-        constant_buffers: List[Buffer] = []
+        constant_buffers: List[bytes] = []
         constant_segment = segments[program.constant_segment.segment_index]
         for i in range(len(program.constant_segment.offsets)):
             start_offset = program.constant_segment.offsets[i]
@@ -649,17 +647,15 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
                 if i < len(program.constant_segment.offsets) - 1
                 else len(constant_segment)
             )
-            constant_buffers.append(
-                Buffer(storage=constant_segment[start_offset:end_offset])
-            )
-        program.constant_buffer = constant_buffers
+            constant_buffers.append(constant_segment[start_offset:end_offset])
+        constant_data = constant_buffers
         program.constant_segment.segment_index = 0
         program.constant_segment.offsets = []
 
     # Extract mutable segments.
     mutable_data = None
     if program.mutable_data_segments and len(program.mutable_data_segments.offsets) > 0:
-        mutable_buffers: List[Buffer] = []
+        mutable_buffers: List[bytes] = []
         mutable_segment = segments[program.mutable_segment.segment_index]
         for i in range(len(program.mutable_segments.offsets)):
             start_offset = program.mutable_segment.offsets[i]
@@ -670,9 +666,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
                 if i < len(program.mutable_segment.offsets) - 1
                 else len(mutable_segment)
             )
-            mutable_buffers.append(
-                Buffer(storage=mutable_segment[start_offset:end_offset])
-            )
+            mutable_buffers.append(mutable_segment[start_offset:end_offset])
         mutable_data = mutable_buffers
         program.mutable_segment.segment_index = 0
         program.mutable_segment.offsets = []
@@ -699,7 +693,12 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
     named_data = named_data_store.get_named_data_store_output()
     program.named_data = []
     program.segments = []
-    return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data)
+    return PTEFile(
+        program=program,
+        constant_data=constant_data,
+        mutable_data=mutable_data,
+        named_data=named_data,
+    )
 
 
 def deserialize_pte_binary(program_data: bytes) -> PTEFile:
diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py
index 60b6079f4a8..41ec1d84cec 100644
--- a/exir/_serialize/_serialize.py
+++ b/exir/_serialize/_serialize.py
@@ -49,6 +49,7 @@ def serialize_for_executorch(
     pte: Cord = serialize_pte_binary(
         pte_file=PTEFile(
             program=emitter_output.program,
+            constant_data=emitter_output.constant_data,
            mutable_data=emitter_output.mutable_data,
             named_data=pte_named_data,
         ),
diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py
index 46e8f020a0b..57b1b8c47f9 100644
--- a/exir/_serialize/test/test_program.py
+++ b/exir/_serialize/test/test_program.py
@@ -49,12 +49,6 @@
 CONSTANT_TENSOR_ALIGNMENT: int = 16
 
 
-def add_constant_data(program: Program, blobs: Sequence[bytes]) -> None:
-    """Adds the provided constant data blobs to the program."""
-    for blob in blobs:
-        program.constant_buffer.append(Buffer(storage=blob))
-
-
 def add_delegate_data(
     program: Program, plan: ExecutionPlan, blobs: Sequence[bytes]
 ) -> None:
@@ -169,12 +163,14 @@ def constant_segment_with_tensor_alignment(
             self.gen_blob_data(constant_tensor_alignment, b"\x30\x33\x03"),
             self.gen_blob_data(constant_tensor_alignment + 1, b"\x40\x44\x04"),
         )
-        add_constant_data(program, blobs)
 
         # Extract blobs into constant segment during serialization.
        pte_data = bytes(
             serialize_pte_binary(
-                PTEFile(program=program),
+                PTEFile(
+                    program=program,
+                    constant_data=blobs,
+                ),
                 segment_alignment=SEGMENT_ALIGNMENT,
                 constant_tensor_alignment=constant_tensor_alignment,
             )
@@ -289,9 +285,7 @@ def constant_segment_with_tensor_alignment(
         # during serialization.
         self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
         # Number of constant tensors should be the same.
-        self.assertEqual(
-            len(deserialized.program.constant_buffer), len(program.constant_buffer)
-        )
+        self.assertEqual(len(deserialized.constant_data), len(blobs))
         self.assertEqual(deserialized.mutable_data, None)
         self.assertEqual(deserialized.named_data, None)
 
@@ -647,12 +641,13 @@ def test_round_trip_with_segments(self) -> None:
     def test_no_constants(self) -> None:
         program = get_test_program()
-        # Insert placeholder for non-const tensors.
-        add_constant_data(program, [b""])
 
         pte_data = bytes(
             serialize_pte_binary(
-                PTEFile(program=program),
+                PTEFile(
+                    program=program,
+                    constant_data=[b""],  # placeholder for non-const tensors.
+                ),
                 extract_delegate_segments=True,
                 segment_alignment=SEGMENT_ALIGNMENT,
                 constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
             )
@@ -736,7 +731,6 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
             self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"),
             self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
         )
-        add_constant_data(program, constant_blobs)
         add_delegate_data(program, program.execution_plan[0], delegate_blobs)
 
         # Create named data segment.
@@ -755,7 +749,9 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
         # Extract the blobs into segments during serialization.
         pte_data = bytes(
             serialize_pte_binary(
-                PTEFile(program=program, named_data=named_data),
+                PTEFile(
+                    program=program, constant_data=constant_blobs, named_data=named_data
+                ),
                 extract_delegate_segments=True,
                 segment_alignment=SEGMENT_ALIGNMENT,
                 constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
             )
@@ -933,9 +929,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
         # during serialization.
         self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
         # Number of constant tensors should be the same.
-        self.assertEqual(
-            len(deserialized.program.constant_buffer), len(program.constant_buffer)
-        )
+        self.assertEqual(len(deserialized.constant_data), len(constant_blobs))
         self.assertEqual(deserialized.mutable_data, None)
 
         self._check_named_data_store_output(deserialized.named_data, named_data)
diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py
index d25ee3c538b..89875f9f81c 100644
--- a/exir/emit/_emit_program.py
+++ b/exir/emit/_emit_program.py
@@ -50,11 +50,14 @@ class EmitterOutput:
     # generated by each instruction.
     instruction_id_to_num_outs_map: Dict[str, Dict[int, int]]
 
-    mutable_data: Optional[List[Buffer]]
+    # Constant data stored in the PTE file.
+    constant_data: Optional[List[bytes]]
+    # Mutable data stored in the PTE file.
+    mutable_data: Optional[List[bytes]]
 
     # Constants are optionally stored in external files.
     # Aggregate unique external constants into one buffer.
-    external_constant_buffer: List[bytes]
+    external_constant_buffer: Optional[List[bytes]]
 
     # Each constant_tag groups a set of constants together.
    # {constant_tag: {fqn: index into external_constant_buffer}}
     external_constant_map: Optional[Dict[str, Dict[str, int]]]
 
@@ -198,7 +201,7 @@ def emit_program(
         program=Program(
             version=EXECUTORCH_SCHEMA_VERSION,
             execution_plan=plans,
-            constant_buffer=program_state.constant_buffer,
+            constant_buffer=[],  # Do not add constants here anymore.
             backend_delegate_data=program_state.backend_delegate_data,
             # Segments may be added at serialization time.
             segments=[],
@@ -206,6 +209,11 @@ def emit_program(
             constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
             mutable_data_segments=None,  # Will be filled in during serialization
         ),
+        constant_data=(
+            program_state.constant_buffer
+            if len(program_state.constant_buffer) > 0  # Keep the placeholder value.
+            else None
+        ),
         mutable_data=(
             program_state.mutable_buffer
             if len(program_state.mutable_buffer) > 1
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 5c24f08c732..0885109d638 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -117,9 +117,9 @@ class _ProgramState:
     cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
     cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict)
     # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
-    constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
+    constant_buffer: List[bytes] = field(default_factory=lambda: [b""])
     # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
-    mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
+    mutable_buffer: List[bytes] = field(default_factory=lambda: [b""])
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
@@ -398,11 +398,6 @@ def _save_new_const_tensor(
         """Saves a new constant tensor to the constant buffer and returns the buffer idx"""
         self.program_state.allocated_specs.append(spec)
-        # +1 because the first buffer location is reserved.
-
-        # Update buffer_idx to point to the end of the list where we are adding the new buffer.
-        buffer = Buffer(storage=buffer_data)
-
         # Tensor is stored outside of the PTE file.
         if (
             spec.extra_tensor_info is not None
@@ -425,12 +420,12 @@ def _save_new_const_tensor(
         elif allocation_info:
             buffer_idx = len(self.program_state.mutable_buffer)
             self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
-            self.program_state.mutable_buffer.append(buffer)
+            self.program_state.mutable_buffer.append(buffer_data)
 
         # Tensor is stored in the PTE file.
        else:
             buffer_idx = len(self.program_state.constant_buffer)
             self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-            self.program_state.constant_buffer.append(buffer)
+            self.program_state.constant_buffer.append(buffer_data)
 
         return buffer_idx
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 4844088c0c2..b320170b3bb 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1020,7 +1020,8 @@ def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
             "forward_relu": program_relu.exported_program(),
             "forward_sigmoid": program_sigmoid.exported_program(),
         }
-        merged_program = emit_program(exir_input, False).program
+        merged_emitter_output = emit_program(exir_input, False)
+        merged_program = merged_emitter_output.program
         self.assertEqual(len(merged_program.execution_plan), 2)
 
         self.assertEqual(
@@ -1033,18 +1034,18 @@ def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         # reserved spot, weight, bias
         self.assertEqual(
-            len(program_sigmoid._emitter_output.program.constant_buffer),
+            len(program_sigmoid._emitter_output.constant_data),
             3,
         )
         self.assertEqual(
-            len(program_relu._emitter_output.program.constant_buffer),
+            len(program_relu._emitter_output.constant_data),
             3,
         )
         # sum of the entry points minus 1 because we only have one reserved spot still
         self.assertEqual(
-            len(merged_program.constant_buffer),
-            len(program_sigmoid._emitter_output.program.constant_buffer)
-            + len(program_relu._emitter_output.program.constant_buffer)
+            len(merged_emitter_output.constant_data),
+            len(program_sigmoid._emitter_output.constant_data)
+            + len(program_relu._emitter_output.constant_data)
             - 1,
         )
 
@@ -1081,20 +1082,21 @@ def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
             "forward_relu": program_relu.exported_program(),
             "forward_sigmoid": program_sigmoid.exported_program(),
         }
-        merged_program = emit_program(exir_input, False).program
+        merged_emitter_output = emit_program(exir_input, False)
+        merged_program = merged_emitter_output.program
         self.assertEqual(len(merged_program.execution_plan), 2)
 
         # reserved spot, weight, bias
         self.assertEqual(
-            len(program_sigmoid._emitter_output.program.constant_buffer),
+            len(program_sigmoid._emitter_output.constant_data),
             3,
         )
         self.assertEqual(
-            len(program_relu._emitter_output.program.constant_buffer),
+            len(program_relu._emitter_output.constant_data),
             3,
         )
         # weights are shared between entry points so the merged one should deduplicate everything
-        self.assertEqual(len(merged_program.constant_buffer), 3)
+        self.assertEqual(len(merged_emitter_output.constant_data), 3)
 
         self._compare_execution_plans(
             merged_program.execution_plan[0],
@@ -1510,12 +1512,13 @@ def forward(self, x):
         self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
         self.assertEqual(model.W2.nbytes, 4)
         self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
-        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()
-
-        program = program._emitter_output.program
+        executorch_program = to_edge(
+            export(model, (torch.ones(1),), strict=True)
+        ).to_executorch()
+        emitter_output = executorch_program._emitter_output
         # each emitted weight is not a view
-        self.assertEqual(len(program.constant_buffer[1].storage), 4)
-        self.assertEqual(len(program.constant_buffer[2].storage), 4)
+        self.assertEqual(len(emitter_output.constant_data[1]), 4)
+        self.assertEqual(len(emitter_output.constant_data[2]), 4)
 
     def test_non_persistent_buffer(self) -> None:
         class NonPersistentBuffer(nn.Module):
@@ -1527,11 +1530,13 @@ def forward(self, x):
                return x + self.buf
 
         model = NonPersistentBuffer()
-        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()
-        program = program._emitter_output.program
+        executorch_program = to_edge(
+            export(model, (torch.ones(1),), strict=True)
+        ).to_executorch()
+        emitter_output = executorch_program._emitter_output
         # confirm that the buffer was emitted
-        self.assertEqual(len(program.constant_buffer), 2)
-        self.assertEqual(len(program.constant_buffer[1].storage), 8)
+        self.assertEqual(len(emitter_output.constant_data), 2)
+        self.assertEqual(len(emitter_output.constant_data[1]), 8)
 
     def test_emit_lifted_tensor_constant(self) -> None:
         class LiftedTensorConstants(nn.Module):
@@ -1545,15 +1550,15 @@ def forward(self, x):
         model = LiftedTensorConstants()
         # Specify that we want to move non-lifted constants to external file
         et_cfg = ExecutorchBackendConfig(external_constants=True)
-        program = to_edge(
+        executorch_program = to_edge(
             export(model, (torch.ones(3, 2),), strict=True)
         ).to_executorch(et_cfg)
-        program = program._emitter_output.program
-        exec_plan = program.execution_plan[0]
+        emitter_output = executorch_program._emitter_output
+        exec_plan = emitter_output.program.execution_plan[0]
         # There should only be 1 input to this model.
         self.assertEqual(len(exec_plan.inputs), 1)
-        self.assertEqual(len(program.constant_buffer), 2)
-        self.assertEqual(len(program.constant_buffer[1].storage), 24)
+        self.assertEqual(len(emitter_output.constant_data), 2)
+        self.assertEqual(len(emitter_output.constant_data[1]), 24)
 
     def test_emit_lifted_constant(self) -> None:
         class LiftedConstants(nn.Module):
@@ -1567,16 +1572,16 @@ def forward(self, x):
         model = LiftedConstants()
         # Specify that we want to move non-lifted constants to external file
         et_cfg = ExecutorchBackendConfig(external_constants=True)
-        program = to_edge(
+        executorch_program = to_edge(
             export(model, (torch.ones(3, 2),), strict=True)
         ).to_executorch(et_cfg)
-        program = program._emitter_output.program
-        exec_plan = program.execution_plan[0]
+        emitter_output = executorch_program._emitter_output
+        exec_plan = emitter_output.program.execution_plan[0]
         # There should only be 1 input to this model.
         self.assertEqual(len(exec_plan.inputs), 1)
-        self.assertEqual(len(program.constant_buffer), 2)
-        self.assertEqual(len(program.constant_buffer[1].storage), 8)
+        self.assertEqual(len(emitter_output.constant_data), 2)
+        self.assertEqual(len(emitter_output.constant_data[1]), 8)
 
     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:
@@ -1709,7 +1714,7 @@ def forward(self, x):
         )
         emitter_output = model._emitter_output
         # Check that constant_buffer is empty besides the non-constant placeholder 0.
-        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        self.assertEqual(len(emitter_output.constant_data), 1)
         # Check that constant weights are in the external constant buffer.
         self.assertEqual(len(emitter_output.external_constant_buffer), 2)
         # Setting external_constants=True, saves all constants to the key
@@ -1742,7 +1747,7 @@ def forward(self, x):
         )
         emitter_output = model._emitter_output
         # constant_buffer is empty besides the non-constant placeholder 0.
-        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        self.assertEqual(len(emitter_output.constant_data), 1)
         # only one item in the external constant buffer.
        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
         # Setting external_constants=True, saves all constants to the key
@@ -1781,7 +1786,7 @@ def forward(self, x):
         )
         emitter_output = model._emitter_output
         # constant_buffer is empty besides the non-constant placeholder 0.
-        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        self.assertEqual(len(emitter_output.constant_data), 1)
         # Two items in the external constant buffer.
         self.assertEqual(len(emitter_output.external_constant_buffer), 2)
         # Setting external_constants=True, saves all constants to the key
@@ -1940,7 +1945,7 @@ def forward(self, input, label):
 
         emitter_output = ep._emitter_output
         # Check that constant_buffer is empty besides the non-constant placeholder 0.
-        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        self.assertEqual(len(emitter_output.constant_data), 1)
         # Check that constant weights are in the external constant buffer.
         self.assertEqual(len(emitter_output.external_constant_buffer), 2)
         # Setting external_mutable_weights=True, saves all constants with an associated gradient to the key
diff --git a/exir/tests/test_verification.py b/exir/tests/test_verification.py
index 90073216b2d..61d069de360 100644
--- a/exir/tests/test_verification.py
+++ b/exir/tests/test_verification.py
@@ -10,9 +10,10 @@
 import torch
 from executorch.exir import to_edge
+
+from executorch.exir._serialize import _PTEFile
 from executorch.exir.passes.const_prop_pass import ConstPropPass
 from executorch.exir.schema import Tensor, TensorList
-
 from executorch.exir.verification.interpreter import Interpreter
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 from torch._export.verifier import SpecViolationError
@@ -34,7 +35,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
             return torch.ones(2) + x + torch.ones(2)
 
         # Generate program
-        program = (
+        emitter_output = (
             to_edge(export(WrapperModule(f), (torch.randn(2),), strict=True))
             .transform(
                 [
@@ -42,10 +43,16 @@ def f(x: torch.Tensor) -> torch.Tensor:
                 ]
             )
             .to_executorch()
-            ._emitter_output.program
+            ._emitter_output
         )
 
-        test = Interpreter(program)
+        test = Interpreter(
+            _PTEFile(
+                program=emitter_output.program,
+                constant_data=emitter_output.constant_data,
+                mutable_data=emitter_output.mutable_data,
+            )
+        )
         for val_idx in range(len(test.execution_plan.values)):
             val = test.execution_plan.values[val_idx].val
             if not (
@@ -96,7 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         # Initialize and test Interpreter -- assert that the operators are same as above
-        test = Interpreter(program)
+        test = Interpreter(_PTEFile(program=program))
         self.assertEqual(
             set(test.get_operators_list()),
             {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
@@ -112,7 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         # Initialize and test Interpreter -- assert that the operators are same as above
-        test = Interpreter(program)
+        test = Interpreter(_PTEFile(program=program))
         self.assertEqual(
             set(test.get_operators_list()),
             {
diff --git a/exir/verification/interpreter.py b/exir/verification/interpreter.py
index fff6a6d79bc..59c6a22ea69 100644
--- a/exir/verification/interpreter.py
+++ b/exir/verification/interpreter.py
@@ -12,10 +12,10 @@
 # pyre-fixme[21]: Could not find module `executorch.exir.verification.bindings`.
 import executorch.exir.verification.bindings as bindings  # @manual=//executorch/exir/verification:bindings
 import executorch.extension.pytree as ex_pytree
-
 import torch
 
 from executorch import exir
+from executorch.exir._serialize import _PTEFile
 
 from executorch.exir.schema import (
     Bool,
@@ -126,9 +126,10 @@ def make_operators_list(
 
 
 class Interpreter:
-    def __init__(self, program: Program) -> None:
+    def __init__(self, pte_file: _PTEFile) -> None:
         # Currently there is only 1 execution plan in the list -- this assert will help
         # catch any changes in the future
+        program = pte_file.program
         assert len(program.execution_plan) == 1
         self.execution_plan: exir.schema.ExecutionPlan = program.execution_plan[0]
         self.container_metatype: exir.schema.ContainerMetadata = program.execution_plan[
@@ -137,11 +138,15 @@ def __init__(self, program: Program) -> None:
 
         # create buffer in memory and get reference to it
         # pyre-ignore
-        self.data_buffers: List[bindings.DataBuffer] = [
-            # pyre-ignore
-            bindings.DataBuffer(b.storage, len(b.storage))
-            for b in program.constant_buffer
-        ]
+        self.data_buffers: List[bindings.DataBuffer] = (
+            []
+            if pte_file.constant_data is None
+            else [
+                # pyre-ignore
+                bindings.DataBuffer(b, len(b))
+                for b in pte_file.constant_data
+            ]
+        )
 
         # generate the list of values (including tensors) and operators from the execution plan
         self._value_list: List[ValueType] = [