diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index 3e412b6dc56..fe57f7f1b63 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -6,6 +6,7 @@ runtime.python_library( name = "cuda_backend", srcs = [ "cuda_backend.py", + "replace_slice_copy_with_slice.py", ], visibility = [ "//executorch/...", diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index a39065f6a52..8ed8cdefbb1 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -144,7 +144,9 @@ def preprocess( } with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel( - [SDPBackend.MATH] + [ + SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + ] ), torch.no_grad(): # torch._logging.set_logs(post_grad_graphs=True) so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_slice_copy_with_slice.py index 55ddef5de9b..4f16759af35 100644 --- a/backends/cuda/replace_slice_copy_with_slice.py +++ b/backends/cuda/replace_slice_copy_with_slice.py @@ -6,20 +6,23 @@ # pyre-strict -from typing import Iterable +from typing import Dict, Iterable, Tuple import torch from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult from torch import fx -_SLICE_COPY_TARGETS = ( +_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload, ...] = ( torch.ops.aten.slice_copy.Tensor, ops.edge.aten.slice_copy.Tensor, ) -_SLICE_TARGETS = { +_SLICE_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, } @@ -99,8 +102,8 @@ def _is_view_user(self, node: fx.Node, user: 
fx.Node) -> bool: return False def _argument_mutates( - self, schema: torch._C.FunctionSchema, key - ) -> bool: # pyre-ignore[11] + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: arguments = schema.arguments if isinstance(key, int): if key >= len(arguments): diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index 99f8d33a766..d794a4f042c 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -8,6 +8,7 @@ from typing import Tuple import torch +from executorch.backends.cuda.cuda_backend import CudaBackend from executorch.backends.cuda.cuda_partitioner import CudaPartitioner from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from torch.export import export @@ -30,7 +31,9 @@ def _export_to_cuda_with_lower( exported_program = export(module, inputs, strict=True) # Create partitioner and compile specs - partitioner = CudaPartitioner([]) + partitioner = CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("forward")] + ) # Use to_edge_transform_and_lower for complete pipeline edge_program_manager = to_edge_transform_and_lower(