From a84c6ea35987ca7cf75f468f1e1baebeb2d7c5db Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 2 Dec 2025 00:05:09 -0800 Subject: [PATCH] make triton kernel usage user controlable (#16030) Summary: We observed perf regression on some models after using triton sdpa kernel, compare with decomposed operator. This diff makes triton kernel controalable by user to prevent such perf regression. Reviewed By: larryliu0820 Differential Revision: D88096054 --- backends/aoti/aoti_backend.py | 11 ++++-- backends/apple/metal/metal_backend.py | 2 +- backends/cuda/cuda_backend.py | 43 ++++++++++++++++++++-- backends/cuda/tests/test_cuda_export.py | 49 ++++++++++++++++++++++--- 4 files changed, 92 insertions(+), 13 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 7742cd53c2d..c2c587da9fe 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -70,10 +70,15 @@ def get_aoti_compile_options( @classmethod @abstractmethod - def get_custom_passes(cls) -> List[typing.Any]: + def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition.""" pass + @classmethod + def get_extra_aoti_compile_context_manager(cls): + """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager.""" + return contextlib.nullcontext() + @classmethod @contextlib.contextmanager def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): @@ -149,7 +154,7 @@ def preprocess( ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) # Apply custom backend-specific passes - custom_passes = cls.get_custom_passes() + custom_passes = cls.get_custom_passes(compile_specs) for custom_pass in custom_passes: custom_pass(device_edge_program.graph_module) @@ -174,7 +179,7 @@ def preprocess( # Compile with fallback kernel collection with cls.collect_unsupported_fallback_kernels( missing_fallback_kernels - ), torch.no_grad(): + ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager(): paths = torch._inductor.aot_compile( edge_program_module, tuple(user_input_placeholders), options=options ) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 1b27b027fc2..1d86cfb8447 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -42,7 +42,7 @@ def get_decomposition_table(cls) -> Dict[Any, Any]: return {} @classmethod - def get_custom_passes(cls) -> List[typing.Any]: + def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: """Return Metal-specific passes (currently none)""" return [] diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index cc2d662b335..f0d3a000ec0 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -17,6 +17,7 @@ from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.compile_spec_schema import CompileSpec from torch._inductor.decomposition import conv1d_to_conv2d +from torch.nn.attention import SDPBackend @final @@ -47,9 +48,27 @@ def get_decomposition_table(cls) -> Dict[Any, Any]: } @classmethod - def get_custom_passes(cls) -> List[typing.Any]: - """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass""" - return [ReplaceEdgeOpWithTritonOpPass()] + def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: + """ + Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass. + + The Triton kernel replacement behavior can be controlled via compile_specs: + - triton_kernel_mode="ON": Always use Triton kernels + - triton_kernel_mode="OFF": Never use Triton kernels and fallback to other implementations like cuda or decomposed operator. + """ + # Parse compile_specs for triton_kernel_mode + triton_kernel_mode = "ON" # Default mode + for spec in compile_specs: + if spec.key == "triton_kernel_mode": + mode = spec.value.decode("utf-8").upper() + if mode not in ["ON", "OFF"]: + raise ValueError( + f"Invalid triton_kernel_mode: {mode}. " + f"Expected 'ON' or 'OFF'." + ) + triton_kernel_mode = mode + + return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] @classmethod def get_aoti_compile_options( @@ -114,3 +133,21 @@ def get_aoti_compile_options( ), "shim_library_path should not be set for Linux" return options + + @classmethod + def get_extra_aoti_compile_context_manager(cls): + """ + Return SDPA MATH backend context manager for CUDA compilation. + + This context manager plays as a fallback solution for any remaining PyTorch SDPA + operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + Note: + - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + this context manager will have no effect on those ops (they are no longer + PyTorch SDPA ops). + - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + context manager will force them to use the MATH backend, causing them to + be automatically decomposed during compilation. + """ + return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index 03f4e4a9602..ff4a9313545 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -12,6 +12,7 @@ from executorch.backends.cuda.cuda_partitioner import CudaPartitioner from executorch.examples.models.toy_model import SdpaModule from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export import export @@ -25,16 +26,27 @@ def setUp(self): self.skipTest("CUDA is not available") def _export_to_cuda_with_lower( - self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + self, + module: torch.nn.Module, + inputs: Tuple[torch.Tensor, ...], + compile_specs: list[CompileSpec] | None = None, ) -> None: - """Helper method to export a module to CUDA backend using to_edge_transform_and_lower.""" + """Helper method to export a module to CUDA backend using to_edge_transform_and_lower. + + Args: + module: The torch.nn.Module to export + inputs: The example inputs for the module + compile_specs: Optional list of compile specs. If not provided, defaults to + only the method name compile spec for "forward" + """ # Export the model exported_program = export(module, inputs, strict=True) - # Create partitioner and compile specs - partitioner = CudaPartitioner( - [CudaBackend.generate_method_name_compile_spec("forward")] - ) + # Create partitioner with compile specs + if compile_specs is None: + compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + + partitioner = CudaPartitioner(compile_specs) # Use to_edge_transform_and_lower for complete pipeline edge_program_manager = to_edge_transform_and_lower( @@ -288,3 +300,28 @@ def test_sdpa_single_kernel(self): edge_program_manager, "SDPA single kernel operation export failed", ) + + def test_triton_kernel_mode_off(self): + """ + Test CUDA export with triton_kernel_mode set to OFF for SDPA kernel. + This validates that the backend correctly processes the triton_kernel_mode + compile spec and can export SDPA operations without Triton kernel replacements. + When triton_kernel_mode is OFF, SDPA should be decomposed using the MATH backend. + """ + + sdpa = SdpaModule() + + # Create compile specs with triton_kernel_mode set to OFF + compile_specs = [ + CudaBackend.generate_method_name_compile_spec("forward"), + CompileSpec(key="triton_kernel_mode", value=b"OFF"), + ] + + # Test export with triton_kernel_mode=OFF + edge_program_manager = self._export_to_cuda_with_lower( + sdpa.get_eager_model(), sdpa.get_example_inputs(), compile_specs + ) + self.assertIsNotNone( + edge_program_manager, + "SDPA kernel export with triton_kernel_mode=OFF failed", + )