diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index 49314bed5e6..a39065f6a52 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -7,10 +7,14 @@
 import contextlib
 import os
 import typing
+from enum import Enum
 from typing import Any, Dict, final, List, Optional, Set
 
 import torch
+from executorch.backends.cuda.replace_slice_copy_with_slice import (
+    ReplaceSliceCopyWithSlicePass,
+)
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -21,7 +25,7 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
-
+from torch.nn.attention import SDPBackend
 
 # exist fallback operators in et namespace;
 supported_fallback_kernels: Dict[str, Any] = {}
@@ -30,6 +34,10 @@
 missing_fallback_kernels: Set[str] = set()
 
 
+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"
+
+
 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
 @contextlib.contextmanager
@@ -108,6 +116,9 @@
         # Move the edge_program from CPU to CUDA for aoti compile
         cuda_edge_program = move_to_device_pass(edge_program, "cuda")
 
+        # Replace non-mutated slice_copy ops with view-producing slice ops
+        ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
+
         edge_program_module = cuda_edge_program.module()
 
         # Grab all input placeholders from the graph
@@ -132,7 +143,9 @@
             "max_autotune_conv_backends": "TRITON",
         }
 
-        with collect_unsupported_fallback_kernels():
+        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
+            [SDPBackend.MATH]
+        ), torch.no_grad():
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
@@ -146,7 +159,10 @@
             so_data = f.read()
 
         named_data_store = NamedDataStore()
-        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+        method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
+        )
 
         # Clean up the generated so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +173,30 @@
             debug_handle_map={},
             data_store_output=named_data_store.get_named_data_store_output(),
         )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that records the method name being lowered; the name
+        is used to key this method's shared object blob in the named data store.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py
index d52d7d3d087..14c75bdb937 100644
--- a/backends/cuda/cuda_partitioner.py
+++ b/backends/cuda/cuda_partitioner.py
@@ -44,12 +44,14 @@
         """
         partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
         for node in exported_program.graph.nodes:
             if node.op != "call_function":
                 continue
 
-            tag = "tag0"
             node.meta["delegation_tag"] = tag
-            partition_tags[tag] = self.delegation_spec
+
+        partition_tags[tag] = self.delegation_spec
 
         tag_constant_data(exported_program)
 
diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_slice_copy_with_slice.py
new file mode 100644
index 00000000000..55ddef5de9b
--- /dev/null
+++ b/backends/cuda/replace_slice_copy_with_slice.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Iterable, Union
+
+import torch
+from executorch.exir.dialects._ops import ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import fx
+
+
+_SLICE_COPY_TARGETS = (
+    torch.ops.aten.slice_copy.Tensor,
+    ops.edge.aten.slice_copy.Tensor,
+)
+
+_SLICE_TARGETS = {
+    torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
+    ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
+}
+
+
+class ReplaceSliceCopyWithSlicePass(ExportPass):
+    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""
+
+    def call(self, graph_module: fx.GraphModule) -> PassResult:
+        graph_changed = False
+
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
+                continue
+
+            if self._has_blocking_user(node, node.users.keys()):
+                continue
+
+            node.target = _SLICE_TARGETS[node.target]
+            graph_changed = True
+
+        if graph_changed:
+            graph_module.graph.lint()
+            graph_module.recompile()
+
+        return PassResult(graph_module, graph_changed)
+
+    def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
+        for user in users:
+            if self._is_mutating_user(node, user) or self._is_view_user(node, user):
+                return True
+        return False
+
+    def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Conservatively treat a tensor method as mutating only when its name
+            # ends with ``_``, the PyTorch convention for in-place ops.
+            return isinstance(user.target, str) and user.target.endswith("_")
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if not hasattr(target, "_schema"):
+            return False
+
+        schema = target._schema  # pyre-ignore[16]
+        # Positional arguments
+        for index, arg in enumerate(user.args):
+            if arg is node and self._argument_mutates(schema, index):
+                return True
+
+        # Keyword arguments
+        for name, arg in user.kwargs.items():
+            if arg is node and self._argument_mutates(schema, name):
+                return True
+
+        return False
+
+    def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Treat tensor methods conservatively and assume they may be view-producing.
+            return True
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if getattr(target, "is_view", False):
+            for arg in user.args:
+                if arg is node:
+                    return True
+            for arg in user.kwargs.values():
+                if arg is node:
+                    return True
+
+        return False
+
+    def _argument_mutates(
+        self, schema: torch._C.FunctionSchema, key: Union[int, str]
+    ) -> bool:  # pyre-ignore[11]
+        arguments = schema.arguments
+        if isinstance(key, int):
+            if key >= len(arguments):
+                return False
+            argument = arguments[key]
+        else:
+            argument = next((arg for arg in arguments if arg.name == key), None)
+            if argument is None:
+                return False
+
+        # Schema alias info with ``is_write`` (e.g. ``Tensor(a!)``) marks arguments
+        # that the op mutates in place.
+        alias_info = argument.alias_info
+        return bool(alias_info and alias_info.is_write)
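
For review context, a minimal round-trip sketch of the new compile-spec helpers. Only the two static methods come from this diff; the module path is assumed to mirror the file path above:

from executorch.backends.cuda.cuda_backend import CudaBackend

# Tag the exported method so preprocess() can recover its name.
specs = [CudaBackend.generate_method_name_compile_spec("forward")]
assert CudaBackend.method_name_from_compile_specs(specs) == "forward"
# preprocess() then stores the AOTI shared object under "forward_so_blob" in
# the "aoti_cuda_blob" named data store, instead of the old fixed "so_blob"
# key, so multiple methods can coexist in one program.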
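
Likewise, a self-contained sketch of what ReplaceSliceCopyWithSlicePass does to a graph. SliceModule is a hypothetical example, and fx.symbolic_trace stands in for the full export pipeline so the slice_copy node is explicit:

import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import (
    ReplaceSliceCopyWithSlicePass,
)
from torch import fx


class SliceModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Call the copy variant explicitly so the pass has a node to rewrite.
        return torch.ops.aten.slice_copy.Tensor(x, 1, 1, 3)


gm = fx.symbolic_trace(SliceModule())
result = ReplaceSliceCopyWithSlicePass()(gm)

# The slice_copy node's only user is the graph output, which neither mutates
# its input nor produces a view, so the node is retargeted to the view op
# aten.slice.Tensor and the pass reports the graph as modified.
assert result.modified
print(result.graph_module.graph)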