diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py
index d425bc728..be82d1635 100644
--- a/mamba_ssm/ops/triton/selective_state_update.py
+++ b/mamba_ssm/ops/triton/selective_state_update.py
@@ -51,9 +51,17 @@ def _selective_scan_update_kernel(
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
 
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+
     if HAS_STATE_BATCH_INDICES:
         state_batch_indices_ptr += pid_b
         state_batch_idx = tl.load(state_batch_indices_ptr)
+        # Skip padding tokens
+        if state_batch_idx < 0:
+            tl.store(out_ptrs, 0.0, mask=offs_m < dim)
+            return
         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
     else:
         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
@@ -67,9 +75,7 @@ def _selective_scan_update_kernel(
     C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
         z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
-    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
 
-    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
     state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
     x_ptrs = x_ptr + offs_m * stride_x_dim
@@ -85,7 +91,6 @@ def _selective_scan_update_kernel(
         D_ptrs = D_ptr + offs_m * stride_D_dim
     if HAS_Z:
         z_ptrs = z_ptr + offs_m * stride_z_dim
-    out_ptrs = out_ptr + offs_m * stride_out_dim
 
     state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py
index 50838d055..633c66e82 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_state.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -233,7 +233,7 @@ def _chunk_state_fwd_kernel(
             scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
         else:
             # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
-            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+            scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
diff --git a/tests/test_generation.py b/tests/test_generation.py
index 77e1aedfa..c1b122128 100644
--- a/tests/test_generation.py
+++ b/tests/test_generation.py
@@ -111,3 +111,83 @@ def test_generation_varlen():
     out_varlen = torch.cat(scores, dim=1)
     print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
     assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
+
+def test_generation_varlen_with_padding():
+    seqlens = [170, 65, 100]
+    non_padded_seqlen = sum(seqlens)
+    padded_seqlen = 512
+    seqlens.append(padded_seqlen - non_padded_seqlen)
+    genlen = 20
+    total_seqlen = sum(seqlens)
+    assert total_seqlen == padded_seqlen
+    device = "cuda"
+    dtype = torch.float16
+
+    config = MambaConfig(
+        d_model=1024,
+        n_layer=4,
+        vocab_size=50277,
+        ssm_cfg=dict(layer="Mamba2"),
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=16,
+    )
+    torch.manual_seed(2357)
+    model = MambaLMHeadModel(config, device=device, dtype=dtype)
+    xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
+
+    # Reference 1: Forward pass with seq_idx
+    x = torch.cat(xs[:-1], dim=1)
+    seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
+                         for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0)
+    cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
+
+    out_ref = model(x, seq_idx=seq_idx).logits
+    # Only take the last @genlen logits of each sequence
+    out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
+                         for i in range(len(seqlens) - 1)], dim=0)
+
+    # Reference 2: Generate the last @genlen tokens of each sequence in a for loop
+    out_loop = []
+    for input_ids in xs[:-1]:
+        out = model.generate(
+            input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
+            return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
+        ).scores
+        out_loop.append(torch.stack(out, dim=1))
+    out_loop = torch.cat(out_loop, dim=0)
+    print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
+
+    # Varlen generation
+    input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
+    prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
+    cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
+    seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
+                         for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
+    inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
+
+    # Account for padding: positions past the real prompts get seq_idx == -1,
+    # and the padding pseudo-sequence gets zero length in cu_seqlens
+    seq_idx[:, cu_seqlens[-2]:] = -1
+    cu_seqlens[-1] = cu_seqlens[-2]
+
+    scores, sequences = [], []
+    # Both seq_idx and cu_seqlens must be passed in for varlen generation
+    logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
+    logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
+    scores.append(logits)
+    # In practice we should sample. In this case we take from the teacher_output for testing
+    sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
+    sequences.append(sampled_tokens)
+    for i in range(1, genlen):
+        inference_params.seqlen_offset += 1
+        logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
+        scores.append(logits)
+        # In practice we should sample. In this case we take from the teacher_output for testing
+        sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
+        sequences.append(sampled_tokens)
+    out_varlen = torch.cat(scores, dim=1)
+
+    print(f"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}")
+    assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
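
For reference, the convention the patched kernels assume: in `_selective_scan_update_kernel`, a negative entry in `state_batch_indices` marks a padding slot, whose output is written as zeros and whose cached state is neither read nor written; in `_chunk_state_fwd_kernel`, `seq_idx == -1` marks padding positions that must not contribute to any chunk state. A minimal PyTorch sketch of that contract (the function names here are illustrative, not part of the mamba_ssm API):

```python
import torch

def mask_padding_outputs(out, state_batch_indices):
    # Decode-step contract: rows whose cache-slot index is negative are
    # padding; the Triton kernel writes zeros for them and returns before
    # touching the recurrent state.
    is_pad = state_batch_indices < 0                 # (batch,) bool
    return out.masked_fill(is_pad[:, None], 0.0)

def chunk_state_contribution_mask(seq_idx_k, seq_idx_last):
    # Prefill-side guard: a token contributes to a chunk's state only if
    # the chunk ends inside a real sequence (seq_idx_last >= 0) and the
    # token belongs to that same sequence.
    return (seq_idx_last >= 0) & (seq_idx_k == seq_idx_last)
```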
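
And a hedged sketch of the caller-side preparation the new test exercises, factored into a hypothetical helper (`pack_prompts_with_padding` is not an existing mamba_ssm function): padding occupies the tail of the packed batch, its positions carry `seq_idx == -1`, and the final `cu_seqlens` entry is repeated so the padding pseudo-sequence has zero length.

```python
import torch
import torch.nn.functional as F

def pack_prompts_with_padding(prompt_seqlens, pad_len, device="cuda"):
    # Build (cu_seqlens, seq_idx) for a varlen prefill whose packed batch
    # ends in pad_len padding tokens, mirroring the test above.
    lens = list(prompt_seqlens) + [pad_len]
    cu_seqlens = F.pad(torch.tensor(lens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
    seq_idx = torch.cat([torch.full((n,), i, dtype=torch.int32, device=device)
                         for i, n in enumerate(lens)], dim=0).unsqueeze(0)
    seq_idx[:, cu_seqlens[-2]:] = -1   # padding positions are marked -1
    cu_seqlens[-1] = cu_seqlens[-2]    # padding pseudo-sequence gets zero length
    return cu_seqlens, seq_idx

# e.g. three real prompts of 150/45/80 tokens plus 157 padding tokens:
# cu_seqlens == [0, 150, 195, 275, 275], seq_idx[:, 275:] == -1
```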