Integrate DeepEP to torchtitan #2107
Conversation
Hi @elfiegg! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
a5875e5 to 6999d1e
tianyu-l left a comment
Thanks for contributing!
Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)
May I ask what the baseline is?
Btw, I think to fully utilize the power of DeepEP, we also need to have node-limited routing, which the current torchtitan DSv3 model doesn't have.
@shuhuayu let's add it? we can refer to HF or deepseek original impl.
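For reference, a simplified sketch of what node-limited (group-limited) routing looks like, in the spirit of the DeepSeek gates; this is illustrative only and not the torchtitan or HF implementation (DeepSeek-V3 scores groups by the sum of the top-2 experts rather than the max used here):

```python
import torch

def node_limited_topk(scores: torch.Tensor, top_k: int, n_groups: int, topk_groups: int):
    # scores: (num_tokens, num_experts) router scores; experts are assumed to be
    # laid out contiguously by group (e.g. one group per node).
    bs, n_experts = scores.shape
    experts_per_group = n_experts // n_groups
    # Score each group by its best expert, then keep only the top `topk_groups` groups.
    group_scores = scores.view(bs, n_groups, experts_per_group).max(dim=-1).values
    group_idx = group_scores.topk(topk_groups, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Mask out experts outside the selected groups, then do the usual top-k.
    expert_mask = group_mask.unsqueeze(-1).expand(bs, n_groups, experts_per_group).reshape(bs, n_experts)
    masked_scores = scores.masked_fill(expert_mask == 0, float("-inf"))
    top_scores, selected_experts_indices = masked_scores.topk(top_k, dim=-1)
    return top_scores, selected_experts_indices
```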
Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g. build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.
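For illustration, a minimal sketch of such a factory; the module paths come from this thread, but the constructor arguments are assumptions rather than the final API:

```python
import torch.nn as nn

def build_moe(moe_args, dim: int, hidden_dim: int, comm_backend: str = "standard") -> nn.Module:
    # Dispatch on a string so model code stays unaware of the comm library.
    if comm_backend == "deep_ep":
        # Imported lazily so DeepEP remains an optional dependency.
        from torchtitan.models.moe.moe_deepep import MoEWithDeepEP
        return MoEWithDeepEP(moe_args, dim=dim, hidden_dim=hidden_dim)
    from torchtitan.models.moe import MoE
    return MoE(moe_args, dim=dim, hidden_dim=hidden_dim)
```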
Sure, that's a great idea - once I confirm this works for larger models and improves perf.
Regarding integrating directly into main - do we need to manage the DeepEP dependency at all, or do we leave it to users to install?
Good point. I prefer
we leave it to the users to install
instead of bundling it by default. We can explicitly mention this in the try/except when we do the import.
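A minimal sketch of the guarded import being suggested, assuming the deep_ep package name and the config option name used later in this thread:

```python
try:
    import deep_ep  # https://github.com/deepseek-ai/DeepEP
    HAS_DEEPEP = True
except ImportError:
    HAS_DEEPEP = False

def require_deepep() -> None:
    # Called only when the user selects the DeepEP comm backend.
    if not HAS_DEEPEP:
        raise ImportError(
            "expert_parallel_comm_backend='deep_ep' requires the DeepEP package; "
            "install it from https://github.com/deepseek-ai/DeepEP or switch back "
            "to the 'standard' backend."
        )
```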
```python
from torch.distributed.tensor import distribute_module


class DeepEPExpertParallel(ParallelStyle):
```
Is this used anywhere? I'm guessing this is not running end-to-end with torchtitan train.py, which is still WIP.
I think we can also handle the token communication in here -- the design would look cleaner and more aligned with the existing interface.
This is our implementation, and our DeepEP interface is similar to Nvidia's version.
Thanks for the PR! I think we should support node-limited routing to make the multi-node setup faster.
Added the node-limited routing here: #2111. Perhaps it helps make DeepEP faster in multi-node setups.
Actually, I reran this last night and the perf caught up - the lag was gone once I enabled FSDP for the MoE layer (which I had disabled for debugging purposes). Running the command below, I got 13% MFU for both the baseline and the DeepEP version. For the baseline, I referred to the config here:
Thanks for posting the work! We had a successful and performant DeepEP integration at AMD-AGI@59fe226. We borrowed some design from Megatron-LM, and we can use it here too. I don't see big differences between our DeepEP interface and yours. Let's work together on this. Feel free to reach out to me or Tianyu for future discussion and collaboration.
```python
if self.score_before_experts:
    recv_x_sorted = (recv_x_sorted.to(torch.float32) * valid_weights_sorted.unsqueeze(-1)).to(recv_x_sorted.dtype)
```
The code before experts.forward should go into DeepEPExpertParallel as input_fn; the same goes for the token unpermute after experts as output_fn.
Also consider using something like _indices_to_multihot_kernel (https://github.com/NVIDIA/Megatron-LM/blob/f5344166732f45bb0dd825dc875288ea97b15b47/megatron/core/fusions/fused_indices_converter.py#L32C5-L32C32) to preprocess received DeepEP data.
You are using a lot of index-selecting here, which I suspect would incur significant CPU overhead (and lock/wait among CPU threads).
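For reference, a pure-PyTorch sketch of the indices-to-multihot idea (the linked Megatron kernel does this in Triton); the shapes and the -1 convention for dropped slots are assumptions:

```python
import torch

def indices_to_multihot(topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int):
    # topk_idx: (num_tokens, top_k) expert ids, with -1 marking dropped slots (assumed).
    # topk_weights: (num_tokens, top_k) routing probabilities.
    # Returns a (num_tokens, num_experts) bool mask and matching dense probs,
    # built with one vectorized assignment instead of per-expert index_selects.
    num_tokens = topk_idx.shape[0]
    multihot = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device=topk_idx.device)
    probs = torch.zeros(num_tokens, num_experts, dtype=topk_weights.dtype, device=topk_idx.device)
    valid = topk_idx >= 0
    rows = torch.arange(num_tokens, device=topk_idx.device).unsqueeze(1).expand_as(topk_idx)
    multihot[rows[valid], topk_idx[valid]] = True
    probs[rows[valid], topk_idx[valid]] = topk_weights[valid]
    return multihot, probs
```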
Thanks all for the valuable advice! I'm currently occupied by a deadline, but I will take a closer look and join the discussion tomorrow.
I scanned through the comments, and here is a summary:
If this looks good to everyone, I'll start revising the PR.
@elfiegg sounds good overall.
I think this may be done in a follow-up PR, assuming that the tradeoff can be justified by benchmarking results. cc @yuankaichen-amd WDYT?
Either way works -- it should be low-hanging fruit. The Triton kernel is available in both Megatron and my integration PR.
I strongly recommend wrapping DeepEP-related ops into a standalone class. There are six methods ([dispatch, combine] x [preprocess, process, postprocess]). It will set a clear boundary between DeepEP and TorchTitan's MoE module or ExpertParallel wrapper. Also, you get a free ride for many things that Nvidia has already implemented in Megatron.
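A sketch of the suggested boundary; the method names mirror the [dispatch, combine] x [preprocess, process, postprocess] split discussed later in this thread, and the signatures are assumptions, not the AMD implementation:

```python
class DeepEPTokenDispatcher:
    def dispatch_preprocess(self, x, topk_idx, topk_weights):
        """Permute / flatten tokens into the layout DeepEP expects."""
        ...

    def dispatch(self, x, topk_idx, topk_weights):
        """Call DeepEP's buffer.dispatch to send tokens to expert ranks."""
        ...

    def dispatch_postprocess(self, recv_x, recv_topk_idx, recv_topk_weights):
        """Sort received tokens per local expert and build num_tokens_per_expert."""
        ...

    def combine_preprocess(self, expert_out):
        """Undo the per-expert sorting before sending results back."""
        ...

    def combine(self, expert_out):
        """Call DeepEP's buffer.combine to return tokens to their source ranks."""
        ...

    def combine_postprocess(self, combined_out):
        """Restore the original token order and apply scores if needed."""
        ...
```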
0a9815b to e0d4fcf
torchtitan/models/moe/moe_deepep.py (Outdated)
```python
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
    x, selected_experts_indices, top_scores
)
```
This version looks much cleaner.
After reading the code, I wonder what's the best way to organize the code. Brainstorming with some (immature) ideas:
1. Use `if use_deepep` in the code for the region of difference. @shuhuayu IIUC you were having this idea?
2. Abstract token_dispatching + routed_experts computation into its own classes, so that the MoE class can be shared.
3. Move `dispatch_preprocess` and `dispatch_postprocess` also inside `ExpertParallel` hooks. The challenge seems to be that `ExpertParallel` classes are not getting all the inputs we need.
4. Unify `dispatch_preprocess` with the `TokenReorderer` concept needed for the non-DeepEP impl. The challenge is similar to 3 in that the interfaces do not really align.
@yuankaichen-amd would love to hear your thoughts.
Yes, we can revise the existing MoE modules by adding an `if use_deepep` branch, or find a way to inherit this MoE module if possible.
If we compare the existing MoE against the DeepEP MoE:

```python
self.router(x, self.expert_bias)
self.reorderer(top_scores, selected_experts_indices)
# ... many lines omitted ...
# shape (bs*slen*top_k, dim)
routed_output = self.experts(routed_input, num_tokens_per_expert)
```

versus:

```python
self.router(x, self.expert_bias)
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
    x, selected_experts_indices, top_scores
)
routed_output = self.experts(x_prep, num_tokens_per_expert)
```
For the old MoE class, what happens between router and experts is also a kind of preprocess.
I think we can have an "AlltoallTokenDispatcher" which modularizes these operations. So in the combined MoE implementation, we will have:
```python
self.router(x, self.expert_bias)
some_token_dispatcher.preprocess(x, selected_experts_indices, top_scores)
routed_output = self.experts(routed_input, ...)
```
With this, we can even combine the ExpertParallel with DeepEPExpertParallel, where _token_dispatch and _token_combine will interface with a token dispatcher directly (see the rough sketch below).
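A rough sketch of that unified wrapper, assuming torchtitan's ExpertParallel hook signatures and a dispatcher attribute on the MoE module (all names hypothetical):

```python
from torchtitan.distributed.expert_parallel import ExpertParallel

class UnifiedExpertParallel(ExpertParallel):
    # Both the standard and DeepEP paths go through whatever token dispatcher
    # the wrapped MoE module carries; only the dispatcher implementation differs.
    def _token_dispatch(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        return mod.token_dispatcher.dispatch(
            routed_input, num_tokens_per_expert, device_mesh.get_group()
        )

    def _token_combine(self, mod, routed_output, device_mesh):
        return mod.token_dispatcher.combine(routed_output, device_mesh.get_group())
```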
@yuankaichen-amd
Sounds nice in general; I have some concerns about the following:
- Now token_dispatcher is a submodule of MoE, but we only call `token_dispatcher.preprocess` in model code and delay the actual dispatch / combine into hooks, which doesn't sound natural.
- The benefit of using hooks was that single-device code is still correct, and we can apply EP on top of single-device code. If we do the DeepEP path, which has a different preprocess method, would the single-device code (EP=1 + DeepEP enabled) still be "correct"?
Based on the above points, should we move towards dispatch / combine being actually called in the model code, and being a no-op when EP is not enabled?
I got your concerns. By "dispatch / combine being actually called in the model code", are you suggesting that with this token_dispatcher design, we should also retire the ExpertParallel wrapper?
Thanks for the clarification! Actually, now I am thinking that we can inherit from the current MoE to create a new DeepEPMoE, since they share most code, and use the current build_moe function to build it if DeepEP is used. So basically, we use separate MoEs (MoE and DeepEPMoE) and also separate ExpertParallels (ExpertParallel and DeepEPExpertParallel) to integrate DeepEP into TorchTitan, so the main interface is preserved. We can create a folder under distributed named expert_parallel, put the main DeepEPExpertParallel class into the existing expert_parallel.py, and put the other supporting classes and functions into another deep_ep.py, with both .py files under the new expert_parallel folder. For the preprocessing and postprocessing functions for token_dispatch or token_combine (in DeepEPExpertParallel), we can add private methods to the same classes if necessary. WDYT?
That would be similar to the current implementation - either way works for me, abstracting a common API or inheriting the current modules. One benefit we get out of the box by not separating ExpertParallel is that ETP seems to work out of the box; also, as @yuankaichen-amd pointed out, comm-comp overlap might come along as well.
No objection to build_moe and DeepEPMoE - either way we need to condition the logic at some level. Separating them might be beneficial if more comm libraries come up and complicate the MoE class.
comm-comp overlap might come along as well.
@yuankaichen-amd could you share more details about the overlapping and how having the same EP class would help? I discussed with @shuhuayu offline, and we are not sure what the benefit is of separating into six methods, namely [dispatch, combine] x [preprocess, process, postprocess].
My take is that the ExpertParallel class is implicitly a token dispatcher; it's just applied as a wrapper from outside the model.
Explicitly creating another token dispatcher (for expert parallel comms) and letting the implicit one (ExpertParallel) access the explicit token dispatcher class indirectly via a hasattr call doesn't sound very straightforward.
At this moment, I'm leaning towards having two MoE classes (sketched below):
- letting `DeepEPMoE` inherit the base `MoE` one and only overwrite the `forward` function
- having another `DeepEPExpertParallel` class which inherits a `BaseExpertParallel` protocol class
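A minimal sketch of that two-class direction, reusing the forward shown earlier in this thread; the names and signatures are assumptions, the shared-expert and score handling are omitted, and I inherit the existing ExpertParallel here for simplicity instead of a separate protocol class:

```python
from torchtitan.distributed.expert_parallel import ExpertParallel
from torchtitan.models.moe import MoE

class DeepEPMoE(MoE):
    # Router, experts, and expert_bias come from the base MoE; only the token
    # path (dispatch/combine via DeepEP) is overridden.
    def forward(self, x):
        top_scores, selected_experts_indices, num_tokens_per_expert = self.router(x, self.expert_bias)
        x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
            x, selected_experts_indices, top_scores
        )
        routed_output = self.experts(x_prep, num_tokens_per_expert)
        return self.deepep_dispatcher.combine_postprocess(routed_output)

class DeepEPExpertParallel(ExpertParallel):
    # Overrides only the comm hooks (_token_dispatch/_token_combine);
    # parameter sharding is inherited unchanged.
    ...
```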
Had a discussion with @tianyu-l and @shuhuayu offline. I agree that:
(1) having two separate MoE classes may work best for now (@elfiegg's original design);
(2) the token_dispatcher now looks unnecessary (sorry for suggesting it in the first place), and we can directly invoke DeepEP-related methods in the DeepEPExpertParallel;
(3) DeepEP needs some additional inputs; we can do this with some additional attributes in the DeepEPMoE module, or add a manager subclass if needed. Since @elfiegg has already implemented the DeepEP manager, I'd suggest we leave it as is for now. I will also need to take a closer look at the different modes of AMD's DeepEP integration. Let's review this design later.
@elfiegg what do you think?
SGTM - will send changes by EoD
torchtitan/models/moe/moe_deepep.py (Outdated)
```python
# Setup dispatcher metadata (routing information) for hooks to use
# The hooks will call token_dispatch/token_combine which need this metadata
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
```
What would preprocess and postprocess do if we don't use EP, e.g. single-device computation -- would they be no-ops?
Changed it to fall back to the standard impl if (no EP && DeepEP enabled).
torchtitan/models/moe/moe_deepep.py (Outdated)
```python
from torchtitan.tools.logging import logger


class MoEWithDeepEP(nn.Module):
```
Can we inherit the existing MoE to reuse most code, except for the MoEFlexTokenDispatcher init and the forward call?
```python
dim=model_args.dim,
hidden_dim=model_args.moe_inter_dim,
communication_backend=model_args.moe_comm_backend,
score_before_experts=model_args.moe_args.score_before_experts,
```
This is already available in the first arg.
torchtitan/config/job_config.py (Outdated)
```python
Note that this is still an experimental feature.
"""

moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
```
maybe follow the convention and call it:

```diff
-moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
+expert_parallel_comm_backend: Literal["standard", "deep_ep"] = "standard"
```
Done
```python
# Select parallelism style based on use_deepep flag
if use_deepep:
    from torchtitan.distributed import ExpertParallelDeepEP
    from torchtitan.tools.logging import logger as parallelism_logger
```
Right now it's the same as the logger in this class.
Done
torchtitan/models/moe/__init__.py (Outdated)
| "MoEWithDeepEP", | ||
| "MoEFlexTokenDispatcher", |
for now maybe let's not expose them here, at the cost of using HAS_DEEPEP everywhere
SG, done
```python
    from deep_ep.utils import EventOverlap, EventHandle
    HAS_DEEPEP = True
except ImportError:
    HAS_DEEPEP = False
```
Instead of testing HAS_DEEPEP at multiple locations, I wonder if we can just error out here and let the callsites be careful about importing this file.
Done. Now only the build_moe util in moe.py deals with HAS_DEEPEP, and it falls back if HAS_DEEPEP is false.
```python
)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
```
Does DeepEP work with TP? If not, we should error out when both are enabled together.
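A sketch of the kind of guard being requested; the config attribute names are assumptions based on the option added in this PR:

```python
if parallel_dims.tp_enabled and job_config.parallelism.expert_parallel_comm_backend == "deep_ep":
    raise NotImplementedError(
        "The 'deep_ep' expert-parallel comm backend has not been validated with "
        "tensor parallel; disable TP or use the 'standard' backend."
    )
```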
Done
```python
from torchtitan.distributed import ExpertParallelDeepEP
from torchtitan.tools.logging import logger as parallelism_logger
experts_plan = ExpertParallelDeepEP()
parallelism_logger.info(f" Applying DeepEP to MoE layer")
```
I think we should add a general log message such as "Applied Expert Parallel with xxx comm backend".
Done
```python
)


class ExpertParallelDeepEP(ExpertParallel):
```
nit:

```diff
-class ExpertParallelDeepEP(ExpertParallel):
+class DeepEPExpertParallel(ExpertParallel):
```
Done
```python
ep_group = device_mesh.get_group()
routed_input, routed_prob = mod.deepep_dispatcher.token_dispatch(routed_input, ep_group)
routed_input, num_tokens_per_expert, routed_prob = mod.deepep_dispatcher.dispatch_postprocess(routed_input, None)
return routed_input, num_tokens_per_expert
```
Is there a place where padding is done so that each expert always gets a multiple of 8 / 16 tokens (required by torch._grouped_mm), similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L129-L136?
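For context, a sketch of the kind of per-expert padding the linked code performs, so each expert's token block length is a multiple of the alignment expected by the grouped GEMM; the names, alignment value, and zero-padding choice are assumptions:

```python
import torch

ALIGN = 16  # assumed alignment for the grouped GEMM's M dimension

def pad_per_expert(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor):
    # routed_input: (total_tokens, dim), already sorted by expert.
    # Round each expert's token count up to a multiple of ALIGN and pad with zeros.
    padded_counts = (num_tokens_per_expert + ALIGN - 1) // ALIGN * ALIGN
    chunks = torch.split(routed_input, num_tokens_per_expert.tolist())  # note: D2H sync
    padded_chunks = []
    for chunk, target in zip(chunks, padded_counts.tolist()):
        pad_rows = target - chunk.shape[0]
        if pad_rows:
            chunk = torch.cat([chunk, chunk.new_zeros(pad_rows, chunk.shape[-1])])
        padded_chunks.append(chunk)
    return torch.cat(padded_chunks), padded_counts
```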
It looks like the alignment requirement on the M dim isn't about functionality (there should be a functional alignment requirement on the contracting dim K, but the DeepSeek MoE intermediate size is for sure a multiple of 16 bytes). I also found that padding doesn't improve performance, so I added a configurable pad_to_alignment to let the user choose.
```python
logger.info(f"Allocated fallback RDMA buffer: {num_rdma_bytes} bytes")

low_latency_mode = is_multinode or group.size() > 8
buffer = Buffer(group=group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=low_latency_mode)
```
What is this Buffer for? It seems it's not proportional to the num of tokens.
Besides input tokens and output tokens, DeepEP needs to initialize symmetric buffers (same address on all ranks) for chunked RDMA/NVLink comm. Since the output size is unknown until a CPU sync, it takes a pre-configured number of bytes.
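For reference, the symmetric-buffer allocation pattern from DeepEP's README looks roughly like this; the size-hint helper names are as I recall them from that README, so treat this as a sketch rather than this PR's code:

```python
import torch.distributed as dist
from deep_ep import Buffer

def make_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    # The NVLink/RDMA buffers are sized from the hidden dim and world size,
    # not from the (unknown until a CPU sync) number of received tokens.
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()),
                   Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
    return Buffer(group, num_nvl_bytes, num_rdma_bytes)
```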
```python
num_recv_tokens_per_expert_tensor = torch.tensor(
    num_recv_tokens_per_expert_list, dtype=torch.int64, device='cpu'
).to(recv_x.device, non_blocking=True)
```
Curious what this is doing? I know in the "standard" impl we are doing a D2H sync, but this seems to be an H2D sync?
Good catch! c:
It's not moving any memory between host and device; it converts a Python list to a tensor on the CPU side, so the .to(recv_x.device) isn't necessary at all!
```python
previous_event = _create_event_if_async(async_finish)

recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, after_event = \
    buffer.dispatch(
```
I wonder how activation checkpointing is done. We need to save the forward comm result so that backward doesn't do the comms again.
Created a torch custom op for DeepEP so that it works with SAC. Without the custom op (with manual caching), SAC would track tensors and assert because the total number of created tensors isn't aligned with the registry.
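For illustration, the general custom-op pattern (torch.library.custom_op, available in recent PyTorch) that lets SAC treat a comm call as a single op with known output metadata; the op name and stand-in body here are hypothetical, not the PR's actual wrapper around buffer.dispatch:

```python
import torch
from torch.library import custom_op

@custom_op("titan_sketch::ep_dispatch", mutates_args=())
def ep_dispatch(x: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
    # Stand-in body; the real op would invoke DeepEP's buffer.dispatch here.
    # SAC then saves/recomputes this op's output instead of tracking the
    # intermediate tensors the comm library creates internally.
    return x.clone()

@ep_dispatch.register_fake
def _(x: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel so tracing knows the output metadata without running comms.
    return torch.empty_like(x)
```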
Do we still need this if we implemented it outside of experiments?
Removed
```python
# Allow use_flex_attn to be set from config
if hasattr(job_config.model, 'use_flex_attn') and job_config.model.use_flex_attn is not None:
    self.use_flex_attn = job_config.model.use_flex_attn
    logger.info(f"Setting use_flex_attn={self.use_flex_attn} from config")
```
Does the current version support flex attention? Is this block still needed?
Removed
Thanks for the feedback! A summary of our discussion:
Thanks for the summary! It looks good to me in general. Just one suggestion -- I think we could easily have a unified MoE module once we have (1). So build_moe may retire if this works.
Agree, we can then construct the right dispatcher based on the environment and stay very compatible with the current flow.
tianyu-l left a comment
LGTM in general. Please address remaining comments and I believe we can merge.
Could you also show a comparison of the loss curves between the standard and DeepEP impls?
You may find https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#reproducibility-between-runs useful, but such a test probably only requires setting the seed.
```diff
  if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
      logger.warning(
-         "Failed to use grouped mm, which is only supported on SM90 or later",
+         "Failed to use grouped_mm, which is only supported on SM90 or later",
```
oops not seeing it
SG - will try to get them done by EOD (also, please ignore my accidentally triggered review request lol).
tianyu-l left a comment
LGTM, please fix the linting error so we can merge.
There's a typing issue. Please resolve. @elfiegg
@elfiegg Thank you very much for the contribution!
Summary
This initial version integrates DeepEP into TorchTitan, focusing on correctness and compatibility rather than maximal performance tuning.
Perf: DeepSeek-V3 671B on 64 H100 nodes (512 GPUs total)
Training config
After:
Before:
Loss Curve:
Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for the internal support!