From c12a58acfd93693d6b332d3036c63634a1c53a02 Mon Sep 17 00:00:00 2001
From: Matthieu Le
Date: Tue, 17 Jun 2025 18:04:38 -0700
Subject: [PATCH] Fix _chunk_state_bwd_db_kernel when using seq_idx

---
 mamba_ssm/ops/triton/ssd_chunk_state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py
index ab5ca3332..50838d055 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_state.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -441,7 +441,7 @@ def _chunk_state_bwd_db_kernel(
         scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
     else:
         # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
-        scale = tl.where(seq_idx_m == seq_idx_last, tl.minimum((dA_cs_last - dA_cs_m), 0.0), 0.0)
+        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
     db *= (scale * dt_m)[:, None]
     if HAS_DDA_CS:
         # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum