dp2ep Expert Parallel #1324
Conversation
```
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
        super().zero_grad(*args, **kwargs)


class ExpertParallelOptimizersContainer(OptimizersContainer):
    """
    This class is created to support fused optimizer implementation for Expert Parallel.
```
Hmm, do we really need this container? I thought that after my PR pytorch/pytorch#147869 earlier this year, we should be able to run fused/foreach optimizers on DTensors that live on different device meshes. Are you hitting a similar issue to pytorch/pytorch#153268?
After another look, I think I'm indeed hitting the same issue as pytorch/pytorch#153268 -- the error is on `aten._fused_adam_.default` (sorry, I thought it was on more elementary ops, like the gradient norm clipping ones).
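For illustration, a rough sketch of what the container-style workaround amounts to (hypothetical helper name, not the PR's `ExpertParallelOptimizersContainer`): group parameters by the mesh their DTensor lives on, and give each group its own fused optimizer, so a single `aten._fused_adam_` call never sees mixed meshes.

```python
import torch
from torch.distributed.tensor import DTensor

def build_per_mesh_optimizers(model: torch.nn.Module, lr: float = 3e-4):
    # Group parameters by the DeviceMesh their DTensor lives on; plain tensors
    # (if any) fall into their own group keyed by None.
    groups: dict = {}
    for param in model.parameters():
        key = param.device_mesh if isinstance(param, DTensor) else None
        groups.setdefault(key, []).append(param)
    # One fused AdamW per group, so each fused kernel only ever sees DTensors
    # sharing a mesh; callers then step() / zero_grad() all of them together.
    return [torch.optim.AdamW(ps, lr=lr, fused=True) for ps in groups.values()]
```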
```
@torch.no_grad()
def _clip_grad_norm_with_ep(
```
Could you explain/document why we need this? Is it a similar issue to the optimizer? If so, IMO we should fix DTensor instead of adding all those wrappers
It is a similar issue to the optimizer one, but not exactly the same. The cross-mesh problem first happens at `aten.stack`:
https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L102
Do you think DTensor should support cross-mesh computation for these more "elementary" ops? It might be easier if gradient norm computation / clipping came with fused ops.
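For illustration, a minimal sketch of the separate-then-combine clipping idea (hypothetical helper, not the PR's `_clip_grad_norm_with_ep`): compute each group's norm on its own mesh, combine the two scalars into the global norm, then scale both groups by the same factor.

```python
import torch
from torch.distributed.tensor import DTensor

@torch.no_grad()
def clip_grad_norm_two_groups(ep_grads, non_ep_grads, max_norm: float, p: float = 2.0):
    def group_norm(grads):
        # norm within one group: all tensors live on the same mesh, so no op mixes meshes
        norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g, p) for g in grads]), p
        )
        # materialize a DTensor norm into a plain tensor before combining groups
        return norm.full_tensor() if isinstance(norm, DTensor) else norm

    ep_norm, rest_norm = group_norm(ep_grads), group_norm(non_ep_grads)
    # global norm: ||g|| = (||g_ep||^p + ||g_rest||^p)^(1/p)
    total_norm = torch.linalg.vector_norm(torch.stack([ep_norm, rest_norm]), p)

    # scale both groups by the same clip coefficient (as a Python float, to stay mesh-agnostic)
    clip_coef = float(torch.clamp(max_norm / (total_norm + 1e-6), max=1.0))
    for g in list(ep_grads) + list(non_ep_grads):
        g.mul_(clip_coef)
    return total_norm
```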
```
        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

    if parallel_dims.ep_enabled and fused:
        if ft_manager.enabled:
```
Curious what's the requirement to be compatible with torchft?
If we support the fused optimizer using DTensor instead of the wrapper code, it should be compatible with torchft.
```
@@ -24,40 +26,6 @@

# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
    def __init__(
```
Why are those methods deleted?
In the `ExpertParallel` and `ExpertTensorParallel` classes:

- I didn't support specifying input (output) placements, as in dp2ep EP they always expect data sharded on the batch dim and perform all-to-all's to dispatch (combine, respectively) tokens to (from, respectively) the corresponding experts.
- I had to convert parameters from DTensor to plain tensor before doing computation, where the inputs always stay as plain tensors.

Since `TensorParallel` is created only for the experts rather than for general purposes, I deliberately made its style consistent with `ExpertParallel` and `ExpertTensorParallel`, meaning:

- it doesn't require input_layouts/output_layouts annotation -- it always expects inputs to be `Replicate` and outputs to be `Partial`;
- it doesn't convert inputs to DTensors in the forward pre-hook, otherwise they wouldn't be consistent with the plain Tensor parameters during computation.

Basically `TensorParallel` is a specialized `ParallelStyle` (combining `Colwise` and `Rowwise` into one), just like `ExpertParallel` and `ExpertTensorParallel`. Let me know what you think.
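To make the design concrete, here is a very rough sketch (hypothetical class and hook names, not the PR's code) of what such a specialized style can look like: shard the experts' parameters on the expert dimension and register dispatch/combine hooks around forward, without annotating input/output layouts. The real implementation additionally converts the DTensor parameters back to local tensors inside the wrapped computation.

```python
import torch
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle


class ToyExpertParallel(ParallelStyle):
    """Hypothetical sketch: shard GroupedExperts weights on the num_experts dim."""

    def _apply(self, module: torch.nn.Module, device_mesh) -> torch.nn.Module:
        for name, param in module.named_parameters(recurse=False):
            # shard every expert weight on dim 0 (the num_experts dimension)
            dtensor_param = torch.nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(0)])
            )
            module.register_parameter(name, dtensor_param)
        # placeholder hooks: a real style would all-to-all tokens to the ranks
        # owning the right experts before forward, and back afterwards
        module.register_forward_pre_hook(self._token_dispatch)
        module.register_forward_hook(self._token_combine)
        return module

    @staticmethod
    def _token_dispatch(module, inputs):
        return inputs   # placeholder: real impl performs an all-to-all on tokens

    @staticmethod
    def _token_combine(module, inputs, output):
        return output   # placeholder: real impl all-to-alls the outputs back
```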
Thanks @wanchaol for the comments. I think for the fused optimizer step, we may let DTensor support it in the "foreach" way. For the others, I'd love to hear more thoughts from you.
Overview
Previously I demonstrated Expert Parallel for expert-choice MoE in a stack of PRs #732.
This PR adds the initial support of dp2ep Expert Parallel for token-choice MoE, being non-intrusive to model code and composable with other parallelisms. See below for details.
This PR also fixes the issue between auxiliary-loss-free load balancing and gradient accumulation, partly inspired by the solution of @hann-wang in #1304, which originally pointed out the issue. This PR does the expert bias update in an optimizer hook, instead of adding another entry in `TrainSpec`.

Note: dp2ep EP + TP integration needs to wait for pytorch/pytorch#157216 to land.
While working on this PR, I also identified numerical issues between AdamW and Tensor Parallel, which I will post in a separate issue to track.
What is dp2ep Expert Parallel
Here are two diagrams illustrating the communication / computation pattern happening in dp2ep Expert Parallel. Basically, the Expert Parallel degree needed for MoE routed experts is borrowed from the Data Parallel (including Context Parallel) degree for non-MoE params (e.g. Attention layers, MLP layers) and other params in MoE layers (including the router's gate and shared experts).
without TP
[diagram: dp2ep Expert Parallel communication / computation pattern without TP]

with TP
[diagram: dp2ep Expert Parallel communication / computation pattern with TP]
Note: In the current implementation, the all-to-all communications across all TP ranks are duplicated, causing unnecessary communication overhead. As the next step, I'm going to implement "Sequence Parallel" for the all-to-all, reducing the communication volume to `1 / tp_degree`.

Design
The EP implementation utilizes DTensor's `parallelize_module` API to shard MoE routed experts on the `num_expert` dimension, and inserts a pair of hooks before and after forward to perform the all-to-all collectives.

In addition, this PR creates an `expert_parallel` wrapper applied to the GroupedExperts computation, serving the following purposes:

- apply the `generate_permute_indices` kernel to permute the inputs to be ordered by local experts (see the `_token_dispatch` function in `ExpertParallel`) and permute the outputs back;
- in order to use `torch._grouped_mm`, make sure the number of tokens each expert gets is a multiple of `ALIGN_SIZE_M` -- the `generate_permute_indices` kernel also helps achieve this via padding, without incurring synchronization between device and host (see the sketch after this list); note that this creates side effects when wrapping the for-loop implementation of GroupedExperts, which does not need padding;
- the permutation itself is needed only when `expert_parallel_degree` > 1, and could be moved to `ExpertParallel`'s `_token_dispatch` if it were not coupled with the padding.
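A minimal sketch of the alignment-padding idea from the list above (pure bookkeeping, not the `generate_permute_indices` kernel; `ALIGN_SIZE_M = 16` and the token counts are assumed values):

```python
import torch

ALIGN_SIZE_M = 16                                   # assumed alignment requirement
tokens_per_expert = torch.tensor([5, 0, 23, 12])    # hypothetical local counts

# round each expert's token count up to a multiple of ALIGN_SIZE_M so the grouped
# GEMM sees well-aligned group sizes; the real kernel does this on device so that
# no device-to-host sync is needed
padded = (tokens_per_expert + ALIGN_SIZE_M - 1) // ALIGN_SIZE_M * ALIGN_SIZE_M
offsets = torch.cumsum(padded, dim=0)               # group boundaries for the grouped GEMM

print(padded.tolist())   # [16, 0, 32, 16]
print(offsets.tolist())  # [16, 16, 48, 64]
```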
Due to the inhomogeneity of `DeviceMesh`es between EP parameters and non-EP parameters, this PR adds the following special treatment to enable TP:

- `DeviceMesh` creation: when EP is enabled, create a special `DeviceMesh` to share between DP/CP (for non-EP parameters) and EP (for EP parameters) -- see the sketch below;
- gradient norm clipping: a `_clip_grad_norm_with_ep` utility that handles the norms of EP parameters and non-EP parameters separately (see the review discussion above);
- fused optimizer step: a new `ExpertParallelOptimizersContainer`, which does fused optimizer steps on EP parameters and non-EP parameters separately.

For `DeviceMesh`, we'll need to improve the way we can express non-homogeneous meshes. For gradient norm clipping and the fused optimizer, since there are at most two groups of parameters, I expect the approach to be fine until we find a better way to support them. Things could change if the LLM / MoE architecture evolves to be more dynamic.
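For illustration, a rough sketch of the mesh-sharing idea (the mesh shapes, dim names, and the use of the private `_flatten` API are assumptions, not the PR's exact code):

```python
from torch.distributed.device_mesh import init_device_mesh

# one world mesh for everything; assumes an initialized process group of size 16
world_mesh = init_device_mesh(
    "cuda", (4, 2, 2), mesh_dim_names=("dp_shard", "cp", "tp")
)

# non-EP params are sharded over dp_shard / cp (plus tp), while EP params reuse
# the same dp_shard * cp ranks as their expert-parallel dimension
dp_cp_mesh = world_mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
ep_mesh = dp_cp_mesh          # EP dimension for the routed experts
tp_mesh = world_mesh["tp"]    # TP dimension shared by EP and non-EP params
```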
Communication Trace Verification

One can see that, in order to call the EP all-to-all `_token_dispatch` and `_token_combine` with correct `input_splits` and `output_splits`, we need to generate the size data via another `dist.all_to_all_single` (in the default stream) and do a device-to-host sync. This can be avoided by utilizing a SymmetricMemory-based `all-to-all-v`, which we will work on soon.
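A simplified sketch of the size-exchange plus device-to-host sync pattern described above (function and variable names are assumptions, not the PR's `_token_dispatch`):

```python
import torch
import torch.distributed as dist

def dispatch_tokens(tokens: torch.Tensor, input_splits: torch.Tensor, group=None):
    # 1) exchange per-rank send counts so each rank learns its receive counts
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=group)

    # 2) device-to-host sync: all_to_all_single needs the split sizes as Python ints
    in_sizes = input_splits.tolist()
    out_sizes = output_splits.tolist()

    # 3) the actual variable-sized token all-to-all
    out = tokens.new_empty(sum(out_sizes), tokens.shape[-1])
    dist.all_to_all_single(out, tokens, out_sizes, in_sizes, group=group)
    return out
```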
DCP Resharding Correctness and Numerical Verification

Note: I used `--optimizer.name="Adam"` instead of `"AdamW"`, as the latter seems to cause numerical issues when TP is enabled.

To verify, I created a seed checkpoint of the debug model, fixed the seed, and ran the same training under different parallelism configs for 100 steps on at most 8 GPUs.
Next Steps
- `torch.compile` support @xmfan
- torchft support @fegin -- in particular, integrating with the EP-specific `ExpertParallelOptimizersContainer`; we need to figure out the UX to integrate with `FTOptimizersContainer`.
- For non-MoE modules, TP uses the existing `ColwiseParallel` and `RowwiseParallel` (see code); for MoE, I'm creating new ad hoc `ParallelStyle`s, including `TensorParallel`, `ExpertParallel`, and `ExpertTensorParallel`.
- Better `DeviceMesh` support and general "ETP" support (where the experts' TP and the attention/mlp TP don't have to have the same TP degree) @fduwjj