From 49e6b1da431fcbb8a1759177776da26e608a1e51 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 3 Nov 2025 09:51:00 -0800 Subject: [PATCH 01/15] backend.py (#15430) Summary: This diff consolidates the backend functionality into a single target `//executorch/backends/aoti:aoti_backend` and simplifies the cuda backend target by making it dependent on the consolidated backend target. The following changes are made in this diff: * Creation of a new target `//executorch/backends/aoti:aoti_backend` in `fbcode/executorch/backends/aoti/targets.bzl` which includes the necessary dependencies for the AOTI backend. * Update of the `//executorch/backends/cuda:cuda_backend` target in `fbcode/executorch/backends/cuda/TARGETS` to depend on the new `//executorch/backends/aoti:aoti_backend` target instead of individual AOTI backend dependencies. * Creation of a new file `fbcode/executorch/backends/aoti/aoti_backend.py` which imports the necessary dependencies and passes for the AOTI backend. * Simplification of the `xplat/executorch/backends/cuda/cuda_backend.py` file by removing unnecessary imports and using the new `AotiBackend` class from the `aoti_backend.py` file. ghstack-source-id: 319556735 Reviewed By: larryliu0820 Differential Revision: D85704977 --- backends/aoti/aoti_backend.py | 261 ++++++++++++++++++++++++++ backends/aoti/targets.bzl | 17 ++ backends/apple/metal/metal_backend.py | 193 +++---------------- backends/cuda/TARGETS | 1 + backends/cuda/cuda_backend.py | 247 ++---------------------- exir/backend/backend_api.py | 16 +- extension/llm/tokenizers | 2 +- 7 files changed, 337 insertions(+), 400 deletions(-) create mode 100644 backends/aoti/aoti_backend.py diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py new file mode 100644 index 00000000000..6c1a8a8661c --- /dev/null +++ b/backends/aoti/aoti_backend.py @@ -0,0 +1,261 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Set + +import torch +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +@experimental( + "This API and all of aoti-driven backend related functionality are experimental." +) +class AotiBackend(BackendDetails, ABC): + """ + Base backend class for AOTInductor-based backends. + + This class provides common functionality for compiling models using AOTInductor + with different device targets (CUDA, Metal/MPS, etc.). 
+ """ + + @staticmethod + @abstractmethod + def get_device_name() -> str: + """Return the device name for this backend (e.g., 'cuda', 'mps').""" + pass + + @staticmethod + @abstractmethod + def get_supported_fallback_kernels() -> Dict[str, Any]: + """Return the set of supported fallback kernels for this backend.""" + pass + + @staticmethod + @abstractmethod + def get_decomposition_table() -> Dict[Any, Any]: + """Return the decomposition table for this backend.""" + pass + + @staticmethod + @abstractmethod + def get_aoti_compile_options() -> Dict[str, typing.Any]: + """Return the AOTInductor compilation options for this backend.""" + pass + + @classmethod + @contextlib.contextmanager + def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): + """ + Context manager to collect unsupported fallback kernels during compilation. + Monitors both extern kernel calls and runtime lookup. + """ + supported_kernels = cls.get_supported_fallback_kernels() + + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + original_generate_fallback_kernel_with_runtime_lookup_aot = ( + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, + kernel, + args, + device, + debug_args=debug_args, + debug_handle=debug_handle, + ) + + def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( + self, + op_overload, + raw_args, + output_args, + raw_outputs, + ): + kernel_name = getattr(op_overload, "_name", str(op_overload)) + if kernel_name not in supported_kernels: + missing_fallback_kernels.add(kernel_name) + + original_generate_fallback_kernel_with_runtime_lookup_aot( + self, op_overload, raw_args, output_args, raw_outputs + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels + + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( + original_generate_fallback_kernel_with_runtime_lookup_aot + ) + + @classmethod + def preprocess( + cls, + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Preprocess the edge program and compile it using AOTInductor. + Weights are always separated from the SO file. 
+ """ + device_name = cls.get_device_name() + decomposition_table = cls.get_decomposition_table() + options = cls.get_aoti_compile_options() + + # Move the edge_program to the target device + device_edge_program = move_to_device_pass(edge_program, device_name) + + # Replace view_copy with view + ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) + + # Run decompositions if any + if decomposition_table: + device_edge_program = device_edge_program.run_decompositions( + decomposition_table + ) + + edge_program_module = device_edge_program.module() + + # Grab all input placeholders from the graph + user_input_names = device_edge_program.graph_signature.user_inputs + user_input_placeholders = [] + for node in device_edge_program.graph.nodes: + if node.op == "placeholder" and node.name in user_input_names: + user_input_placeholders.append(node.meta["val"]) + + # Track missing fallback kernels + missing_fallback_kernels: Set[str] = set() + + # Compile with fallback kernel collection + with cls.collect_unsupported_fallback_kernels( + missing_fallback_kernels + ), torch.no_grad(): + paths = torch._inductor.aot_compile( + edge_program_module, tuple(user_input_placeholders), options=options + ) + + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + method_name = cls.method_name_from_compile_specs(compile_specs) + raise RuntimeError( + f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." + ) + + # Extract paths - weights are always separated + so_path = None + blob_path = None + + if isinstance(paths, list): + for path in paths: + if path.endswith(".wrapper.so"): + so_path = path + elif path.endswith(".wrapper_weights.blob"): + blob_path = path + else: + so_path = paths + + if so_path is None or blob_path is None: + raise RuntimeError( + f"Could not find required files in compiled paths, got {paths}" + ) + + # Read SO file + with open(so_path, "rb") as f: + so_data = f.read() + + # Read weights blob + with open(blob_path, "rb") as f: + blob_data = f.read() + + # Create named data store + named_data_store = NamedDataStore() + method_name = cls.method_name_from_compile_specs(compile_specs) + + # Add SO and weights blob separately + named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) + weights_blob_data_type = f"aoti_{device_name}_blob" + named_data_store.add_named_data( + method_name + "_weights_blob", blob_data, 1, weights_blob_data_type + ) + + # Clean up the generated files + os.remove(so_path) + os.remove(blob_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @staticmethod + def generate_method_name_compile_spec( + method_name: str, + ) -> CompileSpec: + """ + Generate a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @staticmethod + def method_name_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> str: + """ + Extract the method name from the compile specs. 
+ """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index be5fe490721..327bef8cc53 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -16,6 +16,23 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "aoti_backend", + srcs = [ + "aoti_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + ], + ) + # AOTI common shims functionality runtime.cxx_library( name = "common_shims", diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 7d1a5496be3..d73639beb54 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -4,107 +4,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib -import os import typing -from enum import Enum +from typing import Any, Dict, final -from typing import Any, Dict, final, List, Optional, Set - -import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) -from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) -from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu -from torch.export.passes import move_to_device_pass - - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "aoti_torch_mps_convolution": None, - "aoti_torch_mps_mm_out": None, - "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - debug_handle: Optional[int] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) @final @experimental( "This API and all of Metal backend related functionality are experimental." 
) -class MetalBackend(BackendDetails): - @staticmethod - def preprocess( - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - print("entering the lowerable parts in MetalBackend.preprocess....") - # Move the edge_program from CPU to MPS for aoti compile - mps_edge_program = move_to_device_pass(edge_program, "mps") +class MetalBackend(AotiBackend): + """ + MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate + optimized Metal kernels for the model's operators with libtorch-free. The compiled model can be executed on Metal devices + using the Executorch runtime. + """ - # replace slice_copy with slice - ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module) + @staticmethod + def get_device_name() -> str: + return "mps" - edge_program_module = mps_edge_program.module() + @staticmethod + def get_supported_fallback_kernels() -> Dict[str, Any]: + return { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, + } - # Grab all input placeholders from the graph - user_input_names = mps_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in mps_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + @staticmethod + def get_decomposition_table() -> Dict[Any, Any]: + return {} - # Base options for all devices - options: dict[str, typing.Any] = { + @staticmethod + def get_aoti_compile_options() -> Dict[str, typing.Any]: + return { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, # Separate weight constants from the .so file @@ -117,83 +54,3 @@ def preprocess( # "aot_inductor.debug_compile": True, # "aot_inductor.force_mmap_weights": False, } - - with collect_unsupported_fallback_kernels(): - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." - ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = MetalBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. 
- named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob" - ) - - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Generates a CompileSpec for the given method name. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. - """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index d8256f77c41..3ae4eec6680 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -17,6 +17,7 @@ runtime.python_library( "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/aoti:aoti_backend", ], ) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 772e24c75b3..0ba45e44060 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -4,116 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import contextlib -import os import typing -from enum import Enum -from importlib import resources - -from typing import Any, Dict, final, List, Optional, Set +from typing import Any, Dict, final import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) - -from executorch.backends.cuda.triton.replacement_pass import ( - ReplaceEdgeOpWithTritonOpPass, -) -from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) -from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.decomposition import conv1d_to_conv2d -from torch.export.passes import move_to_device_pass - - -cuda_decomposition_table = { - torch.ops.aten.conv1d.default: conv1d_to_conv2d, -} - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "at::_ops::_weight_int4pack_mm::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - original_generate_fallback_kernel_with_runtime_lookup_aot = ( - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args - ) - - def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( - self, - op_overload, - raw_args, - output_args, - raw_outputs, - ): - # Extract kernel name for collection - kernel_name = getattr(op_overload, "_name", str(op_overload)) - if kernel_name not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel_name) - - original_generate_fallback_kernel_with_runtime_lookup_aot( - self, op_overload, raw_args, output_args, raw_outputs - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - original_generate_fallback_kernel_with_runtime_lookup_aot - ) @final @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(BackendDetails): +class CudaBackend(AotiBackend): """ CudaBackend is a backend that compiles a model to run on CUDA devices. 
It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices @@ -121,33 +25,24 @@ class CudaBackend(BackendDetails): """ @staticmethod - def preprocess( # noqa: C901 - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - # Move the edge_program from CPU to CUDA for aoti compile - cuda_edge_program = move_to_device_pass(edge_program, "cuda") - - # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int - ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) - - # Replace aten ops with triton ops - ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module) - - cuda_edge_program = cuda_edge_program.run_decompositions( - cuda_decomposition_table - ) + def get_device_name() -> str: + return "cuda" - edge_program_module = cuda_edge_program.module() + @staticmethod + def get_supported_fallback_kernels() -> Dict[str, Any]: + return { + "at::_ops::_weight_int4pack_mm::call": None, + } - # Grab all input placeholders from the graph - user_input_names = cuda_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in cuda_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + @staticmethod + def get_decomposition_table() -> Dict[Any, Any]: + return { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } - options: dict[str, typing.Any] = { + @staticmethod + def get_aoti_compile_options() -> Dict[str, typing.Any]: + return { # Disable this to support sdpa decomposition # TODO(gasoonjia): remove it after pin bump to latest pytorch "loop_ordering_after_fusion": False, @@ -169,109 +64,3 @@ def preprocess( # noqa: C901 # Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch "max_autotune_conv_backends": "TRITON", } - - platform = "linux" - shim_library_path = None - for spec in compile_specs: - if spec.key == "platform": - platform = spec.value.decode("utf-8") - if spec.key == "shim_library_path": - shim_library_path = spec.value.decode("utf-8") - - assert platform == "linux" or platform == "windows" - if platform == "windows" and shim_library_path is None: - lib_dir = resources.files("executorch").joinpath("data/lib") - shim_library_path = str(lib_dir) - if platform == "linux": - assert shim_library_path is None - - if platform == "windows": - options.update( - { - "aot_inductor.cross_target_platform": "windows", - "aot_inductor.aoti_shim_library": "aoti_cuda_shims", - "aot_inductor.aoti_shim_library_path": shim_library_path, - "aot_inductor.precompile_headers": False, - } - ) - - with collect_unsupported_fallback_kernels(), torch.no_grad(): - # torch._logging.set_logs(post_grad_graphs=True) - # Here we should expect 1 so file and 1 weight blob in the same directory. - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." 
- ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = CudaBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob" - ) - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Returns the compile spec representing the model compute precision, for additional details - please refer to the documentation for ``coremltools.precision``. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. 
- """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index dd8d97d66ac..3a072e03599 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -10,7 +10,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import singledispatch -from typing import Dict, Generator, List, Mapping +from typing import Dict, Generator, List, Mapping, Set import torch @@ -581,9 +581,21 @@ def lower_all_submodules_to_backend( for method_name, call_submodule_nodes in method_to_submodules_nodes.items() } + def _get_all_final_backend_details_subclasses(cls) -> Set[type]: + subclasses = set() + if len(cls.__subclasses__()) == 0: + return {cls} + else: + for subclass in cls.__subclasses__(): + # Recursively check subclasses + subclasses.update(_get_all_final_backend_details_subclasses(subclass)) + return subclasses + backend_name_to_subclass = { - subclass.__name__: subclass for subclass in BackendDetails.__subclasses__() + subclass.__name__: subclass + for subclass in _get_all_final_backend_details_subclasses(BackendDetails) } + if backend_id not in backend_name_to_subclass: raise NotImplementedError(f"Backend {backend_id} was not found.") diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers index 3aada3fe28c..d710a0cf10c 160000 --- a/extension/llm/tokenizers +++ b/extension/llm/tokenizers @@ -1 +1 @@ -Subproject commit 3aada3fe28c945d14d5ec62254eb56ccdf10eb11 +Subproject commit d710a0cf10cfa8cb7ffda33c4e61af63119bc95f From e98d3392246fceac82953e64d5e09a7be6a49b99 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 3 Nov 2025 22:53:02 -0800 Subject: [PATCH 02/15] solve qualcomm import issue --- setup.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/setup.py b/setup.py index 71aa4c543d4..735840a29c3 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ # imports. import logging import os +import platform import re import shutil import site @@ -90,6 +91,23 @@ def _is_windows() -> bool: return sys.platform == "win32" +# Duplicate of the one in backends.qualcomm.scripts.download_qnn_sdk to avoid +# import errors. +def is_linux_x86() -> bool: + """ + Check if the current platform is Linux x86_64. + + Returns: + bool: True if the system is Linux x86_64, False otherwise. 
+ """ + return platform.system().lower() == "linux" and platform.machine().lower() in ( + "x86_64", + "amd64", + "i386", + "i686", + ) + + class Version: """Static strings that describe the version of the pip package.""" From 05e081f9dbf3ee8eda86a00b8cb278455603403c Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Mon, 3 Nov 2025 23:45:36 -0800 Subject: [PATCH 03/15] Update metal.yml --- .github/workflows/metal.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 92351883e8f..798d038a248 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -69,7 +69,14 @@ jobs: echo "::endgroup::" echo "::group::Setup ExecuTorch" - PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh + ./install_executorch.sh + echo "::endgroup::" + + echo "::group::Setup Huggingface" + pip install -U "huggingface_hub[cli]<1.0" accelerate + huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} echo "::endgroup::" echo "::group::Pip List" From 37ba81907b32527671b1a7f7040434a08bb8fc55 Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Mon, 3 Nov 2025 23:57:41 -0800 Subject: [PATCH 04/15] Update metal.yml --- .github/workflows/metal.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 798d038a248..168a61262e9 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -69,14 +69,15 @@ jobs: echo "::endgroup::" echo "::group::Setup ExecuTorch" - ./install_executorch.sh + PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh echo "::endgroup::" echo "::group::Setup Huggingface" pip install -U "huggingface_hub[cli]<1.0" accelerate huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + ${CONDA_RUN} pip install mistral-common librosa echo "::endgroup::" echo "::group::Pip List" From 7bd8781e70c3b5d822b95437450aab0e13430cfc Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Mon, 3 Nov 2025 23:58:35 -0800 Subject: [PATCH 05/15] Update metal.yml --- .github/workflows/metal.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 168a61262e9..45e0eb4b1cf 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -73,8 +73,8 @@ jobs: echo "::endgroup::" echo "::group::Setup Huggingface" - pip install -U "huggingface_hub[cli]<1.0" accelerate - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" accelerate + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} ${CONDA_RUN} pip install mistral-common librosa From 47617316e3499cf51b7df7d6fa9872a61f4a83cb Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 4 Nov 2025 10:34:58 -0800 Subject: [PATCH 06/15] recover metal workflow --- .github/workflows/metal.yml | 8 
-------- 1 file changed, 8 deletions(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 45e0eb4b1cf..92351883e8f 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -72,14 +72,6 @@ jobs: PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh echo "::endgroup::" - echo "::group::Setup Huggingface" - ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" accelerate - ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} - ${CONDA_RUN} pip install mistral-common librosa - echo "::endgroup::" - echo "::group::Pip List" ${CONDA_RUN} pip list echo "::endgroup::" From dc1fd12a861e3187a11bfd8338194b6a7a76fe40 Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Wed, 5 Nov 2025 13:00:58 -0800 Subject: [PATCH 07/15] Update metal device name --- backends/aoti/aoti_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 6c1a8a8661c..fa978e0cfdd 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -39,13 +39,13 @@ class AotiBackend(BackendDetails, ABC): Base backend class for AOTInductor-based backends. This class provides common functionality for compiling models using AOTInductor - with different device targets (CUDA, Metal/MPS, etc.). + with different device targets (CUDA, Metal, etc.). """ @staticmethod @abstractmethod def get_device_name() -> str: - """Return the device name for this backend (e.g., 'cuda', 'mps').""" + """Return the device name for this backend (e.g., 'cuda', 'metal').""" pass @staticmethod From b11a3fbd725f235498a0281917bcd786cc4ab965 Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Wed, 5 Nov 2025 13:01:48 -0800 Subject: [PATCH 08/15] Update metal device name in metal_backend.py --- backends/apple/metal/metal_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index d73639beb54..2bb9d9cc569 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -24,7 +24,7 @@ class MetalBackend(AotiBackend): @staticmethod def get_device_name() -> str: - return "mps" + return "metal" @staticmethod def get_supported_fallback_kernels() -> Dict[str, Any]: From 3b05c522a8520e8e20bf3c1a64f01d931de2e9e5 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 7 Nov 2025 13:24:09 -0800 Subject: [PATCH 09/15] make aoti_backend not a real backend --- backends/aoti/aoti_backend.py | 14 +++++++------- backends/apple/metal/metal_backend.py | 3 ++- backends/cuda/cuda_backend.py | 3 ++- exir/backend/backend_api.py | 16 ++-------------- 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index fa978e0cfdd..4a9340444a2 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -17,11 +17,7 @@ ) from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) +from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from 
torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch.export.passes import move_to_device_pass @@ -34,12 +30,16 @@ class COMPILE_SPEC_KEYS(Enum): @experimental( "This API and all of aoti-driven backend related functionality are experimental." ) -class AotiBackend(BackendDetails, ABC): +class AotiBackend(ABC): """ - Base backend class for AOTInductor-based backends. + Base mixin class for AOTInductor-based backends. This class provides common functionality for compiling models using AOTInductor with different device targets (CUDA, Metal, etc.). + + This is a mixin class, not an actual backend object, for aoti-driven backens. + Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both + BackendDetails and AotiBackend to get the full functionality. """ @staticmethod diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 2bb9d9cc569..1c6d13440ff 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -8,6 +8,7 @@ from typing import Any, Dict, final from executorch.backends.aoti.aoti_backend import AotiBackend +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir._warnings import experimental @@ -15,7 +16,7 @@ @experimental( "This API and all of Metal backend related functionality are experimental." ) -class MetalBackend(AotiBackend): +class MetalBackend(BackendDetails, AotiBackend): """ MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate optimized Metal kernels for the model's operators with libtorch-free. The compiled model can be executed on Metal devices diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 0ba45e44060..1ae43963fb9 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -9,6 +9,7 @@ import torch from executorch.backends.aoti.aoti_backend import AotiBackend +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir._warnings import experimental from torch._inductor.decomposition import conv1d_to_conv2d @@ -17,7 +18,7 @@ @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(AotiBackend): +class CudaBackend(BackendDetails, AotiBackend): """ CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. 
The compiled model can be executed on CUDA devices diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 3a072e03599..dd8d97d66ac 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -10,7 +10,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import singledispatch -from typing import Dict, Generator, List, Mapping, Set +from typing import Dict, Generator, List, Mapping import torch @@ -581,21 +581,9 @@ def lower_all_submodules_to_backend( for method_name, call_submodule_nodes in method_to_submodules_nodes.items() } - def _get_all_final_backend_details_subclasses(cls) -> Set[type]: - subclasses = set() - if len(cls.__subclasses__()) == 0: - return {cls} - else: - for subclass in cls.__subclasses__(): - # Recursively check subclasses - subclasses.update(_get_all_final_backend_details_subclasses(subclass)) - return subclasses - backend_name_to_subclass = { - subclass.__name__: subclass - for subclass in _get_all_final_backend_details_subclasses(BackendDetails) + subclass.__name__: subclass for subclass in BackendDetails.__subclasses__() } - if backend_id not in backend_name_to_subclass: raise NotImplementedError(f"Backend {backend_id} was not found.") From b83071f3deda55b94c9ad4033865275707a53b16 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 10 Nov 2025 13:25:20 -0800 Subject: [PATCH 10/15] swap inherit order --- backends/apple/metal/metal_backend.py | 2 +- backends/cuda/cuda_backend.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 1c6d13440ff..64a428e86fd 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -16,7 +16,7 @@ @experimental( "This API and all of Metal backend related functionality are experimental." ) -class MetalBackend(BackendDetails, AotiBackend): +class MetalBackend(AotiBackend, BackendDetails): """ MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate optimized Metal kernels for the model's operators with libtorch-free. The compiled model can be executed on Metal devices diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 1ae43963fb9..e9a2d594adb 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -9,8 +9,8 @@ import torch from executorch.backends.aoti.aoti_backend import AotiBackend -from executorch.exir.backend.backend_details import BackendDetails from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import BackendDetails from torch._inductor.decomposition import conv1d_to_conv2d @@ -18,7 +18,7 @@ @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(BackendDetails, AotiBackend): +class CudaBackend(AotiBackend, BackendDetails): """ CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. 
The compiled model can be executed on CUDA devices From c9d871a7aa459230ab65141334e5cbe649d0fc12 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 10 Nov 2025 13:47:42 -0800 Subject: [PATCH 11/15] run lintruner --- backends/apple/metal/metal_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 64a428e86fd..7d759f3a3d0 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -8,8 +8,8 @@ from typing import Any, Dict, final from executorch.backends.aoti.aoti_backend import AotiBackend -from executorch.exir.backend.backend_details import BackendDetails from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import BackendDetails @final From 4be608d94ba23f028d4b2b8dc2795acb010704e0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 10 Nov 2025 15:51:28 -0800 Subject: [PATCH 12/15] solve ci --- backends/aoti/aoti_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 4a9340444a2..6b9cb5d34ca 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -148,7 +148,9 @@ def preprocess( options = cls.get_aoti_compile_options() # Move the edge_program to the target device - device_edge_program = move_to_device_pass(edge_program, device_name) + device_edge_program = move_to_device_pass( + edge_program, device_name if device_name != "metal" else "mps" + ) # Replace view_copy with view ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) From 1e90fd0a25456604039c1a8fafe908b71ecd60d8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 20 Nov 2025 15:04:56 -0800 Subject: [PATCH 13/15] merge lastest update --- backends/aoti/aoti_backend.py | 37 +++++++++----- backends/apple/metal/metal_backend.py | 28 +++++++---- backends/cuda/cuda_backend.py | 69 +++++++++++++++++++++++---- 3 files changed, 104 insertions(+), 30 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 6b9cb5d34ca..cea10fc5b05 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -42,30 +42,38 @@ class AotiBackend(ABC): BackendDetails and AotiBackend to get the full functionality. 
""" - @staticmethod + @classmethod @abstractmethod - def get_device_name() -> str: + def get_device_name(cls) -> str: """Return the device name for this backend (e.g., 'cuda', 'metal').""" pass - @staticmethod + @classmethod @abstractmethod - def get_supported_fallback_kernels() -> Dict[str, Any]: + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: """Return the set of supported fallback kernels for this backend.""" pass - @staticmethod + @classmethod @abstractmethod - def get_decomposition_table() -> Dict[Any, Any]: + def get_decomposition_table(cls) -> Dict[Any, Any]: """Return the decomposition table for this backend.""" pass - @staticmethod + @classmethod @abstractmethod - def get_aoti_compile_options() -> Dict[str, typing.Any]: + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: """Return the AOTInductor compilation options for this backend.""" pass + @classmethod + @abstractmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition.""" + pass + @classmethod @contextlib.contextmanager def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): @@ -145,7 +153,7 @@ def preprocess( """ device_name = cls.get_device_name() decomposition_table = cls.get_decomposition_table() - options = cls.get_aoti_compile_options() + options = cls.get_aoti_compile_options(compile_specs) # Move the edge_program to the target device device_edge_program = move_to_device_pass( @@ -155,6 +163,11 @@ def preprocess( # Replace view_copy with view ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) + # Apply custom backend-specific passes + custom_passes = cls.get_custom_passes() + for custom_pass in custom_passes: + custom_pass(device_edge_program.graph_module) + # Run decompositions if any if decomposition_table: device_edge_program = device_edge_program.run_decompositions( @@ -236,8 +249,9 @@ def preprocess( data_store_output=named_data_store.get_named_data_store_output(), ) - @staticmethod + @classmethod def generate_method_name_compile_spec( + cls, method_name: str, ) -> CompileSpec: """ @@ -248,8 +262,9 @@ def generate_method_name_compile_spec( method_name.encode("utf-8"), ) - @staticmethod + @classmethod def method_name_from_compile_specs( + cls, compile_specs: List[CompileSpec], ) -> str: """ diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 7d759f3a3d0..1b27b027fc2 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. import typing -from typing import Any, Dict, final +from typing import Any, Dict, final, List from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental from executorch.exir.backend.backend_details import BackendDetails +from executorch.exir.backend.compile_spec_schema import CompileSpec @final @@ -23,12 +24,12 @@ class MetalBackend(AotiBackend, BackendDetails): using the Executorch runtime. 
""" - @staticmethod - def get_device_name() -> str: + @classmethod + def get_device_name(cls) -> str: return "metal" - @staticmethod - def get_supported_fallback_kernels() -> Dict[str, Any]: + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "aoti_torch_mps_addmm_out": None, "aoti_torch_mps_convolution": None, @@ -36,12 +37,21 @@ def get_supported_fallback_kernels() -> Dict[str, Any]: "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, } - @staticmethod - def get_decomposition_table() -> Dict[Any, Any]: + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: return {} - @staticmethod - def get_aoti_compile_options() -> Dict[str, typing.Any]: + @classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return Metal-specific passes (currently none)""" + return [] + + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Get AOTI compile options for Metal backend.""" + _ = compile_specs # Unused, but required by interface return { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index e9a2d594adb..cc2d662b335 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -5,12 +5,17 @@ # LICENSE file in the root directory of this source tree. import typing -from typing import Any, Dict, final +from importlib import resources +from typing import Any, Dict, final, List import torch from executorch.backends.aoti.aoti_backend import AotiBackend +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) from executorch.exir._warnings import experimental from executorch.exir.backend.backend_details import BackendDetails +from executorch.exir.backend.compile_spec_schema import CompileSpec from torch._inductor.decomposition import conv1d_to_conv2d @@ -25,25 +30,37 @@ class CudaBackend(AotiBackend, BackendDetails): using the Executorch runtime. """ - @staticmethod - def get_device_name() -> str: + @classmethod + def get_device_name(cls) -> str: return "cuda" - @staticmethod - def get_supported_fallback_kernels() -> Dict[str, Any]: + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, } - @staticmethod - def get_decomposition_table() -> Dict[Any, Any]: + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: return { torch.ops.aten.conv1d.default: conv1d_to_conv2d, } - @staticmethod - def get_aoti_compile_options() -> Dict[str, typing.Any]: - return { + @classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass""" + return [ReplaceEdgeOpWithTritonOpPass()] + + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """ + Get AOTI compile options for CUDA backend. + Options may vary based on platform (Linux vs Windows). 
+ """ + # Base options for all platforms + options: Dict[str, typing.Any] = { # Disable this to support sdpa decomposition # TODO(gasoonjia): remove it after pin bump to latest pytorch "loop_ordering_after_fusion": False, @@ -65,3 +82,35 @@ def get_aoti_compile_options() -> Dict[str, typing.Any]: # Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch "max_autotune_conv_backends": "TRITON", } + + # Parse compile_specs to check for platform + platform = "linux" + shim_library_path = None + for spec in compile_specs: + if spec.key == "platform": + platform = spec.value.decode("utf-8") + if spec.key == "shim_library_path": + shim_library_path = spec.value.decode("utf-8") + + # Add platform-specific options + if platform == "windows": + # For Windows, get default shim library path if not provided + if shim_library_path is None: + lib_dir = resources.files("executorch").joinpath("data/lib") + shim_library_path = str(lib_dir) + + options.update( + { + "aot_inductor.cross_target_platform": "windows", + "aot_inductor.aoti_shim_library": "aoti_cuda_shims", + "aot_inductor.aoti_shim_library_path": shim_library_path, + "aot_inductor.precompile_headers": False, + } + ) + else: + # Linux platform + assert ( + shim_library_path is None + ), "shim_library_path should not be set for Linux" + + return options From 48e8c86afbf8a684a4a1789dd39f6de7d036b2a8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 20 Nov 2025 17:05:45 -0800 Subject: [PATCH 14/15] revert extra changes --- extension/llm/tokenizers | 2 +- setup.py | 18 ------------------ 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers index d710a0cf10c..3aada3fe28c 160000 --- a/extension/llm/tokenizers +++ b/extension/llm/tokenizers @@ -1 +1 @@ -Subproject commit d710a0cf10cfa8cb7ffda33c4e61af63119bc95f +Subproject commit 3aada3fe28c945d14d5ec62254eb56ccdf10eb11 diff --git a/setup.py b/setup.py index 735840a29c3..71aa4c543d4 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,6 @@ # imports. import logging import os -import platform import re import shutil import site @@ -91,23 +90,6 @@ def _is_windows() -> bool: return sys.platform == "win32" -# Duplicate of the one in backends.qualcomm.scripts.download_qnn_sdk to avoid -# import errors. -def is_linux_x86() -> bool: - """ - Check if the current platform is Linux x86_64. - - Returns: - bool: True if the system is Linux x86_64, False otherwise. - """ - return platform.system().lower() == "linux" and platform.machine().lower() in ( - "x86_64", - "amd64", - "i386", - "i686", - ) - - class Version: """Static strings that describe the version of the pip package.""" From 57e8ffefd3dbfebaa545ee65488ffe33776666c5 Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Thu, 20 Nov 2025 17:09:16 -0800 Subject: [PATCH 15/15] Update backends/aoti/aoti_backend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- backends/aoti/aoti_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index cea10fc5b05..2d396a296bd 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -37,7 +37,7 @@ class AotiBackend(ABC): This class provides common functionality for compiling models using AOTInductor with different device targets (CUDA, Metal, etc.). - This is a mixin class, not an actual backend object, for aoti-driven backens. 
+ This is a mixin class, not an actual backend object, for aoti-driven backends. Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both BackendDetails and AotiBackend to get the full functionality. """