From 311809cf4b1a198171c5c1ae6274aea77c103a75 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 10 Sep 2025 13:02:56 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- backends/arm/operators/op_abs.py | 4 +--- backends/arm/operators/op_sum.py | 8 ++------ examples/arm/aot_arm_compiler.py | 24 +----------------------- 3 files changed, 4 insertions(+), 32 deletions(-) diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index ec76eb5517f..625293d66e0 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -73,9 +73,7 @@ def define_node( abs_output = output # Do the INT32 Abs - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().ABS, [ rescaled_inputs[0].name, diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 00676d9f9b3..0bd152a8b8c 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -67,9 +67,7 @@ def define_node( dtype=ts.DType.INT32, ) - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_SUM, [rescaled_inputs[0].name], [intermediate.name], @@ -113,9 +111,7 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReduceSumAttribute(tensor.dim_order.index(dim)) - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_SUM, [tensor.name], [output.name], diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 106ab35363c..8132751f6f0 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -18,7 +18,6 @@ import torch from examples.devtools.scripts.export_bundled_program import save_bundled_program -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner from executorch.backends.arm.quantizer import ( EthosUQuantizer, @@ -387,7 +386,6 @@ def get_compile_spec( memory_mode: Optional[str] = None, quantize: bool = False, config: Optional[str] = None, - debug_mode: Optional[str] = None, ) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec: compile_spec = None if target.startswith("TOSA"): @@ -416,10 +414,6 @@ def get_compile_spec( if intermediates is not None: compile_spec.dump_intermediate_artifacts_to(intermediates) - if debug_mode is not None: - mode = ArmCompileSpec.DebugMode[debug_mode.upper()] - compile_spec.dump_debug_info(mode) - return compile_spec @@ -607,12 +601,6 @@ def get_args(): action="store_true", help="Enable the QuantizedOpFusionPass fusion step", ) - parser.add_argument( - "--enable_debug_mode", - required=False, - choices=["json", "tosa"], - help="Flag to enable ATen-to-TOSA debug mode.", - ) args = parser.parse_args() if args.evaluate and ( @@ -747,7 +735,6 @@ def to_edge_TOSA_delegate( args.memory_mode, args.quantize, args.config, - args.enable_debug_mode, ) model_int8 = None @@ -789,7 +776,6 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_ args.memory_mode, args.quantize, args.config, - args.enable_debug_mode, ) model, exported_program = quantize_model( args, model, example_inputs, compile_spec @@ -838,21 +824,12 @@ def transform_for_cortex_m_backend(edge, args): exported_program = torch.export.export( model, example_inputs, strict=args.strict_export ) - model = exported_program.module() model_fp32 = model - model_name = os.path.basename(os.path.splitext(args.model_name)[0]) if args.intermediates: os.makedirs(args.intermediates, exist_ok=True) - # We only support Python3.10 and above, so use a later pickle protocol - torch.export.save( - exported_program, - f"{args.intermediates}/{model_name}_exported_program.pt2", - pickle_protocol=5, - ) - # Quantize if required model_int8 = None if args.delegate: @@ -885,6 +862,7 @@ def transform_for_cortex_m_backend(edge, args): else: raise e + model_name = os.path.basename(os.path.splitext(args.model_name)[0]) output_name = f"{model_name}" + ( f"_arm_delegate_{args.target}" if args.delegate is True From 693e713d3fc29907adabb31fbc179933c0ef5d8a Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 10 Sep 2025 13:03:01 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- backends/arm/arm_backend.py | 245 ++++++++++++++++++ backends/arm/common/arm_compile_spec.py | 195 -------------- backends/arm/debug/schema.py | 6 +- backends/arm/ethosu/__init__.py | 6 +- backends/arm/ethosu/compile_spec.py | 101 -------- backends/arm/ethosu/partitioner.py | 18 +- backends/arm/operators/node_visitor.py | 4 +- backends/arm/quantizer/arm_quantizer.py | 52 ++-- backends/arm/runtime/VelaBinStream.cpp | 2 +- backends/arm/runtime/VelaBinStream.h | 4 +- .../arm/scripts/TOSA_minimal_example.ipynb | 25 +- backends/arm/test/common.py | 193 ++++++++++---- backends/arm/test/misc/test_compile_spec.py | 50 ---- backends/arm/test/misc/test_debug_feats.py | 6 +- backends/arm/test/misc/test_debug_hook.py | 6 +- .../test/misc/test_extract_io_params_tosa.py | 22 +- backends/arm/test/misc/test_outputs_order.py | 11 +- backends/arm/test/ops/test_add.py | 10 +- backends/arm/test/runner_utils.py | 37 ++- .../arm/test/tester/analyze_output_utils.py | 3 +- backends/arm/test/tester/arm_tester.py | 70 ++--- backends/arm/test/tester/test_pipeline.py | 15 +- backends/arm/tosa/backend.py | 6 +- backends/arm/tosa/compile_spec.py | 25 -- backends/arm/tosa/partitioner.py | 20 +- backends/arm/vgf/__init__.py | 6 +- backends/arm/vgf/compile_spec.py | 66 ----- backends/arm/vgf/partitioner.py | 18 +- examples/arm/aot_arm_compiler.py | 49 ++-- examples/arm/ethos_u_minimal_example.ipynb | 11 +- examples/arm/vgf_minimal_example.ipynb | 19 +- 31 files changed, 637 insertions(+), 664 deletions(-) create mode 100644 backends/arm/arm_backend.py delete mode 100644 backends/arm/common/arm_compile_spec.py delete mode 100644 backends/arm/ethosu/compile_spec.py delete mode 100644 backends/arm/test/misc/test_compile_spec.py delete mode 100644 backends/arm/tosa/compile_spec.py delete mode 100644 backends/arm/vgf/compile_spec.py diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py new file mode 100644 index 00000000000..2e71f91dbb6 --- /dev/null +++ b/backends/arm/arm_backend.py @@ -0,0 +1,245 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. +# +from enum import Enum +from typing import List, Optional + +from executorch.backends.arm.tosa import TosaSpecification + +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) + + +class ArmCompileSpecBuilder: + class DebugMode(Enum): + JSON = 1 + TOSA = 2 + + def __init__(self): + self.compile_spec: List[CompileSpec] = [] + self.compiler_flags = [] + self.output_format = None + self.path_for_intermediates = None + self.tosa_spec = None + self.tosa_debug_mode = None + + def vgf_compile_spec( + self, + tosa_spec: TosaSpecification = None, # type: ignore[assignment] + compiler_flags: Optional[str] = "", + ) -> "ArmCompileSpecBuilder": + """ + Generate compile spec for VGF compatible targets + + Args: + compiler_flags: Extra compiler flags for converter_backend + """ + self.output_format = "vgf" + self.compiler_flags = [ + compiler_flags, + ] + + if tosa_spec is None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + + tosa_version = tosa_spec.version # type: ignore[attr-defined] + tosa_profiles = tosa_spec.profiles # type: ignore[attr-defined] + + if tosa_version.major != 1: + raise ValueError( + "Arm backend only supports converter-backend for TOSA version 1. " + f"Invalid TOSA version: {tosa_version}" + ) + + if "FP" not in tosa_profiles and "INT" not in tosa_profiles: + raise ValueError( + "Arm backend only supports converter-backend for FP or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + if len(tosa_profiles) != 1: + raise ValueError( + "For now Arm backend only supports converter-backend for either FP or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + self.tosa_spec = tosa_spec + + return self + + def ethosu_compile_spec( + self, + target: str, + system_config: Optional[str] = None, + memory_mode: Optional[str] = None, + extra_flags: Optional[str] = None, + config_ini: Optional[str] = "Arm/vela.ini", + ) -> "ArmCompileSpecBuilder": + """ + Generate compile spec for Ethos-U NPU + + Args: + target: Ethos-U accelerator configuration, e.g. ethos-u55-128 + system_config: System configuration to select from the Vel + configuration file + memory_mode: Memory mode to select from the Vela configuration file + extra_flags: Extra flags for the Vela compiler + config_ini: Vela configuration file(s) in Python ConfigParser .ini + file format + """ + assert ( + self.output_format is None + ), f"Output format already set to f{self.output_format}" + self.output_format = "vela" + self.compiler_flags = [ + f"--accelerator-config={target}", + f"--config={config_ini}", + ] + + # default system config and memory mode + if "ethos-u55" in target: + if system_config is None: + system_config = "Ethos_U55_High_End_Embedded" + if memory_mode is None: + memory_mode = "Shared_Sram" + elif "ethos-u85" in target: + if system_config is None: + system_config = "Ethos_U85_SYS_DRAM_Mid" + if memory_mode is None: + memory_mode = "Sram_Only" + else: + raise RuntimeError(f"Unknown ethos target: {target}") + + if system_config is not None: + self.compiler_flags.append(f"--system-config={system_config}") + if memory_mode is not None: + self.compiler_flags.append(f"--memory-mode={memory_mode}") + if extra_flags is not None: + self.compiler_flags.append(extra_flags) + + # We require raw output and regor, so add these flags if absent. This + # overrides any other output setting. + self.compiler_flags.append("--output-format=raw") + self.compiler_flags.append("--debug-force-regor") + + base_tosa_version = "TOSA-1.0+INT+int16" + if "u55" in target: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + self.tosa_spec = TosaSpecification.create_from_string(base_tosa_version) + + return self + + def tosa_compile_spec( + self, tosa_spec: str | TosaSpecification + ) -> "ArmCompileSpecBuilder": + """ + Generate compile spec for TOSA flatbuffer output + """ + assert ( + self.output_format is None + ), f"Output format already set: {self.output_format}" + self.output_format = "tosa" + if isinstance(tosa_spec, TosaSpecification): + self.tosa_spec = tosa_spec + elif isinstance(tosa_spec, str): + self.tosa_spec = TosaSpecification.create_from_string(tosa_spec) + else: + raise RuntimeError(f"Invalid type for {tosa_spec}!") + return self + + def dump_intermediate_artifacts_to( + self, output_path: str + ) -> "ArmCompileSpecBuilder": + """ + Sets a path for dumping intermediate results during such as tosa and pte. + """ + self.path_for_intermediates = output_path + return self + + def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder": + """ + Dump debugging information into the intermediates path + """ + self.tosa_debug_mode = debug_mode.name + return self + + def build(self) -> List[CompileSpec]: + """ + Generate a list of compile spec objects from the builder + """ + assert self.tosa_spec + + # Always supply a TOSA version + self.compile_spec = [CompileSpec("tosa_spec", str(self.tosa_spec).encode())] + + # Add compile flags, these are backend specific, refer to the backend + # documentation. + self.compile_spec += [ + CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), + ] + + # encode output format + self.compile_spec.append( + CompileSpec("output_format", self.output_format.encode()) + ) + + if self.path_for_intermediates is not None: + self.compile_spec.append( + CompileSpec("debug_artifact_path", self.path_for_intermediates.encode()) + ) + + if self.tosa_debug_mode is not None: + if not self.path_for_intermediates: + raise ValueError( + "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" + ) + + self.compile_spec.append( + CompileSpec("dump_debug_info", self.tosa_debug_mode.encode()) + ) + + return self.compile_spec + + +def is_tosa(compile_spec: List[CompileSpec]) -> bool: + has_tosa_output = False + has_tosa_spec = False + for spec in compile_spec: + if spec.key == "output_format": + has_tosa_output = spec.value.decode() == "tosa" + if spec.key == "tosa_spec": + has_tosa_spec = True + + return has_tosa_output and has_tosa_spec + + +def is_ethosu(compile_spec: List[CompileSpec]) -> bool: + for spec in compile_spec: + if spec.key == "output_format": + return spec.value.decode() == "vela" + return False + + +def is_vgf(compile_spec: List[CompileSpec]) -> bool: + for spec in compile_spec: + if spec.key == "output_format": + return spec.value.decode() == "vgf" + return False + + +def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]: + for spec in compile_spec: + if spec.key == "debug_artifact_path": + return spec.value.decode() + return None diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py deleted file mode 100644 index c6818e2716a..00000000000 --- a/backends/arm/common/arm_compile_spec.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum - -from executorch.backends.arm.tosa import TosaSpecification - -from executorch.exir.backend.compile_spec_schema import CompileSpec - - -@dataclass(init=False) -class ArmCompileSpec(ABC): - class DebugMode(Enum): - JSON = 1 - TOSA = 2 - - tosa_spec: TosaSpecification - compiler_flags: list[str] = field(default_factory=list) - path_for_intermediates: str | None = None - tosa_debug_mode: DebugMode | None = None - - _TOSA_SPEC_KEY = "tosa_spec" - _COMPILE_FLAGS_KEY = "compile_flags" - _OUTPUT_FORMAT_KEY = "output_format" - _DEBUG_ARTIFACT_KEY = "debug_artifact_path" - _DEBUG_MODE_KEY = "dump_debug_info" - - def _set_compile_specs( - self, - tosa_spec: TosaSpecification, - compiler_flags: list[str], - path_for_intermediates: str | None = None, - tosa_debug_mode: DebugMode | None = None, - ): - """Set all values of dataclass directly.""" - self.tosa_spec = tosa_spec - self.compiler_flags = compiler_flags - self.path_for_intermediates = path_for_intermediates - self.tosa_debug_mode = tosa_debug_mode - - @classmethod - def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 - tosa_spec: TosaSpecification | None = None - output_format: str | None = None - compiler_flags: list[str] | None = None - path_for_intermediates: str | None = None - tosa_debug_mode: ArmCompileSpec.DebugMode | None = None - unknown_specs: dict[str, str] = {} - for spec in compile_specs: - key = spec.key - val = spec.value.decode() - if key == ArmCompileSpec._TOSA_SPEC_KEY: - if tosa_spec is not None: - raise ValueError("More than one tosa_spec entry in compile spec.") - tosa_spec = TosaSpecification.create_from_string(val) - elif key == ArmCompileSpec._COMPILE_FLAGS_KEY: - if compiler_flags is not None: - raise ValueError( - "More than one compiler flags entry in compile spec." - ) - compiler_flags = val.split(" ") - elif key == ArmCompileSpec._OUTPUT_FORMAT_KEY: - if output_format is not None: - raise ValueError( - "More than one output format entry in compile spec." - ) - output_format = val - elif key == ArmCompileSpec._DEBUG_ARTIFACT_KEY: - if path_for_intermediates is not None: - raise ValueError( - "More than one debug artifact path entry in compile spec." - ) - path_for_intermediates = val - elif key == ArmCompileSpec._DEBUG_MODE_KEY: - if tosa_debug_mode is not None: - raise ValueError( - "More than one tosa_debug_mode entry in compile spec." - ) - tosa_debug_mode = ArmCompileSpec.DebugMode[val] - else: - unknown_specs[key] = val - - if tosa_spec is None: - raise ValueError("No tosa_spec in compile spec.") - if output_format is None: - raise ValueError("No output_format in compile spec.") - if output_format != cls.get_output_format(): - raise ValueError( - f"Incorrect output format '{output_format}' for {cls.__name__}, expected '{cls.get_output_format()}'" - ) - if compiler_flags is None: - compiler_flags = [] - - # Create new object from class, but bypass __init__ and use _set_compile_specs instead. - compile_spec = cls.__new__(cls) - compile_spec._set_compile_specs( - tosa_spec=tosa_spec, - compiler_flags=compiler_flags, - path_for_intermediates=path_for_intermediates, - tosa_debug_mode=tosa_debug_mode, - ) - cls.from_list_hook(compile_spec, unknown_specs) - compile_spec.validate() - return compile_spec - - @classmethod - def from_list_hook(cls, compile_spec, specs: dict[str, str]): # noqa: B027 - """Allows subclasses to hook into parsing compile spec lists.""" - pass - - @abstractmethod - def validate(self): - """Throws an error if the compile spec is not valid.""" - - def to_list(self): - """Get the ArmCompileSpec in list form.""" - assert self.tosa_spec - - # Always supply a TOSA version - compile_spec = [ - CompileSpec(ArmCompileSpec._TOSA_SPEC_KEY, str(self.tosa_spec).encode()) - ] - - # Add compile flags, these are backend specific, refer to the backend - # documentation. - if len(self.compiler_flags) > 0: - compile_spec += [ - CompileSpec( - ArmCompileSpec._COMPILE_FLAGS_KEY, - " ".join(self.compiler_flags).encode(), - ), - ] - - # Add output format to identify kind of compile spec. - compile_spec.append( - CompileSpec( - ArmCompileSpec._OUTPUT_FORMAT_KEY, self.get_output_format().encode() - ) - ) - - if self.path_for_intermediates is not None: - compile_spec.append( - CompileSpec( - ArmCompileSpec._DEBUG_ARTIFACT_KEY, - self.path_for_intermediates.encode(), - ) - ) - - if self.tosa_debug_mode is not None: - if not self.path_for_intermediates: - raise ValueError( - "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" - ) - - compile_spec.append( - CompileSpec( - ArmCompileSpec._DEBUG_MODE_KEY, self.tosa_debug_mode.name.encode() - ) - ) - - return compile_spec - - def get_intermediate_path(self) -> str | None: - return self.path_for_intermediates - - def dump_intermediate_artifacts_to(self, output_path: str | None): - """ - Sets a path for dumping intermediate results during such as tosa and pte. - """ - self.path_for_intermediates = output_path - return self - - def dump_debug_info(self, debug_mode: DebugMode | None): - """ - Dump debugging information into the intermediates path - """ - self.tosa_debug_mode = debug_mode - return self - - @classmethod - @abstractmethod - def get_output_format(cls) -> str: - """Returns a constant string that is the output format of the class.""" diff --git a/backends/arm/debug/schema.py b/backends/arm/debug/schema.py index 46742a8ce61..82f0fd6bf7e 100644 --- a/backends/arm/debug/schema.py +++ b/backends/arm/debug/schema.py @@ -13,7 +13,7 @@ import serializer.tosa_serializer as ts # type: ignore import torch -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from torch.fx.traceback import NodeSource @@ -112,7 +112,7 @@ def to_dict(self) -> dict[str, Any]: class DebugHook: - def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None: + def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None: self._debug_events: list[DebugSchema] = [] self.__op_id_to_name = {} self.mode = debug_mode @@ -126,7 +126,7 @@ def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema # If the debug data is being embedded into the TOSA flatbuffer # do not collect TOSADebugSchema data, it's redundent - if self.mode != ArmCompileSpec.DebugMode.TOSA: + if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA: tosa_debug_info = TosaDebugSchema( node_name=str(tosa_op), operator_name=self.__op_id_to_name[tosa_op_id], diff --git a/backends/arm/ethosu/__init__.py b/backends/arm/ethosu/__init__.py index 25a91dc5929..f6cc1329dfe 100644 --- a/backends/arm/ethosu/__init__.py +++ b/backends/arm/ethosu/__init__.py @@ -6,7 +6,9 @@ # pyre-unsafe from .backend import EthosUBackend # noqa: F401 -from .compile_spec import EthosUCompileSpec # noqa: F401 from .partitioner import EthosUPartitioner # noqa: F401 -__all__ = ["EthosUBackend", "EthosUPartitioner", "EthosUCompileSpec"] +__all__ = [ + "EthosUBackend", + "EthosUPartitioner", +] diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py deleted file mode 100644 index 5f3f92fdd0e..00000000000 --- a/backends/arm/ethosu/compile_spec.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec - -from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] - TosaSpecification, -) - -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) - - -class EthosUCompileSpec(ArmCompileSpec): - - _TARGET_KEY = "target" - - def __init__( - self, - target: str, - system_config: str | None = None, - memory_mode: str | None = None, - extra_flags: list[str] | None = None, - config_ini: str | None = "Arm/vela.ini", - ): - """Generate compile spec for Ethos-U NPU - Args: - target: Ethos-U accelerator configuration, e.g. ethos-u55-128 - system_config: System configuration to select from the Vela - configuration file - memory_mode: Memory mode to select from the Vela configuration file - extra_flags: Extra flags for the Vela compiler - config_ini: Vela configuration file(s) in Python ConfigParser .ini - file format - """ - self.target = target - - # Set vela compiler flags - if config_ini is None: - config_ini = "Arm/vela.ini" - compiler_flags = [] if extra_flags is None else extra_flags - compiler_flags.extend( - [ - f"--accelerator-config={target}", - f"--config={config_ini}", - "--output-format=raw", - "--debug-force-regor", - ] - ) - # default system config and memory mode - if "ethos-u55" in self.target: - if system_config is None: - system_config = "Ethos_U55_High_End_Embedded" - if memory_mode is None: - memory_mode = "Shared_Sram" - elif "ethos-u85" in self.target: - if system_config is None: - system_config = "Ethos_U85_SYS_DRAM_Mid" - if memory_mode is None: - memory_mode = "Sram_Only" - else: - raise RuntimeError(f"Unknown ethos target: {self.target}") - - compiler_flags.append(f"--system-config={system_config}") - compiler_flags.append(f"--memory-mode={memory_mode}") - - # Set TOSA version. - base_tosa_version = "TOSA-1.0+INT+int16" - if "u55" in self.target: - # Add the Ethos-U55 extension marker - base_tosa_version += "+u55" - tosa_spec = TosaSpecification.create_from_string(base_tosa_version) - - self._set_compile_specs(tosa_spec, compiler_flags) - self.validate() - - def to_list(self): - compile_specs = super().to_list() - compile_specs.append(CompileSpec(self._TARGET_KEY, self.target.encode())) - return compile_specs - - @classmethod - def from_list_hook(cls, compile_spec, specs: dict[str, str]): - compile_spec.target = specs.get(cls._TARGET_KEY, None) - - def validate(self): - if len(self.compiler_flags) == 0: - raise ValueError( - "compile_flags are required in the CompileSpec list for EthosUBackend" - ) - if "u55" in self.target and not self.tosa_spec.is_U55_subset: - raise ValueError( - f"Target was {self.target} but tosa spec was not u55 subset." - ) - - @classmethod - def get_output_format(cls) -> str: - return "vela" diff --git a/backends/arm/ethosu/partitioner.py b/backends/arm/ethosu/partitioner.py index d2fad094c03..d76b29eb1d9 100644 --- a/backends/arm/ethosu/partitioner.py +++ b/backends/arm/ethosu/partitioner.py @@ -5,10 +5,14 @@ # pyre-unsafe -from typing import final, Optional, Sequence +from typing import final, List, Optional, Sequence -from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec +from executorch.backends.arm.arm_backend import ( + is_ethosu, +) # usort: skip +from executorch.backends.arm.ethosu import EthosUBackend from executorch.backends.arm.tosa.partitioner import TOSAPartitioner +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @@ -17,12 +21,12 @@ class EthosUPartitioner(TOSAPartitioner): def __init__( self, - compile_spec: EthosUCompileSpec, + compile_spec: List[CompileSpec], additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: + if not is_ethosu(compile_spec): + raise RuntimeError("compile spec is not targeting Ethos-U") + # Override the delegation spec for Ethos-U - self.delegation_spec = DelegationSpec( - EthosUBackend.__name__, compile_spec.to_list() - ) + self.delegation_spec = DelegationSpec(EthosUBackend.__name__, compile_spec) self.additional_checks = additional_checks - self.tosa_spec = compile_spec.tosa_spec diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 172adbc7c78..54a81bdaaff 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -10,7 +10,7 @@ import torch -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification @@ -59,7 +59,7 @@ def _serialize_operator( tosa_op_id=tosa_op, ) - if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA: + if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA: op_location = json.dumps(debug_info.to_dict()) tosa_graph.addOperator( diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e6240a08c8e..ae7c8255428 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -14,17 +14,21 @@ from __future__ import annotations import functools -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch -from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.common.arm_compile_spec import ( - ArmCompileSpec, -) # isort: skip -from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.arm.tosa.specification import get_tosa_spec + +from .arm_quantizer_utils import is_annotated, mark_node_as_annotated +from .quantization_annotator import annotate_graph +from executorch.backends.arm.arm_backend import ( + is_ethosu, + is_vgf, +) # usort: skip +from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( @@ -45,9 +49,6 @@ Quantizer, ) -from .arm_quantizer_utils import is_annotated, mark_node_as_annotated -from .quantization_annotator import annotate_graph - __all__ = [ "TOSAQuantizer", "EthosUQuantizer", @@ -299,16 +300,27 @@ def not_module_type_or_name_filter(n: Node) -> bool: class TOSAQuantizer(Quantizer): def __init__( - self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]] ) -> None: super().__init__() if isinstance(compile_spec_or_tosa_spec, TosaSpecification): self.tosa_spec = compile_spec_or_tosa_spec self.compile_spec = None - elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): + elif isinstance(compile_spec_or_tosa_spec, list): self.compile_spec = compile_spec_or_tosa_spec - self.tosa_spec = self.compile_spec.tosa_spec + # find entry that is 'tosa_spec' + for cs in compile_spec_or_tosa_spec: + if cs.key == "tosa_spec": + spec_val = ( + cs.value.decode() if isinstance(cs.value, bytes) else cs.value + ) + self.tosa_spec = TosaSpecification.create_from_string(spec_val) + break + else: + raise ValueError( + "compile_spec list did not contain a 'tosa_spec' entry" + ) else: raise TypeError( f"TOSAQuantizer constructor expects " @@ -454,10 +466,18 @@ def validate(self, model: GraphModule) -> None: class EthosUQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: EthosUCompileSpec) -> None: - super().__init__(compile_spec) + def __init__(self, compile_spec: list[CompileSpec]) -> None: + if not is_ethosu(compile_spec): + raise RuntimeError("compile spec is not targeting Ethos-U") + + tosa_spec = get_tosa_spec(compile_spec) + super().__init__(tosa_spec) class VgfQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: VgfCompileSpec) -> None: - super().__init__(compile_spec) + def __init__(self, compile_spec: list[CompileSpec]) -> None: + if not is_vgf(compile_spec): + raise RuntimeError("compile spec is not targeting VGF") + + tosa_spec = get_tosa_spec(compile_spec) + super().__init__(tosa_spec) diff --git a/backends/arm/runtime/VelaBinStream.cpp b/backends/arm/runtime/VelaBinStream.cpp index c8d568499c9..180219c75b5 100644 --- a/backends/arm/runtime/VelaBinStream.cpp +++ b/backends/arm/runtime/VelaBinStream.cpp @@ -6,7 +6,7 @@ */ /* - * Warning: Do not change this without changing arm_vela.py::vela_compile + * Warning: Do not change this without changing arm_backend.py::vela_compile * as that function emits this format and the two need to align. */ diff --git a/backends/arm/runtime/VelaBinStream.h b/backends/arm/runtime/VelaBinStream.h index 7a7ea9b6266..04b8b2ada00 100644 --- a/backends/arm/runtime/VelaBinStream.h +++ b/backends/arm/runtime/VelaBinStream.h @@ -1,5 +1,5 @@ /* - * Copyright 2023-2025 Arm Limited and/or its affiliates. + * Copyright 2023-2024 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -7,7 +7,7 @@ /* * Minimal reading function for vela_bin_stream wire format. This is an - * implementation detail of the arm backend AoT flow and ArmBackendEthosU + * implementation detail of the arm_backend AoT flow and ArmBackendEthosU * and subject to change. * This format captures the command stream, I/O and memory layout data to * enable execution of the command stream on Ethos-U hardware. diff --git a/backends/arm/scripts/TOSA_minimal_example.ipynb b/backends/arm/scripts/TOSA_minimal_example.ipynb index b79780c6a07..785affc657b 100644 --- a/backends/arm/scripts/TOSA_minimal_example.ipynb +++ b/backends/arm/scripts/TOSA_minimal_example.ipynb @@ -86,7 +86,10 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", + "from executorch.backends.arm.arm_backend import (\n", + " ArmCompileSpecBuilder,\n", + ")\n", + "from executorch.backends.arm.tosa.specification import TosaSpecification\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -96,7 +99,11 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", + "tosa_spec = TosaSpecification.create_from_string(target)\n", + "spec_builder = (ArmCompileSpecBuilder()\n", + " .tosa_compile_spec(tosa_spec)\n", + " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", + "compile_spec = spec_builder.build()\n", "\n", "_ = graph_module.print_readable()\n", "\n", @@ -123,11 +130,15 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", + "from executorch.backends.arm.arm_backend import (\n", + " ArmCompileSpecBuilder,\n", + " get_tosa_spec,\n", + ")\n", "from executorch.backends.arm.quantizer import (\n", " TOSAQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", + "from executorch.backends.arm.tosa.specification import TosaSpecification\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -137,10 +148,14 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", + "tosa_spec = TosaSpecification.create_from_string(target)\n", + "spec_builder = (ArmCompileSpecBuilder()\n", + " .tosa_compile_spec(tosa_spec)\n", + " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", + "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", - "quantizer = TOSAQuantizer(compile_spec)\n", + "quantizer = TOSAQuantizer(get_tosa_spec(compile_spec))\n", "operator_config = get_symmetric_quantization_config()\n", "quantizer.set_global(operator_config)\n", "\n", diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 963084d6091..608c273b2ef 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -13,7 +13,7 @@ from typing import Any, Optional import pytest -from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.test.runner_utils import ( arm_executor_runner_exists, corstone300_installed, @@ -22,8 +22,7 @@ vkml_emulation_layer_installed, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.exir.backend.compile_spec_schema import CompileSpec def get_time_formatted_path(path: str, log_prefix: str) -> str: @@ -65,21 +64,43 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( tosa_spec: str | TosaSpecification, - custom_path=None, - tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, -) -> TosaCompileSpec: - """Get the compile spec for default TOSA tests.""" + custom_path: Optional[str] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, +) -> list[CompileSpec]: + """ + Default compile spec for TOSA tests. + """ + return get_tosa_compile_spec_unbuilt( + tosa_spec, + custom_path, + tosa_debug_mode, + ).build() + + +def get_tosa_compile_spec_unbuilt( + tosa_spec: str | TosaSpecification, + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], +) -> ArmCompileSpecBuilder: + """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify + the compile spec before calling .build() to finalize it. + """ if not custom_path: custom_path = maybe_get_tosa_collate_path() + if custom_path is not None: os.makedirs(custom_path, exist_ok=True) - compile_spec = ( - TosaCompileSpec(tosa_spec) + compile_spec_builder = ( + ArmCompileSpecBuilder() + .tosa_compile_spec(tosa_spec) .dump_intermediate_artifacts_to(custom_path) - .dump_debug_info(tosa_debug_mode) ) - return compile_spec + + if tosa_debug_mode is not None: + compile_spec_builder.dump_debug_info(tosa_debug_mode) + + return compile_spec_builder def get_u55_compile_spec( @@ -88,10 +109,72 @@ def get_u55_compile_spec( memory_mode: str = "Shared_Sram", extra_flags: str = "--debug-force-regor --output-format=raw", custom_path: Optional[str] = None, - config: Optional[str] = None, - tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, -) -> EthosUCompileSpec: - """Default compile spec for Ethos-U55 tests.""" + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + config: Optional[str] = "Arm/vela.ini", +) -> list[CompileSpec]: + """ + Compile spec for Ethos-U55. + """ + return get_u55_compile_spec_unbuilt( + macs=macs, + system_config=system_config, + memory_mode=memory_mode, + extra_flags=extra_flags, + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, + config=config, + ).build() + + +def get_u85_compile_spec( + macs: int = 128, + system_config: str = "Ethos_U85_SYS_DRAM_Mid", + memory_mode: str = "Shared_Sram", + extra_flags: str = "--output-format=raw", + custom_path: Optional[str] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + config: Optional[str] = "Arm/vela.ini", +) -> list[CompileSpec]: + """ + Compile spec for Ethos-U85. + """ + return get_u85_compile_spec_unbuilt( # type: ignore[attr-defined] + macs=macs, + system_config=system_config, + memory_mode=memory_mode, + extra_flags=extra_flags, + custom_path=custom_path, + tosa_debug_mode=tosa_debug_mode, + config=config, + ).build() + + +def get_vgf_compile_spec( + tosa_spec: str | TosaSpecification, + compiler_flags: Optional[str] = "", + custom_path: Optional[str] = "", + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, +) -> list[CompileSpec]: + """ + Default compile spec for VGF tests. + """ + return get_vgf_compile_spec_unbuilt( + tosa_spec, compiler_flags, custom_path, tosa_debug_mode + ).build() + + +def get_u55_compile_spec_unbuilt( + macs: int, + system_config: str, + memory_mode: str, + extra_flags: str, + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], + config: Optional[str], +) -> ArmCompileSpecBuilder: + """Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify + the compile spec before calling .build() to finalize it. + """ artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u55_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) @@ -99,67 +182,67 @@ def get_u55_compile_spec( # https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md assert macs in [32, 64, 128, 256], "Unsupported MACs value" - if extra_flags is not None: - extra_flags_list = extra_flags.split(" ") - else: - extra_flags_list = [] compile_spec = ( - EthosUCompileSpec( + ArmCompileSpecBuilder() + .ethosu_compile_spec( f"ethos-u55-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags_list, + extra_flags=extra_flags, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) - .dump_debug_info(tosa_debug_mode) ) - return compile_spec + if tosa_debug_mode is not None: + compile_spec.dump_debug_info(tosa_debug_mode) + + return compile_spec -def get_u85_compile_spec( - macs: int = 128, - system_config="Ethos_U85_SYS_DRAM_Mid", - memory_mode="Shared_Sram", - extra_flags="--output-format=raw", - custom_path: Optional[str] = None, - config: Optional[str] = None, - tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, -) -> EthosUCompileSpec: - """Default compile spec for Ethos-U85 tests.""" +def get_u85_compile_spec_unbuilt( + macs: int, + system_config: str, + memory_mode: str, + extra_flags: str, + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], + config: Optional[str], +) -> list[CompileSpec]: + """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify + the compile spec before calling .build() to finalize it. + """ artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u85_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) assert macs in [128, 256, 512, 1024, 2048], "Unsupported MACs value" - if extra_flags is not None: - extra_flags_list = extra_flags.split(" ") - else: - extra_flags_list = [] - compile_spec = ( - EthosUCompileSpec( + ArmCompileSpecBuilder() + .ethosu_compile_spec( f"ethos-u85-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags_list, + extra_flags=extra_flags, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) - .dump_debug_info(tosa_debug_mode) ) + + if tosa_debug_mode is not None: + compile_spec.dump_debug_info(tosa_debug_mode) + return compile_spec # type: ignore[return-value] -def get_vgf_compile_spec( +def get_vgf_compile_spec_unbuilt( tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str] = "", - custom_path=None, - tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, -) -> VgfCompileSpec: - """Get the ArmCompileSpec for the default VGF tests, to modify + compiler_flags: Optional[str], + custom_path: Optional[str], + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], +) -> ArmCompileSpecBuilder: + """Get the ArmCompileSpecBuilder for the default VGF tests, to modify the compile spec before calling .build() to finalize it. """ if "FP" in repr(tosa_spec): @@ -172,18 +255,16 @@ def get_vgf_compile_spec( if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) - if compiler_flags is not None: - compiler_flags_list = compiler_flags.split(" ") - else: - compiler_flags_list = [] - - compile_spec = ( - VgfCompileSpec(tosa_spec, compiler_flags_list) + compile_spec_builder = ( + ArmCompileSpecBuilder() + .vgf_compile_spec(tosa_spec, compiler_flags) .dump_intermediate_artifacts_to(artifact_path) - .dump_debug_info(tosa_debug_mode) ) - return compile_spec + if tosa_debug_mode is not None: + compile_spec_builder.dump_debug_info(tosa_debug_mode) + + return compile_spec_builder XfailIfNoCorstone300 = pytest.mark.xfail( diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py deleted file mode 100644 index a1b42cd22b5..00000000000 --- a/backends/arm/test/misc/test_compile_spec.py +++ /dev/null @@ -1,50 +0,0 @@ -from executorch.backends.arm.ethosu import EthosUCompileSpec -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.vgf import VgfCompileSpec -from pytest import raises - - -def test_ethos_u_compile_spec(): - compile_spec = ( - EthosUCompileSpec("ethos-u55", extra_flags=["--my-flag"]) - .dump_intermediate_artifacts_to("my_path") - .dump_debug_info(EthosUCompileSpec.DebugMode.TOSA) - ) - spec_list = compile_spec.to_list() - - assert EthosUCompileSpec.from_list(spec_list) == compile_spec - assert "--my-flag" in compile_spec.compiler_flags - assert "--output-format=raw" in compile_spec.compiler_flags - with raises(ValueError, match="Incorrect output format"): - VgfCompileSpec.from_list(spec_list) - - spec_list.pop(0) - with raises(ValueError, match="No tosa_spec in compile spec."): - EthosUCompileSpec.from_list(spec_list) - - -def test_vgf_compile_spec(): - compile_spec = ( - VgfCompileSpec(compiler_flags=["--my-flag"]) - .dump_intermediate_artifacts_to("my_path") - .dump_debug_info(None) - ) - compile_spec2 = VgfCompileSpec( - compiler_flags=["--my-flag2"] - ).dump_intermediate_artifacts_to("my_path") - - spec_list = compile_spec.to_list() - - assert VgfCompileSpec.from_list(spec_list) == compile_spec - assert VgfCompileSpec.from_list(spec_list) != compile_spec2 - with raises(ValueError, match="Incorrect output format"): - EthosUCompileSpec.from_list(spec_list) - - -def test_tosa_compile_spec(): - compile_spec = TosaCompileSpec("TOSA-1.0+INT") - spec_list = compile_spec.to_list() - - assert TosaCompileSpec.from_list(spec_list) == compile_spec - with raises(ValueError, match="Incorrect output format"): - VgfCompileSpec.from_list(spec_list) diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 3796d3dce4a..3e10a9336f9 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -14,7 +14,7 @@ import pytest import torch -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -200,7 +200,7 @@ def test_dump_tosa_debug_json(test_data: input_t1): aten_op=[], exir_op=[], custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpec.DebugMode.JSON, + tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.JSON, ) pipeline.pop_stage("run_method_and_compare_outputs") @@ -231,7 +231,7 @@ def test_dump_tosa_debug_tosa(test_data: input_t1): aten_op=[], exir_op=[], custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA, + tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA, ) pipeline.pop_stage("run_method_and_compare_outputs") diff --git a/backends/arm/test/misc/test_debug_hook.py b/backends/arm/test/misc/test_debug_hook.py index 376c65ff093..935f3984403 100644 --- a/backends/arm/test/misc/test_debug_hook.py +++ b/backends/arm/test/misc/test_debug_hook.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from types import SimpleNamespace -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.debug.schema import DebugHook, DebugSchema from executorch.backends.arm.test import common @@ -158,7 +158,7 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_json(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpec.DebugMode.JSON) + hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON) hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events @@ -171,7 +171,7 @@ def test_debug_hook_add_json(test_data: DebugHookTestCase): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_tosa(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpec.DebugMode.TOSA) + hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA) hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events diff --git a/backends/arm/test/misc/test_extract_io_params_tosa.py b/backends/arm/test/misc/test_extract_io_params_tosa.py index 90104c54899..da471b0bb74 100644 --- a/backends/arm/test/misc/test_extract_io_params_tosa.py +++ b/backends/arm/test/misc/test_extract_io_params_tosa.py @@ -7,6 +7,7 @@ import pytest import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import VgfQuantizer from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, @@ -14,9 +15,9 @@ ) from executorch.backends.arm.test.common import SkipIfNoModelConverter -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.backends.arm.vgf import VgfPartitioner from executorch.exir import to_edge_transform_and_lower from executorch.exir.passes.quantize_io_pass import extract_io_quant_params from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -28,11 +29,11 @@ def forward(self, x, y): @pytest.mark.parametrize( - "compile_spec_cls, quantizer_cls, partitioner_cls", + "builder_method, quantizer_cls, partitioner_cls", [ - (TosaCompileSpec, TOSAQuantizer, TOSAPartitioner), + ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner), pytest.param( - VgfCompileSpec, + "vgf_compile_spec", VgfQuantizer, VgfPartitioner, marks=SkipIfNoModelConverter, @@ -40,11 +41,7 @@ def forward(self, x, y): ), ], ) -def test_roundtrip_extracts_io_params( - compile_spec_cls: type[TosaCompileSpec] | type[VgfCompileSpec], - quantizer_cls, - partitioner_cls, -): +def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls): """ Validates that IO quantization parameters round-trip for both flows. """ @@ -54,7 +51,10 @@ def test_roundtrip_extracts_io_params( ) mod = SimpleAdd().eval() - compile_spec = compile_spec_cls("TOSA-1.0+INT") + base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)( + tosa_spec=base_spec + ).build() quantizer = quantizer_cls(compile_spec) operator_config = get_symmetric_quantization_config(is_qat=True) diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index ff02ffc360a..43d35b6d13c 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -9,11 +9,11 @@ import pytest import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, ) -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.exir import to_edge_transform_and_lower @@ -81,7 +81,7 @@ def test_network_output_order_and_restore(tmp_path, batch_size): model = Network(batch_norm=True).eval() # Prepare spec spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = TosaCompileSpec(tosa_spec=spec) + compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build() # Setup quantizer quantizer = TOSAQuantizer(compile_spec) quantizer.set_global( @@ -89,7 +89,7 @@ def test_network_output_order_and_restore(tmp_path, batch_size): ) # Trace the model dummy = torch.randn(batch_size, 1, 28, 28) - fx_mod = torch.export.export(model, (dummy,)).module() + fx_mod = torch.export.export_for_training(model, (dummy,)).module() model = prepare_pt2e(fx_mod, quantizer) model(dummy) model = convert_pt2e(model) @@ -98,7 +98,10 @@ def test_network_output_order_and_restore(tmp_path, batch_size): with tempfile.TemporaryDirectory() as tmpdir: art_dir = Path(tmpdir) part = TOSAPartitioner( - TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir)) + ArmCompileSpecBuilder() + .tosa_compile_spec(spec) + .dump_intermediate_artifacts_to(str(art_dir)) + .build() ) _ = to_edge_transform_and_lower(aten_gm, partitioner=[part]) # Expect exactly one .tosa file in the artefact dir diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 24fdfbb5457..2eabd302df6 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import cast, Tuple +from typing import Tuple import pytest import torch @@ -23,6 +23,7 @@ VgfPipeline, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import get_tosa_spec from executorch.backends.xnnpack.test.tester import Quantize from torchao.quantization.pt2e import HistogramObserver from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -102,13 +103,14 @@ def test_add_tensor_tosa_INT(test_data: input_t1): @common.parametrize("test_data", Add.test_data) def test_add_tensor_tosa_INT_i32(test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op) - tosa_version = cast(str, conftest.get_option("tosa_version")) + tosa_version = conftest.get_option("tosa_version") tosa_profiles = { "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), } # Create a quantizer with int8 quantization on the input and output but int32 on everything else. - quantizer = arm_quantizer.TOSAQuantizer(tosa_profiles[tosa_version]) - + quantizer = arm_quantizer.TOSAQuantizer( + get_tosa_spec(common.get_tosa_compile_spec(tosa_profiles[tosa_version])) + ) quantizer.set_io(arm_quantizer.get_symmetric_quantization_config()) observer_options = {"eps": 2**-16} observer = HistogramObserver.with_args(**observer_options) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 1b59b186a2e..aeb0e3a56bd 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,14 +17,16 @@ import numpy as np import torch -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.arm_backend import is_tosa, is_vgf from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec -from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification -from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.arm.tosa.specification import ( + get_tosa_spec, + Tosa_1_00, + TosaSpecification, +) from executorch.exir import ExecutorchProgramManager, ExportedProgram +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule from torch.fx.node import Node @@ -166,9 +168,14 @@ def __init__(self): def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs): tosa_buffer = lowered_backend_module.processed_bytes - compile_spec = TosaCompileSpec.from_list(lowered_backend_module.compile_specs) + compile_specs = lowered_backend_module.compile_specs + if not is_tosa(compile_specs): + raise RuntimeError( + "Model needs to be compiled to tosa to run reference model." + ) + tosa_spec = get_tosa_spec(compile_specs) - return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs) + return run_tosa_graph(tosa_buffer, tosa_spec, inputs) def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) @@ -718,12 +725,14 @@ def run_tosa_graph( return [torch.from_numpy(output) for output in outputs_np] -def get_target_board(compile_spec: ArmCompileSpec) -> str | None: - if isinstance(compile_spec, VgfCompileSpec): +def get_target_board(compile_spec: list[CompileSpec]) -> str | None: + if is_vgf(compile_spec): return "vkml_emulation_layer" - if isinstance(compile_spec, EthosUCompileSpec): - if "u55" in compile_spec.target: - return "corstone-300" - if "u85" in compile_spec.target: - return "corstone-320" + for spec in compile_spec: + if spec.key == "compile_flags": + flags = spec.value.decode() + if "u55" in flags: + return "corstone-300" + elif "u85" in flags: + return "corstone-320" return None diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index c707eed8013..82d4f5d9837 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -7,6 +7,7 @@ import tempfile import torch +from executorch.backends.arm.arm_backend import get_intermediate_path from executorch.backends.arm.test.runner_utils import ( get_input_quantization_params, get_output_quantization_params, @@ -244,7 +245,7 @@ def dump_error_output( # Capture assertion error and print more info banner = "=" * 40 + "TOSA debug info" + "=" * 40 logger.error(banner) - path_to_tosa_files = tester.compile_spec.get_intermediate_path() + path_to_tosa_files = get_intermediate_path(tester.compile_spec) if path_to_tosa_files is None: path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_") diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 284d4d6d1c4..fe17bd3f448 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -32,8 +32,13 @@ from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner +from executorch.backends.arm.arm_backend import ( + get_intermediate_path, + is_ethosu, + is_tosa, + is_vgf, +) +from executorch.backends.arm.ethosu import EthosUPartitioner from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -54,11 +59,11 @@ print_error_diffs, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.mapping import extract_tensor_meta from executorch.backends.arm.tosa.partitioner import TOSAPartitioner +from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.backends.arm.vgf import VgfPartitioner from executorch.backends.test.harness.stages import Stage, StageType from executorch.backends.xnnpack.test.tester import Tester @@ -72,6 +77,7 @@ to_edge_transform_and_lower, ) from executorch.exir.backend.backend_api import validation_disabled +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.operator_support import ( DontPartition, DontPartitionModule, @@ -125,7 +131,7 @@ def get_output_format(lowered_module) -> str | None: to_print = dbg_tosa_fb_to_json(tosa_fb) to_print = pformat(to_print, compact=True, indent=1) output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == EthosUCompileSpec.get_output_format(): + elif output_format == "vela": vela_cmd_stream = lowered_module.processed_bytes output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" else: @@ -180,7 +186,7 @@ def run( class Serialize(tester.Serialize): - def __init__(self, compile_spec: ArmCompileSpec, timeout): + def __init__(self, compile_spec: list[CompileSpec], timeout): super().__init__() self.timeout = timeout self.executorch_program_manager: ExecutorchProgramManager | None @@ -197,7 +203,7 @@ def run_artifact(self, inputs): "Tried running artifact from Serialize stage without running the stage." ) inputs_flattened, _ = tree_flatten(inputs) - intermediate_path = self.compile_spec.get_intermediate_path() + intermediate_path = get_intermediate_path(self.compile_spec) target_board = get_target_board(self.compile_spec) elf_path = get_elf_path(target_board) @@ -291,7 +297,7 @@ def __init__( self, model: torch.nn.Module, example_inputs: Tuple, - compile_spec: ArmCompileSpec, + compile_spec: List[CompileSpec], tosa_ref_model_path: str | None = None, dynamic_shapes: Optional[Tuple[Any]] = None, constant_methods: Optional[Dict[str, Any]] = None, @@ -325,11 +331,12 @@ def quantize( ): if quantize_stage is None: quantizer = None - if isinstance(self.compile_spec, TosaCompileSpec): - quantizer = TOSAQuantizer(self.compile_spec) - elif isinstance(self.compile_spec, EthosUCompileSpec): + if is_tosa(self.compile_spec): + tosa_spec = get_tosa_spec(self.compile_spec) + quantizer = TOSAQuantizer(tosa_spec) + elif is_ethosu(self.compile_spec): quantizer = EthosUQuantizer(self.compile_spec) - elif isinstance(self.compile_spec, VgfCompileSpec): + elif is_vgf(self.compile_spec): quantizer = VgfQuantizer(self.compile_spec) quantize_stage = tester.Quantize( quantizer, @@ -352,12 +359,10 @@ def to_edge( def partition(self, partition_stage: Optional[Partition] = None): if partition_stage is None: - if isinstance(self.compile_spec, TosaCompileSpec): - arm_partitioner = TOSAPartitioner(self.compile_spec) - elif isinstance(self.compile_spec, EthosUCompileSpec): - arm_partitioner = EthosUPartitioner(self.compile_spec) - elif isinstance(self.compile_spec, VgfCompileSpec): - arm_partitioner = VgfPartitioner(self.compile_spec) + if is_tosa(self.compile_spec): + arm_partitioner = TOSAPartitioner(compile_spec=self.compile_spec) + elif is_ethosu(self.compile_spec): + arm_partitioner = EthosUPartitioner(compile_spec=self.compile_spec) else: raise ValueError("compile spec doesn't target any Arm Partitioner") partition_stage = Partition(arm_partitioner) @@ -375,24 +380,23 @@ def to_edge_transform_and_lower( Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, ): - if transform_passes is not None: - raise RuntimeError( - "transform passes are given to ArmTester at construction." - ) - if to_edge_and_lower_stage is None: if partitioners is None: - if isinstance(self.compile_spec, TosaCompileSpec): + arm_partitioner = None + if is_tosa(self.compile_spec): arm_partitioner = TOSAPartitioner( - self.compile_spec, additional_checks + compile_spec=self.compile_spec, + additional_checks=additional_checks, ) - elif isinstance(self.compile_spec, EthosUCompileSpec): + elif is_ethosu(self.compile_spec): arm_partitioner = EthosUPartitioner( - self.compile_spec, additional_checks + compile_spec=self.compile_spec, + additional_checks=additional_checks, ) - elif isinstance(self.compile_spec, VgfCompileSpec): + elif is_vgf(self.compile_spec): arm_partitioner = VgfPartitioner( - self.compile_spec, additional_checks + compile_spec=self.compile_spec, + additional_checks=additional_checks, ) else: raise ValueError("compile spec doesn't target any Arm Partitioner") @@ -421,7 +425,7 @@ def serialize( if serialize_stage is None: serialize_stage = Serialize(self.compile_spec, timeout) assert ( - self.compile_spec.get_intermediate_path() is not None + get_intermediate_path(self.compile_spec) is not None ), "Can't dump serialized file when compile specs do not contain an artifact path." return super().serialize(serialize_stage) @@ -617,7 +621,7 @@ def dump_dtype_distribution( to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n" graph = self.get_graph(self.cur) - tosa_spec = self.compile_spec.tosa_spec + tosa_spec = get_tosa_spec(self.compile_spec) dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution( graph, tosa_spec ) @@ -664,7 +668,7 @@ def run_transform_for_annotation_pipeline( # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run. artifact = self.get_artifact(stage) if self.cur == StageType.EXPORT: - new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type] + new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type] graph_module=artifact.graph_module ) else: @@ -780,7 +784,7 @@ def _get_tosa_operator_distribution( [operator["op"] for operator in block["operators"]] ) break - elif spec.value == EthosUCompileSpec.get_output_format().encode(): + elif spec.value == b"vela": return "Can not get operator distribution for Vela command stream." else: return f"Unknown output format '{spec.value}'." diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 123c1af44c3..102ccd209e9 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -21,8 +21,8 @@ ) import torch -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -37,6 +37,7 @@ ) from executorch.backends.xnnpack.test.tester.tester import Quantize +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_base import ExportPass from torch._export.pass_base import PassType @@ -103,7 +104,7 @@ def __init__( module: torch.nn.Module, test_data: T, aten_ops: str | List[str], - compile_spec: ArmCompileSpec, + compile_spec: List[CompileSpec], exir_ops: Optional[str | List[str]] = None, use_to_edge_transform_and_lower: bool = True, dynamic_shapes: Optional[Tuple[Any]] = None, @@ -339,7 +340,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -444,7 +445,7 @@ def __init__( run_on_tosa_ref_model: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 0, @@ -525,7 +526,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -616,7 +617,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -929,7 +930,7 @@ def __init__( per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, + tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 08b0d55aaeb..ce2b7a27487 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -16,7 +16,7 @@ from typing import cast, Dict, final, List, Set import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.process_node import ( @@ -132,7 +132,7 @@ def preprocess( # noqa: C901 debug_hook = None if dump_debug_info is not None: - debug_hook = DebugHook(ArmCompileSpec.DebugMode[dump_debug_info]) + debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info]) # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors @@ -192,7 +192,7 @@ def _sort_key(t: Node) -> int: ) if debug_hook is not None: - if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: + if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON: json_output = debug_hook.serialize() with open(f"{artifact_path}/debug.json", "w") as f: f.write(json_output) diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py deleted file mode 100644 index 39403c867d7..00000000000 --- a/backends/arm/tosa/compile_spec.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.tosa import TosaSpecification - - -class TosaCompileSpec(ArmCompileSpec): - def __init__(self, tosa_spec: TosaSpecification | str): - if isinstance(tosa_spec, str): - tosa_spec = TosaSpecification.create_from_string(tosa_spec) - self._set_compile_specs(tosa_spec, []) - - def validate(self): - if len(self.compiler_flags) != 0: - raise ValueError( - f"TosaCompileSpec can't have compiler flags, got {self.compiler_flags}" - ) - pass - - @classmethod - def get_output_format(cls) -> str: - return "tosa" diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index ab381470968..c0f546fe50a 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -18,7 +18,8 @@ tosa_support_factory, ) from executorch.backends.arm.tosa.backend import TOSABackend -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import get_tosa_spec +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, Partitioner, @@ -37,7 +38,7 @@ def is_noop_clone(node: torch.fx.node.Node) -> bool: return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default -def is_noop_alias_copy(node: torch.fx.Node) -> bool: +def is_noop_alias_copy(node: torch.fx.node.Node) -> bool: return node.target == exir_ops.edge.aten.alias_copy.default @@ -59,14 +60,15 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: class TOSAPartitioner(Partitioner): def __init__( self, - compile_spec: TosaCompileSpec, + compile_spec: List[CompileSpec], additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - self.delegation_spec = DelegationSpec( - TOSABackend.__name__, compile_spec.to_list() - ) + from executorch.backends.arm.arm_backend import is_tosa + + if not is_tosa(compile_spec): + raise RuntimeError("compile spec is not targeting TOSA") + self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec) self.additional_checks = additional_checks - self.tosa_spec = compile_spec.tosa_spec def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa # Run the CapabilityBasedPartitioner to return the largest possible @@ -75,7 +77,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no logger.info("TOSAPartitioner::partition") partition_tags: dict[str, DelegationSpec] = {} - tosa_spec = self.tosa_spec + tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs) logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}") @@ -213,7 +215,7 @@ def filter_fn(node: torch.fx.Node) -> bool: torch.ops.aten.logit.default, ] + ops_to_not_decompose_if_quant_op - tosa_spec = self.tosa_spec + tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs) if not tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d # and upsample_nearest2d decompose into that it will not be possible to diff --git a/backends/arm/vgf/__init__.py b/backends/arm/vgf/__init__.py index f4ce8f5d1a4..4ab8144cbd6 100644 --- a/backends/arm/vgf/__init__.py +++ b/backends/arm/vgf/__init__.py @@ -6,7 +6,9 @@ # pyre-unsafe from .backend import VgfBackend # noqa: F401 -from .compile_spec import VgfCompileSpec # noqa: F401 from .partitioner import VgfPartitioner # noqa: F401 -__all__ = ["VgfBackend", "VgfPartitioner", "VgfCompileSpec"] +__all__ = [ + "VgfBackend", + "VgfPartitioner", +] diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py deleted file mode 100644 index 452ea5c1956..00000000000 --- a/backends/arm/vgf/compile_spec.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging - -from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] - TosaSpecification, -) - -# debug functionality -logger = logging.getLogger(__name__) - - -class VgfCompileSpec(ArmCompileSpec): - - def __init__( - self, - tosa_spec: TosaSpecification | str | None = None, - compiler_flags: list[str] | None = None, - ): - """ - Generate compile spec for VGF compatible targets - - Args: - compiler_flags: Extra compiler flags for converter_backend - """ - - if tosa_spec is None: - tosa_spec = "TOSA-1.0+FP" - if isinstance(tosa_spec, str): - tosa_spec = TosaSpecification.create_from_string(tosa_spec) - - if compiler_flags is None: - compiler_flags = [] - self._set_compile_specs(tosa_spec, compiler_flags) - self.validate() - - def validate(self): - """Throws an error if the compile spec is not valid.""" - tosa_version = self.tosa_spec.version # type: ignore[attr-defined] - tosa_profiles = self.tosa_spec.profiles # type: ignore[attr-defined] - - if tosa_version.major != 1: - raise ValueError( - "Arm backend only supports converter-backend for TOSA version 1. " - f"Invalid TOSA version: {tosa_version}" - ) - - if "FP" not in tosa_profiles and "INT" not in tosa_profiles: - raise ValueError( - "Arm backend only supports converter-backend for FP or INT. " - f"Invalid TOSA profile: {tosa_profiles}" - ) - - if len(tosa_profiles) != 1: - raise ValueError( - "For now Arm backend only supports converter-backend for either FP or INT. " - f"Invalid TOSA profile: {tosa_profiles}" - ) - - @classmethod - def get_output_format(cls) -> str: - return "vgf" diff --git a/backends/arm/vgf/partitioner.py b/backends/arm/vgf/partitioner.py index ea10730e810..f6dab597487 100644 --- a/backends/arm/vgf/partitioner.py +++ b/backends/arm/vgf/partitioner.py @@ -5,10 +5,14 @@ # pyre-unsafe -from typing import final, Optional, Sequence +from typing import final, List, Optional, Sequence +from executorch.backends.arm.arm_backend import ( + is_vgf, +) # usort: skip from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfBackend, VgfCompileSpec +from executorch.backends.arm.vgf import VgfBackend +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @@ -17,12 +21,12 @@ class VgfPartitioner(TOSAPartitioner): def __init__( self, - compile_spec: VgfCompileSpec, + compile_spec: List[CompileSpec], additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: + if not is_vgf(compile_spec): + raise RuntimeError("compile spec is not targeting Vgf") + # Override the delegation spec for Vgf - self.delegation_spec = DelegationSpec( - VgfBackend.__name__, compile_spec.to_list() - ) + self.delegation_spec = DelegationSpec(VgfBackend.__name__, compile_spec) self.additional_checks = additional_checks - self.tosa_spec = compile_spec.tosa_spec diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 8132751f6f0..d7e1b64e3ca 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -18,7 +18,13 @@ import torch from examples.devtools.scripts.export_bundled_program import save_bundled_program -from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner +from executorch.backends.arm.arm_backend import ( + ArmCompileSpecBuilder, + is_ethosu, + is_tosa, + is_vgf, +) +from executorch.backends.arm.ethosu import EthosUPartitioner from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -26,15 +32,15 @@ VgfQuantizer, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner +from executorch.backends.arm.tosa.specification import get_tosa_spec from executorch.backends.arm.util.arm_model_evaluator import ( GenericModelEvaluator, MobileNetV2Evaluator, ) -from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.backends.arm.vgf import VgfPartitioner # To use Cortex-M backend from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( @@ -54,6 +60,7 @@ ExecutorchBackendConfig, to_edge_transform_and_lower, ) +from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate from torch.utils.data import DataLoader @@ -142,7 +149,7 @@ def get_model_and_inputs_from_name( def quantize( model: torch.nn.Module, model_name: str, - compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec, + compile_specs: list[CompileSpec], example_inputs: Tuple[torch.Tensor], evaluator_name: str | None, evaluator_config: Dict[str, Any] | None, @@ -151,11 +158,11 @@ def quantize( logging.info("Quantizing Model...") logging.debug(f"Original model: {model}") quantizer = None - if isinstance(compile_specs, EthosUCompileSpec): + if is_ethosu(compile_specs): quantizer = EthosUQuantizer(compile_specs) - elif isinstance(compile_specs, TosaCompileSpec): - quantizer = TOSAQuantizer(compile_specs) - elif isinstance(compile_specs, VgfCompileSpec): + elif is_tosa(compile_specs): + quantizer = TOSAQuantizer(get_tosa_spec(compile_specs)) + elif is_vgf(compile_specs): quantizer = VgfQuantizer(compile_specs) else: raise RuntimeError("Unsupported compilespecs for quantization!") @@ -386,20 +393,20 @@ def get_compile_spec( memory_mode: Optional[str] = None, quantize: bool = False, config: Optional[str] = None, -) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec: - compile_spec = None +) -> list[CompileSpec]: + spec_builder = None if target.startswith("TOSA"): try: tosa_spec = TosaSpecification.create_from_string(target) - except Exception: + except: tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = TosaCompileSpec(tosa_spec) + spec_builder = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec) elif "ethos-u" in target: - compile_spec = EthosUCompileSpec( + spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec( target, system_config=system_config, memory_mode=memory_mode, - extra_flags=["--verbose-operators", "--verbose-cycle-estimate"], + extra_flags="--verbose-operators --verbose-cycle-estimate", config_ini=config, ) elif "vgf" in target: @@ -407,14 +414,12 @@ def get_compile_spec( tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") else: tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") - compile_spec = VgfCompileSpec(tosa_spec) - else: - raise RuntimeError(f"Unkown target {target}") + spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec) if intermediates is not None: - compile_spec.dump_intermediate_artifacts_to(intermediates) + spec_builder.dump_intermediate_artifacts_to(intermediates) - return compile_spec + return spec_builder.build() def evaluate_model( @@ -744,11 +749,11 @@ def to_edge_TOSA_delegate( ) model = model_int8 - if isinstance(compile_spec, EthosUCompileSpec): + if is_ethosu(compile_spec): partitioner = EthosUPartitioner(compile_spec) - elif isinstance(compile_spec, TosaCompileSpec): + elif is_tosa(compile_spec): partitioner = TOSAPartitioner(compile_spec) - elif isinstance(compile_spec, VgfCompileSpec): + elif is_vgf(compile_spec): partitioner = VgfPartitioner(compile_spec) else: raise RuntimeError(f"Unhandled compile spec: {compile_spec}") diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb index dc8ea7193aa..e63a7d37e58 100644 --- a/examples/arm/ethos_u_minimal_example.ipynb +++ b/examples/arm/ethos_u_minimal_example.ipynb @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.ethosu import EthosUCompileSpec\n", + "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n", "from executorch.backends.arm.quantizer import (\n", " EthosUQuantizer,\n", " get_symmetric_quantization_config,\n", @@ -90,12 +90,13 @@ "# Create a compilation spec describing the target for configuring the quantizer\n", "# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an\n", "# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md\n", - "compile_spec = EthosUCompileSpec(\n", + "spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(\n", " target=\"ethos-u55-128\",\n", " system_config=\"Ethos_U55_High_End_Embedded\",\n", " memory_mode=\"Shared_Sram\",\n", - " extra_flags=[\"--output-format=raw\", \"--debug-force-regor\"]\n", + " extra_flags=\"--output-format=raw --debug-force-regor\"\n", " )\n", + "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = EthosUQuantizer(compile_spec)\n", @@ -241,7 +242,7 @@ ], "metadata": { "kernelspec": { - "display_name": "et_env", + "display_name": ".venv (3.10.15)", "language": "python", "name": "python3" }, @@ -255,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/examples/arm/vgf_minimal_example.ipynb b/examples/arm/vgf_minimal_example.ipynb index 36004f2c7cd..35378817a7d 100644 --- a/examples/arm/vgf_minimal_example.ipynb +++ b/examples/arm/vgf_minimal_example.ipynb @@ -82,15 +82,21 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.vgf import VgfCompileSpec\n", + "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n", + "from executorch.backends.arm.tosa import ( \n", + " TosaSpecification,\n", + ")\n", "\n", "# Create a compilation spec describing the floating point target.\n", - "compile_spec = VgfCompileSpec(\"TOSA-1.0+FP\")\n", + "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+FP\")\n", + "\n", + "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", + "compile_spec = spec_builder.build()\n", "\n", "_ = graph_module.print_readable()\n", "\n", "# Create a new exported program using the graph_module\n", - "exported_program = torch.export.export(graph_module, example_inputs)" + "exported_program = torch.export.export_for_training(graph_module, example_inputs)" ] }, { @@ -119,7 +125,10 @@ "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", - "compile_spec = VgfCompileSpec(\"TOSA-1.0+INT\")\n", + "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+INT\")\n", + "\n", + "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", + "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = VgfQuantizer(compile_spec)\n", @@ -134,7 +143,7 @@ "_ = quantized_graph_module.print_readable()\n", "\n", "# Create a new exported program using the quantized_graph_module\n", - "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)" + "quantized_exported_program = torch.export.export_for_training(quantized_graph_module, example_inputs)" ] }, {