dp2ep Expert Parallel #1324

Open · tianyu-l wants to merge 1 commit into main

Conversation

tianyu-l (Contributor) commented Jun 21, 2025

Overview

Previously I demonstrated Expert Parallel for expert-choice MoE in a stack of PRs #732.

This PR adds initial support for dp2ep Expert Parallel with token-choice MoE, in a way that is non-intrusive to model code and composable with other parallelisms. See below for details.

This PR also fixes the interplay between auxiliary-loss-free load balancing and gradient accumulation, partly inspired by the solution from @hann-wang in #1304, which originally pointed out the issue. Instead of adding another entry in TrainSpec, this PR performs the expert bias update in an optimizer hook.
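
For illustration, here is a minimal sketch of doing the bias update in an optimizer hook so that it runs exactly once per optimizer step and therefore composes with gradient accumulation. The names `tokens_per_expert`, `expert_bias`, and the update rule are assumptions for the sketch, not the PR's exact code.

```python
import torch


def register_expert_bias_update(optimizer, moe_layers, bias_update_speed: float = 1e-3):
    """Sketch: update the per-expert bias used by auxiliary-loss-free load balancing
    from an optimizer hook. `tokens_per_expert`, `expert_bias`, and the update rule
    are illustrative assumptions, not the PR's exact code."""

    @torch.no_grad()
    def _update_expert_bias(*_):
        for moe in moe_layers:
            # Token counts accumulated over all micro-batches since the last step.
            counts = moe.tokens_per_expert.float()
            # Nudge under-loaded experts' bias up and over-loaded experts' bias down.
            moe.expert_bias.add_(bias_update_speed * torch.sign(counts.mean() - counts))
            moe.tokens_per_expert.zero_()

    # Runs right before each optimizer.step(); a post-step hook would also work.
    optimizer.register_step_pre_hook(_update_expert_bias)
```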

Note: dp2ep EP + TP integration needs to wait for pytorch/pytorch#157216 to land.

While working on this PR, I also identified numerical issues between AdamW and Tensor Parallel, which I will post in a separate issue to track.

What is dp2ep Expert Parallel

Below are two diagrams illustrating the communication / computation pattern in dp2ep Expert Parallel. In short, the Expert Parallel degree used for the MoE routed experts is borrowed from the Data Parallel (including Context Parallel) degree used for the non-MoE params (e.g. Attention layers, MLP layers) and the other params in MoE layers (including the router's gate and shared experts).
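
To make the "borrowing" concrete, here is a minimal sketch assuming 8 GPUs with dp_shard=4, cp=2, ep=8; the actual mesh construction in this PR is more involved (see "Design" below).

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical degrees on 8 GPUs: non-MoE params are sharded over dp_shard x cp = 8
# ranks, while the routed experts reuse those same 8 ranks as the EP group.
dp_shard, cp, ep = 4, 2, 8
assert (dp_shard * cp) % ep == 0  # the EP degree is borrowed from (divides) DP x CP

# 2-D mesh used for FSDP/CP on the non-EP params; the EP mesh is conceptually the
# same ranks flattened into one dimension (the PR builds a special DeviceMesh so
# both views can coexist -- see "Design" below).
mesh = init_device_mesh("cuda", (dp_shard, cp), mesh_dim_names=("dp_shard", "cp"))
```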

[diagram: without TP]

[diagram: with TP]

Note: In the current implementation, the all-to-all communications are duplicated across TP ranks, causing unnecessary communication overhead. As a next step, I'm going to implement "Sequence Parallel" for the all-to-all, reducing the communication volume to 1 / tp_degree.

Design

The EP implementation utilizes DTensor's parallelize_module API to shard MoE routed experts on the num_expert dimension, and inserts a pair of hooks before and after forward to perform all-to-all collectives (a sketch of the wiring is shown below).
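
As a rough illustration, a sketch of how the routed experts could be parallelized; the module path `layer.moe.experts` and the helper name are assumptions about the torchtitan layout, not the PR's exact code.

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_ep(model: nn.Module, ep_mesh: DeviceMesh, ep_style) -> None:
    """Sketch of wiring the routed experts onto the EP mesh. `layer.moe.experts`
    and the `ep_style` instance (this PR's ExpertParallel) are assumptions."""
    for layer in model.layers:
        parallelize_module(
            module=layer.moe.experts,   # GroupedExperts holding the routed experts
            device_mesh=ep_mesh,        # 1-D mesh over the EP ranks
            parallelize_plan=ep_style,  # shards params on the expert dim and registers
                                        # the _token_dispatch / _token_combine hooks
        )
```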

In addition, this PR creates an expert_parallel wrapper applied to the GroupedExperts computation, serving the following three purposes (a simplified sketch of such a wrapper is included below):

  1. Convert parameters from DTensors to plain Tensors, to work with dynamic-shape inputs which cannot be easily expressed as DTensors.
  2. In Expert Parallel, apply the generate_permute_indices kernel to permute the inputs to be ordered by local experts (see the _token_dispatch function in ExpertParallel) and permute the outputs back.
  3. In order to use torch._grouped_mm, we need to make sure the number of tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices kernel also helps achieve this via padding, without incurring synchronization between device and host. Note that this creates side effects when wrapping the for-loop implementation of GroupedExperts, which does not need padding.

Among the above:

  • 1 and 2 are needed only when expert_parallel_degree > 1.
  • 3 is needed even for single-device computation.
  • 2 can be moved to ExpertParallel's _token_dispatch if not coupled with 3.
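
A simplified sketch of such a wrapper, showing only purpose 1 and marking where 2 and 3 would go; the argument list mirrors a grouped-experts forward and is an assumption, not the PR's exact signature.

```python
import functools

from torch.distributed.tensor import DTensor


def expert_parallel(func):
    """Simplified sketch of the wrapper described above. The real one also calls the
    generate_permute_indices kernel to sort tokens by local expert and pad each
    group to a multiple of ALIGN_SIZE_M."""

    @functools.wraps(func)
    def wrapper(w1, w2, w3, x, num_tokens_per_expert):
        # (1) DTensor params -> local plain tensors, since the dynamic-shape token
        #     tensor x is hard to express as a DTensor.
        if isinstance(w1, DTensor):
            w1, w2, w3 = w1.to_local(), w2.to_local(), w3.to_local()
        # (2)+(3) permute tokens so they are grouped by local expert and padded to
        #     a multiple of ALIGN_SIZE_M (done by generate_permute_indices in the PR),
        #     then run the grouped computation ...
        out = func(w1, w2, w3, x, num_tokens_per_expert)
        # ... and un-permute the outputs back to the original token order.
        return out

    return wrapper
```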

Due to the inhomogeneity of the DeviceMeshes for EP parameters and non-EP parameters, this PR adds the following special treatment to enable TP:

  • DeviceMesh creation: when EP is enabled, create a special DeviceMesh shared between DP/CP (for non-EP parameters) and EP (for EP parameters).
  • gradient norm clipping: when EP is enabled, separately compute the norm of EP parameters and non-EP parameters -> compute the global norm -> separately perform grad norm clipping with the global norm (see the sketch below).
  • fused optimizer step: create a new optimizer container class ExpertParallelOptimizersContainer which performs fused optimizer steps on EP parameters and non-EP parameters separately.

For DeviceMesh, we'll need to improve the way we express non-homogeneous meshes. For gradient norm clipping and the fused optimizer, since there are at most two groups of parameters, I expect the approach to be fine until we find a better way to support this. Things could change if LLM / MoE architectures evolve to be more dynamic.
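
For gradient norm clipping, a minimal sketch of the two-group idea follows. It is simplified: L2 norm only, gradients materialized via full_tensor() for clarity, and no PP/foreach handling; the PR's _clip_grad_norm_with_ep is more careful than this.

```python
import math

import torch
from torch.distributed.tensor import DTensor


@torch.no_grad()
def clip_grad_norm_with_ep(ep_params, non_ep_params, max_norm: float) -> float:
    """Sketch of two-group clipping: per-group norms -> one global norm -> clip both."""
    ep_params, non_ep_params = list(ep_params), list(non_ep_params)

    def _group_norm_sq(params) -> float:
        total_sq = 0.0
        for p in params:
            if p.grad is None:
                continue
            g = p.grad
            # For DTensor grads, materialize the full gradient so the norm is global
            # over that parameter's own mesh (EP mesh or DP/CP mesh).
            if isinstance(g, DTensor):
                g = g.full_tensor()
            total_sq += g.norm(2).item() ** 2
        return total_sq

    # Separately accumulate the squared norms, then combine into one global norm.
    total_norm = math.sqrt(_group_norm_sq(ep_params) + _group_norm_sq(non_ep_params))

    # Clip both groups with the same global norm.
    clip_coef = min(1.0, max_norm / (total_norm + 1e-6))
    if clip_coef < 1.0:
        for p in ep_params + non_ep_params:
            if p.grad is not None:
                p.grad.mul_(clip_coef)
    return total_norm
```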

Communication Trace Verification

[screenshot: communication trace]

One can see that in order to call the EP all-to-all _token_dispatch and _token_combine with the correct input_splits and output_splits, we need to generate the size data via another dist.all_to_all_single (in the default stream) and perform a device-to-host sync. This can be avoided by utilizing SymmetricMemory-based all-to-all-v, which we will work on soon. A sketch of the current pattern is below.
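
The sketch below shows the splits exchange followed by the variable-size all-to-all. It is a hypothetical helper using per-rank splits only, whereas the PR's _token_dispatch works at per-expert granularity.

```python
import torch
import torch.distributed as dist


def token_dispatch(tokens: torch.Tensor, input_splits: torch.Tensor, ep_group) -> torch.Tensor:
    """Sketch of the splits exchange + all-to-all-v described above. `tokens` is
    (num_tokens, dim), already permuted so tokens headed to EP rank r are contiguous;
    `input_splits` is a length-world int tensor of send counts per rank."""
    world = dist.get_world_size(ep_group)
    assert input_splits.numel() == world

    # 1) Exchange how many tokens each rank will receive from every other rank.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=ep_group)

    # 2) Device-to-host sync: the split sizes below must be Python ints.
    in_splits = input_splits.tolist()
    out_splits = output_splits.tolist()

    # 3) The actual all-to-all-v on the token data.
    recv = tokens.new_empty(sum(out_splits), tokens.shape[-1])
    dist.all_to_all_single(
        recv, tokens,
        output_split_sizes=out_splits, input_split_sizes=in_splits, group=ep_group,
    )
    return recv
```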

DCP Resharding Correctness and Numerical Verification

Note: I used --optimizer.name="Adam" instead of "AdamW", as the latter seems to cause numerical issues when TP is enabled.

To verify correctness, I created a seed checkpoint of the debug model, fixed the seed, and ran the same training under different parallelism configs for 100 steps on at most 8 GPUs:

  • FSDP 2
  • FSDP 2 (EP 2), TP 2, PP 2
  • HSDP 4 (DP 2, CP 2, EP 4), TP 2
[plot: loss curves for the above configs]

Next Steps

  • Sequence Parallel for the all-to-all communication collectives when TP is enabled (at the cost of another pair of TP all-gather and reduce-scatter)
  • adopt SymmetricMemory-based all-to-all and avoid D2H syncs (cc @kwen2501)
  • enable EP in torchtitan's DeepSeekV3 @wwwjn
  • FSDP2 non-dim-0 sharding (cc @weifengpy)
  • torch.compile support @xmfan
    • which blocks torchao quantization enablement
  • computation / communication overlapping
    • either via inductor passes to overlap all-to-all with shared expert computation @xmfan
    • or via fine-grained Pipeline Parallel splitting & scheduling @H-Huang
  • torchft support @fegin
    • This is because this PR creates a new ExpertParallelOptimizersContainer; we need to figure out the UX to integrate with FTOptimizersContainer.
  • float8 + MoE TP integration @danielvegamyhre
    • Previously, float8 worked with TP via specialized ColwiseParallel and RowwiseParallel (see code). For MoE, I'm creating new ad hoc ParallelStyles, including TensorParallel, ExpertParallel, and ExpertTensorParallel.
  • better DeviceMesh support and general "ETP" support (where experts TP and attention/mlp TP don't have to have the same TP degree) @fduwjj

@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
super().zero_grad(*args, **kwargs)


class ExpertParallelOptimizersContainer(OptimizersContainer):
"""
This class is created to support fused optimizer implementation for Expert Parallel.

Collaborator:

Hmmm, do we really need this container? I thought that after my PR pytorch/pytorch#147869 earlier this year, we should be able to run the fused/foreach optimizer on DTensors that live on different device meshes. Are you hitting a similar issue to pytorch/pytorch#153268?

tianyu-l (Author):

After another look, I think I'm indeed hitting the same issue as pytorch/pytorch#153268 -- the error is on aten._fused_adam_.default (sorry, I thought it was more elementary ops like the gradient norm clipping ones).



@torch.no_grad()
def _clip_grad_norm_with_ep(

Collaborator:

Could you explain/document why we need this? Is it a similar issue to the optimizer one? If so, IMO we should fix DTensor instead of adding all these wrappers.

tianyu-l (Author):

It is a similar issue to the optimizer one, but not exactly the same. The cross-mesh problem first happens at aten.stack (https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102). Do you think DTensor should support cross-mesh computation for these more "elementary" ops? It might be easier if gradient norm computation / clipping came with fused ops.

return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

if parallel_dims.ep_enabled and fused:
if ft_manager.enabled:

Collaborator:

Curious, what's the requirement to be compatible with torchft?

tianyu-l (Author):

If we support the fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.

@@ -24,40 +26,6 @@

# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
def __init__(

Collaborator:

Why are those methods deleted?

tianyu-l (Author):

In the ExpertParallel and ExpertTensorParallel classes:

  1. I didn't support specifying input (output) placements, as in dp2ep EP they always expect data sharded on the batch dim and perform all-to-all's to dispatch (respectively combine) tokens to (respectively from) the corresponding experts.
  2. I had to convert parameters from DTensors to plain tensors before doing computation, while the inputs always stay as plain tensors.

Since TensorParallel is created only for the experts rather than for general purpose, I deliberately made its style consistent with ExpertParallel and ExpertTensorParallel, meaning

  1. it doesn't require input_layouts/output_layouts annotation -- it always expects inputs to be Replicate and outputs to be Partial;
  2. it doesn't convert inputs to DTensors in the forward pre-hook, otherwise they wouldn't be consistent with the plain Tensor parameters during computation.

Basically, TensorParallel is a specialized ParallelStyle (combining Colwise and Rowwise into one), just like ExpertParallel and ExpertTensorParallel. Let me know what you think.
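
For illustration, a minimal sketch of what such a specialized style could look like; the weight names, shapes, and sharding dims are assumptions about GroupedExperts, not the PR's code.

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard, distribute_module, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle


class GroupedExpertsTensorParallel(ParallelStyle):
    """Hypothetical specialized TP style for GroupedExperts (not the PR's class):
    combines colwise/rowwise sharding of the expert weights in one style, with no
    input/output layout annotations."""

    def _partition_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh) -> None:
        # Assumed weight shapes: w1/w3 (num_experts, dim, hidden), w2 (num_experts,
        # hidden, dim); shard w1/w3 "colwise" and w2 "rowwise" on the hidden dim.
        shard_dims = {"w1": 2, "w3": 2, "w2": 1}
        for pname, param in module.named_parameters(recurse=False):
            if pname in shard_dims:
                dist_param = nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dims[pname])])
                )
                module.register_parameter(pname, dist_param)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # No input/output DTensor conversion: inputs stay plain (replicated) tensors
        # and outputs are mathematically Partial, reduced by the surrounding hooks.
        return distribute_module(module, device_mesh, self._partition_fn)
```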

tianyu-l (Author) left a comment:

Thanks @wanchaol for the comments. I think for the fused optimizer step, we may let DTensor support it in the "foreach" way. For the others I'd love to hear more thoughts from you.

