From df4edb956c6896b54584d1c4bd7f00d76e49ef59 Mon Sep 17 00:00:00 2001
From: Stephen Youn
Date: Fri, 7 Jun 2024 18:42:05 +0000
Subject: [PATCH] make mamba triton kernels compatible with triton version >= 3.0.0

Move the softplus expression that was inlined in the Triton kernels into a
new helper, mamba_ssm/ops/triton/softplus.py, and call it from
_selective_scan_update_kernel (both dt branches), _chunk_cumsum_fwd_kernel
and _chunk_cumsum_bwd_kernel.

The helper is selected by Triton version: on Triton >= 3.0.0 it computes
tl.math.log(tl.math.exp(dt) + 1); on older versions it keeps
tl.math.log1p(tl.exp(dt)). In both cases inputs above 20.0 are returned
unchanged.

Behavior change: _chunk_cumsum_bwd_kernel previously used `ddt` as the
fallback value for dt > 20.0; it now uses `dt`, the same fallback as
_chunk_cumsum_fwd_kernel.
---
Review notes (not part of the commit message):
- Line breaks and indentation of this patch were restored by hand from a
  whitespace-collapsed copy; run `git apply --check` against the target
  tree before applying.
- The From: header carries no e-mail address in the copy this was restored
  from.
- softplus.py: mode changed from 100755 to 100644 to match the other two
  files, comments added, and a trailing newline added. The blob id on its
  index line was not recomputed after these edits.
- NOTE(review): in float32, log(exp(dt) + 1) rounds to 0 for very negative
  dt where log1p would return a small positive value; confirm this loss of
  precision is acceptable for dt.

 mamba_ssm/ops/triton/selective_state_update.py |  6 ++++--
 mamba_ssm/ops/triton/softplus.py               | 26 ++++++++++++++++++++++++++
 mamba_ssm/ops/triton/ssd_chunk_state.py        |  6 ++++--
 3 files changed, 34 insertions(+), 4 deletions(-)
 create mode 100644 mamba_ssm/ops/triton/softplus.py

diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py
index 193552a0..bc78de90 100644
--- a/mamba_ssm/ops/triton/selective_state_update.py
+++ b/mamba_ssm/ops/triton/selective_state_update.py
@@ -12,6 +12,8 @@
 
 from einops import rearrange, repeat
 
+from mamba_ssm.ops.triton.softplus import softplus
+
 
 @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
 @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@@ -83,7 +85,7 @@ def _selective_scan_update_kernel(
         if HAS_DT_BIAS:
             dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
         if DT_SOFTPLUS:
-            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+            dt = softplus(dt)
         A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
         dA = tl.exp(A * dt[:, None])
     else:
@@ -91,7 +93,7 @@ def _selective_scan_update_kernel(
         if HAS_DT_BIAS:
             dt += tl.load(dt_bias_ptr).to(tl.float32)
         if DT_SOFTPLUS:
-            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+            dt = softplus(dt)
         A = tl.load(A_ptr).to(tl.float32)
         dA = tl.exp(A * dt)  # scalar, not a matrix
 
diff --git a/mamba_ssm/ops/triton/softplus.py b/mamba_ssm/ops/triton/softplus.py
new file mode 100644
index 00000000..de68b461
--- /dev/null
+++ b/mamba_ssm/ops/triton/softplus.py
@@ -0,0 +1,26 @@
+"""Softplus helper shared by the Mamba Triton kernels.
+
+Which implementation is defined depends on the installed Triton version,
+checked once at import time.
+"""
+
+import triton
+import triton.language as tl
+from packaging import version
+
+# True when the installed Triton is 3.0.0 or newer.
+TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
+
+
+if TRITON3:
+    # Triton >= 3.0.0: log1p(exp(dt)) is spelled as log(exp(dt) + 1).
+    @triton.jit
+    def softplus(dt):
+        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
+        return dt
+else:
+    # Older Triton: use tl.math.log1p directly.
+    @triton.jit
+    def softplus(dt):
+        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        return dt
diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py
index 4333e6a7..0c23f327 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_state.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -12,6 +12,8 @@
 
 from einops import rearrange, repeat
 
+from mamba_ssm.ops.triton.softplus import softplus
+
 
 def init_to_zero(names):
     return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
@@ -66,7 +68,7 @@ def _chunk_cumsum_fwd_kernel(
         dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
         dt += dt_bias[:, None]
     if DT_SOFTPLUS:
-        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        dt = softplus(dt)
     # As of Triton 2.2.0, tl.clamp is not available yet
     # dt = tl.clamp(dt, dt_min, dt_max)
     dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
@@ -139,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
         dt += dt_bias[:, None]
     if DT_SOFTPLUS:
         dt_presoftplus = dt
-        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt)
+        dt = softplus(dt)
     clamp_mask = (dt < dt_min) | (dt > dt_max)
     # As of Triton 2.2.0, tl.clamp is not available yet
     # dt = tl.clamp(dt, dt_min, dt_max)