feat: add train_with_decode support for speculative decoding during training#28

Merged
torchspec-bot merged 1 commit into main from feat/train-with-decode
Mar 4, 2026
Conversation

@cicirori cicirori (Collaborator) commented Mar 4, 2026

Summary

  • Add decode-mode training where the inference engine generates new tokens with speculative decoding and captures hidden states for the full prompt+completion sequence
  • Add DecodeConfig dataclass, SglDecodeEngineMixin for decode-mode generation and draft weight sync, sglang decode patch, configs, and example scripts
  • Integrate into training loop with periodic draft model weight sync to inference engines
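
A minimal sketch of how the new DecodeConfig might hang off the training config. The names train_with_decode and DecodeConfig come from the PR; the fields below (max_new_tokens, sync_interval_steps, temperature) are illustrative assumptions, since the PR does not list DecodeConfig's contents:

```python
from dataclasses import dataclass, field


@dataclass
class DecodeConfig:
    # Hypothetical fields for decode-mode generation; the real
    # DecodeConfig in torchspec/config/train_config.py may differ.
    max_new_tokens: int = 256
    sync_interval_steps: int = 100  # how often draft weights are pushed to engines
    temperature: float = 1.0


@dataclass
class TrainConfig:
    # Decode mode is opt-in: off by default, configured via the nested dataclass.
    train_with_decode: bool = False
    decode: DecodeConfig = field(default_factory=DecodeConfig)
```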

Changes

  • torchspec/config/train_config.py: Add train_with_decode flag, DecodeConfig dataclass, update dynamic_loss_mask conditional
  • torchspec/inference/engine/sgl_engine_decode.py: New mixin with generate_with_decode() and update_weights_from_disk()
  • torchspec/inference/engine/sgl_engine.py: Inherit SglDecodeEngineMixin, add decode engine kwargs
  • torchspec/controller/inference_manager.py: Conditional dispatch (generate vs generate_with_decode)
  • torchspec/controller/loop.py: _maybe_sync_draft_weights() for periodic draft model updates
  • torchspec/train_entry.py: _maybe_create_scratch_draft(), startup validation for decode mode
  • torchspec/data/dataset.py: Pass train_with_decode as add_generation_prompt through data pipeline
  • patches/sglang/v0.5.8.post1/sglang_decode.patch: Full sglang patch for decode-mode support
  • tools/apply_sglang_patch.sh: Add --decode flag to apply decode patch
  • tools/convert_to_hf.py: Add train_with_decode=False parameter
  • configs/train_with_decode/: Qwen3-8B and Kimi-K2.5-NVFP4 decode-mode configs
  • examples/train-with-decode/: Launch scripts for decode-mode training
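
The conditional dispatch in inference_manager.py and the periodic draft sync in loop.py can be sketched as follows. The method names generate(), generate_with_decode(), and update_weights_from_disk() are taken from the change list above; the signatures, the interval logic, and the helper name maybe_sync_draft_weights are illustrative assumptions:

```python
class InferenceManager:
    """Sketch of the decode-mode dispatch; the real class holds more state."""

    def __init__(self, engine, train_with_decode: bool):
        self.engine = engine
        self.train_with_decode = train_with_decode

    def run(self, prompts):
        # Decode mode: the engine generates new tokens with speculative
        # decoding and captures hidden states for prompt + completion.
        if self.train_with_decode:
            return self.engine.generate_with_decode(prompts)
        # Plain mode: ordinary generation without decode-mode capture.
        return self.engine.generate(prompts)


def maybe_sync_draft_weights(step: int, interval: int, engines, ckpt_dir: str):
    # Hypothetical counterpart of _maybe_sync_draft_weights(): every
    # `interval` steps, push freshly trained draft weights from disk to
    # each inference engine so speculation uses up-to-date drafts.
    if interval > 0 and step % interval == 0:
        for engine in engines:
            engine.update_weights_from_disk(ckpt_dir)
```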

Test plan

  • All 232 existing tests pass (0 failures)
  • ruff lint + format checks pass
  • Pre-commit hooks pass
  • Manual integration test with decode-mode config

Acknowledgements

Co-authored-by: BobbyIsHandsome <96061080+BobbyIsHandsome@users.noreply.github.com>
Co-authored-by: Junxiong Wang <16102460+jxiw@users.noreply.github.com>
Co-authored-by: xwuShirley <37637998+xwuShirley@users.noreply.github.com>
Co-authored-by: Yubo Wang <10526540+yubofredwang@users.noreply.github.com>

Generated with Claude Code

@cicirori cicirori marked this pull request as draft March 4, 2026 00:44
@cicirori cicirori force-pushed the feat/train-with-decode branch 8 times, most recently from b9f8d3d to 60fa773, March 4, 2026 02:46
@cicirori cicirori changed the title from "feat: restore train_with_decode functionality" to "feat: add train_with_decode support for speculative decoding during training" Mar 4, 2026
@cicirori cicirori force-pushed the feat/train-with-decode branch 2 times, most recently from 46b9971 to b3f976c, March 4, 2026 02:50
@cicirori cicirori marked this pull request as ready for review March 4, 2026 03:00
@cicirori cicirori force-pushed the feat/train-with-decode branch from b3f976c to 52cd929, March 4, 2026 03:13
…raining

Co-authored-by: Bobbie Bie <96061080+BobbyIsHandsome@users.noreply.github.com>
Co-authored-by: Junxiong Wang <chuangzhetianxia@gmail.com>
Co-authored-by: Shirley Wu <shirley@research-dev-b200-04.cloud.together.ai>
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
@torchspec-bot torchspec-bot merged commit 3c59b32 into main Mar 4, 2026
1 check passed
@torchspec-bot torchspec-bot deleted the feat/train-with-decode branch March 4, 2026 05:56
cicirori added a commit that referenced this pull request Mar 4, 2026
…raining (#28)

Co-authored-by: BobbyIsHandsome <96061080+BobbyIsHandsome@users.noreply.github.com>
Co-authored-by: Junxiong Wang <16102460+jxiw@users.noreply.github.com>
Co-authored-by: xwuShirley <37637998+xwuShirley@users.noreply.github.com>
Co-authored-by: Yubo Wang <10526540+yubofredwang@users.noreply.github.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request Mar 22, 2026
- Results: Compute sub-breakdown, 200-step stability, optimization tests
- Issues: torchspec-project#27 torch.compile recompilation, torchspec-project#28 GPU Direct RDMA, torchspec-project#29 Mooncake bypass
- Pending work: Updated completed items, active training tasks
- Best config: no_sync + bf16 reduce → 2.7 step/s (+8%), ~3.9hr training

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request Mar 23, 2026
…raining (torchspec-project#28)

Co-authored-by: BobbyIsHandsome <96061080+BobbyIsHandsome@users.noreply.github.com>
Co-authored-by: Junxiong Wang <16102460+jxiw@users.noreply.github.com>
Co-authored-by: xwuShirley <37637998+xwuShirley@users.noreply.github.com>
Co-authored-by: Yubo Wang <10526540+yubofredwang@users.noreply.github.com>
zhubohao911 pushed a commit to zhubohao911/TorchSpec that referenced this pull request Mar 23, 2026
- Results: Compute sub-breakdown, 200-step stability, optimization tests
- Issues: torchspec-project#27 torch.compile recompilation, torchspec-project#28 GPU Direct RDMA, torchspec-project#29 Mooncake bypass
- Pending work: Updated completed items, active training tasks
- Best config: no_sync + bf16 reduce → 2.7 step/s (+8%), ~3.9hr training