diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index ab5ca3332..50838d055 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -441,7 +441,7 @@ def _chunk_state_bwd_db_kernel( scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) else: # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - scale = tl.where(seq_idx_m == seq_idx_last, tl.minimum((dA_cs_last - dA_cs_m), 0.0), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) db *= (scale * dt_m)[:, None] if HAS_DDA_CS: # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum