Background
The CPU MTP foundation (#25, commit a28bac4) ships sequential N=1 verify-and-accept. By construction this spends 2 main forwards per 2 emitted tokens, i.e. one main forward per token, which is the same throughput as baseline. The acceptance criterion on #25 (>=1.3x decode speedup vs MTP-disabled) is therefore unreachable without batched verify.
Memory note: project_mtp_n1_no_speedup.md documents the per-iter math.
What "batched verify" means here
After MTP drafts t2_draft, perform a single main forward that processes BOTH t1 (at position P) and t2_draft (at position P+1), advancing the main caches through both positions in one pass. It returns logits[P+1] (for verifying t2_draft) and logits[P+2] (for the next iteration's saved-logits). On accept (argmax(logits[P+1]) == t2_draft): emit t1, t2_draft and set saved-logits = logits[P+2]. On reject: emit t1 and t2_target = argmax(logits[P+1]), then ROLL THE GDN STATE BACK to position P+1 and re-do Forward(t2_target, P+1).
The "roll back" is why this needs the per-token GDN snapshot ring (Phase 11.7 / Risk #6 in docs/qwen35moe-plan.md).
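The accept/reject flow above can be sketched as follows. This is a minimal illustration, not the project's implementation: ToyRecurrentModel, BatchVerify(int, int), RollbackToAfterFirst and Step are hypothetical names, a single integer stands in for the GDN state, and the snapshot is assumed to be taken right after t1 is consumed so that a restore lands at position P+1.

```csharp
using System;

// Toy stand-in for the main model (hypothetical, not the project's API):
// one integer of recurrent state plus a deterministic logits function.
public sealed class ToyRecurrentModel
{
    public const int Vocab = 8;
    private int _state = 1;
    private int _stateAfterFirst; // assumed per-token snapshot, taken inside BatchVerify

    // Consume one token and return the logits that predict the next position.
    public float[] Forward(int token)
    {
        _state = (_state * 31 + token + 7) & 0xFFFF;
        var logits = new float[Vocab];
        for (int v = 0; v < Vocab; v++)
            logits[v] = ((_state >> v) ^ (v * 5)) % 11;
        return logits;
    }

    // Consume t1 (position P) and t2Draft (position P+1) in one call.
    // Returns logits[P+1] and logits[P+2] in the issue's notation.
    public (float[] LogitsP1, float[] LogitsP2) BatchVerify(int t1, int t2Draft)
    {
        float[] logitsP1 = Forward(t1);
        _stateAfterFirst = _state;            // state at position P+1
        float[] logitsP2 = Forward(t2Draft);
        return (logitsP1, logitsP2);
    }

    public void RollbackToAfterFirst() => _state = _stateAfterFirst;
}

public static class BatchedVerifySketch
{
    // Index of the largest element; the first one wins on ties.
    public static int ArgMax(float[] x)
    {
        int best = 0;
        for (int i = 1; i < x.Length; i++)
            if (x[i] > x[best]) best = i;
        return best;
    }

    // One decode iteration: emits two tokens and returns the saved-logits
    // for the next iteration. draftFor stands in for the MTP head.
    public static (int[] Emitted, float[] NextSaved) Step(
        ToyRecurrentModel model, float[] savedLogits, Func<int, int> draftFor)
    {
        int t1 = ArgMax(savedLogits);
        int t2Draft = draftFor(t1);
        var (logitsP1, logitsP2) = model.BatchVerify(t1, t2Draft);
        int t2Target = ArgMax(logitsP1);

        if (t2Target == t2Draft)
            return (new[] { t1, t2Draft }, logitsP2);   // accept

        model.RollbackToAfterFirst();                    // reject: rewind to P+1
        float[] next = model.Forward(t2Target);          // re-do Forward(t2_target, P+1)
        return (new[] { t1, t2Target }, next);
    }
}
```

Whatever draftFor returns, the emitted sequence should match plain greedy decoding on the same model; only the number of main forwards changes. The toy BatchVerify simply calls Forward twice, so it shows the control flow only. Any real speedup would have to come from sharing weight reads across the two positions.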
Scope
- Batched forward for hybrid GDN. New IForwardPass.BatchVerify(int[] tokens, int startPos) (or similar) on HybridGdnForwardPass + CudaHybridGdnForwardPass. The GDN recurrence stays sequential per token (rank-1 update is inherently sequential), but attention projections and FFN/MoE projections batch via matvec-with-multiple-inputs. The attention scores+softmax+output are batched per-head (k=2 sequential is cheap; full GEMM unnecessary for k=2).
- Per-token GDN snapshot ring on GdnStateCache. Add Snapshot(int slot) / Restore(int slot) for N+1 host-pinned buffers (~140 MiB each for 27B). For N=2 a single pre-batch snapshot + restore-on-reject suffices.
- Mirror on CUDA hybrid. _gpuGdnScanState snapshots stay on-device (140 MiB each; one VRAM ring slot for N=2 = ~140 MiB extra) to avoid PCIe traffic on every iter.
- MtpDecoder switch to batched mode. When IForwardPass.SupportsBatchVerify, use the batched algorithm. Otherwise keep sequential N=1 (correctness preserved).
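A snapshot ring with the Snapshot(int slot) / Restore(int slot) shape from the scope list could look roughly like the sketch below. It is illustrative only: StateSnapshotRing, its constructor parameters and the Live accessor are hypothetical, and plain managed arrays stand in for the host-pinned (CPU) or on-device (CUDA) buffers a real GdnStateCache would use.

```csharp
using System;

// Hypothetical sketch: a fixed set of snapshot buffers for a recurrent state.
public sealed class StateSnapshotRing
{
    private readonly float[] _live;     // stands in for the live GDN state
    private readonly float[][] _slots;  // stands in for the snapshot buffers

    public StateSnapshotRing(int stateLength, int slotCount)
    {
        if (stateLength < 1) throw new ArgumentOutOfRangeException(nameof(stateLength));
        if (slotCount < 1) throw new ArgumentOutOfRangeException(nameof(slotCount));

        _live = new float[stateLength];
        _slots = new float[slotCount][];
        for (int i = 0; i < slotCount; i++)
            _slots[i] = new float[stateLength];
    }

    // Live state that the forward pass would mutate in place.
    public float[] Live => _live;

    // Copy the live state into the given slot.
    public void Snapshot(int slot) => Array.Copy(_live, _slots[slot], _live.Length);

    // Overwrite the live state with the contents of the given slot.
    public void Restore(int slot) => Array.Copy(_slots[slot], _live, _live.Length);
}
```

Under these assumptions, slotCount = 1 corresponds to the single-snapshot N=2 case: call Snapshot(0) before the batched forward and Restore(0) only on reject, so the accept path pays for one copy and no restore.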
Acceptance criteria
- bench-textgen.ps1 row qwen36-27b-mtp-cuda-hybrid shows decode t/s >= 1.3x of SHARPI_DISABLE_MTP=1 baseline on the same model + ctx.
- [...] argmax(target_logits) so output is identical).
- [...] SHARPI_TRACE_MTP=1.
Risk callouts
- The N=2 verify path advances main GDN state through 2 positions before checking acceptance. The snapshot ring is mandatory; without it, a rejected t2_draft leaves the GDN state un-rewindable.
- Batched matvec needs to amortize weight reads across the 2 input vectors. CPU mmap FFN already benefits (the dense FFN is DRAM-bound; 2 dots per row at the cost of 1 weight read). GPU GEMM with N=2 is roughly 1.2x single-input cost.
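The weight-read amortization named in the batched-matvec risk callout can be illustrated with a two-input matvec. This is a sketch under stated assumptions: DualMatVec.Multiply is a hypothetical helper, the weights are dense row-major float32, and no quantization, SIMD or mmap behaviour is modeled.

```csharp
using System;

public static class DualMatVec
{
    // Computes y0 = W * x0 and y1 = W * x1 in one pass over W.
    // W is row-major with the given number of rows and columns.
    public static void Multiply(
        float[] w, int rows, int cols,
        float[] x0, float[] x1,
        float[] y0, float[] y1)
    {
        if (w.Length != rows * cols) throw new ArgumentException("w has the wrong size");
        if (x0.Length != cols || x1.Length != cols) throw new ArgumentException("input length must equal cols");
        if (y0.Length != rows || y1.Length != rows) throw new ArgumentException("output length must equal rows");

        for (int r = 0; r < rows; r++)
        {
            float acc0 = 0f, acc1 = 0f;
            int rowStart = r * cols;
            for (int c = 0; c < cols; c++)
            {
                float weight = w[rowStart + c]; // one weight load feeds both dot products
                acc0 += weight * x0[c];
                acc1 += weight * x1[c];
            }
            y0[r] = acc0;
            y1[r] = acc1;
        }
    }
}
```

Each weight is read once and used for two multiply-adds, which is the "2 dots per row at the cost of 1 weight read" pattern. Whether that translates into wall-clock savings depends on the kernel being memory-bound, as the callout notes for the dense FFN.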
Related
- docs/qwen35moe-plan.md Phase 11.7 + Risk #6, and the new "follow-ups" list at the bottom.