Fake balanced routing in MoE #1670
Conversation
Force-pushed from 2bc8578 to f7a6e8a.
Thanks for the comments, I’ve updated the code as per your suggestions. It seems that variables starting with
Thanks a lot, left some final nit comments.
Out of curiosity, what's your use case for this, @rakkit? I've used forced MoE balancing before to ensure I can run my model setup without OOMing in (at the very least) the well-balanced routing regime, and to find the best-case throughput, but I wondered if you had other reasons. Same question for @tianyu-l; maybe the above, and also for test cases?

PS: a little tangential, but I have (messy) code for aggregating MoE routing stats across PP ranks so that all layer stats get reported, in case you're using pipeline parallel and ever need that.
I think also for getting an idea of throughput under some specific setup, without the impact of load imbalance, which is also along the lines of
Force-pushed from 65144ee to 973d6cb.
Thanks a lot for the comments @tianyu-l, these are solved now!
Thanks @garrett361, sorry for the late response.
As tianyu said, the main motivation is testing throughput more conveniently. We also noticed that enforcing load balance helps avoid OOM. We plan to run extensive throughput checks (similar to HF’s Ultra-Scale Playbook) once MoE+Compile becomes available in torchtitan.
Thanks a lot! Can you please share the link to the code? For our use case (here), we decided to log each PP stage separately, for two reasons: 1) we have lots of log entries on each PP stage from the Muon/Scion optimizer (we have a preprint on arXiv, maybe this week), which provide better insight into training dynamics than simply tracking the global gradient norm; and 2) aggregating metrics across PP stages requires extra communication. It's OK for most PP strategies, but I feel it's awkward for zb-v (zero-bubble PP), where
Thanks a lot! Had one last nit comment.
torchtitan/config/job_config.py (outdated)
"""Use deterministic algorithms wherever possible, may be slower""" | ||
|
||
debug_moe_force_load_balance: bool = False | ||
# if True, we force each experts get same amount of token via round-robin |
One last nit, to let the arg parser pick up the help message.
Suggested change:
```diff
-# if True, we force each experts get same amount of token via round-robin
+"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
```
Force-pushed from 973d6cb to f8ab474.
Oh, nice, that is probably the better option. Didn't know about that. My implementation certainly has extra exposed comms. I thought they'd usually be pretty negligible for most models which are large enough to require PP, but maybe that's not the case for zb-v. In any case, the code is around here.
torchtitan/models/moe.py (outdated)
```python
k = torch.arange(self.top_k, device=scores.device)[None, :]  # [1, K]
selected_experts_indices = (
    (i * self.top_k + k) % self.num_experts
).long()  # [N, K]
```
Probably a nit, but I think a version like

```python
selected_experts_indices = torch.arange(
    num_tokens * self.top_k, device=device, dtype=torch.int64
).reshape(num_tokens, self.top_k) % self.num_experts
```

is a little cleaner, as well as faster.
Also, we could cache this and avoid recomputing the selected indices every time, right?
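To make the suggestion concrete, here is a small self-contained sketch (the helper name and example shapes are illustrative, not torchtitan's actual router code) of the round-robin assignment: a flattened arange taken modulo the number of experts, so every expert receives the same number of token slots.

```python
import torch


def balanced_expert_indices(
    num_tokens: int, top_k: int, num_experts: int, device: torch.device
) -> torch.Tensor:
    """Round-robin routing: token i gets experts (i*top_k .. i*top_k + top_k - 1) mod num_experts."""
    return (
        torch.arange(num_tokens * top_k, device=device, dtype=torch.int64)
        .reshape(num_tokens, top_k)
        % num_experts
    )


# Example: 6 tokens, top_k=2, 4 experts -> each expert receives exactly 3 token slots.
idx = balanced_expert_indices(6, 2, 4, torch.device("cpu"))
print(idx)
print(torch.bincount(idx.flatten(), minlength=4))  # tensor([3, 3, 3, 3])
```

Since the result depends only on (num_tokens, top_k, num_experts), it could indeed be computed once and reused for a fixed batch shape, per the caching suggestion above.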
Just had one nit, think this is pretty useful!
@rakkit would you like to address the final comment? Anyway, it seems a rebase needs to happen before we merge.
Force-pushed from f8ab474 to 35ad499.
@tianyu-l which final comment? This one, #1670 (comment), is fixed, and I took the version @garrett361 proposed in #1670 (comment). Also rebased.
…ad_balance take version @garrett361 proposed
Force-pushed from 35ad499 to acf814b.
We can set `DEBUG_FORCE_LOAD_BALANCED=1` to force each expert to get the same amount of tokens.

To reproduce:

```
DEBUG_FORCE_LOAD_BALANCED=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --compile.enable
```

Here is a test on 8 layers, with 8 activated and 64 total experts. The green curve is the vanilla run and the purple one is the run with forced load balance.
