Batched main verify + per-token GDN snapshot ring — realize the MTP >=1.3x speedup #30

@pekkah

Background

The CPU MTP foundation (#25, commit a28bac4) ships sequential N=1 verify-and-accept. By construction this spends 2 main forwards per 2 emitted tokens, which is the same throughput as baseline (and slightly worse once the MTP draft cost is counted). The acceptance criterion on #25 (>=1.3x decode speedup vs MTP-disabled) is therefore unreachable without batched verify.

Memory note: project_mtp_n1_no_speedup.md documents the per-iter math.

What "batched verify" means here

After MTP drafts t2_draft, perform a single main forward that processes both t1 (at position P) and t2_draft (at position P+1), advancing the main caches through both positions in one pass. The forward returns logits[P+1] (used to verify t2_draft) and logits[P+2] (the next iteration's saved-logits). Let t2_target = argmax(logits[P+1]). On accept (t2_target == t2_draft): emit t1 and t2_draft, and set saved-logits = logits[P+2]. On reject: emit t1 and t2_target, then roll the GDN state back to position P+1 (the state after t1 only) and re-do Forward(t2_target, P+1).

The "roll back" is why this needs the per-token GDN snapshot ring (Phase 11.7 / Risk #6 in docs/qwen35moe-plan.md).
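The accept/reject flow above can be sketched with a toy recurrent model. This is a minimal illustration, not the project's code: ToyRecurrentModel, batch_verify, greedy_baseline, mtp_batched and the draft callables are all hypothetical names, the "state" is a single integer standing in for the GDN state, and the sketch assumes the snapshot used for rollback is the state captured right after t1 has been applied.

```python
from typing import Callable, List, Tuple

VOCAB = 8  # toy vocabulary size (assumption for the sketch)


def argmax(values: List[int]) -> int:
    # Index of the first maximum, so ties resolve deterministically.
    return max(range(len(values)), key=values.__getitem__)


class ToyRecurrentModel:
    """Hypothetical stand-in for the main model: mutable recurrent state + forward."""

    def __init__(self) -> None:
        self.state = 1

    def forward(self, token: int) -> List[int]:
        # Consume one token, advance the state, return logits for the next position.
        self.state = (self.state * 31 + token + 7) % 1009
        return [(self.state * (i + 3)) % 17 for i in range(VOCAB)]

    def batch_verify(self, tokens: List[int]) -> Tuple[List[List[int]], List[int]]:
        # One call covering several positions. The recurrence is still sequential,
        # and the state after each token is recorded so it can be rolled back.
        logits, snapshots = [], []
        for tok in tokens:
            logits.append(self.forward(tok))
            snapshots.append(self.state)
        return logits, snapshots


def greedy_baseline(last_prompt_token: int, n: int) -> List[int]:
    model = ToyRecurrentModel()
    logits = model.forward(last_prompt_token)
    out = []
    for _ in range(n):
        tok = argmax(logits)
        out.append(tok)
        logits = model.forward(tok)
    return out


def mtp_batched(last_prompt_token: int, n: int, draft: Callable[[int], int]) -> List[int]:
    model = ToyRecurrentModel()
    saved = model.forward(last_prompt_token)  # saved-logits for position P
    out: List[int] = []
    while len(out) < n:
        t1 = argmax(saved)
        t2_draft = draft(t1)
        logits, snapshots = model.batch_verify([t1, t2_draft])
        t2_target = argmax(logits[0])         # logits[0] plays the role of logits[P+1]
        out += [t1, t2_target]                # on accept, t2_target == t2_draft
        if t2_target == t2_draft:
            saved = logits[1]                 # plays the role of logits[P+2]
        else:
            model.state = snapshots[0]        # roll back to the state after t1
            saved = model.forward(t2_target)  # re-do Forward(t2_target, P+1)
    return out[:n]
```

Because the second emitted token is always argmax of the main model's logits, mtp_batched(5, 20, draft) should match greedy_baseline(5, 20) for any draft function, however poor; the draft quality only changes how often the re-do forward is paid.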

Scope

  1. Batched forward for hybrid GDN. New IForwardPass.BatchVerify(int[] tokens, int startPos) (or similar) on HybridGdnForwardPass + CudaHybridGdnForwardPass. The GDN recurrence stays sequential per token (rank-1 update is inherently sequential), but attention projections and FFN/MoE projections batch via matvec-with-multiple-inputs. The attention scores+softmax+output are batched per-head (k=2 sequential is cheap; full GEMM unnecessary for k=2).

  2. Per-token GDN snapshot ring on GdnStateCache. Add Snapshot(int slot) / Restore(int slot) for N+1 host-pinned buffers (~140 MiB each for 27B). For N=2 a single pre-batch snapshot + restore-on-reject suffices.

  3. Mirror on CUDA hybrid. _gpuGdnScanState snapshots stay on-device (140 MiB each; one VRAM ring slot for N=2 = ~140 MiB extra) to avoid PCIe traffic on every iter.

  4. MtpDecoder switch to batched mode. When IForwardPass.SupportsBatchVerify, use the batched algorithm. Otherwise keep sequential N=1 (correctness preserved).
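As a rough illustration of item 2, a snapshot ring can be a fixed set of preallocated buffers that the live state is copied into and out of. The sketch below is a hypothetical Python stand-in (GdnStateRing, live, snapshot and restore are illustrative names, and the modulo slot indexing is an assumption); the real GdnStateCache would copy into host-pinned or on-device buffers instead of bytearrays.

```python
class GdnStateRing:
    """Hypothetical ring of preallocated snapshot slots for a flat recurrent state."""

    def __init__(self, state_size: int, slots: int) -> None:
        if slots < 1:
            raise ValueError("need at least one snapshot slot")
        self.live = bytearray(state_size)  # the state the forward pass mutates
        self._slots = [bytearray(state_size) for _ in range(slots)]

    def snapshot(self, slot: int) -> None:
        # In-place copy into a preallocated buffer: no allocation per iteration.
        self._slots[slot % len(self._slots)][:] = self.live

    def restore(self, slot: int) -> None:
        # In-place copy back, so references to `live` stay valid.
        self.live[:] = self._slots[slot % len(self._slots)]


ring = GdnStateRing(state_size=16, slots=2)
ring.live[:] = bytes(range(16))
ring.snapshot(0)      # capture before speculating
ring.live[3] = 99     # speculative update that later gets rejected
ring.restore(0)       # rewind to the captured state
```

Copying into fixed slots keeps the per-iteration cost to one state-sized memcpy per snapshot, which is why the slot size (~140 MiB each for 27B, per the list above) matters more than the slot count.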

Acceptance criteria

  • bench-textgen.ps1 row qwen36-27b-mtp-cuda-hybrid shows decode t/s >= 1.3x the SHARPI_DISABLE_MTP=1 baseline on the same model + ctx.
  • N=2 draft length supported (algorithm structure should generalize; v1 can land N=2 specifically).
  • Greedy parity vs MTP-disabled baseline preserved (correction always picks argmax(target_logits) so output is identical).
  • Acceptance rate logging continues to work via SHARPI_TRACE_MTP=1.
  • All existing hybrid GDN tests still pass.
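To get a feel for what the 1.3x criterion demands, here is a back-of-the-envelope cost model derived from the batched algorithm described in this issue. It is an assumption-laden sketch, not a measurement: costs are expressed in units of one single-token main forward, each iteration is assumed to emit 2 tokens, and a reject is assumed to cost exactly one extra single-token forward. The function names and the batch_cost and draft_cost parameters are hypothetical.

```python
def mtp_speedup(accept_rate: float, batch_cost: float = 1.0, draft_cost: float = 0.0) -> float:
    """Estimated decode speedup over the MTP-disabled baseline.

    batch_cost: cost of the 2-token batched forward, relative to a single forward.
    draft_cost: cost of the MTP draft step, relative to a single forward.
    """
    if not 0.0 <= accept_rate <= 1.0:
        raise ValueError("accept_rate must be in [0, 1]")
    # Every iteration pays the batched forward and the draft;
    # a reject additionally pays one re-do forward.
    cost_per_iter = batch_cost + draft_cost + (1.0 - accept_rate)
    tokens_per_iter = 2.0
    # Baseline emits 1 token per unit of cost.
    return tokens_per_iter / cost_per_iter


def min_accept_rate(target_speedup: float, batch_cost: float, draft_cost: float) -> float:
    """Acceptance rate at which mtp_speedup reaches target_speedup (may fall outside [0, 1])."""
    return 1.0 - (2.0 / target_speedup - batch_cost - draft_cost)
```

Under this model, a batched forward costing 1.2x a single forward and a free draft step would need an acceptance rate of about 0.66 to reach 1.3x; any real draft cost pushes that threshold higher.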

Risk callouts

  • The N=2 verify path advances main GDN state through 2 positions before checking acceptance. The snapshot ring is mandatory; without it, a rejected t2_draft leaves the GDN state un-rewindable.
  • Batched matvec needs to amortize weight reads across the 2 input vectors. CPU mmap FFN already benefits (the dense FFN is DRAM-bound; 2 dots per row at the cost of 1 weight read). GPU GEMM with N=2 is roughly 1.2x single-input cost.
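The weight-read amortization in the second bullet can be shown with a tiny pure-Python sketch: each weight row is visited once and used for both input vectors. The function name matvec2 is hypothetical, and a real kernel would operate on quantized, vectorized data instead of Python lists.

```python
from typing import List, Sequence, Tuple


def matvec2(
    weights: Sequence[Sequence[float]],
    x0: Sequence[float],
    x1: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Compute W @ x0 and W @ x1 in a single pass over the rows of W."""
    y0: List[float] = []
    y1: List[float] = []
    for row in weights:              # each weight row is read once...
        acc0 = 0.0
        acc1 = 0.0
        for w, a, b in zip(row, x0, x1):
            acc0 += w * a            # ...and feeds both dot products
            acc1 += w * b
        y0.append(acc0)
        y1.append(acc1)
    return y0, y1


W = [[1.0, 2.0], [3.0, 4.0]]
y0, y1 = matvec2(W, [1.0, 0.0], [0.0, 1.0])
```

When the kernel is bound by memory bandwidth, the second dot product reuses weights that are already loaded, which is the effect the bullet describes for the DRAM-bound dense FFN.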
