diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 1859ab0d..36b16d47 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -230,7 +230,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" xBC = self.act( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):] + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)] ) # (B, L, self.d_ssm + 2 * ngroups * d_state) else: xBC = causal_conv1d_fn(