
feat: DFlash block-parallel draft model training #71

Closed
cicirori wants to merge 8 commits into main from feat/dflash-training-oss

Conversation

@cicirori
Collaborator

Summary

  • Implement DFlash (Block Diffusion for Flash Speculative Decoding) training in TorchSpec
  • Based on community PR "feat: DFlash (block-parallel) draft model training" #64 by @zhubohao911, with additional correctness fixes
  • DFlash predicts 16-token blocks in parallel using dual-source KV attention, achieving ~5x fewer forward passes per training step vs Eagle3
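The block-parallel prediction above hinges on a block-causal attention rule: each draft token may attend bidirectionally within its own 16-token block and causally to all earlier blocks (the PR references SpecForge PR #427 for the bidirectional intra-block behavior). A minimal pure-Python sketch of that rule, not the actual FlexAttention `mask_mod` from the PR:

```python
def block_causal_allowed(q_idx: int, kv_idx: int, block_size: int = 16) -> bool:
    # A query token may attend to every key in its own block
    # (bidirectional intra-block attention) and to all earlier blocks.
    return kv_idx // block_size <= q_idx // block_size

# Build a full boolean mask for a 2-block sequence.
mask = [[block_causal_allowed(q, k) for k in range(32)] for q in range(32)]
```

Because the intra-block part is bidirectional, all 16 positions of a block can be predicted in one forward pass, which is where the ~5x reduction in forward passes comes from.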

Changes over PR #64

This branch includes all commits from #64 plus correctness fixes:

  • Mooncake storage leak fix: SglEngine omits last_hidden_states from DFlash tensor schema and passes spec_training_store_last_hidden_states=False to prevent orphaned _lhs keys
  • Fail-fast validation: Draft config parsing moved before dataset loading; DFlash validates inference_engine_type=sgl, defer_tokenization=False, and min_loss_tokens >= 2 * block_size before any async work begins
  • DataCollator safety: hard error on missing target/last_hidden_states replaced with a one-time debug log (avoids a per-batch flood for DFlash; Eagle3 still fails at the trainer level)
  • set_vocab_buffers guard: hasattr check prevents AttributeError on DFlashDraftModel
  • Lint fixes: Shebang permissions, ambiguous variable names, noqa for lazy imports
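The fail-fast validation in the list above amounts to a handful of cheap checks run before dataset loading. A hypothetical sketch (function name and messages are illustrative, the conditions come from the PR description):

```python
def validate_dflash_config(inference_engine_type: str,
                           defer_tokenization: bool,
                           min_loss_tokens: int,
                           block_size: int) -> None:
    # Run before dataset loading so misconfiguration surfaces
    # before any async work begins.
    if inference_engine_type != "sgl":
        raise ValueError("DFlash requires inference_engine_type='sgl'")
    if defer_tokenization:
        raise ValueError("DFlash requires defer_tokenization=False")
    if min_loss_tokens < 2 * block_size:
        raise ValueError(
            f"min_loss_tokens={min_loss_tokens} must be >= "
            f"2 * block_size ({2 * block_size})"
        )
```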

Remaining work

  • SGLang patch regeneration: add spec_training_store_last_hidden_states flag to patches/sglang/v0.5.8.post1/sglang.patch and sglang_decode.patch via update_sglang_patch.sh workflow
  • E2E validation on B200 with regenerated patches

Test plan

  • Pre-commit (ruff + ruff-format) all passing
  • python -m pytest tests/test_dflash.py -v on GPU
  • Patch smoke test: apply_sglang_patch.sh and apply_sglang_patch.sh --decode with regenerated patches
  • E2E training on B200

🤖 Generated with Claude Code

zhubohao911 and others added 5 commits April 12, 2026 20:06
Add DFlash block-parallel draft model training to TorchSpec, enabling
disaggregated online training with SGLang inference backend. DFlash
predicts 16-token blocks in parallel using dual-source KV attention,
achieving ~5x fewer forward passes per step compared to Eagle3.

Core implementation:
- torchspec/models/draft/dflash.py: DFlashDraftModel architecture
  (5-layer transformer, dual-source KV, GQA, shared embedding/LM head)
- torchspec/models/dflash.py: Training wrapper with anchor sampling,
  block-causal mask (FlexAttention), CE loss with exponential decay
- torchspec/training/dflash_trainer.py: FSDP2 trainer with WSD scheduler
- tests/test_dflash.py: 67 tests covering config, architecture,
  anchor sampling, block-causal mask, forward pass, loss, accuracy
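The "CE loss with exponential decay" mentioned above weights positions deeper into the block less heavily, since later positions are harder to predict from the anchor. A sketch of the weighting scheme (the decay factor shown is an assumed value, not taken from the PR):

```python
def block_loss_weights(block_size: int = 16, decay: float = 0.8) -> list:
    # Hypothetical per-position CE weights: position t within the block
    # gets weight decay**t, so early positions dominate the loss.
    return [decay ** t for t in range(block_size)]
```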

Integration:
- Config-based trainer dispatch (DFlashConfig → DFlashTrainer)
- Generalized N-layer target model support (was hardcoded to 3)
- Async data prefetching, min_lr/weight_decay optimizer params
- Checkpoint rotation, lazy SGLang/vLLM imports

Validation (best model P2-WSD, 800K PerfectBlend, 3 epochs, 8x H100):
- Math avg τ=3.94 (2.7% gap to z-lab's 4.05)
- Decode-only speedup: 3.02x on livecodebench

- Read target_num_hidden_layers from target model config instead of
  hardcoding 36 (Qwen3-8B specific)
- Remove duplicate SglEngine/VllmEngine imports in factory.py
- Explain why accuracy uses binary mask without decay (intentional)
- Note _apply_rotary_pos_emb kept as utility matching SpecForge
- Document RoPE cache +20 buffer purpose
- Reference SpecForge PR #427 for bidirectional intra-block attention
…r, lint

Address runtime correctness issues in DFlash training pipeline:

- SglEngine: omit last_hidden_states from DFlash tensor schema to prevent
  Mooncake storage leak (orphaned _lhs keys never read or cleaned up).
  Add self._is_dflash flag from draft_model_config_obj. Pass
  spec_training_store_last_hidden_states=False in engine kwargs.

- train_entry: move draft config parsing + DFlash validation before
  dataset loading (step [1.5]) so defer_tokenization, backend mismatch,
  and min_loss_tokens misconfiguration fail fast before async work.
  Add min_loss_tokens >= 2*block_size consistency check.

- DataCollator: replace hard ValueError with one-time debug log when
  hidden_states present without target/last_hidden_states. Avoids
  per-batch warning flood for DFlash; Eagle3 still fails at trainer.

- trainer_actor: guard set_vocab_buffers with hasattr (DFlashDraftModel
  does not implement this method).

- Fix pre-existing lint: shebang permissions, ambiguous var names,
  noqa for lazy imports.

- Add TestDFlashHotfixes covering collator and min_loss_tokens validation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
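The `set_vocab_buffers` guard described above reduces to a `hasattr` check before dispatch. A sketch with illustrative stand-in classes (only the method names come from the PR):

```python
def maybe_set_vocab_buffers(model, buffers) -> bool:
    # DFlashDraftModel does not implement set_vocab_buffers, so only
    # forward the call when the method exists; returns whether it ran.
    if hasattr(model, "set_vocab_buffers"):
        model.set_vocab_buffers(buffers)
        return True
    return False
```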
…dden_states

Add spec_training_store_last_hidden_states flag to ServerArgs (default True
for Eagle3 backward compat). Guard _lhs storage in
_send_hidden_states_to_mooncake() so DFlash does not write orphaned
last_hidden_states keys to Mooncake.

Regenerated via tools/update_sglang_patch.sh on B200.
Smoke tested: clean git apply on base commit succeeds, flag present in
both server_args.py and scheduler_output_processor_mixin.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
cicirori and others added 3 commits April 13, 2026 19:10
- Remove torch.compile graph break in _sample_anchor_positions() by
  eliminating .item() call (valid_counts check). Normal path handles
  all-zero case via keep_mask=False → zero loss.

- Fix sglang_qwen3_8b_dflash.yaml: correct the GPU allocation comment
  (8 GPUs, not 4), set global_batch=8 (not 12), switch to REPLICATE strategy.

- Add configs/dflash_qwen3_8b_repro.yaml for 8x B200 reproduction.

- Add scripts/bench_dflash_opts.py: micro-benchmarks for QKV fusion
  (1.39x), gate-up fusion (0.43x, not worth it), batch size scaling
  (1.48x at bs=4), and torch.compile modes (1.81x default best).

- Add scripts/eval_dflash.sh: evaluation script using SpecForge
  benchmarks with DFlash speculative decoding.

Validation on 8x B200 (760K PerfectBlend, 3 epochs, 18.8h):
  gsm8k: τ=5.55, accuracy=94%
  math500: τ=3.25, accuracy=30.5%

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
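The graph-break removal in `_sample_anchor_positions` can be illustrated with a pure-Python analogue (the real code operates on CUDA tensors under torch.compile): instead of branching on a host-side `.item()` value, the loss is computed through the mask so an all-False `keep_mask` yields zero loss on the normal path.

```python
def masked_mean_loss(losses, keep_mask):
    # No data-dependent host-side branch: an all-False keep_mask gives
    # total == 0, and the max(count, 1) clamp avoids division by zero,
    # so the all-zero case falls out of the normal path with no .item().
    total = sum(l * k for l, k in zip(losses, keep_mask))
    count = sum(keep_mask)
    return total / max(count, 1)
```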
…ilures

- Remove dead code: `_apply_rotary_pos_emb` (never called),
  `get_default_dflash_aux_layer_ids` (zero callers)
- Remove unused params `norm_weight`/`norm_eps` from DFlashModel.forward()
- Use Python scalars in torch.where instead of creating scalar CUDA
  tensors on every step (avoids per-step GPU allocation + potential
  torch.compile recompilation)
- Replace class-level mutable `_warned_no_target_lhs` flag with
  standard `warnings.warn(stacklevel=2)` for proper once-only behavior
- Add debug logging to silent ImportError catches in engine/__init__.py
  so internal import bugs are traceable

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
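The `warnings.warn(stacklevel=2)` replacement works because Python's default warning filter deduplicates by message, category, and call site, giving once-per-location behavior without a hand-rolled mutable flag. A sketch (the collator function name is illustrative):

```python
import warnings

def collate_check(has_target_lhs: bool) -> None:
    # Under the default warning filter this fires once per call
    # location rather than flooding every batch; stacklevel=2 points
    # the warning at the caller instead of this helper.
    if not has_target_lhs:
        warnings.warn(
            "hidden_states present without target/last_hidden_states",
            stacklevel=2,
        )
```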
- Checkpoint cleanup: move _cleanup_old_checkpoints AFTER save at all 3
  call sites to prevent race condition where cleanup deletes existing
  checkpoints before new save succeeds (data loss on save failure).

- Checkpoint cleanup: replace shutil.rmtree(ignore_errors=True) with
  try/except that logs warning on failure, preventing silent disk leak.

- DFlashTrainer: raise TypeError for unsupported draft_model_config
  types instead of silently accepting anything.

- PrefetchedDataFetcher: preserve original traceback when re-raising
  prefetch thread exceptions, so error points to actual failure site.

- DFlashDraftModel.load_embedding: add weights_only=True to torch.load
  calls for security and to suppress deprecation warning.

- load_hf_dataset: narrow bare except Exception to specific schema
  inference errors (ValueError, TypeError, etc.) and log the fallback,
  so auth errors and network failures are not silently swallowed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
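The traceback-preserving re-raise for the prefetch thread can be sketched as follows (class shape is illustrative, not the actual PrefetchedDataFetcher): the worker stores its exception, and the consumer re-raises it with the original traceback so the error points at the actual failure site rather than the `join` call.

```python
import threading

class PrefetchThread(threading.Thread):
    """Minimal sketch: run fn in a thread, surface its exception later."""

    def __init__(self, fn):
        super().__init__()
        self._fn = fn
        self._exc = None

    def run(self):
        try:
            self._fn()
        except BaseException as e:
            # Keep the exception object; its __traceback__ already
            # points at the failing frame inside fn.
            self._exc = e

    def join_and_raise(self):
        self.join()
        if self._exc is not None:
            # Re-raise with the original traceback preserved.
            raise self._exc.with_traceback(self._exc.__traceback__)
```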
@cicirori cicirori closed this Apr 13, 2026
@torchspec-bot torchspec-bot deleted the feat/dflash-training-oss branch April 14, 2026 01:54