From 67e616ebdbeffb6c572cf9c5940dda3b03b3ae07 Mon Sep 17 00:00:00 2001 From: Roger Waleffe Date: Mon, 4 Nov 2024 22:12:10 -0800 Subject: [PATCH 1/7] Numerical stability for large negative values --- mamba_ssm/ops/triton/ssd_chunk_scan.py | 15 ++++++++++----- mamba_ssm/ops/triton/ssd_chunk_state.py | 23 ++++++++++++++++------- mamba_ssm/ops/triton/ssd_combined.py | 9 ++++++--- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index fa5b813a..95907806 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -132,7 +132,8 @@ def _chunk_scan_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: @@ -679,7 +680,8 @@ def _chunk_scan_bwd_dx_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. 
@@ -816,7 +818,8 @@ def _chunk_scan_bwd_dcb_kernel( dcb *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) - dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) if HAS_DDA_CS: tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") ddA_cs = dcb * cb @@ -1008,7 +1011,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( acc *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) mask = offs_m[:, None] >= offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) acc = tl.cumsum(acc, axis=1) @@ -1134,7 +1138,8 @@ def _chunk_scan_bwd_ddAcs_stable_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) acc *= cb dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) rowsum_new = rowsum + tl.sum(acc, axis=1) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index bb49c9a9..0589a45c 100644 --- 
a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -229,9 +229,11 @@ def _chunk_state_fwd_kernel( seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k else: scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -332,7 +334,8 @@ def _chunk_state_bwd_dx_kernel( dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] + # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] + acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None] x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) @@ -434,9 +437,11 @@ def _chunk_state_bwd_db_kernel( dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) + # scale = tl.exp(dA_cs_last - dA_cs_m) + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) else: - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, 
tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) db *= (scale * dt_m)[:, None] if HAS_DDA_CS: # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum @@ -549,11 +554,13 @@ def _chunk_state_bwd_ddAcs_stable_kernel( dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) + # scale = tl.exp(dA_cs_last - dA_cs_m) + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) else: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) acc *= scale[:, None] x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) @@ -634,8 +641,10 @@ def _chunk_state_varlen_kernel( b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) diff --git 
a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 58a6e04a..37691fce 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -120,11 +120,13 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) + # scale = tl.exp(dA_cs_last - dA_cs_m) + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) else: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 # However, we're getting error with the Triton compiler 2.1.0 for that code path: # Unexpected mma -> mma layout conversion @@ -170,7 +172,8 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. 
From d3e3d526ba69b06c9898eebd1b682c151879207a Mon Sep 17 00:00:00 2001 From: Duncan Riach Date: Thu, 6 Feb 2025 17:33:43 -0800 Subject: [PATCH 2/7] Fix causal_conv1d xBC stride not multiple of 8 issue --- mamba_ssm/ops/triton/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 37691fce..e6cb922a 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -779,7 +779,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) seq_idx = seq_idx.contiguous() if seq_idx is not None else None xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s").contiguous(), conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), "b d s -> b s d" ) From 9beb2b5bce11acfcd6165b5eeb59510294996995 Mon Sep 17 00:00:00 2001 From: Duncan Riach Date: Fri, 14 Feb 2025 12:17:09 -0800 Subject: [PATCH 3/7] Fix backprop for causal_conv1d xBC stride not multiple of 8 issue --- mamba_ssm/ops/triton/ssd_combined.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index e6cb922a..ebb6c277 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -853,7 +853,7 @@ def backward(ctx, dout, *args): zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) # Recompute x, B, C xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s").contiguous(), conv1d_weight, conv1d_bias, seq_idx, None, None, 
ctx.activation in ["silu", "swish"]), "b d s -> b s d" ) @@ -903,10 +903,11 @@ def backward(ctx, dout, *args): else: doutproj_weight, doutproj_bias = None, None dxBC_given = rearrange(dxBC_given, "b s d -> b d s") - dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, - rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"] + dxBC_given_contiguous, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange(xBC, "b s d -> b d s").contiguous(), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s").contiguous(), seq_idx, None, None, dxBC_given.contiguous(), False, ctx.activation in ["silu", "swish"] ) + dxBC_given.copy_(dxBC_given_contiguous) dxBC_given = rearrange(dxBC_given, "b d s -> b s d") return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None From 41a0994558340a5f275f3044a927f26d973e794a Mon Sep 17 00:00:00 2001 From: Roger Waleffe Date: Mon, 24 Mar 2025 16:26:08 -0700 Subject: [PATCH 4/7] Fix ddt -> dt typo --- mamba_ssm/ops/triton/ssd_chunk_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 0589a45c..eb2846b3 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel( dt += dt_bias[:, None] if DT_SOFTPLUS: dt_presoftplus = dt - dt = tl.where(dt <= 20.0, softplus(dt), ddt) + dt = tl.where(dt <= 20.0, softplus(dt), dt) clamp_mask = (dt < dt_min) | (dt > dt_max) # As of Triton 2.2.0, tl.clamp is not available yet # dt = tl.clamp(dt, dt_min, dt_max) From a8240c18e2d25dea57b3e2293f9acf8f9e920bd4 Mon Sep 17 00:00:00 2001 From: Roger Waleffe Date: Mon, 24 Mar 2025 16:34:22 -0700 Subject: [PATCH 5/7] 
Add nit comment --- mamba_ssm/ops/triton/ssd_chunk_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index eb2846b3..ab5ca333 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -232,7 +232,7 @@ def _chunk_state_fwd_kernel( # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k else: - scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) From 0284454bd4ec61c87a9fd2e370df822a01557a1f Mon Sep 17 00:00:00 2001 From: Roger Waleffe Date: Mon, 31 Mar 2025 11:58:52 -0700 Subject: [PATCH 6/7] Call contiguous() before causal_conv1d only when stride is not a multiple of 8 --- mamba_ssm/ops/triton/ssd_combined.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index ebb6c277..05eba533 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -47,6 +47,13 @@ def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] +def rearrange_and_update_stride(tensor, pattern=None, dim=2): + # ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern, + # if not call contiguous(), rearrange only if pattern is not None + tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor + return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged + + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 
'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), @@ -779,7 +786,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) seq_idx = seq_idx.contiguous() if seq_idx is not None else None xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s").contiguous(), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), "b d s -> b s d" ) @@ -853,7 +860,7 @@ def backward(ctx, dout, *args): zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) # Recompute x, B, C xBC_conv = rearrange( - causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s").contiguous(), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), "b d s -> b s d" ) @@ -903,11 +910,11 @@ def backward(ctx, dout, *args): else: doutproj_weight, doutproj_bias = None, None dxBC_given = rearrange(dxBC_given, "b s d -> b d s") - dxBC_given_contiguous, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - rearrange(xBC, "b s d -> b d s").contiguous(), conv1d_weight, conv1d_bias, - rearrange(dxBC, "b s d -> b d s").contiguous(), seq_idx, None, None, dxBC_given.contiguous(), False, ctx.activation in ["silu", "swish"] + dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"] ) - dxBC_given.copy_(dxBC_given_contiguous) + 
dxBC_given.copy_(dxBC_given_update) dxBC_given = rearrange(dxBC_given, "b d s -> b s d") return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None From 39135c8f53cce4f1843bc04f561f893254c24c6a Mon Sep 17 00:00:00 2001 From: Roger Waleffe Date: Mon, 31 Mar 2025 12:16:05 -0700 Subject: [PATCH 7/7] Copy only if strides differ --- mamba_ssm/ops/triton/ssd_combined.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 05eba533..ab93ec08 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -914,7 +914,10 @@ def backward(ctx, dout, *args): rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"] ) - dxBC_given.copy_(dxBC_given_update) + if dxBC_given.stride() != dxBC_given_update.stride(): + dxBC_given.copy_(dxBC_given_update) + else: + dxBC_given = dxBC_given_update dxBC_given = rearrange(dxBC_given, "b d s -> b s d") return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None