50 changes: 47 additions & 3 deletions backends/cuda/cuda_backend.py
@@ -7,10 +7,14 @@
import contextlib
import os
import typing
from enum import Enum

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import (
ReplaceSliceCopyWithSlicePass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import (
@@ -21,7 +25,7 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass

from torch.nn.attention import SDPBackend

# fallback operators that exist in the et namespace;
supported_fallback_kernels: Dict[str, Any] = {}
@@ -30,6 +34,10 @@
missing_fallback_kernels: Set[str] = set()


class COMPILE_SPEC_KEYS(Enum):
METHOD_NAME = "method_name"


# context manager for the non-fallback guarantee;
# it raises an exception when fallback kernels are generated during aoti compile
@contextlib.contextmanager
@@ -108,6 +116,9 @@ def preprocess(
# Move the edge_program from CPU to CUDA for aoti compile
cuda_edge_program = move_to_device_pass(edge_program, "cuda")

# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
Contributor:

Seems pretty hacky to run the force functionalization pass and then come through and undo it (but only for slice). Won't you in practice have to do this for all view ops?

Does AOTI lowering typically happen on functionalized IR?

Contributor (Author):

Inductor's reinplace pass reverts most of the functionalization. I don't think it handles slice_copy, though, since it comes from this pass: https://github.com/pytorch/executorch/blob/main/exir/passes/replace_broken_ops_with_function_ops_pass.py#L13

The other option would be to run this pass optionally in to_edge().
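For readers following along, a minimal standalone sketch of the semantic difference the pass exploits: `slice_copy` materializes a new tensor, while `slice` returns a view of its input, so the rewrite is only safe when the result is never mutated or further aliased (exactly what `_has_blocking_user` in the new pass checks).

```python
import torch

x = torch.arange(6.0)
view = torch.ops.aten.slice.Tensor(x, 0, 0, 3)       # shares storage with x
copy = torch.ops.aten.slice_copy.Tensor(x, 0, 0, 3)  # owns its own storage

x[0] = 42.0
assert view[0].item() == 42.0  # the view observes the mutation
assert copy[0].item() == 0.0   # the copy does not
```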


edge_program_module = cuda_edge_program.module()

# Grab all input placeholders from the graph
@@ -132,7 +143,10 @@ def preprocess(
"max_autotune_conv_backends": "TRITON",
}

with collect_unsupported_fallback_kernels():
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
[SDPBackend.MATH]
), torch.no_grad():
# torch._logging.set_logs(post_grad_graphs=True)
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
if len(missing_fallback_kernels) > 0:
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -146,7 +160,10 @@ def preprocess(
so_data = f.read()

named_data_store = NamedDataStore()
named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
named_data_store.add_named_data(
method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
)

# Clean up the generated so file; it has been packaged into the NamedDataStore
# pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +174,30 @@ def preprocess(
debug_handle_map={},
data_store_output=named_data_store.get_named_data_store_output(),
)

@staticmethod
def generate_method_name_compile_spec(
method_name: str,
) -> CompileSpec:
"""
Returns the compile spec representing the model compute precision, for additional details
please refer to the documentation for ``coremltools.precision``.
"""
return CompileSpec(
COMPILE_SPEC_KEYS.METHOD_NAME.value,
method_name.encode("utf-8"),
)

@staticmethod
def method_name_from_compile_specs(
compile_specs: List[CompileSpec],
) -> str:
"""
Returns the method name from the compile specs.
"""
for spec in compile_specs:
if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
return spec.value.decode("utf-8")
raise RuntimeError(
f"Could not find method name in compile specs: {compile_specs}"
)
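For context, a minimal sketch of how the two new helpers pair up; the `"forward"` name here is illustrative, not taken from this diff:

```python
# Hypothetical usage: tag the method being lowered with a compile spec.
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

# preprocess() recovers the name from the same specs, so the compiled shared
# library is stored as "forward_so_blob" under the "aoti_cuda_blob" store.
method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
assert method_name == "forward"
```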
6 changes: 4 additions & 2 deletions backends/cuda/cuda_partitioner.py
@@ -44,12 +44,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""

partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
tag = "tag0"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)

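Because the flattened view interleaves removed and added lines, here is a sketch of how the loop reads after this change (reconstructed from the hunk, not verbatim): the tag is hoisted out of the loop and `partition_tags` is populated once, so every `call_function` node lands in a single partition.

```python
partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
    if node.op != "call_function":
        continue
    node.meta["delegation_tag"] = tag  # every op shares one tag -> one partition

# Registered once after the loop instead of on every iteration.
partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
```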
115 changes: 115 additions & 0 deletions backends/cuda/replace_slice_copy_with_slice.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Iterable

import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch import fx


_SLICE_COPY_TARGETS = (
torch.ops.aten.slice_copy.Tensor,
ops.edge.aten.slice_copy.Tensor,
)

_SLICE_TARGETS = {
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
}


class ReplaceSliceCopyWithSlicePass(ExportPass):
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""

def call(self, graph_module: fx.GraphModule) -> PassResult:
graph_changed = False

for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
continue

if self._has_blocking_user(node, node.users.keys()):
continue

node.target = _SLICE_TARGETS[node.target]
graph_changed = True

if graph_changed:
graph_module.graph.lint()
graph_module.recompile()

return PassResult(graph_module, graph_changed)

def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
for user in users:
if self._is_mutating_user(node, user) or self._is_view_user(node, user):
return True
return False

def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
if user.op == "call_method":
# Treat in-place tensor methods conservatively as mutations only when the
# method name ends with ``_`` which is the PyTorch convention for mutation.
return isinstance(user.target, str) and user.target.endswith("_")

if user.op != "call_function":
return False

target = user.target
if not hasattr(target, "_schema"):
return False

schema = target._schema # pyre-ignore[16]
# Positional arguments
for index, arg in enumerate(user.args):
if arg is node and self._argument_mutates(schema, index):
return True

# Keyword arguments
for name, arg in user.kwargs.items():
if arg is node and self._argument_mutates(schema, name):
return True

return False

def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
if user.op == "call_method":
# Treat tensor methods conservatively and assume they may be view-producing.
return True

if user.op != "call_function":
return False

target = user.target
if getattr(target, "is_view", False):
for arg in user.args:
if arg is node:
return True
for arg in user.kwargs.values():
if arg is node:
return True

return False

def _argument_mutates(
self, schema: torch._C.FunctionSchema, key
) -> bool: # pyre-ignore[11]
arguments = schema.arguments
if isinstance(key, int):
if key >= len(arguments):
return False
argument = arguments[key]
else:
argument = next((arg for arg in arguments if arg.name == key), None)
if argument is None:
return False

alias_info = argument.alias_info
return bool(alias_info and alias_info.is_write)
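A minimal usage sketch. The alias-info probe is runnable as-is and shows the schema property `_argument_mutates` keys off; the final call assumes `ep` is an `ExportedProgram` whose graph contains `slice_copy` nodes (e.g. one lowered through `to_edge()`), as in the `preprocess()` change above.

```python
import torch

from executorch.backends.cuda.replace_slice_copy_with_slice import (
    ReplaceSliceCopyWithSlicePass,
)

# Schema alias info is what _argument_mutates inspects; e.g. in-place add
# marks its first argument ("self") as written.
schema = torch.ops.aten.add_.Tensor._schema
assert schema.arguments[0].alias_info.is_write

# Apply the pass to the program's graph module (assumption: `ep` is an
# ExportedProgram with slice_copy nodes). The pass returns a PassResult.
result = ReplaceSliceCopyWithSlicePass()(ep.graph_module)
print("graph changed:", result.modified)
```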