From e0976131301c6e511371a1549640152ddb4e4786 Mon Sep 17 00:00:00 2001 From: Thomas Frost <100413067+tdgfrost@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:25:32 +0100 Subject: [PATCH] Fixes typo when indexing after self.conv1d Original version indexes sequence to the padding, not the sequence length (and has a typo when refering self.dconv instead of self.d_conv). --- mamba_ssm/modules/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 1859ab0d..36b16d47 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -230,7 +230,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" xBC = self.act( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):] + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)] ) # (B, L, self.d_ssm + 2 * ngroups * d_state) else: xBC = causal_conv1d_fn(