feat: DFlash block-parallel draft model training #71
Closed
Add DFlash block-parallel draft model training to TorchSpec, enabling disaggregated online training with the SGLang inference backend. DFlash predicts 16-token blocks in parallel using dual-source KV attention, achieving ~5x fewer forward passes per step compared to Eagle3.

Core implementation:
- `torchspec/models/draft/dflash.py`: DFlashDraftModel architecture (5-layer transformer, dual-source KV, GQA, shared embedding/LM head)
- `torchspec/models/dflash.py`: training wrapper with anchor sampling, block-causal mask (FlexAttention), CE loss with exponential decay
- `torchspec/training/dflash_trainer.py`: FSDP2 trainer with WSD scheduler
- `tests/test_dflash.py`: 67 tests covering config, architecture, anchor sampling, block-causal mask, forward pass, loss, accuracy

Integration:
- Config-based trainer dispatch (DFlashConfig → DFlashTrainer)
- Generalized N-layer target model support (was hardcoded to 3)
- Async data prefetching; min_lr/weight_decay optimizer params
- Checkpoint rotation; lazy SGLang/vLLM imports

Validation (best model P2-WSD, 800K PerfectBlend, 3 epochs, 8x H100):
- Math avg τ = 3.94 (2.7% gap to z-lab's 4.05)
- Decode-only speedup: 3.02x on livecodebench
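To make the block-causal attention pattern concrete, here is a small pure-Python sketch of the mask predicate (an illustration, not the repo's FlexAttention code): queries attend bidirectionally within their own 16-token block and causally to all earlier blocks.

```python
# Sketch of the block-causal predicate described above. This is an
# illustration only; the actual implementation uses FlexAttention, where a
# mask_mod(b, h, q_idx, kv_idx) wrapping this predicate would be passed to
# create_block_mask instead of materializing a dense mask.
BLOCK_SIZE = 16  # block size taken from the PR description

def block_causal_allowed(q_idx: int, kv_idx: int,
                         block_size: int = BLOCK_SIZE) -> bool:
    """True if query position q_idx may attend to key position kv_idx:
    same block (bidirectional intra-block) or any earlier block."""
    return (q_idx // block_size) >= (kv_idx // block_size)

def build_mask(seq_len: int, block_size: int = BLOCK_SIZE):
    """Dense boolean mask, for illustration and testing on CPU."""
    return [[block_causal_allowed(q, k, block_size) for k in range(seq_len)]
            for q in range(seq_len)]
```

Within a block the mask is fully dense (tokens of one draft block see each other), which is what lets DFlash predict all 16 tokens of a block in a single forward pass.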
- Read `target_num_hidden_layers` from the target model config instead of hardcoding 36 (Qwen3-8B specific)
- Remove duplicate SglEngine/VllmEngine imports in `factory.py`
- Explain why accuracy uses a binary mask without decay (intentional)
- Note `_apply_rotary_pos_emb` kept as a utility matching SpecForge
- Document the RoPE cache +20 buffer purpose
- Reference SpecForge PR #427 for bidirectional intra-block attention
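The loss-vs-accuracy distinction documented above can be sketched as two weighting schemes (illustrative only; the decay base here is a hypothetical value, not taken from the repo): the CE loss down-weights later in-block positions exponentially, while accuracy deliberately weights every position equally.

```python
# Illustrative sketch of the two weighting schemes described above.
def loss_weights(block_size: int, decay: float = 0.8):
    """Exponential decay weights for the CE loss. The `decay` value is a
    hypothetical placeholder; the repo's actual base may differ."""
    return [decay ** i for i in range(block_size)]

def accuracy_mask(block_size: int):
    """Binary mask for accuracy: no decay, by design, so acceptance-rate
    metrics count every predicted token in the block equally."""
    return [1.0] * block_size
```

Decaying the loss reflects that later positions in a block depend on earlier (still-unverified) draft tokens, while accuracy measures raw per-position agreement with the target model.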
…r, lint

Address runtime correctness issues in the DFlash training pipeline:
- SglEngine: omit `last_hidden_states` from the DFlash tensor schema to prevent a Mooncake storage leak (orphaned `_lhs` keys never read or cleaned up). Add a `self._is_dflash` flag from `draft_model_config_obj`. Pass `spec_training_store_last_hidden_states=False` in engine kwargs.
- train_entry: move draft config parsing + DFlash validation before dataset loading (step [1.5]) so `defer_tokenization`, backend mismatch, and `min_loss_tokens` misconfiguration fail fast before async work. Add a `min_loss_tokens >= 2*block_size` consistency check.
- DataCollator: replace the hard ValueError with a one-time debug log when hidden_states are present without `target/last_hidden_states`. Avoids a per-batch warning flood for DFlash; Eagle3 still fails at the trainer.
- trainer_actor: guard `set_vocab_buffers` with `hasattr` (DFlashDraftModel does not implement this method).
- Fix pre-existing lint: shebang permissions, ambiguous variable names, noqa for lazy imports.
- Add TestDFlashHotfixes covering collator and `min_loss_tokens` validation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
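The fail-fast validation described above might look like the following sketch (function name, parameters, and error messages are hypothetical, modeled on the checks the commit lists): cheap config consistency checks run before any async dataset work, so misconfiguration surfaces immediately.

```python
# Illustrative fail-fast validation; names are hypothetical placeholders
# for the checks described in the commit message above.
def validate_dflash_config(min_loss_tokens: int, block_size: int,
                           inference_engine_type: str,
                           defer_tokenization: bool) -> None:
    """Raise before dataset loading begins, not after async work starts."""
    if inference_engine_type != "sgl":
        raise ValueError("DFlash requires inference_engine_type=sgl")
    if defer_tokenization:
        raise ValueError("DFlash requires defer_tokenization=False")
    if min_loss_tokens < 2 * block_size:
        raise ValueError(
            f"min_loss_tokens={min_loss_tokens} must be >= "
            f"2*block_size={2 * block_size}")
```

Failing here is cheap; failing after dataset prefetch threads have spawned wastes minutes and produces harder-to-read tracebacks.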
…dden_states

Add a `spec_training_store_last_hidden_states` flag to ServerArgs (default True for Eagle3 backward compatibility). Guard `_lhs` storage in `_send_hidden_states_to_mooncake()` so DFlash does not write orphaned `last_hidden_states` keys to Mooncake.

Regenerated via `tools/update_sglang_patch.sh` on B200. Smoke tested: a clean `git apply` on the base commit succeeds, and the flag is present in both `server_args.py` and `scheduler_output_processor_mixin.py`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed from f98baa3 to 59081fe
- Remove a torch.compile graph break in `_sample_anchor_positions()` by eliminating the `.item()` call (valid_counts check). The normal path handles the all-zero case via `keep_mask=False` → zero loss.
- Fix `sglang_qwen3_8b_dflash.yaml`: correct the GPU allocation comment (8 GPUs, not 4), `global_batch=8` not 12, switch to the REPLICATE strategy.
- Add `configs/dflash_qwen3_8b_repro.yaml` for 8x B200 reproduction.
- Add `scripts/bench_dflash_opts.py`: micro-benchmarks for QKV fusion (1.39x), gate-up fusion (0.43x, not worth it), batch size scaling (1.48x at bs=4), and torch.compile modes (1.81x, default mode best).
- Add `scripts/eval_dflash.sh`: evaluation script using SpecForge benchmarks with DFlash speculative decoding.

Validation on 8x B200 (760K PerfectBlend, 3 epochs, 18.8h):
- gsm8k: τ=5.55, accuracy=94%
- math500: τ=3.25, accuracy=30.5%

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ilures

- Remove dead code: `_apply_rotary_pos_emb` (never called), `get_default_dflash_aux_layer_ids` (zero callers)
- Remove unused params `norm_weight`/`norm_eps` from `DFlashModel.forward()`
- Use Python scalars in `torch.where` instead of creating scalar CUDA tensors on every step (avoids per-step GPU allocation + potential torch.compile recompilation)
- Replace the class-level mutable `_warned_no_target_lhs` flag with standard `warnings.warn(stacklevel=2)` for proper once-only behavior
- Add debug logging to silent ImportError catches in `engine/__init__.py` so internal import bugs are traceable

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
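The once-only warning pattern mentioned above works because Python's `"default"` warning filter already deduplicates by call site, which is what makes a hand-rolled class-level flag unnecessary. A minimal sketch (function and field names here are hypothetical, not the repo's):

```python
import warnings

def fetch_batch(batch: dict) -> dict:
    """Hypothetical collator step illustrating the once-only warning.

    Under Python's default warning filter, warnings.warn fires once per
    (message, category, location), replacing a hand-rolled mutable
    class-level flag like `_warned_no_target_lhs`."""
    if not batch.get("last_hidden_states"):
        # stacklevel=2 attributes the warning to the caller, so the log
        # points at the call site rather than this helper.
        warnings.warn("no target/last_hidden_states in batch", stacklevel=2)
    return batch
```

A class-level flag also misbehaves across subclasses and tests (shared mutable state), whereas the warnings registry is reset by `warnings.catch_warnings`/`simplefilter`, which test harnesses already manage.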
- Checkpoint cleanup: move `_cleanup_old_checkpoints` AFTER the save at all 3 call sites to prevent a race where cleanup deletes existing checkpoints before the new save succeeds (data loss on save failure).
- Checkpoint cleanup: replace `shutil.rmtree(ignore_errors=True)` with try/except that logs a warning on failure, preventing a silent disk leak.
- DFlashTrainer: raise TypeError for unsupported `draft_model_config` types instead of silently accepting anything.
- PrefetchedDataFetcher: preserve the original traceback when re-raising prefetch thread exceptions, so the error points to the actual failure site.
- `DFlashDraftModel.load_embedding`: add `weights_only=True` to `torch.load` calls for security and to suppress the deprecation warning.
- `load_hf_dataset`: narrow the bare `except Exception` to specific schema inference errors (ValueError, TypeError, etc.) and log the fallback, so auth errors and network failures are not silently swallowed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
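The save-then-cleanup ordering above can be sketched as follows (a minimal illustration with hypothetical names, not the repo's trainer code): prune old checkpoints only after the new save has succeeded, and log removal failures instead of swallowing them.

```python
# Illustrative checkpoint rotation; names are hypothetical placeholders.
import logging
import shutil
from pathlib import Path

logger = logging.getLogger(__name__)

def save_and_rotate(ckpt_dir: Path, step: int, keep_last: int, save_fn) -> None:
    """Save first, prune after: a failed save can never leave us with
    fewer checkpoints than we started with."""
    new_ckpt = ckpt_dir / f"step_{step}"
    save_fn(new_ckpt)  # if this raises, old checkpoints are untouched
    ckpts = sorted(ckpt_dir.glob("step_*"),
                   key=lambda p: int(p.name.split("_")[1]))
    for old in ckpts[:-keep_last]:
        try:
            shutil.rmtree(old)
        except OSError as exc:
            # Log instead of shutil.rmtree(..., ignore_errors=True),
            # so a stuck deletion (disk leak) is visible in the logs.
            logger.warning("failed to remove old checkpoint %s: %s", old, exc)
```

Cleaning up before saving inverts the failure mode: a crash mid-save would then have already deleted the only good checkpoints.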
Summary
Changes over PR #64
This branch includes all commits from #64 plus correctness fixes:
- SglEngine omits `last_hidden_states` from the DFlash tensor schema and passes `spec_training_store_last_hidden_states=False` to prevent orphaned `_lhs` keys
- Fail-fast validation of `inference_engine_type=sgl`, `defer_tokenization=False`, and `min_loss_tokens >= 2 * block_size` before any async work begins
- DataCollator ValueError on missing `target/last_hidden_states` replaced with a one-time debug log (avoids a per-batch flood for DFlash; Eagle3 still fails at the trainer level)
- `set_vocab_buffers` guard: `hasattr` check prevents AttributeError on DFlashDraftModel

Remaining work
- Add the `spec_training_store_last_hidden_states` flag to `patches/sglang/v0.5.8.post1/sglang.patch` and `sglang_decode.patch` via the `update_sglang_patch.sh` workflow

Test plan
- `python -m pytest tests/test_dflash.py -v` on GPU
- `apply_sglang_patch.sh` and `apply_sglang_patch.sh --decode` with regenerated patches

🤖 Generated with Claude Code