diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py
index fa5b813a..95907806 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_scan.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -132,7 +132,8 @@ def _chunk_scan_fwd_kernel(
         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
         # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
         # So we don't need masking wrt seq_idx here.
-        cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
+        # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
+        cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
         dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
         cb *= dt_k
         if IS_CAUSAL:
@@ -679,7 +680,8 @@ def _chunk_scan_bwd_dx_kernel(
         cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
         dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
-        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
         # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
         # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
         # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -816,7 +818,8 @@ def _chunk_scan_bwd_dcb_kernel(
     dcb *= dt_n
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
-    dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
     if HAS_DDA_CS:
         tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
         ddA_cs = dcb * cb
@@ -1008,7 +1011,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
     acc *= dt_n
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
-    acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+    acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
     mask = offs_m[:, None] >= offs_n[None, :] + 1
     acc = tl.where(mask, acc, 0.0)
     acc = tl.cumsum(acc, axis=1)
@@ -1134,7 +1138,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
         cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
         acc *= cb
         dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
-        acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+        # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
+        acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
         mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
         acc = tl.where(mask, acc, 0.0)
         rowsum_new = rowsum + tl.sum(acc, axis=1)
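Note on the `tl.minimum(..., 0.0)` clamp: for valid causal index pairs the decay exponent is mathematically non-positive (the `dA` cumsum is nonincreasing since `A < 0` in Mamba2), so the clamp is a no-op there and only tames the garbage differences at masked-out positions that the kernel comments above describe. A standalone repro of that failure mode, with `torch` standing in for `triton.language`:

```python
import torch

# Repro of the failure mode from the kernel comments: one operand of the
# difference is masked to 0.0 while the other is a large negative cumsum
# value, so the exponent is hugely positive, exp() overflows to inf, and the
# 0.0 that the mask already placed in cb turns the product into NaN.
masked_cs = torch.tensor([0.0])     # loaded with other=0.0 outside the mask
valid_cs = torch.tensor([-1e5])     # in-range cumsum entry, very negative
cb = torch.tensor([0.0])            # this entry is already zeroed by the mask

bad = cb * torch.exp(masked_cs - valid_cs)                    # 0.0 * inf -> nan
good = cb * torch.exp((masked_cs - valid_cs).clamp(max=0.0))  # 0.0 * 1.0 -> 0.0
print(bad.item(), good.item())                                # nan 0.0
```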
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
         dt += dt_bias[:, None]
     if DT_SOFTPLUS:
         dt_presoftplus = dt
-        dt = tl.where(dt <= 20.0, softplus(dt), ddt)
+        dt = tl.where(dt <= 20.0, softplus(dt), dt)
     clamp_mask = (dt < dt_min) | (dt > dt_max)
     # As of Triton 2.2.0, tl.clamp is not available yet
     # dt = tl.clamp(dt, dt_min, dt_max)
@@ -229,9 +229,11 @@ def _chunk_state_fwd_kernel(
             seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
         dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
         if not HAS_SEQ_IDX:
-            scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
+            # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
+            scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
         else:
-            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+            # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
@@ -332,7 +334,8 @@ def _chunk_state_bwd_dx_kernel(
     dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
     dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
-    acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
+    # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
+    acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]
 
     x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
     x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
@@ -434,9 +437,11 @@ def _chunk_state_bwd_db_kernel(
     dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     db *= (scale * dt_m)[:, None]
     if HAS_DDA_CS:
         # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
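Context for the `_chunk_cumsum_bwd_kernel` one-liner above: the `dt > 20.0` branch must fall back to `dt` itself, not the unrelated gradient `ddt`, because past that threshold `softplus(dt) = log(1 + exp(dt))` equals `dt` to float32 precision, and keeping `dt` also avoids overflow in `exp(dt)`. PyTorch's own `Softplus` uses the same `threshold=20` trick. A quick sketch of why 20 is enough:

```python
import torch
import torch.nn.functional as F

# softplus(x) = log(1 + exp(x)) approaches x from above with gap exp(-x);
# by x ~ 20 the gap (~2e-9) is far below float32 resolution at that magnitude,
# so passing dt through unchanged is exact in float32.
x = torch.tensor([1.0, 5.0, 10.0, 20.0, 30.0])
print(F.softplus(x) - x)  # gaps: ~3e-1, ~7e-3, ~5e-5, then exactly 0.0
```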
@@ -549,11 +554,13 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
     dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
     dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
         seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     acc *= scale[:, None]
 
     x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
@@ -634,8 +641,10 @@ def _chunk_state_varlen_kernel(
         b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
         dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+        #                  tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
         scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
-                         tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+                         tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py
index 58a6e04a..ab93ec08 100644
--- a/mamba_ssm/ops/triton/ssd_combined.py
+++ b/mamba_ssm/ops/triton/ssd_combined.py
@@ -47,6 +47,13 @@ def init_to_zero(names):
     return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
 
 
+def rearrange_and_update_stride(tensor, pattern=None, dim=2):
+    # Rearrange according to pattern (skipped when pattern is None), then ensure the
+    # result's stride along dim is a multiple of eight, falling back to contiguous() if not.
+    tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
+    return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
+
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
@@ -120,11 +127,13 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
     dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
     if not HAS_SEQ_IDX:
-        scale = tl.exp(dA_cs_last - dA_cs_m)
+        # scale = tl.exp(dA_cs_last - dA_cs_m)
+        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
         seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
 
     # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
     # However, we're getting error with the Triton compiler 2.1.0 for that code path:
     # Unexpected mma -> mma layout conversion
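Why `rearrange_and_update_stride` checks `stride(dim)`: `rearrange(x, "b s d -> b d s")` is a transpose view, so the stride along the new seqlen dim equals the old channel count `d`, which for `xBC` need not be a multiple of eight; the causal_conv1d kernels apparently want that alignment. A standalone illustration with made-up shapes:

```python
import torch
from einops import rearrange

# A transpose is a view: strides are permuted, not recomputed. With d = 100
# channels, the seqlen stride after "b s d -> b d s" is 100, not a multiple
# of 8, so the helper falls back to an actual contiguous copy.
x = torch.randn(2, 64, 100)          # (batch, seqlen, d), d % 8 != 0
xt = rearrange(x, "b s d -> b d s")
print(xt.stride())                   # (6400, 1, 100): misaligned seqlen stride
print(xt.contiguous().stride())      # (6400, 64, 1): dense layout
```

The contiguous copy has unit stride along `dim` (the dense case, which the kernel presumably handles fine), so the helper never needs to copy more than once.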
@@ -170,7 +179,8 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
         cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
         dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
-        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
+        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
         # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
         # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
         # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -776,7 +786,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
         zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
         xBC_conv = rearrange(
-            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+            causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                                  conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
             "b d s -> b s d"
         )
@@ -850,7 +860,7 @@ def backward(ctx, dout, *args):
         zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
         # Recompute x, B, C
         xBC_conv = rearrange(
-            causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+            causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
                                                  conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
             "b d s -> b s d"
         )
@@ -900,10 +910,14 @@ def backward(ctx, dout, *args):
         else:
             doutproj_weight, doutproj_bias = None, None
         dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
-        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
-            rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
-            rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
+        dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
+            rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
         )
+        if dxBC_given.stride() != dxBC_given_update.stride():
+            dxBC_given.copy_(dxBC_given_update)
+        else:
+            dxBC_given = dxBC_given_update
         dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
 
         return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
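On the copy-back after `causal_conv1d_bwd`: when `rearrange_and_update_stride(dxBC_given)` has to substitute a contiguous copy, the kernel writes the gradient into that copy rather than into the original strided buffer (presumably a view into the preallocated `dzxbcdt` that `backward` returns), so the result must be folded back in place. A sketch of the idiom with illustrative names and shapes, not from the patch:

```python
import torch

# `grad` stands in for dxBC_given: a non-dense view that other tensors alias.
# `grad_update` stands in for the kernel output, which may be a dense copy.
grad = torch.zeros(2, 100, 64).transpose(1, 2)   # shape (2, 64, 100), strided
grad_update = torch.randn(2, 64, 100)            # dense result from the kernel

if grad.stride() != grad_update.stride():
    grad.copy_(grad_update)       # in-place: every alias of `grad` sees the update
else:
    grad = grad_update            # layouts match: kernel already wrote in place
```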