
Conversation

@rakkit (Contributor) commented Sep 1, 2025

We can set DEBUG_FORCE_LOAD_BALANCED=1 to force each expert to receive the same number of tokens.

Reproduce: DEBUG_FORCE_LOAD_BALANCED=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --compile.enable

Here is a test on 8 layers, with 8 activated and 64 total experts. The green curve is the vanilla run and the purple curve is the run with forced load balance.
[screenshots: training curves comparing the vanilla run (green) with the forced-load-balance run (purple)]
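
For readers skimming this thread: the round-robin idea under discussion amounts to replacing the router's top-k expert indices with a fixed pattern so every expert sees the same number of tokens. A minimal standalone sketch (not the merged torchtitan code; function and variable names here are assumed for illustration):

```python
import torch

def force_balanced_topk(scores: torch.Tensor, top_k: int):
    """Override the router's top-k selection with a round-robin assignment.

    scores: [num_tokens, num_experts] router output. Token i gets experts
    (i*top_k, i*top_k+1, ..., i*top_k+top_k-1) mod num_experts, so per-expert
    token counts differ by at most one (exactly equal when
    num_tokens * top_k is divisible by num_experts).
    """
    num_tokens, num_experts = scores.shape
    selected_experts_indices = (
        torch.arange(num_tokens * top_k, device=scores.device, dtype=torch.int64)
        .reshape(num_tokens, top_k)
        % num_experts
    )
    # Keep the routing weights consistent with the overridden indices.
    top_scores = scores.gather(dim=1, index=selected_experts_indices)
    return top_scores, selected_experts_indices

# Example: 8 tokens, 4 experts, top_k=2 -> each expert gets exactly 4 tokens.
scores = torch.randn(8, 4).softmax(dim=-1)
_, idx = force_balanced_topk(scores, top_k=2)
print(torch.bincount(idx.flatten(), minlength=4))  # tensor([4, 4, 4, 4])
```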

@meta-cla meta-cla bot added the CLA Signed label Sep 1, 2025
@tianyu-l tianyu-l linked an issue Sep 6, 2025 that may be closed by this pull request
@rakkit rakkit force-pushed the allow_force_router_balance branch from 2bc8578 to f7a6e8a Compare September 23, 2025 04:43
@rakkit rakkit requested a review from tianyu-l September 23, 2025 04:43
@rakkit (Contributor, Author) commented Sep 23, 2025

Thanks for the comments, I’ve updated the code as per your suggestions.

It seems that variables starting with _ break the CLI, so I left it as training.debug_moe_force_load_balance.

@tianyu-l (Contributor) left a comment

Thanks a lot, left some final nit comments.

@garrett361 (Contributor) commented

Out of curiosity, what's your use case for this @rakkit ? I've used forced MoE balancing before to ensure I can run my model setup without OOMing in (at the very least) the well-balanced routing regime and to find the best-case throughput, but wondered if you had other reasons.

Same question for @tianyu-l ; maybe the above and also for test cases?

PS: a little tangential, but I have (messy) code for aggregating MoE routing stats across PP ranks so that all layer stats get reported, in case you're using pipeline parallel and ever need that.

@tianyu-l (Contributor) commented

@garrett361

> Same question for @tianyu-l; maybe the above and also for test cases?

I think also for getting an idea of throughput under some specific setup, without the impact of load imbalance, which is also along the lines of

> to ensure I can run my model setup without OOMing in (at the very least) the well-balanced routing regime and to find the best-case throughput

@rakkit rakkit force-pushed the allow_force_router_balance branch 2 times, most recently from 65144ee to 973d6cb Compare September 28, 2025 22:29
@rakkit (Contributor, Author) commented Sep 28, 2025

Thanks a lot for the comments @tianyu-l, these are solved now!

@rakkit rakkit requested a review from tianyu-l September 28, 2025 22:31
@rakkit (Contributor, Author) commented Sep 28, 2025

Thanks @garrett361, sorry for the late response.

> what's your use case for this

As tianyu said, the main motivation is testing throughput more conveniently. We also noticed that enforcing load balance helps avoid OOM. We plan to run extensive throughput checks (similar to HF’s Ultra-Scale Playbook) once MoE+Compile becomes available in torchtitan.

> PS: a little tangential, but I have (messy) code for aggregating MoE routing stats across PP ranks so that all layer stats get reported, in case you're using pipeline parallel and ever need that.

Thanks a lot! Can you please share the link to the code? For our use case (here), we decided to log each PP stage separately, since torchtitan supports creating a wandb run per rank via save_for_all_ranks.

The reasons are: 1) we have many log entries on each PP stage from the Muon/Scion optimizer (we have a preprint on arXiv, maybe this week), which provide better insight into training dynamics than simply tracking the global gradient norm; and 2) aggregating metrics across PP stages requires extra communication. That is OK for most PP strategies, but I feel it is awkward for zb-v (zero-bubble PP), where optimizer.step runs sequentially along PP stages. In that case, the communication would either add sync overhead to the async steps, or we would need buffering for proper logging.
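
To illustrate the trade-off being weighed here: aggregating per-layer MoE routing stats across pipeline-parallel ranks needs an extra collective on the PP process group, whereas logging each stage to its own wandb run avoids that sync entirely. A rough sketch under assumed names and shapes (this is not the code referenced in the thread):

```python
import torch
import torch.distributed as dist

def gather_pp_routing_stats(local_counts: torch.Tensor, pp_group) -> torch.Tensor:
    """local_counts: [num_local_layers, num_experts] per-expert token counts on
    this PP rank; assumes every rank holds the same number of local layers."""
    world_size = dist.get_world_size(group=pp_group)
    gathered = [torch.empty_like(local_counts) for _ in range(world_size)]
    dist.all_gather(gathered, local_counts, group=pp_group)  # the extra sync point
    return torch.cat(gathered, dim=0)  # [num_total_layers, num_experts]
```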

@rakkit rakkit requested a review from garrett361 September 28, 2025 22:53
@tianyu-l (Contributor) left a comment

Thanks a lot! Had one last nit comment.

"""Use deterministic algorithms wherever possible, may be slower"""

debug_moe_force_load_balance: bool = False
# if True, we force each experts get same amount of token via round-robin
@tianyu-l (Contributor) left an inline comment

One last nit to let the arg parser pick up the help message.

Suggested change
# if True, we force each experts get same amount of token via round-robin
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""

@rakkit rakkit force-pushed the allow_force_router_balance branch from 973d6cb to f8ab474 Compare September 29, 2025 05:32
@garrett361 (Contributor) commented Sep 29, 2025

> we decided to log each PP stage separately, since torchtitan supports creating a wandb run per rank via save_for_all_ranks

Oh, nice, that is probably the better option. Didn't know about save_for_all_ranks.

My implementation certainly has extra exposed comms. I thought they'd usually be pretty negligible for most models which are large enough to require PP, but maybe not the case for zb-v. In any case, the code is around here.

k = torch.arange(self.top_k, device=scores.device)[None, :]  # [1, K]
selected_experts_indices = (
    (i * self.top_k + k) % self.num_experts
).long()  # [N, K]
@garrett361 (Contributor) left an inline comment

Probably a nit, but I think the version like

selected_experts_indices = (
    torch.arange(num_tokens * self.top_k, device=device, dtype=torch.int64)
    .reshape(num_tokens, self.top_k)
    % self.num_experts
)

is a little cleaner, as well as faster.

Also, we could cache this and avoid recomputing the selected indices every time, right?
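
On the caching point, a minimal sketch of what reusing the round-robin indices across steps might look like (class and attribute names are assumptions, not the merged code):

```python
import torch

class RoundRobinIndexCache:
    """Caches the forced-balance indices, which depend only on
    (num_tokens, top_k, num_experts), so they need not be rebuilt every step."""

    def __init__(self, top_k: int, num_experts: int):
        self.top_k = top_k
        self.num_experts = num_experts
        self._cached: torch.Tensor | None = None

    def get(self, num_tokens: int, device: torch.device) -> torch.Tensor:
        stale = (
            self._cached is None
            or self._cached.shape[0] != num_tokens
            or self._cached.device != device
        )
        if stale:
            # Rebuild only when the token count or device changes.
            self._cached = (
                torch.arange(num_tokens * self.top_k, device=device, dtype=torch.int64)
                .reshape(num_tokens, self.top_k)
                % self.num_experts
            )
        return self._cached

# Usage sketch: cache = RoundRobinIndexCache(top_k=8, num_experts=64)
#               idx = cache.get(num_tokens=4096, device=torch.device("cuda"))
```

Whether a Python-side cache like this interacts cleanly with --compile.enable would need to be checked separately.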

@garrett361 (Contributor) left a comment

Just had one nit, think this is pretty useful!

@tianyu-l (Contributor) commented Oct 2, 2025

@rakkit would you like to address the final comment? Anyway it seems a rebase needs to happen before we merge.

@rakkit rakkit force-pushed the allow_force_router_balance branch from f8ab474 to 35ad499 Compare October 2, 2025 23:35
@rakkit (Contributor, Author) commented Oct 2, 2025

@tianyu-l which final comment? #1670 (comment) is fixed now, and I took the version garrett361 proposed in #1670 (comment).

Also rebased.

@rakkit rakkit force-pushed the allow_force_router_balance branch from 35ad499 to acf814b Compare October 2, 2025 23:38
@tianyu-l tianyu-l merged commit 99fee81 into pytorch:main Oct 3, 2025
8 checks passed
Successfully merging this pull request may close this issue: create fake balanced routing in MoE / EP for infra development