Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ class Training:
deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

debug_moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""


@dataclass
class Parallelism:
Expand Down
4 changes: 4 additions & 0 deletions torchtitan/experiments/llama4/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
"CP support for FlexAttention is still in progress."
)

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, float]:
Expand Down
4 changes: 4 additions & 0 deletions torchtitan/experiments/qwen3/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, float]:
Expand Down
4 changes: 4 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
"CP support for FlexAttention is still in progress."
)

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, float]:
Expand Down
30 changes: 30 additions & 0 deletions torchtitan/models/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class MoEArgs:
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

_debug_force_load_balance: bool = False
# if True, force each expert to receive the same number of tokens via round-robin


# can be used as dense FFN layer or shared experts in MoE layers
class FeedForward(nn.Module):
Expand Down Expand Up @@ -180,6 +183,7 @@ def __init__(
score_func: Literal["softmax", "sigmoid"],
route_norm: bool,
route_scale: float,
_debug_force_load_balance: bool = False,
):
super().__init__()
self.gate = nn.Linear(dim, num_experts, bias=False)
Expand All @@ -188,6 +192,24 @@ def __init__(
self.score_func = score_func
self.route_norm = route_norm
self.route_scale = route_scale
self._debug_force_load_balance = _debug_force_load_balance

def _debug_force_load_balance_routing(
self, scores: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Balanced round-robin expert assignment.
Returns (selected_experts_indices [N, K] LongTensor, top_scores [N, K] FloatTensor).
"""
n_tokens = scores.size(0)
# Round-robin indices with exact balance
selected_experts_indices = (
torch.arange(
n_tokens * self.top_k, device=scores.device, dtype=torch.int64
).reshape(n_tokens, self.top_k)
% self.num_experts
)
top_scores = scores.gather(dim=1, index=selected_experts_indices) # [N,K]
return selected_experts_indices, top_scores

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
Expand Down Expand Up @@ -231,6 +253,13 @@ def forward(
scores, k=self.top_k, dim=1
)

# debug override: balanced round-robin routing
if self._debug_force_load_balance:
(
selected_experts_indices,
top_scores,
) = self._debug_force_load_balance_routing(scores)

if self.route_norm:
denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
top_scores = top_scores / denominator
Expand Down Expand Up @@ -329,6 +358,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
score_func=moe_args.score_func,
route_norm=moe_args.route_norm,
route_scale=moe_args.route_scale,
_debug_force_load_balance=moe_args._debug_force_load_balance,
)
self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k)
self.shared_experts = (
Expand Down