|
7 | 7 |
|
8 | 8 | import logging
|
9 | 9 |
|
10 |
| -import os |
11 | 10 | from collections import Counter
|
12 | 11 | from pprint import pformat
|
13 | 12 | from typing import (
|
|
47 | 46 | )
|
48 | 47 | from executorch.backends.arm.test.runner_utils import (
|
49 | 48 | dbg_tosa_fb_to_json,
|
50 |
| - get_elf_path, |
51 | 49 | get_output_quantization_params,
|
52 |
| - get_target_board, |
53 |
| - run_target, |
54 | 50 | TosaReferenceModelDispatch,
|
55 | 51 | )
|
56 | 52 |
|
57 | 53 | from executorch.backends.arm.test.tester.analyze_output_utils import (
|
58 | 54 | dump_error_output,
|
59 | 55 | print_error_diffs,
|
60 | 56 | )
|
| 57 | +from executorch.backends.arm.test.tester.serialize import Serialize |
61 | 58 | from executorch.backends.arm.tosa import TosaSpecification
|
62 | 59 | from executorch.backends.arm.tosa.mapping import extract_tensor_meta
|
63 | 60 | from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
|
|
96 | 93 |
|
97 | 94 | from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
|
98 | 95 | from torch.fx import Graph
|
99 |
| -from torch.utils._pytree import tree_flatten |
100 | 96 |
|
101 | 97 |
|
102 | 98 | logger = logging.getLogger(__name__)
|
@@ -184,44 +180,6 @@ def run(
|
184 | 180 | generate_etrecord=generate_etrecord,
|
185 | 181 | )
|
186 | 182 |
|
187 |
| - |
188 |
| -class Serialize(tester.Serialize): |
189 |
| - def __init__(self, compile_spec: list[CompileSpec], timeout): |
190 |
| - super().__init__() |
191 |
| - self.timeout = timeout |
192 |
| - self.executorch_program_manager: ExecutorchProgramManager | None |
193 |
| - self.compile_spec = compile_spec |
194 |
| - |
195 |
| - def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: |
196 |
| - super().run(artifact, inputs) |
197 |
| - # Keep the entire ExecutorchProgramManager for execution. |
198 |
| - self.executorch_program_manager = artifact |
199 |
| - |
200 |
| - def run_artifact(self, inputs): |
201 |
| - if self.executorch_program_manager is None: |
202 |
| - raise RuntimeError( |
203 |
| - "Tried running artifact from Serialize stage without running the stage." |
204 |
| - ) |
205 |
| - inputs_flattened, _ = tree_flatten(inputs) |
206 |
| - intermediate_path = get_intermediate_path(self.compile_spec) |
207 |
| - target_board = get_target_board(self.compile_spec) |
208 |
| - elf_path = get_elf_path(target_board) |
209 |
| - |
210 |
| - if not os.path.exists(elf_path): |
211 |
| - raise FileNotFoundError( |
212 |
| - f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" |
213 |
| - ) |
214 |
| - |
215 |
| - return run_target( |
216 |
| - self.executorch_program_manager, |
217 |
| - inputs_flattened, |
218 |
| - intermediate_path, |
219 |
| - target_board, |
220 |
| - elf_path, |
221 |
| - self.timeout, |
222 |
| - ) |
223 |
| - |
224 |
| - |
225 | 183 | class ToExecutorch(tester.ToExecutorch):
|
226 | 184 | def run_artifact(self, inputs):
|
227 | 185 | with TosaReferenceModelDispatch():
|
@@ -423,7 +381,11 @@ def serialize(
|
423 | 381 | self, serialize_stage: Optional[Serialize] = None, timeout: int = 480
|
424 | 382 | ):
|
425 | 383 | if serialize_stage is None:
|
426 |
| - serialize_stage = Serialize(self.compile_spec, timeout) |
| 384 | + serialize_stage = Serialize( |
| 385 | + compile_spec=self.compile_spec, |
| 386 | + module=self.original_module, |
| 387 | + timeout=timeout |
| 388 | + ) |
427 | 389 | assert (
|
428 | 390 | get_intermediate_path(self.compile_spec) is not None
|
429 | 391 | ), "Can't dump serialized file when compile specs do not contain an artifact path."
|
|
0 commit comments