11 changes: 8 additions & 3 deletions mamba_ssm/ops/triton/selective_state_update.py
@@ -51,9 +51,17 @@ def _selective_scan_update_kernel(
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
out_ptrs = out_ptr + offs_m * stride_out_dim

if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr)
# Skip padding tokens
if state_batch_idx < 0:
tl.store(out_ptrs, 0.0, mask=offs_m < dim)
return
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
@@ -67,9 +75,7 @@ def _selective_scan_update_kernel(
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
@@ -85,7 +91,6 @@ def _selective_scan_update_kernel(
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim

state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
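
A minimal PyTorch sketch of the semantics this change adds (not the Triton kernel itself, and with simplified, assumed shapes): the output pointers are now set up before the state pointers, so a batch entry whose state_batch_indices value is negative (a padding slot) can write zeros to its output and return without reading or updating any cached state.

```python
import torch

def padded_state_update_ref(state, x, dt, A, B, C, state_batch_indices):
    # state: (num_slots, dim, dstate) cache; x, dt: (batch, dim); A: (dim, dstate)
    # B, C: (batch, dstate); state_batch_indices: (batch,) int, -1 marks padding
    out = torch.empty_like(x)
    for b in range(x.shape[0]):
        slot = int(state_batch_indices[b])
        if slot < 0:  # padding token: zero output, leave every cache slot untouched
            out[b].zero_()
            continue
        dA = torch.exp(dt[b, :, None] * A)                    # (dim, dstate)
        dBx = dt[b, :, None] * B[b][None, :] * x[b, :, None]  # (dim, dstate)
        state[slot] = state[slot] * dA + dBx                  # in-place cache update
        out[b] = (state[slot] * C[b][None, :]).sum(dim=-1)    # (dim,)
    return out
```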
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -233,7 +233,7 @@ def _chunk_state_fwd_kernel(
scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
else:
# scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
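
In plain terms, a token's contribution to the chunk state is kept only when it belongs to the same sequence as the chunk's last token and that sequence index is non-negative, so padding tokens (seq_idx == -1) no longer match each other and leak into the state. A plain PyTorch sketch of the guarded scale (variable names mirror the kernel; shapes are assumptions):

```python
import torch

def chunk_state_scale(dA_cs_last, dA_cs_k, dt_k, seq_idx_k, seq_idx_last):
    # dA_cs_last: cumulative dA at the chunk end (scalar); dA_cs_k, dt_k: (chunk,)
    # seq_idx_k: (chunk,) int; seq_idx_last: int, -1 when the chunk ends in padding
    decay = torch.exp(torch.clamp(dA_cs_last - dA_cs_k, max=0.0)) * dt_k
    keep = (seq_idx_k == seq_idx_last) & (seq_idx_last >= 0)
    return torch.where(keep, decay, torch.zeros_like(decay))
```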
80 changes: 80 additions & 0 deletions tests/test_generation.py
@@ -111,3 +111,83 @@ def test_generation_varlen():
out_varlen = torch.cat(scores, dim=1)
print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()

def test_generation_varlen_with_padding():
seqlens = [170, 65, 100]
non_padded_seqlen = sum(seqlens)
padded_seqlen = 512
seqlens.append(padded_seqlen - non_padded_seqlen)
genlen = 20
total_seqlen = sum(seqlens)
assert total_seqlen == padded_seqlen
device = "cuda"
dtype = torch.float16

config = MambaConfig(
d_model=1024,
n_layer=4,
vocab_size=50277,
ssm_cfg=dict(layer="Mamba2"),
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=16,
)
torch.manual_seed(2357)
model = MambaLMHeadModel(config, device=device, dtype=dtype)
xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]

# Reference 1: Forward pass with seq_idx
x = torch.cat(xs[:-1], dim=1)
seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0)
cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))

out_ref = model(x, seq_idx=seq_idx).logits
# Only take the last @genlen logits of each sequence
out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
for i in range(len(seqlens) - 1)], dim=0)

# Reference 2: Generate the last @genlen tokens of each sequence in a for loop
out_loop = []
for input_ids in xs[:-1]:
out = model.generate(
input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
).scores
out_loop.append(torch.stack(out, dim=1))
out_loop = torch.cat(out_loop, dim=0)
print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")

# Varlen generation
input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))

    # Account for padding: positions after the last real prompt belong to no
    # sequence, so mark them with seq_idx -1 and give the padding "sequence"
    # zero length in cu_seqlens
    seq_idx[:, cu_seqlens[-2]:] = -1
    cu_seqlens[-1] = cu_seqlens[-2]

scores, sequences = [], []
# Both seq_idx and cu_seqlens must be passed in for varlen generation
logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
scores.append(logits)
# In practice we should sample. In this case we take from the teacher_output for testing
sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
sequences.append(sampled_tokens)
for i in range(1, genlen):
inference_params.seqlen_offset += 1
logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
scores.append(logits)
# In practice we should sample. In this case we take from the teacher_output for testing
sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
sequences.append(sampled_tokens)
out_varlen = torch.cat(scores, dim=1)

print(f"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}")
assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
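
For callers, the padding bookkeeping performed above before the prefill call can be condensed into a small helper. This is a hypothetical convenience (build_padded_varlen_inputs, prompt_lens and pad_len are names introduced here, not part of the library), mirroring the test: padding positions get seq_idx -1 and the final cu_seqlens boundary is clamped so the padding "sequence" has zero length.

```python
import torch
import torch.nn.functional as F

def build_padded_varlen_inputs(prompt_lens, pad_len, device="cuda"):
    # prompt_lens: real prompt lengths; pad_len: padding tokens appended at the end
    lens = list(prompt_lens) + [pad_len]
    cu_seqlens = F.pad(torch.tensor(lens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
    seq_idx = torch.cat([torch.full((l,), i, dtype=torch.int32, device=device)
                         for i, l in enumerate(lens)]).unsqueeze(0)
    seq_idx[:, cu_seqlens[-2]:] = -1   # padding positions belong to no sequence
    cu_seqlens[-1] = cu_seqlens[-2]    # the padding "sequence" gets zero length
    return seq_idx, cu_seqlens

# e.g. seq_idx, cu_seqlens = build_padded_varlen_inputs([150, 45, 80], 157)
```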