11 changes: 8 additions & 3 deletions backends/aoti/aoti_backend.py
@@ -70,10 +70,15 @@ def get_aoti_compile_options(
 
     @classmethod
     @abstractmethod
-    def get_custom_passes(cls) -> List[typing.Any]:
+    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
         """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
         pass
 
+    @classmethod
+    def get_extra_aoti_compile_context_manager(cls):
+        """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager."""
+        return contextlib.nullcontext()
+
     @classmethod
     @contextlib.contextmanager
     def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -149,7 +154,7 @@ def preprocess(
         ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
 
         # Apply custom backend-specific passes
-        custom_passes = cls.get_custom_passes()
+        custom_passes = cls.get_custom_passes(compile_specs)
         for custom_pass in custom_passes:
             custom_pass(device_edge_program.graph_module)
 
@@ -174,7 +179,7 @@ def preprocess(
         # Compile with fallback kernel collection
         with cls.collect_unsupported_fallback_kernels(
             missing_fallback_kernels
-        ), torch.no_grad():
+        ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager():
            paths = torch._inductor.aot_compile(
                edge_program_module, tuple(user_input_placeholders), options=options
            )
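To see how these two hooks are meant to be used, here is a minimal sketch of a concrete backend filling them in. It is illustration only, not code from this PR: the base-class name `AotiBackend`, the `MyBackend`/`MyGraphPass` names, and the `disable_custom_passes` spec key are assumptions, and a real subclass would also implement the remaining abstract methods.

import contextlib
import typing
from typing import List

import torch
from executorch.backends.aoti.aoti_backend import AotiBackend  # assumed class name
from executorch.exir.backend.compile_spec_schema import CompileSpec


class MyGraphPass:
    """Hypothetical pass: custom passes are invoked on the partition's graph_module."""

    def __call__(self, graph_module: torch.fx.GraphModule) -> None:
        # A real pass would rewrite nodes here.
        pass


class MyBackend(AotiBackend):  # hypothetical; other required methods omitted
    @classmethod
    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
        # Compile specs now flow into pass selection, so behavior can be toggled
        # per lowering (see the CUDA backend's triton_kernel_mode below).
        if any(s.key == "disable_custom_passes" and s.value == b"ON" for s in compile_specs):
            return []  # "disable_custom_passes" is a hypothetical spec key
        return [MyGraphPass()]

    @classmethod
    def get_extra_aoti_compile_context_manager(cls):
        # Entered together with torch.no_grad() around torch._inductor.aot_compile;
        # the base-class default is an empty contextlib.nullcontext().
        return contextlib.nullcontext()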
2 changes: 1 addition & 1 deletion backends/apple/metal/metal_backend.py
@@ -42,7 +42,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
         return {}
 
     @classmethod
-    def get_custom_passes(cls) -> List[typing.Any]:
+    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
         """Return Metal-specific passes (currently none)"""
         return []
 
43 changes: 40 additions & 3 deletions backends/cuda/cuda_backend.py
@@ -17,6 +17,7 @@
 from executorch.exir.backend.backend_details import BackendDetails
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.decomposition import conv1d_to_conv2d
+from torch.nn.attention import SDPBackend
 
 
 @final
@@ -47,9 +48,27 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
         }
 
     @classmethod
-    def get_custom_passes(cls) -> List[typing.Any]:
-        """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
-        return [ReplaceEdgeOpWithTritonOpPass()]
+    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
+        """
+        Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass.
+
+        The Triton kernel replacement behavior can be controlled via compile_specs:
+        - triton_kernel_mode="ON": Always use Triton kernels.
+        - triton_kernel_mode="OFF": Never use Triton kernels; fall back to other implementations such as CUDA or decomposed operators.
+        """
+        # Parse compile_specs for triton_kernel_mode
+        triton_kernel_mode = "ON"  # Default mode
+        for spec in compile_specs:
+            if spec.key == "triton_kernel_mode":
+                mode = spec.value.decode("utf-8").upper()
+                if mode not in ["ON", "OFF"]:
+                    raise ValueError(
+                        f"Invalid triton_kernel_mode: {mode}. "
+                        f"Expected 'ON' or 'OFF'."
+                    )
+                triton_kernel_mode = mode
+
+        return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
 
     @classmethod
     def get_aoti_compile_options(
@@ -114,3 +133,21 @@ def get_aoti_compile_options(
         ), "shim_library_path should not be set for Linux"
 
         return options
+
+    @classmethod
+    def get_extra_aoti_compile_context_manager(cls):
+        """
+        Return the SDPA MATH backend context manager for CUDA compilation.
+
+        This context manager serves as a fallback that forces any remaining PyTorch SDPA
+        operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
+
+        Note:
+        - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
+          this context manager has no effect on those ops (they are no longer
+          PyTorch SDPA ops).
+        - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
+          context manager forces them to use the MATH backend, causing them to
+          be automatically decomposed during compilation.
+        """
+        return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
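For reference, a standalone sketch (not part of this PR) of what restricting SDPA to the MATH backend does: inside `sdpa_kernel([SDPBackend.MATH])`, `scaled_dot_product_attention` can only take the decomposed math path, which is the fallback the backend wraps around `torch._inductor.aot_compile` when no Triton replacement happened. As written it assumes a CUDA device is available.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 4, 8, 16, device="cuda")
k = torch.randn(1, 4, 8, 16, device="cuda")
v = torch.randn(1, 4, 8, 16, device="cuda")

# Only the MATH (decomposed) SDPA implementation may be selected inside this
# block; this is the same restriction the CUDA backend applies around aot_compile.
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 8, 16])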
49 changes: 43 additions & 6 deletions backends/cuda/tests/test_cuda_export.py
@@ -12,6 +12,7 @@
 from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
 from executorch.examples.models.toy_model import SdpaModule
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
+from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export import export
 
 
@@ -25,16 +26,27 @@ def setUp(self):
             self.skipTest("CUDA is not available")
 
     def _export_to_cuda_with_lower(
-        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
+        self,
+        module: torch.nn.Module,
+        inputs: Tuple[torch.Tensor, ...],
+        compile_specs: list[CompileSpec] | None = None,
     ) -> None:
-        """Helper method to export a module to CUDA backend using to_edge_transform_and_lower."""
+        """Helper method to export a module to CUDA backend using to_edge_transform_and_lower.
+
+        Args:
+            module: The torch.nn.Module to export
+            inputs: The example inputs for the module
+            compile_specs: Optional list of compile specs. If not provided, defaults to
+                only the method name compile spec for "forward"
+        """
         # Export the model
         exported_program = export(module, inputs, strict=True)
 
-        # Create partitioner and compile specs
-        partitioner = CudaPartitioner(
-            [CudaBackend.generate_method_name_compile_spec("forward")]
-        )
+        # Create partitioner with compile specs
+        if compile_specs is None:
+            compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
+
+        partitioner = CudaPartitioner(compile_specs)
 
         # Use to_edge_transform_and_lower for complete pipeline
         edge_program_manager = to_edge_transform_and_lower(
@@ -288,3 +300,28 @@ def test_sdpa_single_kernel(self):
             edge_program_manager,
             "SDPA single kernel operation export failed",
         )
+
+    def test_triton_kernel_mode_off(self):
+        """
+        Test CUDA export with triton_kernel_mode set to OFF for SDPA kernel.
+        This validates that the backend correctly processes the triton_kernel_mode
+        compile spec and can export SDPA operations without Triton kernel replacements.
+        When triton_kernel_mode is OFF, SDPA should be decomposed using the MATH backend.
+        """
+
+        sdpa = SdpaModule()
+
+        # Create compile specs with triton_kernel_mode set to OFF
+        compile_specs = [
+            CudaBackend.generate_method_name_compile_spec("forward"),
+            CompileSpec(key="triton_kernel_mode", value=b"OFF"),
+        ]
+
+        # Test export with triton_kernel_mode=OFF
+        edge_program_manager = self._export_to_cuda_with_lower(
+            sdpa.get_eager_model(), sdpa.get_example_inputs(), compile_specs
+        )
+        self.assertIsNotNone(
+            edge_program_manager,
+            "SDPA kernel export with triton_kernel_mode=OFF failed",
+        )