From fb469ca2f97b3e83c75146eeb8e2790abf9eb594 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Wed, 15 Oct 2025 14:24:36 -0700 Subject: [PATCH 1/2] Mamba2 Chunk Scan --- benchmarks/run.py | 12 ++ examples/mamba2_chunk_scan.py | 273 ++++++++++++++++++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 examples/mamba2_chunk_scan.py diff --git a/benchmarks/run.py b/benchmarks/run.py index fbbb46fd8..f766fcc37 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -317,6 +317,11 @@ class RunResult: "input_id": 1, }, ), + "mamba2_chunk_scan": ( + "tritonbench.operators.mamba2_chunk_scan.operator", + "examples.mamba2_chunk_scan", + "helion_mamba2_chunk_scan_kernel", + ), } @@ -601,6 +606,13 @@ class RunResult: "helion_attention-speedup": "helion_speedup", "helion_attention-accuracy": "helion_accuracy", }, + "mamba2_chunk_scan": { + "eager": "baseline", + "compile_speedup": "torch_compile_speedup", + "compile_accuracy": "torch_compile_accuracy", + "helion_mamba2_chunk_scan_kernel_speedup": "helion_speedup", + "helion_mamba2_chunk_scan_kernel_accuracy": "helion_accuracy", + }, } diff --git a/examples/mamba2_chunk_scan.py b/examples/mamba2_chunk_scan.py new file mode 100644 index 000000000..a4c014527 --- /dev/null +++ b/examples/mamba2_chunk_scan.py @@ -0,0 +1,273 @@ +""" +Mamba2 Chunk Scan Kernel +======================== + +This code implements a chunked scan kernel as used for Mamba2 +""" + +# %% +# Imports +# ------- +from __future__ import annotations + +import functools + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +import helion.language as hl + + +# %% +# Helion Kernel Implementation +# ---------------------------- +@helion.kernel() +def helion_mamba2_chunk_scan_kernel( + cb: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + dA_cumsum: torch.Tensor, + C: torch.Tensor, + prev_states: torch.Tensor, + D: torch.Tensor, +) -> torch.Tensor: + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + + batch, nchunks, ngroups, chunk_size, _ = cb.shape + _, seqlen, nheads, headdim = x.shape + _, _, _, dstate = C.shape + assert nchunks == (seqlen + chunk_size - 1) // chunk_size + + block_m = hl.register_block_size(chunk_size) + block_n = hl.register_block_size(headdim) + block_k = hl.register_block_size(64, 64) + dstate = hl.specialize(dstate) + + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert C.shape == (batch, seqlen, ngroups, dstate) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert D.shape == (nheads,) + + dtype = cb.dtype + accum_dtype = torch.float32 + assert ( + x.dtype + == dt.dtype + == dA_cumsum.dtype + == C.dtype + == prev_states.dtype + == D.dtype + == dtype + ) + + out = torch.empty_like(x) + + p = 1.44269504 + + for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( + [nheads, chunk_size, headdim, batch, nchunks], + block_size=[1, block_m, block_n, 1, 1], + ): + acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) + dA_cumsum_local_m = dA_cumsum[ + tile_b.begin, tile_h.begin, tile_c.begin, tile_m + ].to(torch.float32) + scale_m_local = torch.exp2(dA_cumsum_local_m * p) + + C_local = C[ + tile_b.begin, + tile_m.index + tile_c.begin * chunk_size, + tile_h.begin // (nheads // ngroups), + :, + ] + prev_states_local = prev_states[ + tile_b.begin, tile_c.begin, tile_h.begin, tile_n, : + ] + acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o) + acc_o *= scale_m_local[:, None] + + for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k): + cb_local = cb[ + tile_b.begin, + tile_c.begin, + tile_h.begin // (nheads // ngroups), + tile_m, + tile_k, + ] + dA_cumsum_local_k = dA_cumsum[ + tile_b.begin, tile_h.begin, tile_c.begin, tile_k + ].to(torch.float32) + cb_local *= torch.exp2( + dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p + ) + dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to( + torch.float32 + ) + cb_local = (cb_local * dt_local[None, :]).to(dtype) + pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] + cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local)) + x_local = x[ + tile_b.begin, + tile_c.begin * chunk_size + tile_k.index, + tile_h.begin, + tile_n, + ] + acc_o = hl.dot(cb_local, x_local, acc=acc_o) + + D_local = D[tile_h.begin].to(torch.float32) + x_residual = x[ + tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n + ].to(torch.float32) + acc_o += x_residual * D_local + out[ + tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n + ] = acc_o.to(dtype=dtype) + + return out + + +# %% +# Reference Function +# ------------- +def ref_chunk_scan( + cb: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + dA_cumsum: torch.Tensor, + C: torch.Tensor, + prev_states: torch.Tensor, + D: torch.Tensor, +) -> torch.Tensor: + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, dhead) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, dhead, dstate) + D: (nheads,) + Return: + out: (batch, seqlen, nheads, dhead) + """ + _, _, ngroups, _, _ = cb.shape + batch, seqlen, nheads, dhead = x.shape + # _, _, ngroups, dstate = B.shape + # assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + dstate = C.shape[-1] + assert seqlen == nchunks * chunk_size + # assert C.shape == B.shape + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = torch.repeat_interleave(C, nheads // ngroups, dim=2) + cb = torch.repeat_interleave(cb, nheads // ngroups, dim=2) + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = cb * decay.permute(0, 2, 1, 3, 4) + causal_mask = torch.tril( + torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool), + diagonal=0, + ) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", + scores_decay.to(x.dtype), + dt.to(x.dtype), + x.reshape(batch, nchunks, chunk_size, nheads, dhead), + ) + # state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + state_decay_out = torch.exp(dA_cumsum.permute(0, 2, 3, 1).unsqueeze(-1)) + out_prev = ( + torch.einsum( + "bclhn,bchpn->bclhp", + C.reshape(batch, nchunks, chunk_size, nheads, dstate), + prev_states.to(C.dtype), + ) + * state_decay_out + ) + out = out + out_prev + out = out.reshape(batch, seqlen, nheads, dhead) + if D is not None: + if D.dim() == 1: + D = D.unsqueeze(-1) + out = out + x * D + return out + + +# %% +# Testing Function +# ------------- +def test( + init: str, + batch: int, + nheads: int, + ngroups: int, + seqlen: int, + chunk_size: int, + headdim: int, + dstate: int, + dtype: torch.dtype = torch.float16, +) -> None: + INIT = { + "r": functools.partial(torch.randn, dtype=dtype, device=DEVICE), + "u": functools.partial(torch.rand, dtype=dtype, device=DEVICE), + "z": functools.partial(torch.zeros, dtype=dtype, device=DEVICE), + "o": functools.partial(torch.ones, dtype=dtype, device=DEVICE), + } + nchunks = (seqlen + chunk_size - 1) // chunk_size + idx = 0 + + def fn(*args: int) -> torch.Tensor: + nonlocal idx + ret = INIT[init[idx]](*args) + idx += 1 + return ret + + cb = fn(batch, nchunks, ngroups, chunk_size, chunk_size) + x = fn(batch, seqlen, nheads, headdim) + dt = fn(batch, nheads, nchunks, chunk_size) + dA_cumsum = fn(batch, nheads, nchunks, chunk_size) # init range is too large + C = fn(batch, seqlen, ngroups, dstate) + prev_states = fn(batch, nchunks, nheads, headdim, dstate) + D = fn(nheads) + args = (cb, x, dt, dA_cumsum, C, prev_states, D) + run_example(helion_mamba2_chunk_scan_kernel, ref_chunk_scan, args) + + +# %% +# Main Function +# ----------- +def main() -> None: + """ + Main entry point that runs the attention kernel test with specific parameters. + Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16. + """ + test("zzzzzzz", 8, 80, 1, 4096, 256, 64, 128) + test("zrzzzzr", 8, 80, 1, 4096, 256, 64, 128) # D * x + test("zzzzrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state + test("zzzrrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state * dA + test("rrrzzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt + test("rrruzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt * dA + + +if __name__ == "__main__": + main() From 9f9df2b5e6b2f54c6d33a7e8ec992a5b72586980 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Wed, 15 Oct 2025 14:47:17 -0700 Subject: [PATCH 2/2] Mamab2 Chunk State --- benchmarks/run.py | 12 +++ examples/mamba2_chunk_state.py | 185 +++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 examples/mamba2_chunk_state.py diff --git a/benchmarks/run.py b/benchmarks/run.py index f766fcc37..3d7c7e6c4 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -322,6 +322,11 @@ class RunResult: "examples.mamba2_chunk_scan", "helion_mamba2_chunk_scan_kernel", ), + "mamba2_chunk_state": ( + "tritonbench.operators.mamba2_chunk_state.operator", + "examples.mamba2_chunk_state", + "helion_mamba2_chunk_state_kernel", + ), } @@ -613,6 +618,13 @@ class RunResult: "helion_mamba2_chunk_scan_kernel_speedup": "helion_speedup", "helion_mamba2_chunk_scan_kernel_accuracy": "helion_accuracy", }, + "mamba2_chunk_state": { + "eager": "baseline", + "compile_speedup": "torch_compile_speedup", + "compile_accuracy": "torch_compile_accuracy", + "helion_mamba2_chunk_state_kernel_speedup": "helion_speedup", + "helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy", + }, } diff --git a/examples/mamba2_chunk_state.py b/examples/mamba2_chunk_state.py new file mode 100644 index 000000000..aa104adff --- /dev/null +++ b/examples/mamba2_chunk_state.py @@ -0,0 +1,185 @@ +""" +Mamba2 Chunk State Kernel +======================== + +This code implements a chunked state kernel as used for Mamba2 +""" + +# %% +# Imports +# ------- +from __future__ import annotations + +import functools + +import torch +import torch.nn.functional as F + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +import helion.language as hl + + +# %% +# Helion Kernel Implementation +# ---------------------------- +@helion.kernel() +def helion_mamba2_chunk_state_kernel( + B: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, dA_cumsum: torch.Tensor +) -> torch.Tensor: + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + + batch, seqlen, ngroups, dstate = B.shape + batch, seqlen, nheads, headdim = x.shape + batch, nheads, nchunks, chunk_size = dt.shape + batch, nheads, nchunks, chunk_size = dA_cumsum.shape + + assert nchunks == (seqlen + chunk_size - 1) // chunk_size + + block_m = hl.register_block_size(headdim) + block_n = hl.register_block_size(dstate) + block_k = hl.register_block_size(chunk_size) + + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + + dtype = B.dtype + accum_dtype = torch.float32 + assert x.dtype == dt.dtype == dA_cumsum.dtype == dtype + + out = B.new_empty(batch, nchunks, nheads, headdim, dstate) + + p = 1.44269504 + + for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( + [nheads, headdim, dstate, batch, nchunks], + block_size=[1, block_m, block_n, 1, 1], + ): + dA_cumsum_last = dA_cumsum[ + tile_b.begin, tile_h.begin, tile_c.begin, chunk_size - 1 + ].to(accum_dtype) + acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) + for tile_k in hl.tile(chunk_size, block_size=block_k): + x_local = x[ + tile_b.begin, + tile_k.index + tile_c.begin * chunk_size, + tile_h.begin, + tile_m, + ] + dA_cumsum_local = dA_cumsum[ + tile_b.begin, tile_h.begin, tile_c.begin, tile_k + ].to(accum_dtype) + dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k] + scale = torch.exp2(dA_cumsum_last * p - dA_cumsum_local * p) * dt_local + xt_local = (x_local.T * scale[None, :]).to(dtype) + B_local = B[ + tile_b.begin, + tile_c.begin * chunk_size + tile_k.index, + tile_h.begin // (nheads // ngroups), + tile_n, + ] + acc_o = hl.dot(xt_local, B_local, acc=acc_o) + out[tile_b.begin, tile_c.begin, tile_h.begin, tile_m, tile_n] = acc_o.to(dtype) + + return out + + +# %% +# Reference Function +# ------------- +def ref_chunk_state( + B: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, dA_cumsum: torch.Tensor +) -> torch.Tensor: + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, dhead) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, dhead, dstate) + """ + # Check constraints. + batch, seqlen, nheads, dhead = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, dhead) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = torch.repeat_interleave(B, nheads // ngroups, dim=2) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = x.reshape(batch, nchunks, chunk_size, nheads, dhead) + B = B.reshape(batch, nchunks, chunk_size, nheads, dstate) + decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum) + return torch.einsum( + "bclhn,bhcl,bhcl,bclhp->bchpn", + B.to(x.dtype), + decay_states.to(x.dtype), + dt.to(x.dtype), + x, + ) + + +# %% +# Testing Function +# ------------- +def test( + init: str, + batch: int, + nheads: int, + ngroups: int, + seqlen: int, + chunk_size: int, + headdim: int, + dstate: int, + dtype: torch.dtype = torch.float16, +) -> None: + INIT = { + "r": functools.partial(torch.randn, dtype=dtype, device=DEVICE), + "u": functools.partial(torch.rand, dtype=dtype, device=DEVICE), + "z": functools.partial(torch.zeros, dtype=dtype, device=DEVICE), + "o": functools.partial(torch.ones, dtype=dtype, device=DEVICE), + } + nchunks = (seqlen + chunk_size - 1) // chunk_size + idx = 0 + + def fn(*args: int) -> torch.Tensor: + nonlocal idx + ret = INIT[init[idx]](*args) + idx += 1 + return ret + + B = fn(batch, seqlen, ngroups, dstate) + x = fn(batch, seqlen, nheads, headdim) + dt = fn(batch, nheads, nchunks, chunk_size) + dA_cumsum = fn(batch, nheads, nchunks, chunk_size) + args = (B, x, dt, dA_cumsum) + run_example(helion_mamba2_chunk_state_kernel, ref_chunk_state, args) + + +# %% +# Main Function +# ----------- +def main() -> None: + test("uuuu", 8, 80, 1, 4096, 256, 64, 128) + + +if __name__ == "__main__": + main()