From 8c3ec9eddb74d0d7e81f55bf253f6a2d0cbf94a5 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 30 Sep 2025 23:36:58 -0700
Subject: [PATCH 1/2] Make it work

---
 backends/cuda/cuda_backend.py                 |  42 ++++++-
 backends/cuda/cuda_partitioner.py             |   6 +-
 .../cuda/replace_slice_copy_with_slice.py     | 113 ++++++++++++++++++
 3 files changed, 156 insertions(+), 5 deletions(-)
 create mode 100644 backends/cuda/replace_slice_copy_with_slice.py

diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index 49314bed5e6..437b2ed9cc6 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -7,10 +7,12 @@
 import contextlib
 import os
 import typing
+from enum import Enum
 from typing import Any, Dict, final, List, Optional, Set
 
 import torch
+from executorch.backends.cuda.replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -21,7 +23,7 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
-
+from torch.nn.attention import SDPBackend
 
 # exist fallback operators in et namespace;
 supported_fallback_kernels: Dict[str, Any] = {}
@@ -29,6 +31,8 @@
 # required fallback kernels but not supported
 missing_fallback_kernels: Set[str] = set()
 
+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"
 
 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
@@ -108,6 +112,9 @@ def preprocess(
         # Move the edge_program from CPU to CUDA for aoti compile
         cuda_edge_program = move_to_device_pass(edge_program, "cuda")
 
+        # Replace slice_copy with slice so the compiled graph can use views instead of copies
+        ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
+
         edge_program_module = cuda_edge_program.module()
 
         # Grab all input placeholders from the graph
@@ -132,7 +139,8 @@ def preprocess(
             "max_autotune_conv_backends": "TRITON",
         }
 
-        with collect_unsupported_fallback_kernels():
+        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            torch._logging.set_logs(post_grad_graphs=True)
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -146,7 +154,8 @@ def preprocess(
             so_data = f.read()
 
         named_data_store = NamedDataStore()
-        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+        method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, "aoti_cuda_blob")
 
         # Clean up the generated so file; it has been packaged into the NamdeDataStore
         # pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +166,30 @@ def preprocess(
             debug_handle_map={},
             data_store_output=named_data_store.get_named_data_store_output(),
         )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that records the method name of the exported
+        program, so the backend can key its compiled artifact per method.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )

diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py
index d52d7d3d087..14c75bdb937 100644
--- a/backends/cuda/cuda_partitioner.py
+++ b/backends/cuda/cuda_partitioner.py
@@ -44,12 +44,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """
         partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
         for node in exported_program.graph.nodes:
             if node.op != "call_function":
                 continue
-            tag = "tag0"
             node.meta["delegation_tag"] = tag
-            partition_tags[tag] = self.delegation_spec
+
+        partition_tags[tag] = self.delegation_spec
 
         tag_constant_data(exported_program)

diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_slice_copy_with_slice.py
new file mode 100644
index 00000000000..0eaadf66f5c
--- /dev/null
+++ b/backends/cuda/replace_slice_copy_with_slice.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Iterable
+
+import torch
+from executorch.exir.dialects._ops import ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import fx
+
+
+_SLICE_COPY_TARGETS = (
+    torch.ops.aten.slice_copy.Tensor,
+    ops.edge.aten.slice_copy.Tensor,
+)
+
+_SLICE_TARGETS = {
+    torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
+    ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
+}
+
+
+class ReplaceSliceCopyWithSlicePass(ExportPass):
+    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""
+
+    def call(self, graph_module: fx.GraphModule) -> PassResult:
+        graph_changed = False
+
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
+                continue
+
+            if self._has_blocking_user(node, node.users.keys()):
+                continue
+
+            node.target = _SLICE_TARGETS[node.target]
+            graph_changed = True
+
+        if graph_changed:
+            graph_module.graph.lint()
+            graph_module.recompile()
+
+        return PassResult(graph_module, graph_changed)
+
+    def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
+        for user in users:
+            if self._is_mutating_user(node, user) or self._is_view_user(node, user):
+                return True
+        return False
+
+    def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Treat in-place tensor methods conservatively as mutations only when
+            # the method name ends with ``_``, the PyTorch convention for mutation.
+            return isinstance(user.target, str) and user.target.endswith("_")
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if not hasattr(target, "_schema"):
+            return False
+
+        schema = target._schema  # pyre-ignore[16]
+        # Positional arguments
+        for index, arg in enumerate(user.args):
+            if arg is node and self._argument_mutates(schema, index):
+                return True
+
+        # Keyword arguments
+        for name, arg in user.kwargs.items():
+            if arg is node and self._argument_mutates(schema, name):
+                return True
+
+        return False
+
+    def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Treat tensor methods conservatively and assume they may be view-producing.
+            return True
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if getattr(target, "is_view", False):
+            for arg in user.args:
+                if arg is node:
+                    return True
+            for arg in user.kwargs.values():
+                if arg is node:
+                    return True
+
+        return False
+
+    def _argument_mutates(self, schema: torch._C.FunctionSchema, key) -> bool:  # pyre-ignore[11]
+        arguments = schema.arguments
+        if isinstance(key, int):
+            if key >= len(arguments):
+                return False
+            argument = arguments[key]
+        else:
+            argument = next((arg for arg in arguments if arg.name == key), None)
+            if argument is None:
+                return False
+
+        alias_info = argument.alias_info
+        return bool(alias_info and alias_info.is_write)

From 5a40be7efc3cd4b3bf20df9884a3d4e171b392a2 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Wed, 1 Oct 2025 15:26:44 -0700
Subject: [PATCH 2/2] Address comments

---
 backends/cuda/cuda_backend.py                  | 16 ++++++++++++----
 backends/cuda/replace_slice_copy_with_slice.py |  4 +++-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index 437b2ed9cc6..a39065f6a52 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -12,7 +12,9 @@
 from typing import Any, Dict, final, List, Optional, Set
 
 import torch
-from executorch.backends.cuda.replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
+from executorch.backends.cuda.replace_slice_copy_with_slice import (
+    ReplaceSliceCopyWithSlicePass,
+)
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -31,9 +33,11 @@
 # required fallback kernels but not supported
 missing_fallback_kernels: Set[str] = set()
 
+
 class COMPILE_SPEC_KEYS(Enum):
     METHOD_NAME = "method_name"
 
+
 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
 @contextlib.contextmanager
@@ -139,8 +143,10 @@ def preprocess(
             "max_autotune_conv_backends": "TRITON",
         }
 
-        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            torch._logging.set_logs(post_grad_graphs=True)
+        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
+            [SDPBackend.MATH]
+        ), torch.no_grad():
+            # torch._logging.set_logs(post_grad_graphs=True)
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -155,7 +161,9 @@ def preprocess(
 
         named_data_store = NamedDataStore()
         method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
-        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, "aoti_cuda_blob")
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
+        )
 
         # Clean up the generated so file; it has been packaged into the NamdeDataStore
         # pyre-ignorep[6]: Incompatible parameter type

diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_slice_copy_with_slice.py
index 0eaadf66f5c..55ddef5de9b 100644
--- a/backends/cuda/replace_slice_copy_with_slice.py
+++ b/backends/cuda/replace_slice_copy_with_slice.py
@@ -98,7 +98,9 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
 
         return False
 
-    def _argument_mutates(self, schema: torch._C.FunctionSchema, key) -> bool:  # pyre-ignore[11]
+    def _argument_mutates(
+        self, schema: torch._C.FunctionSchema, key
+    ) -> bool:  # pyre-ignore[11]
         arguments = schema.arguments
         if isinstance(key, int):
             if key >= len(arguments):
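
A minimal end-to-end sketch of how the pieces above fit together. This is
illustrative rather than part of the patches: it assumes CudaPartitioner's
constructor accepts a List[CompileSpec] (the usual ExecuTorch partitioner
convention), and the Slicer model and the "forward" method name are made-up
examples.

    import torch
    from executorch.backends.cuda.cuda_backend import CudaBackend
    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class Slicer(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Lowers to slice_copy in the edge dialect; ReplaceSliceCopyWithSlicePass
            # rewrites it to slice because no user mutates or re-views the result.
            return x[:, 1:]

    ep = torch.export.export(Slicer(), (torch.randn(2, 8),))

    # Tag the program with its method name so preprocess stores the AOTI shared
    # library under "forward_so_blob" rather than the previous fixed "so_blob" key.
    specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    program = to_edge_transform_and_lower(
        ep, partitioner=[CudaPartitioner(specs)]
    ).to_executorch()

Keying each blob by method name is what lets several methods coexist in one
"aoti_cuda_blob" named data store entry, which the old fixed "so_blob" key
could not support.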