Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,8 +1365,10 @@ def fused_gdn_gating_kernel(
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
# compute beta_output = sigmoid(b)
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
tl.store(
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
)


def fused_gdn_gating(
Expand All @@ -1387,7 +1389,7 @@ def fused_gdn_gating(
seq_len = 1
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
fused_gdn_gating_kernel[grid](
g,
beta_output,
Expand Down