Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cuda/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ runtime.python_library(
name = "cuda_backend",
srcs = [
"cuda_backend.py",
+"replace_slice_copy_with_slice.py",
],
visibility = [
"//executorch/...",
Expand Down
4 changes: 3 additions & 1 deletion backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def preprocess(
}

with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-[SDPBackend.MATH]
+[
+SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
+]
), torch.no_grad():
# torch._logging.set_logs(post_grad_graphs=True)
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
Expand Down
13 changes: 8 additions & 5 deletions backends/cuda/replace_slice_copy_with_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@

# pyre-strict

-from typing import Iterable
+from typing import Dict, Iterable, Tuple

import torch
from executorch.exir.dialects._ops import ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch import fx


-_SLICE_COPY_TARGETS = (
+_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
torch.ops.aten.slice_copy.Tensor,
ops.edge.aten.slice_copy.Tensor,
)

-_SLICE_TARGETS = {
+_SLICE_TARGETS: Dict[
+torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
+] = {
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
}
Expand Down Expand Up @@ -99,8 +102,8 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
return False

def _argument_mutates(
-self, schema: torch._C.FunctionSchema, key
-) -> bool: # pyre-ignore[11]
+self, schema: torch._C.FunctionSchema, key: int | str
+) -> bool:
arguments = schema.arguments
if isinstance(key, int):
if key >= len(arguments):
Expand Down
5 changes: 4 additions & 1 deletion backends/cuda/tests/test_cuda_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Tuple

import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from torch.export import export
Expand All @@ -30,7 +31,9 @@ def _export_to_cuda_with_lower(
exported_program = export(module, inputs, strict=True)

# Create partitioner and compile specs
-partitioner = CudaPartitioner([])
+partitioner = CudaPartitioner(
+[CudaBackend.generate_method_name_compile_spec("forward")]
+)

# Use to_edge_transform_and_lower for complete pipeline
edge_program_manager = to_edge_transform_and_lower(
Expand Down
Loading