From 98c68998a1865bdb73299418d13e8266988581b3 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 18 Nov 2025 15:57:52 -0800 Subject: [PATCH 1/2] introduce triton sdpa kernel to cuda backend Summary: **Introduce Triton SDPA Kernel to CUDA Backend** This diff introduces a kernel-generator (https://github.com/meta-pytorch/KernelAgent) driven, Triton-optimized implementation of scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation to accelerate the model inference and get rid of sdpa decomposition. **Changes** * Added a new file `sdpa.py` to `fbcode/executorch/backends/cuda/triton/kernels` and `fbcode/executorch/backends/cuda/triton/kernels` directories, which contains the Triton-optimized SDPA kernel implementation. * Added a new `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given existing edge ops with target triton kernels. * Added tests for sdpa exporting with triton kernel. Without the triton kernel, sdpa model can not be exported. **Purpose** The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs Reviewed By: larryliu0820 Differential Revision: D87259044 --- .github/workflows/cuda.yml | 2 +- backends/cuda/TARGETS | 31 ++ backends/cuda/cuda_backend.py | 15 +- backends/cuda/tests/TARGETS | 1 + backends/cuda/tests/test_cuda_export.py | 18 ++ backends/cuda/triton/__init__.py | 17 ++ backends/cuda/triton/kernels/__init__.py | 11 + backends/cuda/triton/kernels/sdpa.py | 365 +++++++++++++++++++++++ backends/cuda/triton/replacement_pass.py | 129 ++++++++ examples/models/__init__.py | 2 + examples/models/toy_model/__init__.py | 2 + examples/models/toy_model/model.py | 36 +++ 12 files changed, 622 insertions(+), 7 deletions(-) create mode 100644 backends/cuda/triton/__init__.py create mode 100644 backends/cuda/triton/kernels/__init__.py create mode 100644 backends/cuda/triton/kernels/sdpa.py create mode 100644 backends/cuda/triton/replacement_pass.py diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 7cc937fe6ca..1d237f5d8ef 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -71,7 +71,7 @@ jobs: strategy: fail-fast: false matrix: - model: [linear, add, add_mul, resnet18, conv1d] + model: [linear, add, add_mul, resnet18, conv1d, sdpa] with: timeout: 90 runner: linux.g5.4xlarge.nvidia.gpu diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index 94af87bbaed..d8256f77c41 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -11,6 +11,7 @@ runtime.python_library( "//executorch/...", ], deps = [ + ":triton_replacement_pass", "//caffe2:torch", "//executorch/backends/aoti/passes:passes", "//executorch/exir/_serialize:lib", @@ -32,3 +33,33 @@ runtime.python_library( "//executorch/backends/aoti:aoti_partitioner", ], ) + +runtime.python_library( + name = "triton_kernels", + srcs = [ + "triton/kernels/__init__.py", + "triton/kernels/sdpa.py", + ], + visibility = [ + "//executorch/backends/cuda/...", + ], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "triton_replacement_pass", + srcs = [ + "triton/__init__.py", + "triton/replacement_pass.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + ":triton_kernels", + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 5176ca42710..772e24c75b3 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -16,6 +16,10 @@ from executorch.backends.aoti.passes.replace_view_copy_with_view import ( ReplaceViewCopyWithViewPass, ) + +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental from executorch.exir.backend.backend_details import ( @@ -27,7 +31,7 @@ from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.decomposition import conv1d_to_conv2d from torch.export.passes import move_to_device_pass -from torch.nn.attention import SDPBackend + cuda_decomposition_table = { torch.ops.aten.conv1d.default: conv1d_to_conv2d, @@ -127,6 +131,9 @@ def preprocess( # noqa: C901 # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) + # Replace aten ops with triton ops + ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module) + cuda_edge_program = cuda_edge_program.run_decompositions( cuda_decomposition_table ) @@ -188,11 +195,7 @@ def preprocess( # noqa: C901 } ) - with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel( - [ - SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. - ] - ), torch.no_grad(): + with collect_unsupported_fallback_kernels(), torch.no_grad(): # torch._logging.set_logs(post_grad_graphs=True) # Here we should expect 1 so file and 1 weight blob in the same directory. paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] diff --git a/backends/cuda/tests/TARGETS b/backends/cuda/tests/TARGETS index 12718c04388..974086cd4c5 100644 --- a/backends/cuda/tests/TARGETS +++ b/backends/cuda/tests/TARGETS @@ -19,6 +19,7 @@ python_unittest_remote_gpu( "//executorch/exir:lib", "//executorch/exir/backend:backend_api", "//executorch/exir/backend:compile_spec_schema", + "//executorch/examples/models/toy_model:toy_model", ], keep_gpu_sections = True, ) diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index ef43a3ab3cb..03f4e4a9602 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -10,6 +10,7 @@ import torch from executorch.backends.cuda.cuda_backend import CudaBackend from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.examples.models.toy_model import SdpaModule from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from torch.export import export @@ -270,3 +271,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Test export edge_program_manager = self._export_to_cuda_with_lower(module, inputs) self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed") + + def test_sdpa_single_kernel(self): + """ + Test CUDA export for model containing single SDPA kernel. + SDPA: Scaled Dot Product Attention + """ + + sdpa = SdpaModule() + + # Test export + edge_program_manager = self._export_to_cuda_with_lower( + sdpa.get_eager_model(), sdpa.get_example_inputs() + ) + self.assertIsNotNone( + edge_program_manager, + "SDPA single kernel operation export failed", + ) diff --git a/backends/cuda/triton/__init__.py b/backends/cuda/triton/__init__.py new file mode 100644 index 00000000000..4b9c36249ac --- /dev/null +++ b/backends/cuda/triton/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Import all kernels to ensure @triton_op decorators are executed +# and ops are registered to torch.ops.triton namespace +from executorch.backends.cuda.triton import kernels # noqa: F401 + +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) + +__all__ = [ + "ReplaceEdgeOpWithTritonOpPass", +] diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py new file mode 100644 index 00000000000..5bd582679c4 --- /dev/null +++ b/backends/cuda/triton/kernels/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.cuda.triton.kernels.sdpa import sdpa + +__all__ = [ + "sdpa", +] diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py new file mode 100644 index 00000000000..432079b15cb --- /dev/null +++ b/backends/cuda/triton/kernels/sdpa.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Triton SDPA Kernel for ExecuTorch CUDA Backend. + +This module provides a Triton-optimized implementation of scaled dot-product attention +that can replace the default ATen/Edge SDPA operator during graph transformation to allow +us export the model without decomposing the SDPA operator under libtorch free environment +and have better performance. +""" + +import math +from typing import Optional + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + + +def _validate_qkv_shapes( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[int, int, int, int, int, int]: + """ + Validate dimensions and return shape info. + Args: + query: Query tensor [B, H, L_q, D] + key: Key tensor [B, H, L_kv, D] + value: Value tensor [B, H, L_kv, D] + Returns: + Tuple of (B, H, L_q, L_kv, D_q, D_kv) + Raises: + RuntimeError: If dimensions are incompatible + """ + B_q, H_q, L_q, D_q = query.shape + B_k, H_k, L_kv_k, D_k = key.shape + B_v, H_v, L_kv_v, D_v = value.shape + # Validate batch and head dimensions + if not (B_q == B_k == B_v): + raise RuntimeError( + f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." + ) + + if not (H_q == H_k == H_v): + raise RuntimeError( + f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." + ) + # Head dimension must match + if not (D_q == D_k == D_v): + raise RuntimeError( + f"Head dimension must match across Q, K, V; got D_q={D_q}, D_k={D_k}, D_v={D_v}." + ) + # Key and Value sequence lengths must match + if L_kv_k != L_kv_v: + raise RuntimeError( + f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." + ) + return B_q, H_q, L_q, L_kv_k, D_q, D_k + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_stages=4, num_warps=8), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=1, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2), + ], + key=["L_Q", "L_KV", "HEAD_DIM"], +) +@triton.jit +def _sdpa_fwd_kernel( + q_ptr, + k_ptr, + v_ptr, + mask_ptr, + o_ptr, + B, + H, + L_Q, # Query sequence length + L_KV, # Key/Value sequence length + HEAD_DIM, + stride_qb, + stride_qh, + stride_ql, + stride_qd, + stride_kb, + stride_kh, + stride_kl, + stride_kd, + stride_vb, + stride_vh, + stride_vl, + stride_vd, + stride_mb, + stride_mh, + stride_ml, + stride_mn, + stride_ob, + stride_oh, + stride_ol, + stride_od, + sm_scale, + IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM_CE: tl.constexpr, +): + """ + Fused SDPA kernel that handles different sequence lengths for Q and K/V. + + Q shape: [B, H, L_Q, D] + K/V shape: [B, H, L_KV, D] + Output shape: [B, H, L_Q, D] + """ + # Program IDs + pid_m = tl.program_id(axis=0) # along query length + pid_hz = tl.program_id(axis=1) # flattened batch*head + off_b = pid_hz // H + off_h = pid_hz % H + # Compute ranges for queries + start_m = pid_m * BLOCK_M + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM_CE) + mask_m = offs_m < L_Q # Mask based on query length + # Base pointers for this (b, h) + q_base = q_ptr + off_b * stride_qb + off_h * stride_qh + k_base = k_ptr + off_b * stride_kb + off_h * stride_kh + v_base = v_ptr + off_b * stride_vb + off_h * stride_vh + o_base = o_ptr + off_b * stride_ob + off_h * stride_oh + # Mask base pointer (if provided) + if HAS_MASK: + mask_base = mask_ptr + off_b * stride_mb + off_h * stride_mh + # Make head-dim addresses compiler-friendly + offs_d_ctg = tl.max_contiguous(tl.multiple_of(offs_d, 16), HEAD_DIM_CE) + # Load Q tile [BLOCK_M, HEAD_DIM] - coalesced along HEAD_DIM + q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d_ctg[None, :] * stride_qd) + q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0) + q = q.to(tl.bfloat16) + # Initialize accumulators and softmax stats + acc = tl.zeros((BLOCK_M, HEAD_DIM_CE), dtype=tl.float32) + m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) + l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) + # Convert to base-2 scale for exp2 + qk_scale = sm_scale * 1.4426950408889634 + # Loop over keys/values along L_KV dimension (not L_Q!) + for start_n in tl.range(0, L_KV, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < L_KV # Mask based on key/value length + # Load K tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) + k_ptrs = k_base + ( + offs_n[:, None] * stride_kl + offs_d_ctg[None, :] * stride_kd + ) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + k = k.to(tl.bfloat16) + # Compute attention logits [BLOCK_M, BLOCK_N] = Q[BM,D] @ K[BN,D]^T + qk = tl.dot(q, tl.trans(k)).to(tl.float32) + qk = qk * qk_scale + # Apply causal mask if needed + # For causal masking with different lengths: position i can attend to position j if i >= j + if IS_CAUSAL: + causal_mask = offs_m[:, None] >= offs_n[None, :] + qk = tl.where(causal_mask, qk, -float("inf")) + # Apply attention mask if provided + if HAS_MASK: + # Load mask tile [BLOCK_M, BLOCK_N] + # Mask shape should be [B, H, L_Q, L_KV] + mask_ptrs = mask_base + ( + offs_m[:, None] * stride_ml + offs_n[None, :] * stride_mn + ) + attn_mask = tl.load( + mask_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0 + ) + # Convert boolean mask to additive mask (-inf for False, 0 for True) + qk = tl.where(attn_mask, qk, -float("inf")) + # Apply OOB masks for both rows and cols + qk = tl.where(mask_n[None, :], qk, -float("inf")) + qk = tl.where(mask_m[:, None], qk, -float("inf")) + # Online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + # Load V tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) + v_ptrs = v_base + ( + offs_n[:, None] * stride_vl + offs_d_ctg[None, :] * stride_vd + ) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + v = v.to(tl.bfloat16) + # Update accumulator + acc = acc * alpha[:, None] + p_bf16 = p.to(tl.bfloat16) + acc = tl.dot(p_bf16, v, acc) + # Update softmax stats + l_i = l_i * alpha + l_ij + m_i = m_ij + # Normalize accumulator by softmax denominator + acc = acc / l_i[:, None] + # Store output [BLOCK_M, HEAD_DIM] - shape matches query + o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d_ctg[None, :] * stride_od) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None]) + + +@triton_op("triton::sdpa", mutates_args={}) +def sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + """ + Triton fused Scaled Dot-Product Attention with support for different sequence lengths. + + Args: + query: Query tensor with szie [B, H, L_q, D] and dtype torch.bfloat16 + key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 + value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 + attn_mask: Optional attention mask [B, H, L_q, L_kv] or broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) + dropout_p: must be 0.0 (others are not supported) + is_causal: whether to apply causal masking + scale: attention scale (default: 1/sqrt(D)) + enable_gqa: must be False (True is not supported) + Returns: + Output tensor [B, H, L_q, D] with dtype torch.bfloat16 + """ + # Validate inputs + if not (query.is_cuda and key.is_cuda and value.is_cuda): + raise RuntimeError("Q, K, V must be CUDA tensors.") + if ( + query.dtype != torch.bfloat16 + or key.dtype != torch.bfloat16 + or value.dtype != torch.bfloat16 + ): + raise RuntimeError("Expected bfloat16 inputs") + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise RuntimeError( + f"Expected 4D tensors shaped [B, H, L, D]; got query.dim()={query.dim()}, key.dim()={key.dim()}, value.dim()={value.dim()}." + ) + # Enforce unsupported features + if dropout_p != 0.0: + raise RuntimeError( + "dropout_p must be 0.0 (not supported in this implementation)." + ) + if enable_gqa is not False: + raise RuntimeError( + "enable_gqa must be False (not supported in this implementation)." + ) + # Validate and get dimensions + B, H, L_q, L_kv, D_q, D_kv = _validate_qkv_shapes(query, key, value) + D = D_q # Head dimension + # Allocate output with query shape + out = torch.empty_like(query) + # Element-wise strides + sqb, sqh, sql, sqd = query.stride() + skb, skh, skl, skd = key.stride() + svb, svh, svl, svd = value.stride() + sob, soh, sol, sod = out.stride() + + # Grid: tile queries (M) and batch*heads axis + def grid(META): + return ( + triton.cdiv(L_q, META["BLOCK_M"]), # Based on query length + B * H, + ) + + # Scale factor for SDPA + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + # Handle attention mask + has_mask = attn_mask is not None + if has_mask: + # Expand mask to [B, H, L_q, L_kv] if needed + if attn_mask.dim() == 2: + # [L_q, L_kv] -> [B, H, L_q, L_kv] + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1) + elif attn_mask.dim() == 3: + # [B, L_q, L_kv] -> [B, H, L_q, L_kv] + attn_mask = attn_mask.unsqueeze(1).expand(-1, H, -1, -1) + + # Validate mask shape + if attn_mask.shape != (B, H, L_q, L_kv): + # Try to expand if broadcastable + attn_mask = attn_mask.expand(B, H, L_q, L_kv) + + smb, smh, sml, smn = attn_mask.stride() + else: + # Dummy strides and mask + smb, smh, sml, smn = 0, 0, 0, 0 + attn_mask = torch.empty(0, dtype=torch.bool, device=query.device) + # Launch kernel + wrap_triton(_sdpa_fwd_kernel)[grid]( + query, + key, + value, + attn_mask, + out, + B, + H, + L_q, # Query sequence length + L_kv, # Key/Value sequence length + D, + sqb, + sqh, + sql, + sqd, + skb, + skh, + skl, + skd, + svb, + svh, + svl, + svd, + smb, + smh, + sml, + smn, + sob, + soh, + sol, + sod, + sm_scale, + IS_CAUSAL=is_causal, + HAS_MASK=has_mask, + HEAD_DIM_CE=D, + ) + return out + + +# Register the abstract/fake implementation for torch.export +# This is critical to avoid accessing real tensor data during export +@sdpa.register_fake +def _sdpa_abstract( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gq: bool = False, +) -> torch.Tensor: + """ + Abstract/fake implementation for torch.export. + This just returns an empty tensor with the correct shape/dtype/device. + """ + # Validate dtypes match + assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" + # Validate kqv's shape and get the output shape + B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) + + return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py new file mode 100644 index 00000000000..afe0854b2cf --- /dev/null +++ b/backends/cuda/triton/replacement_pass.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph Transformation Pass for Triton Kernel Replacement. + +This pass replaces ATen operators with optimized Triton kernels in the graph. +""" + +import logging + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +logger = logging.getLogger(__name__) +triton = torch.ops.triton + +# Global mapping from edge dialect operators to Triton kernel functions +EDGE_TO_TRITON_KERNELS = { + exir_ops.edge.aten.scaled_dot_product_attention.default: triton.sdpa, +} + + +class ReplaceEdgeOpWithTritonOpPass(PassBase): + """ + Pass to replace ATen operators with Triton kernels. + + This pass scans the graph for Edge operators that have registered Triton + replacements using EDGE_TO_TRITON_KERNELS and replaces them with the + optimized Triton implementations. + """ + + def __init__(self): + """Initialize the pass.""" + super().__init__() + self._replacement_count = 0 + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Execute the pass on the graph module. + + Args: + graph_module: The graph module to transform + + Returns: + PassResult indicating success/failure and the modified graph module + """ + self._replacement_count = 0 + modified = False + + if not EDGE_TO_TRITON_KERNELS: + return PassResult(graph_module, False) + + # Iterate through all nodes in the graph + for node in graph_module.graph.nodes: + if self._should_replace_node(node): + try: + self._replace_node_with_triton(graph_module, node) + modified = True + self._replacement_count += 1 + except Exception as e: + logger.warning(f"Failed to replace node {node.name}: {e}") + # Continue with other replacements even if one fails + + if modified: + # Recompile the graph module after modifications + graph_module.recompile() + + logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels") + + return PassResult(graph_module, modified) + + def _should_replace_node(self, node: Node) -> bool: + """ + Check if a node should be replaced with a Triton kernel. + + Args: + node: The node to check + + Returns: + True if the node should be replaced + """ + # Only consider call_function nodes + if node.op != "call_function": + return False + + return node.target in EDGE_TO_TRITON_KERNELS + + def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> None: + """ + Replace an edge dialect node with a Triton kernel call. + + Args: + graph_module: The graph module containing the node + node: The node to replace + """ + # Get the target operator (should be an exir_ops edge dialect op) + target = node.target + + # Get the replacement kernel + if target not in EDGE_TO_TRITON_KERNELS: + raise ValueError(f"No replacement kernel found for {target}") + + triton_kernel_fn = EDGE_TO_TRITON_KERNELS[target] + + # Create a new node with the Triton kernel + with graph_module.graph.inserting_before(node): + # The triton_kernel_fn is already registered as a custom op via @triton_op + # We can call it directly + new_node = graph_module.graph.call_function( + triton_kernel_fn, + args=node.args, + kwargs=node.kwargs, + ) + + # Copy metadata from original node + new_node.meta = node.meta.copy() + + # Replace all uses of the old node with the new node + node.replace_all_uses_with(new_node) + + # Remove the old node + graph_module.graph.erase_node(node) diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 45abfd8f89d..6a6c4ff1875 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -40,6 +40,7 @@ class Model(str, Enum): Phi4Mini = "phi_4_mini" SmolLM2 = "smollm2" DeiTTiny = "deit_tiny" + Sdpa = "sdpa" def __str__(self) -> str: return self.value @@ -89,6 +90,7 @@ def __str__(self) -> str: str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"), str(Model.SmolLM2): ("smollm2", "SmolLM2Model"), str(Model.DeiTTiny): ("deit_tiny", "DeiTTinyModel"), + str(Model.Sdpa): ("toy_model", "SdpaModule"), } __all__ = [ diff --git a/examples/models/toy_model/__init__.py b/examples/models/toy_model/__init__.py index 333a625af1b..87456e3fd4c 100644 --- a/examples/models/toy_model/__init__.py +++ b/examples/models/toy_model/__init__.py @@ -10,6 +10,7 @@ Conv1dModule, LinearModule, MulModule, + SdpaModule, SoftmaxModule, ) @@ -19,5 +20,6 @@ Conv1dModule, LinearModule, MulModule, + SdpaModule, SoftmaxModule, ] diff --git a/examples/models/toy_model/model.py b/examples/models/toy_model/model.py index e1dd290b829..a31149c29af 100644 --- a/examples/models/toy_model/model.py +++ b/examples/models/toy_model/model.py @@ -105,3 +105,39 @@ def get_eager_model(self) -> torch.nn.Module: def get_example_inputs(self): return (torch.randn(1, 3, 10),) + + +class SdpaModule(torch.nn.Module, EagerModelBase): + def __init__(self): + super().__init__() + + def forward(self, query, key, value): + out = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + return out + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + # Input shape: (batch, num_heads, seq_len, head_dim) + batch_size = 2 + num_heads = 8 + seq_len = 128 + head_dim = 64 + query = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + key = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + value = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + return (query, key, value) From fc87357997398930af88cefb1d918390aef70c78 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 18 Nov 2025 20:11:49 -0800 Subject: [PATCH 2/2] solve gemma3 export issue --- backends/cuda/triton/kernels/sdpa.py | 44 ++++++++++++++++++------ backends/cuda/triton/replacement_pass.py | 3 +- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 432079b15cb..7e8eb1444df 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -22,6 +22,19 @@ from torch.library import triton_op, wrap_triton +def _next_power_of_2(n: int) -> int: + """Round up to the next power of 2.""" + if n <= 0: + return 1 + if n & (n - 1) == 0: + return n + + power = 1 + while power < n: + power <<= 1 + return power + + def _validate_qkv_shapes( query: torch.Tensor, key: torch.Tensor, @@ -88,7 +101,7 @@ def _sdpa_fwd_kernel( H, L_Q, # Query sequence length L_KV, # Key/Value sequence length - HEAD_DIM, + HEAD_DIM, # Actual head dimension (may not be power of 2) stride_qb, stride_qh, stride_ql, @@ -114,7 +127,7 @@ def _sdpa_fwd_kernel( HAS_MASK: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - HEAD_DIM_CE: tl.constexpr, + HEAD_DIM_CE: tl.constexpr, # Rounded up for tl.arange ): """ Fused SDPA kernel that handles different sequence lengths for Q and K/V. @@ -141,11 +154,13 @@ def _sdpa_fwd_kernel( # Mask base pointer (if provided) if HAS_MASK: mask_base = mask_ptr + off_b * stride_mb + off_h * stride_mh + # Mask for actual head dimension (HEAD_DIM may not be power of 2) + mask_d = offs_d < HEAD_DIM # Make head-dim addresses compiler-friendly offs_d_ctg = tl.max_contiguous(tl.multiple_of(offs_d, 16), HEAD_DIM_CE) # Load Q tile [BLOCK_M, HEAD_DIM] - coalesced along HEAD_DIM q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d_ctg[None, :] * stride_qd) - q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0) + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) q = q.to(tl.bfloat16) # Initialize accumulators and softmax stats acc = tl.zeros((BLOCK_M, HEAD_DIM_CE), dtype=tl.float32) @@ -161,7 +176,7 @@ def _sdpa_fwd_kernel( k_ptrs = k_base + ( offs_n[:, None] * stride_kl + offs_d_ctg[None, :] * stride_kd ) - k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) k = k.to(tl.bfloat16) # Compute attention logits [BLOCK_M, BLOCK_N] = Q[BM,D] @ K[BN,D]^T qk = tl.dot(q, tl.trans(k)).to(tl.float32) @@ -179,7 +194,9 @@ def _sdpa_fwd_kernel( offs_m[:, None] * stride_ml + offs_n[None, :] * stride_mn ) attn_mask = tl.load( - mask_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0 + mask_ptrs, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, ) # Convert boolean mask to additive mask (-inf for False, 0 for True) qk = tl.where(attn_mask, qk, -float("inf")) @@ -195,7 +212,7 @@ def _sdpa_fwd_kernel( v_ptrs = v_base + ( offs_n[:, None] * stride_vl + offs_d_ctg[None, :] * stride_vd ) - v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) v = v.to(tl.bfloat16) # Update accumulator acc = acc * alpha[:, None] @@ -208,7 +225,7 @@ def _sdpa_fwd_kernel( acc = acc / l_i[:, None] # Store output [BLOCK_M, HEAD_DIM] - shape matches query o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d_ctg[None, :] * stride_od) - tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None]) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) @triton_op("triton::sdpa", mutates_args={}) @@ -229,7 +246,8 @@ def sdpa( query: Query tensor with szie [B, H, L_q, D] and dtype torch.bfloat16 key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 - attn_mask: Optional attention mask [B, H, L_q, L_kv] or broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) + attn_mask: Optional attention mask [B, H, L_q, L_kv] or + broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) dropout_p: must be 0.0 (others are not supported) is_causal: whether to apply causal masking scale: attention scale (default: 1/sqrt(D)) @@ -248,7 +266,9 @@ def sdpa( raise RuntimeError("Expected bfloat16 inputs") if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: raise RuntimeError( - f"Expected 4D tensors shaped [B, H, L, D]; got query.dim()={query.dim()}, key.dim()={key.dim()}, value.dim()={value.dim()}." + f"Expected 4D tensors shaped [B, H, L, D]; got " + f"query.dim()={query.dim()}, key.dim()={key.dim()}, " + f"value.dim()={value.dim()}." ) # Enforce unsupported features if dropout_p != 0.0: @@ -300,6 +320,8 @@ def grid(META): # Dummy strides and mask smb, smh, sml, smn = 0, 0, 0, 0 attn_mask = torch.empty(0, dtype=torch.bool, device=query.device) + # Round up head dimension to next power of 2 for tile.arange in Triton kernel + HEAD_DIM_CE = _next_power_of_2(D) # Launch kernel wrap_triton(_sdpa_fwd_kernel)[grid]( query, @@ -311,7 +333,7 @@ def grid(META): H, L_q, # Query sequence length L_kv, # Key/Value sequence length - D, + D, # Actual head dimension sqb, sqh, sql, @@ -335,7 +357,7 @@ def grid(META): sm_scale, IS_CAUSAL=is_causal, HAS_MASK=has_mask, - HEAD_DIM_CE=D, + HEAD_DIM_CE=HEAD_DIM_CE, # Rounded to power of 2 ) return out diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index afe0854b2cf..bfa3838296b 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -72,7 +72,8 @@ def call(self, graph_module: GraphModule) -> PassResult: # Recompile the graph module after modifications graph_module.recompile() - logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels") + # logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels") + print(f"Replaced {self._replacement_count} nodes with Triton kernels") return PassResult(graph_module, modified)