2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
@@ -71,7 +71,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        model: [linear, add, add_mul, resnet18, conv1d]
+        model: [linear, add, add_mul, resnet18, conv1d, sdpa]
     with:
       timeout: 90
       runner: linux.g5.4xlarge.nvidia.gpu
31 changes: 31 additions & 0 deletions backends/cuda/TARGETS
@@ -11,6 +11,7 @@ runtime.python_library(
         "//executorch/...",
     ],
     deps = [
+        ":triton_replacement_pass",
         "//caffe2:torch",
         "//executorch/backends/aoti/passes:passes",
         "//executorch/exir/_serialize:lib",
@@ -32,3 +33,33 @@
         "//executorch/backends/aoti:aoti_partitioner",
     ],
 )
+
+runtime.python_library(
+    name = "triton_kernels",
+    srcs = [
+        "triton/kernels/__init__.py",
+        "triton/kernels/sdpa.py",
+    ],
+    visibility = [
+        "//executorch/backends/cuda/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_library(
+    name = "triton_replacement_pass",
+    srcs = [
+        "triton/__init__.py",
+        "triton/replacement_pass.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        ":triton_kernels",
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)
15 changes: 9 additions & 6 deletions backends/cuda/cuda_backend.py
@@ -16,6 +16,10 @@
 from executorch.backends.aoti.passes.replace_view_copy_with_view import (
     ReplaceViewCopyWithViewPass,
 )
+
+from executorch.backends.cuda.triton.replacement_pass import (
+    ReplaceEdgeOpWithTritonOpPass,
+)
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -27,7 +31,7 @@
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch._inductor.decomposition import conv1d_to_conv2d
 from torch.export.passes import move_to_device_pass
-from torch.nn.attention import SDPBackend
+
 
 
 cuda_decomposition_table = {
     torch.ops.aten.conv1d.default: conv1d_to_conv2d,
@@ -127,6 +131,9 @@ def preprocess(  # noqa: C901
         # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
         ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
 
+        # Replace aten ops with triton ops
+        ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module)
+
         cuda_edge_program = cuda_edge_program.run_decompositions(
             cuda_decomposition_table
         )
@@ -188,11 +195,7 @@ def preprocess(  # noqa: C901
             }
         )
 
-        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-            [
-                SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
-            ]
-        ), torch.no_grad():
+        with collect_unsupported_fallback_kernels(), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
             # Here we should expect 1 so file and 1 weight blob in the same directory.
             paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
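With SDPA rewritten to a Triton-backed op before `torch._inductor.aot_compile` runs, the backend no longer needs to pin SDPA to `SDPBackend.MATH` during compilation, which is why the `sdpa_kernel` context manager is dropped in the last hunk above. The pass being wired in, `triton/replacement_pass.py`, is not part of this diff; the sketch below shows one plausible shape for such a pass, built on ExecuTorch's `ExportPass`. The op mapping and class body are illustrative assumptions, not the actual implementation.

```python
# Illustrative sketch only; the real ReplaceEdgeOpWithTritonOpPass is not
# shown in this diff. Assumes a kernel already registered as
# torch.ops.triton.sdpa (see the triton/__init__.py comment below).
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ReplaceEdgeOpWithTritonOpSketch(ExportPass):
    """Rewrites matching edge-dialect ops to Triton-backed ops."""

    # Assumed mapping: edge-dialect SDPA -> Triton SDPA registered via @triton_op.
    _REPLACEMENTS = {
        exir_ops.edge.aten.scaled_dot_product_attention.default: (
            torch.ops.triton.sdpa.default
        ),
    }

    def call_operator(self, op, args, kwargs, meta):
        # ExportPass visits every call_function node; substituting the target
        # here rewrites the node while preserving its metadata.
        return super().call_operator(
            self._REPLACEMENTS.get(op, op), args, kwargs, meta
        )
```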
1 change: 1 addition & 0 deletions backends/cuda/tests/TARGETS
@@ -19,6 +19,7 @@ python_unittest_remote_gpu(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:compile_spec_schema",
+        "//executorch/examples/models/toy_model:toy_model",
     ],
     keep_gpu_sections = True,
 )
18 changes: 18 additions & 0 deletions backends/cuda/tests/test_cuda_export.py
@@ -10,6 +10,7 @@
 import torch
 from executorch.backends.cuda.cuda_backend import CudaBackend
 from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
+from executorch.examples.models.toy_model import SdpaModule
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 from torch.export import export
 
@@ -270,3 +271,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Test export
         edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
         self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")
+
+    def test_sdpa_single_kernel(self):
+        """
+        Test CUDA export for a model containing a single SDPA kernel.
+        SDPA: Scaled Dot Product Attention.
+        """
+
+        sdpa = SdpaModule()
+
+        # Test export
+        edge_program_manager = self._export_to_cuda_with_lower(
+            sdpa.get_eager_model(), sdpa.get_example_inputs()
+        )
+        self.assertIsNotNone(
+            edge_program_manager,
+            "SDPA single kernel operation export failed",
+        )
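The test relies on `SdpaModule` exposing `get_eager_model()` and `get_example_inputs()`. A hypothetical sketch of a module satisfying that interface follows; the real definition lives in `executorch/examples/models/toy_model` and is not shown in this diff, so the module body and input shapes are assumptions.

```python
# Hypothetical sketch of the interface test_sdpa_single_kernel relies on.
# The actual SdpaModule in examples/models/toy_model may differ.
import torch


class SdpaSketch(torch.nn.Module):
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        # A single scaled-dot-product-attention call: the op under test.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        # (batch, num_heads, seq_len, head_dim); shapes are illustrative.
        q, k, v = (torch.randn(1, 8, 32, 64) for _ in range(3))
        return (q, k, v)
```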
17 changes: 17 additions & 0 deletions backends/cuda/triton/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Import all kernels to ensure @triton_op decorators are executed
+# and ops are registered to torch.ops.triton namespace
+from executorch.backends.cuda.triton import kernels  # noqa: F401
+
+from executorch.backends.cuda.triton.replacement_pass import (
+    ReplaceEdgeOpWithTritonOpPass,
+)
+
+__all__ = [
+    "ReplaceEdgeOpWithTritonOpPass",
+]
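The comment in this `__init__.py` names the key mechanism: importing the `kernels` package executes each `@triton_op` decorator, which registers the op under `torch.ops.triton.*` where the replacement pass can reference it. Below is a minimal, self-contained sketch of that pattern using a toy elementwise-add kernel; the kernel and the op name `triton::add_sketch` are illustrative only, since the real `triton/kernels/sdpa.py` (not shown in this diff) implements scaled dot product attention instead.

```python
# Minimal sketch of the @triton_op registration pattern described above.
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


# Running this decorator at import time registers torch.ops.triton.add_sketch,
# which is why the package __init__ imports all kernel modules eagerly.
@triton_op("triton::add_sketch", mutates_args={})
def add_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # wrap_triton keeps the kernel launch traceable by torch.export/compile.
    wrap_triton(_add_kernel)[grid](x, y, out, n, BLOCK=1024)
    return out
```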
11 changes: 11 additions & 0 deletions backends/cuda/triton/kernels/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.cuda.triton.kernels.sdpa import sdpa
+
+__all__ = [
+    "sdpa",
+]
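With the kernel package importable, a quick eager smoke test of the registration might look like the following. This is hypothetical: it assumes the op is registered as `triton::sdpa` with a plain `(query, key, value)` signature and that a CUDA device is available.

```python
# Hypothetical smoke test; the op name and signature are assumptions.
import torch
from executorch.backends.cuda.triton.kernels import sdpa

q, k, v = (torch.randn(1, 8, 32, 64, device="cuda") for _ in range(3))
out = sdpa(q, k, v)                        # re-exported wrapper
out_ref = torch.ops.triton.sdpa(q, k, v)   # same kernel via the dispatcher
torch.testing.assert_close(out, out_ref)
```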