diff --git a/benchmarks/run.py b/benchmarks/run.py
index 999824e04..7dd1dfd09 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -304,6 +304,16 @@ class RunResult:
             "examples.bf16xint16_gemm",
             "bf16xint16_gemm_tritonbench",
         ),
+        "blackwell_attentions": (
+            "tritonbench.operators.blackwell_attentions.operator",
+            "examples.blackwell_attention",
+            "blackwell_attention",
+            {
+                "d_head": 128,  # Set default head dimension to 128 for TLX attention compatibility
+                "num_inputs": 6,  # flash_attention takes a long time on Benchmark CI, so use fewer inputs.
+                "input_id": 1,
+            },
+        ),
     }
@@ -579,6 +589,15 @@ class RunResult:
             "helion_bf16xint16_gemm_tritonbench-speedup": "helion_speedup",
             "helion_bf16xint16_gemm_tritonbench-accuracy": "helion_accuracy",
         },
+        "blackwell_attentions": {
+            "aten": "baseline",
+            "triton_tutorial_flash_v2_tma_ws_persistent-speedup": "triton_speedup",
+            "triton_tutorial_flash_v2_tma_ws_persistent-accuracy": "triton_accuracy",
+            "flex_attention-speedup": "torch_compile_speedup",
+            "flex_attention-accuracy": "torch_compile_accuracy",
+            "helion_attention-speedup": "helion_speedup",
+            "helion_attention-accuracy": "helion_accuracy",
+        },
     }
diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py
new file mode 100644
index 000000000..7dab8cb4f
--- /dev/null
+++ b/examples/blackwell_attention.py
@@ -0,0 +1,280 @@
+"""
+Blackwell Attention Example
+===========================
+
+This code implements a custom attention kernel using Helion and PyTorch for efficient computation of scaled dot-product attention,
+specifically tuned for Blackwell.
+"""
+# %%
+# Imports
+# -------
+
+# %%
+from __future__ import annotations
+
+import math
+
+import torch
+from triton.testing import do_bench
+
+import helion
+from helion._testing import run_example
+from helion.autotuner.config_fragment import EnumFragment
+import helion.language as hl
+
+# %%
+# Utility Functions
+# -------------------------------
+
+
+# %%
+def _mul_f32x2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """Vectorized F32 PTX MUL"""
+    return hl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mul.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=torch.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
+# %%
+def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+    """Vectorized F32 PTX FMA"""
+    return hl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc, rd;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mov.b64 rc, { $6, $7 };
+            fma.rn.f32x2 rd, ra, rb, rc;
+            mov.b64 { $0, $1 }, rd;
+        }
+        """,
+        "=r,=r,r,r,r,r,r,r",
+        [a, b, c],
+        dtype=torch.float32,
+        is_pure=True,
+        pack=2,
+    )
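+
+
+# %%
+# For reference only: numerically, the two PTX helpers above are just elementwise
+# ``a * b`` and ``a * b + c``; the inline assembly merely forces the packed f32x2
+# instructions available on Blackwell (sm_100+). The ``*_ref`` helpers below are
+# illustrative plain-PyTorch equivalents and are not used by the kernel.
+def _mul_f32x2_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """Plain-PyTorch equivalent of _mul_f32x2 (illustration only)."""
+    return a * b
+
+
+def _fma_f32x2_ref(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+    """Plain-PyTorch equivalent of _fma_f32x2 (illustration only)."""
+    return a * b + c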
+
+
+# %%
+# Attention Kernel Implementation
+# -------------------------------
+
+
+# %%
+@helion.kernel(
+    configs=[
+        helion.Config(
+            block_sizes=[256, N],
+            range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
+            range_multi_buffers=[None, False],
+            pid_type="persistent_interleaved",
+            indexing="tensor_descriptor",
+            num_warps=4,
+            num_stages=3,
+            _triton_range_id_data_partition_factor=0,
+            _triton_range_value_data_partition_factor=2,
+            _triton_config_maxRegAutoWS=maxreg,
+        )
+        for N in [64, 128]
+        for OUTER_LOOP in [True]
+        for maxreg in [152, 192]
+    ],
+    static_shapes=True,
+    autotune_accuracy_check=False,
+)
+def blackwell_attention(
+    q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Computes scaled dot-product attention.
+
+    Implements the attention mechanism: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
+
+    Args:
+        q_in: Query tensor of shape [..., seq_len_q, head_dim]
+        k_in: Key tensor of shape [..., seq_len_k, head_dim]
+        v_in: Value tensor of shape [..., seq_len_k, head_dim_v]
+
+    Returns:
+        Tuple of the output tensor of shape [..., seq_len_q, head_dim_v] and the
+        logsumexp tensor of shape [..., seq_len_q].
+    """
+    B, H, M, D = q_in.shape
+    Bk, Hk, N, Dk = k_in.shape
+    assert Dk == D
+    assert Bk == B
+    assert Hk == H
+    Bv, Hv, Nv, Dv = v_in.shape
+    assert Bv == B
+    assert Hv == Hk
+    assert Nv == N
+    D = hl.specialize(D)
+    Dv = hl.specialize(Dv)
+    q = q_in.reshape(-1, D)
+    k = k_in.reshape(-1, D)
+    v = v_in.reshape(-1, Dv)
+    MM = q.shape[0]
+    assert v.shape[0] == k.shape[0]
+    o = q.new_empty(MM, Dv)
+    lse = q.new_empty(MM, dtype=torch.float32)
+    block_m = hl.register_block_size(M)
+    block_n = hl.register_block_size(N)
+    assert M % block_m == 0
+    assert N % block_n == 0
+    hl.register_tunable(
+        "_triton_range_id_data_partition_factor", EnumFragment(choices=(0,))
+    )
+    hl.register_tunable(
+        "_triton_range_value_data_partition_factor", EnumFragment(choices=(2,))
+    )
+    hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192)))
+    SUBTILING = True
+    VECT_MUL = 1
+    sm_scale = 1.0 / math.sqrt(D)
+    qk_scale = sm_scale * 1.44269504  # log2(e) = 1/ln(2), so exp2 can be used in place of exp
+    for tile_m in hl.tile(MM, block_size=block_m):
+        m_i = hl.zeros([tile_m]) - float("inf")
+        l_i = hl.zeros([tile_m]) + 1.0
+        acc = hl.zeros([tile_m, Dv])
+        q_i = q[tile_m, :]
+
+        start_N = tile_m.begin // M * N
+        for tile_n in hl.tile(N, block_size=block_n):
+            k_j = k[tile_n + start_N, :]
+            v_j = v[tile_n + start_N, :]
+            qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32)
+            m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
+            if VECT_MUL == 2 or VECT_MUL == 3:
+                qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])  # pyright: ignore[reportArgumentType]
+            else:
+                qk = qk * qk_scale - m_ij[:, None]
+
+            p = torch.exp2(qk)
+            # -- compute correction factor
+            alpha = torch.exp2(m_i - m_ij)
+            l_ij = torch.sum(p, -1)
+
+            if SUBTILING:
+                acc0, acc1 = hl.split(
+                    acc.reshape([tile_m, 2, Dv // 2]).permute(0, 2, 1)
+                )
+                if VECT_MUL == 1 or VECT_MUL == 3:
+                    acc0 = _mul_f32x2(acc0, alpha[:, None])
+                    acc1 = _mul_f32x2(acc1, alpha[:, None])
+                else:
+                    acc0 = acc0 * alpha[:, None]
+                    acc1 = acc1 * alpha[:, None]
+                acc = (
+                    hl.join(acc0, acc1)
+                    .permute(0, 2, 1)
+                    .reshape(acc.size(0), acc.size(1))
+                )
+            else:
+                acc = acc * alpha[:, None]
+
+            # update m_i and l_i
+
+            # We can potentially move these to be before updating l_ij, so the dot
+            # is not blocked.
+            # prepare p and v for the dot
+            p = p.to(v.dtype)
+            # note that this non-transposed v for FP8 is only supported on Blackwell
+            acc = hl.dot(p, v_j, acc=acc)
+
+            l_i = l_i * alpha + l_ij
+            m_i = m_ij
+
+        m_i += torch.log2(l_i)
+        acc = acc / l_i[:, None]
+        lse[tile_m] = m_i
+        o[tile_m, :] = acc
+
+    return o.reshape(B, H, M, Dv), lse.reshape(B, H, M)
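+
+
+# %%
+# Reference sketch of the online-softmax update that the inner loop above performs
+# for each (query tile, key/value tile) pair, written in plain PyTorch. The helper
+# name and signature are illustrative only; the kernel's SUBTILING and vectorized
+# PTX paths are omitted, and qk_scale is assumed to already fold in log2(e) so that
+# exp2 can be used.
+def _online_softmax_step(
+    qk: torch.Tensor,  # [block_m, block_n] raw Q @ K^T scores for this tile
+    v_j: torch.Tensor,  # [block_n, head_dim_v] value tile
+    acc: torch.Tensor,  # [block_m, head_dim_v] running output accumulator
+    m_i: torch.Tensor,  # [block_m] running row maxima (base-2 domain)
+    l_i: torch.Tensor,  # [block_m] running softmax denominators
+    qk_scale: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """One tile update; returns the new (acc, m_i, l_i)."""
+    m_new = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
+    p = torch.exp2(qk * qk_scale - m_new[:, None])
+    alpha = torch.exp2(m_i - m_new)
+    acc = acc * alpha[:, None] + torch.matmul(p.to(v_j.dtype), v_j)
+    l_i = l_i * alpha + torch.sum(p, -1)
+    return acc, m_new, l_i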
+
+
+# %%
+# Testing Function
+# ----------------
+
+
+# %%
+def test(
+    z: int,
+    h: int,
+    n_ctx: int,
+    head_dim: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | str = "cuda",
+) -> None:
+    """
+    Test the attention kernel implementation against PyTorch's native attention functions.
+
+    Args:
+        z: Batch size
+        h: Number of attention heads
+        n_ctx: Sequence length (context size)
+        head_dim: Dimension of each attention head
+        dtype: Data type for the tensors
+        device: Device to run the test on
+    """
+    q, k, v = [
+        torch.randn((z, h, n_ctx, head_dim), dtype=dtype, device=device)
+        for _ in range(3)
+    ]
+
+    def ref_attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Reference manual attention implementation"""
+        p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
+        p = torch.softmax(p.float(), dim=-1).to(dtype)
+        return torch.matmul(p, v)
+
+    baselines = {
+        "torch": torch.nn.functional.scaled_dot_product_attention,
+        "ref": ref_attention,
+    }
+
+    run_example(
+        lambda *args: blackwell_attention(*args)[0],
+        baselines,
+        (q, k, v),
+        atol=0.1,
+        rtol=0.1,
+    )
+    dur: float = do_bench(lambda: blackwell_attention(q, k, v))  # pyright: ignore[reportArgumentType, reportAssignmentType]
+    print(
+        f"{z=} {h=} {n_ctx=} {head_dim=} tflops={z * h * n_ctx * n_ctx * head_dim * 4 / dur * 1e-9:.2f}"
+    )
+
+
+# %%
+# Main Function
+# -------------
+
+
+# %%
+def main() -> None:
+    """
+    Main entry point that runs the attention kernel test with specific parameters.
+    Tests with batch size 4, 32 heads, sequence length 8192, and head dimensions 64 and 128 using bfloat16.
+    """
+    test(4, 32, 8192, 64, torch.bfloat16)
+    test(4, 32, 8192, 128, torch.bfloat16)
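+
+
+# %%
+# Note on the custom ``_triton_*`` tunables in the kernel config above (explanatory
+# comment only, nothing is executed here): entries prefixed with ``_triton_config_``
+# have the prefix stripped and are forwarded as extra keyword arguments to the
+# Triton kernel launch, so ``_triton_config_maxRegAutoWS=152`` becomes
+# ``maxRegAutoWS=152``. The ``_triton_range_id/value_data_partition_factor`` pair
+# adds ``data_partition_factor=<value>`` to the ``tl.range(...)`` whose block index
+# matches, roughly ``tl.range(..., data_partition_factor=2)`` for this config.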
+
+
+if __name__ == "__main__":
+    main()
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 4977f38a9..49abcdded 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -225,6 +225,11 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
                 "num_warps",
                 "num_stages",
             ]
+            + [
+                x.removeprefix("_triton_config_")
+                for x in config
+                if x.startswith("_triton_config_")
+            ]
         )
         self._variable_renames: dict[str, list[str]] = {}
         self.dce_vars: list[str] = []
@@ -614,6 +619,11 @@ def codegen_function_call(self) -> ast.AST:
                 f"num_warps={num_warps}",
                 f"num_stages={self.config.num_stages}",
             ]
+            + [
+                f"{x.removeprefix('_triton_config_')}={self.config[x]}"
+                for x in self.config
+                if x.startswith("_triton_config_")
+            ]
         )
         pid = self.pid
         assert pid is not None
diff --git a/helion/_compiler/roll_reduction.py b/helion/_compiler/roll_reduction.py
index 80abc6ba1..4aed09f93 100644
--- a/helion/_compiler/roll_reduction.py
+++ b/helion/_compiler/roll_reduction.py
@@ -222,7 +222,7 @@ def start_new_graph(self) -> None:
         location_meta = {
             "location": next(iter(inner_nodes)).meta["location"],
-            "stack_trace": next(iter(inner_nodes)).meta["stack_trace"],
+            "stack_trace": next(iter(inner_nodes)).meta.get("stack_trace", ""),
         }
         output_node = self.outer_graph.call_function(
             _for_loop,
diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py
index 21313e6b5..73406f4e7 100644
--- a/helion/_compiler/tile_strategy.py
+++ b/helion/_compiler/tile_strategy.py
@@ -184,6 +184,13 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
         )
         if range_flatten is not None:
             kwargs.append(f"flatten={range_flatten}")
+
+        dpf_range = config.get("_triton_range_id_data_partition_factor", None)
+        dpf_value = config.get("_triton_range_value_data_partition_factor", None)
+
+        if dpf_range is not None and dpf_value is not None and dpf_range == block_idx:
+            kwargs.append(f"data_partition_factor={dpf_value}")
+
         return kwargs
 
     @staticmethod
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 1f2c9d718..5def1da5c 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -60,8 +60,14 @@ def default_launcher(
     *args: object,
     num_warps: int,
     num_stages: int,
+    **kwargs: dict,
 ) -> object:
     """Default launcher function that executes the kernel immediately."""
     return triton_kernel.run(
-        *args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
+        *args,
+        grid=grid,
+        warmup=False,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **kwargs,
     )