Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mamba_ssm/ops/triton/selective_state_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from einops import rearrange, repeat

from mamba_ssm.ops.triton.softplus import softplus


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
Expand Down Expand Up @@ -83,15 +85,15 @@ def _selective_scan_update_kernel(
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
dt = softplus(dt)
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix

Expand Down
17 changes: 17 additions & 0 deletions mamba_ssm/ops/triton/softplus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import triton
import triton.language as tl
from packaging import version

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")


if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
6 changes: 4 additions & 2 deletions mamba_ssm/ops/triton/ssd_chunk_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from einops import rearrange, repeat

from mamba_ssm.ops.triton.softplus import softplus


def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
Expand Down Expand Up @@ -66,7 +68,7 @@ def _chunk_cumsum_fwd_kernel(
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
dt = softplus(dt)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
Expand Down Expand Up @@ -139,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt_presoftplus = dt
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt)
dt = softplus(dt)
clamp_mask = (dt < dt_min) | (dt > dt_max)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
Expand Down