Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions mamba_ssm/ops/triton/ssd_chunk_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def _chunk_scan_fwd_kernel(
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
Expand Down Expand Up @@ -679,7 +680,8 @@ def _chunk_scan_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
Expand Down Expand Up @@ -816,7 +818,8 @@ def _chunk_scan_bwd_dcb_kernel(
dcb *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
# dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
if HAS_DDA_CS:
tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
ddA_cs = dcb * cb
Expand Down Expand Up @@ -1008,7 +1011,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
acc *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
# acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
mask = offs_m[:, None] >= offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
acc = tl.cumsum(acc, axis=1)
Expand Down Expand Up @@ -1134,7 +1138,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
acc *= cb
dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
# acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
rowsum_new = rowsum + tl.sum(acc, axis=1)
Expand Down
27 changes: 18 additions & 9 deletions mamba_ssm/ops/triton/ssd_chunk_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt_presoftplus = dt
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
dt = tl.where(dt <= 20.0, softplus(dt), dt)
clamp_mask = (dt < dt_min) | (dt > dt_max)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
Expand Down Expand Up @@ -229,9 +229,11 @@ def _chunk_state_fwd_kernel(
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
# scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
else:
scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
# scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
Expand Down Expand Up @@ -332,7 +334,8 @@ def _chunk_state_bwd_dx_kernel(
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
# acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]

x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
Expand Down Expand Up @@ -434,9 +437,11 @@ def _chunk_state_bwd_db_kernel(
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
# Decay factor applied to db: exp(dA_cs_last - dA_cs_m), with the exponent clamped to <= 0.
# NOTE(review): the clamp presumably keeps tl.exp from overflowing for rows whose dA_cs_m
# was masked to 0.0 by the load above -- confirm against the kernel's masking.
if not HAS_SEQ_IDX:
    scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
else:
    # Only positions in the same sequence as the last position contribute; others get 0.
    # The clamped exponent must still go through tl.exp, exactly as in the branch above;
    # without it `scale` would be the (non-positive) exponent rather than the decay factor.
    scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
db *= (scale * dt_m)[:, None]
if HAS_DDA_CS:
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
Expand Down Expand Up @@ -549,11 +554,13 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_last - dA_cs_m)
# scale = tl.exp(dA_cs_last - dA_cs_m)
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
# scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
acc *= scale[:, None]

x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
Expand Down Expand Up @@ -634,8 +641,10 @@ def _chunk_state_varlen_kernel(
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
# scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
# tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
Expand Down
30 changes: 22 additions & 8 deletions mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


def rearrange_and_update_stride(tensor, pattern=None, dim=2):
    """Optionally rearrange ``tensor`` and fix up its stride along ``dim``.

    Args:
        tensor: Tensor to process; must provide ``stride(dim)`` and ``contiguous()``.
        pattern: Pattern string forwarded to ``rearrange``. When ``None`` the
            tensor is used as-is and ``rearrange`` is not called.
        dim: Dimension whose stride is checked after the optional rearrange.

    Returns:
        The (possibly rearranged) tensor unchanged when its stride along ``dim``
        is a multiple of eight; otherwise the result of calling ``contiguous()``
        on it.
    """
    candidate = tensor if pattern is None else rearrange(tensor, pattern)
    if candidate.stride(dim) % 8 == 0:
        # Stride already satisfies the multiple-of-eight requirement; avoid a copy.
        return candidate
    return candidate.contiguous()


@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
Expand Down Expand Up @@ -120,11 +127,13 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(

dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_last - dA_cs_m)
# scale = tl.exp(dA_cs_last - dA_cs_m)
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
# scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
# Unexpected mma -> mma layout conversion
Expand Down Expand Up @@ -170,7 +179,8 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
Expand Down Expand Up @@ -776,7 +786,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
xBC_conv = rearrange(
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
"b d s -> b s d"
)
Expand Down Expand Up @@ -850,7 +860,7 @@ def backward(ctx, dout, *args):
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
# Recompute x, B, C
xBC_conv = rearrange(
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
"b d s -> b s d"
)
Expand Down Expand Up @@ -900,10 +910,14 @@ def backward(ctx, dout, *args):
else:
doutproj_weight, doutproj_bias = None, None
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
)
if dxBC_given.stride() != dxBC_given_update.stride():
dxBC_given.copy_(dxBC_given_update)
else:
dxBC_given = dxBC_given_update
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None

Expand Down