2 changes: 1 addition & 1 deletion torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
)

if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)

dp_mesh: DeviceMesh | None = None
if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
20 changes: 18 additions & 2 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -129,7 +129,7 @@ def parallelize_llama(

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)

dp_mesh: DeviceMesh | None = None
if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -506,7 +506,7 @@ def apply_moe_ep_tp(
)


-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
repeated structure. Alternatively one can compile the whole model (after applying DP).
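
The docstring above describes the per-TransformerBlock compilation strategy. A minimal sketch of that pattern, assuming a model whose blocks live in a `layers` container (the names here are illustrative, not the exact torchtitan structure):

import torch
import torch.nn as nn

def apply_block_compile(model: nn.Module) -> None:
    # Compile each TransformerBlock individually; because the blocks share
    # the same structure, one traced graph is reused across layers, which
    # keeps compile time low compared to compiling the whole model.
    for name, block in model.layers.named_children():
        compiled_block = torch.compile(block, fullgraph=True)
        model.layers.register_module(name, compiled_block)
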
@@ -577,6 +577,22 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
fullgraph=True,
)

+        if ep_enabled:
+            # Wrap the compiled grouped-mm path so that every call marks the
+            # token dimension of its input as dynamic.
+            compiled_fn = moe_module._run_experts_grouped_mm
+
+            def _run_experts_grouped_mm_dynamic(
+                w1: torch.Tensor,
+                w2: torch.Tensor,
+                w3: torch.Tensor,
+                x: torch.Tensor,
+                num_tokens_per_expert: torch.Tensor,
+            ) -> torch.Tensor:
+                # Under expert parallelism the number of tokens routed to the
+                # local experts varies per iteration; marking dim 0 of x as
+                # dynamic avoids recompiling for each new token count.
+                torch._dynamo.mark_dynamic(x, 0)
+                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

# NOTE: We don't compile for loop code path due to an issue with unbacked symints:
# https://github.com/pytorch/pytorch/issues/166460
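
The mark_dynamic wrapper above is the core of this change: with expert parallelism enabled, the number of tokens dispatched to the local experts changes from step to step, and marking the token dimension dynamic lets the compiled grouped-mm path reuse a single graph instead of recompiling for every new shape. A self-contained sketch of the same pattern, using a toy stand-in for the expert computation (shapes and names are illustrative only):

import torch

def run_experts(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the grouped-mm expert computation.
    return x @ w

compiled_run_experts = torch.compile(run_experts, fullgraph=True)

def run_experts_dynamic(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Mark the token dimension of x as dynamic so varying token counts
    # reuse the same compiled graph rather than triggering recompilation.
    torch._dynamo.mark_dynamic(x, 0)
    return compiled_run_experts(w, x)

w = torch.randn(64, 64)
for num_tokens in (8, 24, 17):  # token counts change across iterations
    out = run_experts_dynamic(w, torch.randn(num_tokens, 64))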

2 changes: 1 addition & 1 deletion torchtitan/models/qwen3/infra/parallelize.py
@@ -119,7 +119,7 @@ def parallelize_qwen3(

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)

if parallel_dims.fsdp_enabled:
# apply FSDP or HSDP, potentially with Context Parallel