diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 2f87ab064..304f84bda 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -239,6 +239,9 @@ class Training:
     deterministic: bool = False
     """Use deterministic algorithms wherever possible, may be slower"""
 
+    debug_moe_force_load_balance: bool = False
+    """If True, force each expert to receive the same number of tokens via round-robin routing. For debugging only."""
+
 
 @dataclass
 class Parallelism:
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index 9d17de816..e34d4d3cc 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -72,6 +72,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
             )
 
+        self.moe_args._debug_force_load_balance = (
+            job_config.training.debug_moe_force_load_balance
+        )
+
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
diff --git a/torchtitan/experiments/qwen3/model/args.py b/torchtitan/experiments/qwen3/model/args.py
index 7828291f0..5fd98fdce 100644
--- a/torchtitan/experiments/qwen3/model/args.py
+++ b/torchtitan/experiments/qwen3/model/args.py
@@ -55,6 +55,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len
 
+        self.moe_args._debug_force_load_balance = (
+            job_config.training.debug_moe_force_load_balance
+        )
+
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index 238880247..3bac6e82f 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -106,6 +106,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
            )
 
+        self.moe_args._debug_force_load_balance = (
+            job_config.training.debug_moe_force_load_balance
+        )
+
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 9f519dc04..ab3f4226e 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -30,6 +30,9 @@ class MoEArgs:
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
 
+    _debug_force_load_balance: bool = False
+    # if True, force each expert to receive the same number of tokens via round-robin
+
 
 # can be used as dense FFN layer or shared experts in MoE layers
 class FeedForward(nn.Module):
@@ -180,6 +183,7 @@ def __init__(
         score_func: Literal["softmax", "sigmoid"],
         route_norm: bool,
         route_scale: float,
+        _debug_force_load_balance: bool = False,
     ):
         super().__init__()
         self.gate = nn.Linear(dim, num_experts, bias=False)
@@ -188,6 +192,24 @@ def __init__(
         self.score_func = score_func
         self.route_norm = route_norm
         self.route_scale = route_scale
+        self._debug_force_load_balance = _debug_force_load_balance
+
+    def _debug_force_load_balance_routing(
+        self, scores: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Balanced round-robin expert assignment.
+        Returns (selected_experts_indices [N, K] LongTensor, top_scores [N, K] FloatTensor).
+        """
+        n_tokens = scores.size(0)
+        # round-robin indices with exact balance
+        selected_experts_indices = (
+            torch.arange(
+                n_tokens * self.top_k, device=scores.device, dtype=torch.int64
+            ).reshape(n_tokens, self.top_k)
+            % self.num_experts
+        )
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N, K]
+        return selected_experts_indices, top_scores
 
     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
@@ -231,6 +253,13 @@ def forward(
             scores, k=self.top_k, dim=1
         )
 
+        # debug override: balanced round-robin routing
+        if self._debug_force_load_balance:
+            (
+                selected_experts_indices,
+                top_scores,
+            ) = self._debug_force_load_balance_routing(scores)
+
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
@@ -329,6 +358,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
             score_func=moe_args.score_func,
             route_norm=moe_args.route_norm,
             route_scale=moe_args.route_scale,
+            _debug_force_load_balance=moe_args._debug_force_load_balance,
         )
         self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k)
         self.shared_experts = (