
Conversation

@elfiegg
Contributor

@elfiegg elfiegg commented Dec 4, 2025

Summary

This initial version integrates DeepEP into TorchTitan, focusing on correctness and compatibility rather than maximal performance tuning.

  • Functional DeepEP-backed MoE + Expert Parallelism
  • User-controlled configuration
  • Compatible with torch.compile and SAC
  • Intended as a first unblocker for benchmarking and iteration

Perf: DeepSeek-V3 671B on 64 nodes × H100 (512 GPUs total)

Training config:
config_path="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml",
command_args=[
    "--training.dataset_path=/lustre/fsw/portfolios/sw/users/elfieg/hf_datasets/c4",
    "--training.seq_len=4096",
    "--training.steps=120",
    "--metrics.log_freq=10",
    "--profiling.no-enable-profiling",
    "--comm.init_timeout_seconds=2000",
    "--comm.train_timeout_seconds=300",
    "--metrics.disable_color_printing",

    # Parallelism
    "--parallelism.data_parallel_replicate_degree=1",
    "--parallelism.data_parallel_shard_degree=64",
    "--parallelism.fsdp_reshard_after_forward=default",
    "--parallelism.tensor_parallel_degree=1",
    "--parallelism.expert_parallel_degree=32",
    "--parallelism.expert_tensor_parallel_degree=1",
    "--parallelism.pipeline_parallel_degree=8",
    "--parallelism.pipeline_parallel_schedule=Interleaved1F1B",

    # Training
    "--training.local_batch_size=16",
    "--activation_checkpoint.mode=full",

    # Compilation
    "--compile.enable",
    "--compile.components=model",
    "--compile.components=loss",

    # MoE / DeepEP
    "--debug.moe_force_load_balance",
    "--parallelism.expert_parallel_comm_backend=deepep",
],

After:

memory: 56.75GiB(71.74%)  tps: 579  tflops: 162.82  mfu: 16.46%

Before:

memory: 60.18GiB(76.07%)  tps: 346  tflops: 97.24  mfu: 9.83%

Loss Curve:

[screenshot: loss curve]

Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for the internal support!

@meta-cla

meta-cla bot commented Dec 4, 2025

Hi @elfiegg!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@elfiegg elfiegg force-pushed the loss_bug branch 7 times, most recently from a5875e5 to 6999d1e on December 4, 2025 05:13
Contributor

@tianyu-l tianyu-l left a comment


Thanks for contributing!

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what is the baseline?

Btw, I think to fully utilize the power of DeepEP, we also need to have node-limited routing, which the current torchtitan DSv3 model doesn't have.

@shuhuayu let's add it? We can refer to the HF or DeepSeek original impl.
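For illustration, node-limited (group-limited) routing in the DeepSeek-V3 style restricts each token to experts from a few top-scoring node groups before the final top-k. A minimal sketch, assuming a scores tensor of shape (num_tokens, num_experts) and illustrative hyperparameters (group scoring here uses the per-group max; DeepSeek-V3 uses a slightly different group score):

import torch

def node_limited_topk(scores: torch.Tensor, top_k: int, n_groups: int, topk_groups: int):
    """Keep only `topk_groups` node groups per token, then pick top_k experts among them."""
    num_tokens, num_experts = scores.shape
    # score each group by its best expert: (num_tokens, n_groups)
    group_scores = scores.view(num_tokens, n_groups, -1).max(dim=-1).values
    group_idx = group_scores.topk(topk_groups, dim=-1).indices
    # build a mask that zeroes out experts living in non-selected groups
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(num_tokens, n_groups, num_experts // n_groups)
    masked_scores = scores.masked_fill(expert_mask.reshape(num_tokens, num_experts) == 0, float("-inf"))
    top_scores, selected_experts_indices = masked_scores.topk(top_k, dim=-1)
    return top_scores, selected_experts_indices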

Contributor

Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g. build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.
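A minimal sketch of what such a factory could look like; MoE, DeepEPMoE, and the module path are placeholders assumed to be in scope, not the final API:

def build_moe(model_args, backend: str = "standard"):
    """Hypothetical factory: select the MoE implementation from a config string."""
    if backend == "standard":
        return MoE(model_args)
    if backend == "deep_ep":
        # imported lazily so the standard path never requires DeepEP to be installed
        from .moe_deepep import DeepEPMoE
        return DeepEPMoE(model_args)
    raise ValueError(f"Unknown MoE comm backend: {backend!r}")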

Contributor Author

Sure, that's a great idea - once I confirm this works for larger models and improves perf.

Regarding integrating directly into main - do we need to manage the DeepEP dependency at all, or do we leave it to the users to install?

Contributor

Good point. I prefer

we leave it to the users to install

instead of bundling it by default. We can explicitly mention this in a try/except when we do the import.
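A sketch of that guarded import; the error text and helper name are illustrative:

try:
    from deep_ep import Buffer  # optional, user-installed dependency
    HAS_DEEPEP = True
except ImportError:
    Buffer = None
    HAS_DEEPEP = False


def require_deepep() -> None:
    """Call this from DeepEP-only code paths to fail with an actionable message."""
    if not HAS_DEEPEP:
        raise ImportError(
            "The 'deepep' expert-parallel comm backend requires the DeepEP package; "
            "please install it (https://github.com/deepseek-ai/DeepEP) or switch back "
            "to the standard backend."
        )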

from torch.distributed.tensor import distribute_module


class DeepEPExpertParallel(ParallelStyle):
Contributor

Is this used anywhere? I'm guessing that this is not running e2e with torchtitan train.py which is still WIP.


I think we can also handle the token communication in here -- the design would look cleaner and more aligned with the existing interface.

This is our implementation, and our DeepEP interface is similar to NVIDIA's version.

https://github.com/AMD-AGI/torchtitan-amd/blob/59fe226a46d76a553d4e591411a464390475be02/torchtitan/distributed/expert_parallel.py#L441

Contributor

Thanks for the PR! I think we should support node-limited routing to make the multi-node setup faster.

Contributor

Added the node-limited routing here: #2111. Perhaps it helps make DeepEP faster in multi-node setups.

@elfiegg
Contributor Author

elfiegg commented Dec 4, 2025

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what is the baseline?

Actually I reran this last night and the perf caught up - the lag was gone once I enabled FSDP for the MoE layer (which I had disabled for debugging purposes). Running the command below, I got 13% MFU for both the baseline and the DeepEP version:

torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$NPROC_PER_NODE \
    --rdzv_id=deepseek_16b_multinode \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    -m torchtitan.train \
    --parallelism.expert_parallel_degree 16 \
    --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

For the baseline I used the config here: ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
The DeepEP version additionally overrides --model.name=deepep.deepseek_v3

@yuankaichen-amd

Thanks for posting the work!

We had a successful and performant DeepEP integration at: AMD-AGI@59fe226

We borrowed some design from Megatron-LM and we can use it here too.

I don't see big differences between our DeepEP interface and yours. Let's work together on this. Feel free to reach out to me or Tianyu for future discussion and collaboration.


if self.score_before_experts:
    recv_x_sorted = (recv_x_sorted.to(torch.float32) * valid_weights_sorted.unsqueeze(-1)).to(recv_x_sorted.dtype)


The code before experts.forward should go into DeepEPExpertParallel as input_fn; same for the token unpermute after the experts as output_fn.

Also consider using something like _indices_to_multihot_kernel (https://github.com/NVIDIA/Megatron-LM/blob/f5344166732f45bb0dd825dc875288ea97b15b47/megatron/core/fusions/fused_indices_converter.py#L32C5-L32C32) to preprocess the received DeepEP data.

You are using a lot of index-selecting here, which I suspect would incur significant CPU overhead (and lock/wait among CPU threads).
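For context, the hook-based pattern being suggested mirrors the existing ExpertParallel style: communication is attached through distribute_module input/output hooks instead of being called in the model's forward. A rough, no-op sketch (the real hooks would permute tokens and call DeepEP dispatch/combine; the class name here is illustrative):

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle


class DeepEPExpertParallelSketch(ParallelStyle):
    """Sketch: route DeepEP dispatch/combine through module input/output hooks."""

    def _token_dispatch(self, mod, inputs, device_mesh):
        # placeholder: token permutation + DeepEP all-to-all dispatch would happen here
        return inputs

    def _token_combine(self, mod, routed_output, device_mesh):
        # placeholder: DeepEP combine + un-permutation would happen here
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            input_fn=self._token_dispatch,   # runs before experts.forward
            output_fn=self._token_combine,   # runs after experts.forward
        )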

@elfiegg
Contributor Author

elfiegg commented Dec 5, 2025

Thanks all for the valuable advice! - I'm currently occupied by a deadline but I will take a closer look and join the discussion tomorrow

@elfiegg
Contributor Author

elfiegg commented Dec 5, 2025

I scanned through the comments, and here is a summary:

  1. We prefer wrapping the DeepEP dispatch and combine logic into the ExpertParallel module for clean injection
  2. We prefer a factory method to build the MoE module based on a configurable string, depending on the user's choice and/or container environment
  3. We prefer lowering fusible naive torch ops like indexing to hand-written Triton kernels for efficiency
  4. We prefer integrating directly into the non-experimental codebase

If this looks good to everyone, I'll start revising the PR
cc @tianyu-l @yuankaichen-amd @shuhuayu

@tianyu-l
Contributor

tianyu-l commented Dec 5, 2025

@elfiegg sounds good overall

  1. We prefer lowering fusible naive torch ops like indexing to hand-written Triton kernels for efficiency

I think this may be done in a follow-up PR, assuming that the tradeoff can be justified by benchmarking results. cc @yuankaichen-amd WDYT?

@yuankaichen-amd

@elfiegg sounds good overall

  1. We prefer lowering fusible naive torch ops like indexing to hand-written Triton kernels for efficiency

I think this may be done in a follow-up PR, assuming that the tradeoff can be justified by benchmarking results. cc @yuankaichen-amd WDYT?

Either way works -- it should be low-hanging fruit. The Triton kernel is available in both Megatron and my integration PR.

@yuankaichen-amd

I scanned through the comments, and here is a summary:

  1. We prefer wrapping the DeepEP dispatch and combine logic into the ExpertParallel module for clean injection
  2. We prefer a factory method to build the MoE module based on a configurable string, depending on the user's choice and/or container environment
  3. We prefer lowering fusible naive torch ops like indexing to hand-written Triton kernels for efficiency
  4. We prefer integrating directly into the non-experimental codebase

If this looks good to everyone, I'll start revising the PR cc @tianyu-l @yuankaichen-amd @shuhuayu

I strongly recommend wrapping the DeepEP-related ops into a standalone class:
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py#L52

There are six methods:
pre_dispatch / dispatch / post_dispatch
pre_combine / combine / post_combine

It will set a clear boundary between DeepEP and TorchTitan's MoE module or ExpertParallel wrapper. Also, you get a free ride on many things that NVIDIA has already implemented in Megatron.
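Roughly, the boundary being proposed is an abstract dispatcher along these lines (a sketch of the interface shape, not the Megatron code; signatures are illustrative):

from abc import ABC, abstractmethod

import torch


class TokenDispatcher(ABC):
    """Sketch of the six-method dispatcher boundary described above."""

    @abstractmethod
    def pre_dispatch(self, x: torch.Tensor, indices: torch.Tensor, probs: torch.Tensor):
        """Local permutation / metadata setup before any communication."""

    @abstractmethod
    def dispatch(self, x: torch.Tensor):
        """The all-to-all itself (DeepEP dispatch in the DeepEP implementation)."""

    @abstractmethod
    def post_dispatch(self, recv_x: torch.Tensor):
        """Sort / unpack received tokens into the layout the experts expect."""

    @abstractmethod
    def pre_combine(self, expert_out: torch.Tensor):
        """Restore the communication layout from the expert output."""

    @abstractmethod
    def combine(self, expert_out: torch.Tensor):
        """The reverse all-to-all (DeepEP combine)."""

    @abstractmethod
    def post_combine(self, combined: torch.Tensor):
        """Un-permute back to the original token order."""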

@elfiegg elfiegg force-pushed the loss_bug branch 2 times, most recently from 0a9815b to e0d4fcf on December 6, 2025 19:03
Comment on lines 168 to 170
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
x, selected_experts_indices, top_scores
)
Contributor

This version looks much cleaner.

After reading the code, I wonder what's the best way to organize code. Brainstorming with some (immature) ideas:

  1. Use if use_deepep in the code for the region of difference. @shuhuayu IIUC you had this idea?
  2. Abstract token_dispatching + routed_experts computation into their own classes, so that the MoE class can be shared.
  3. Move dispatch_preprocess and dispatch_postprocess inside ExpertParallel hooks as well. The challenge seems to be that the ExpertParallel classes don't get all the inputs we need.
  4. Unify dispatch_preprocess with the TokenReorderer concept needed for the non-deepep impl. The challenge is similar to 3 in that the interfaces do not really align.

@yuankaichen-amd would love to hear your thoughts.

Contributor

Yes, we can revise the existing MoE modules by adding an if use_deepep branch, or find a way to inherit from this MoE module if possible.

@yuankaichen-amd yuankaichen-amd Dec 8, 2025

If we compare the existing MoE against the DeepEP MoE:

self.router(x, self.expert_bias)  

self.reorderer(top_scores, selected_experts_indices)

#### many lines omitted

# shape (bs*slen*top_k, dim)

routed_output = self.experts(routed_input, num_tokens_per_expert)

===============================================

self.router(x, self.expert_bias)

x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
       x, selected_experts_indices, top_scores
)

routed_output = self.experts(x_prep, num_tokens_per_expert)

For the old MoE class, what happens between the router and the experts is also a kind of preprocessing.

I think we can have an "AlltoallTokenDispatcher" which modularizes these operations. So in the combined MoE implementation, we will have:

self.router(x, self.expert_bias)

some_token_dispatcher.preprocess(x, selected_experts_indices, top_scores)

routed_output = self.experts(routed_input, ...)

With this, we can even combine ExpertParallel with DeepEPExpertParallel, where _token_dispatch and _token_combine will interface with a token dispatcher directly.

Contributor

@tianyu-l tianyu-l Dec 9, 2025

@yuankaichen-amd
Sounds nice in general, I have some concerns about the following:

  • Now token_dispatcher is a submodule of MoE, but we only call token_dispatcher.preprocess in the model code and delay the actual dispatch / combine into hooks, which doesn't sound natural.
  • The benefit of using hooks was that the single-device code is still correct, and we can apply EP on top of it. If we take the DeepEP path, which has a different preprocess method, would the single-device code (EP=1 + DeepEP enabled) still be "correct"?

Based on the above points, should we move towards dispatch / combine being actually called "in the model code", and being a no-op when EP is not enabled?


I got your concerns. By "dispatch / combine being actually called "in the model code"", are you suggesting that with this token_dispatcher design, we should also retire the ExpertParallel wrapper?

Contributor

@shuhuayu shuhuayu Dec 10, 2025

Thx for the clarification! Actually, now I am thinking that we can inherit from the current MoE to create a new DeepEPMoE, since they share most code, and use the current build_moe function to build it when DeepEP is used. So basically, we use separate MoEs (MoE and DeepEPMoE) and separate ExpertParallels (ExpertParallel and DeepEPExpertParallel) to integrate DeepEP into TorchTitan, so the main interface is preserved. We can create a folder under distributed named expert_parallel, put the main DeepEPExpertParallel class into the existing expert_parallel.py, and put the other supporting classes and functions into another deep_ep.py, with both .py files under the new expert_parallel folder. For the preprocessing and postprocessing functions for token_dispatch or token_combine (in DeepEPExpertParallel), we can add private methods to the same classes if necessary. WDYT?

Contributor Author

@elfiegg elfiegg Dec 10, 2025

That would be similar to the current implementation - either way works for me, abstracting a common API or inheriting the current modules. One benefit we get by not separating ExpertParallel is that ETP seems to work out of the box; also, as @yuankaichen-amd pointed out, comm-comp overlap might work as well.
No objection to build_moe and DeepEPMoE - either way we need to condition the logic at some level. Separating them might be beneficial if more comm libraries show up and complicate the MoE class.

Contributor

@tianyu-l tianyu-l Dec 10, 2025

comm-comp overlap might work as well.

@yuankaichen-amd could you share more details about the overlapping and how having the same EP class would help? I discussed with @shuhuayu offline and we are not sure what's the benefit of separating into six methods, namely [dispatch, combine] x [preprocess, process, postprocess].

My take is that the ExpertParallel class is implicitly a token dispatcher; it's just applied in a wrapper from outside the model.
Explicitly creating another token dispatcher (for expert parallel comms) and letting the implicit one (ExpertParallel) access the explicit token dispatcher indirectly via a hasattr call doesn't sound very straightforward.

At this moment, I'm leaning towards having two MoE classes

  • letting DeepEPMoE inherit the base MoE and only override the forward function
  • having another DeepEPExpertParallel class which inherits from a BaseExpertParallel protocol class
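A minimal skeleton of that two-class direction; the MoE base below is only a stand-in for torchtitan's existing class, and the comments describe where the real override would diverge:

import torch
import torch.nn as nn


class MoE(nn.Module):
    """Stand-in for the existing torchtitan MoE (router/experts/shared expert omitted)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder

class DeepEPMoE(MoE):
    """Sketch: inherit the base MoE, override only forward()."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real override would:
        #   1) call the router as in the base class,
        #   2) run deepep_dispatcher.dispatch_preprocess(...) instead of the standard reorderer,
        #   3) call self.experts(...); the DeepEP dispatch/combine run in the
        #      DeepEPExpertParallel input/output hooks,
        #   4) un-permute and add the shared-expert output as in the base class.
        return super().forward(x)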


Had a discussion with @tianyu-l and @shuhuayu offline. I agree that
(1) having two separate MoE classes may work best for now (@elfiegg's original design);
(2) the token_dispatcher now looks unnecessary (sorry for suggesting it in the first place) and we can directly invoke the DeepEP-related methods in DeepEPExpertParallel;
(3) DeepEP needs some additional inputs; we can handle this with some additional attributes on the DeepEPMoE module, or add a manager subclass if needed. Since @elfiegg has already implemented the DeepEP manager, I'd suggest we leave it as is for now. I will also need to take a closer look at the different modes in AMD's DeepEP integration. Let's review this design later.

@elfiegg what do you think?

@tianyu-l @shuhuayu Please add if I missed anything.

Contributor Author

SGTM - will send changes by EoD


# Setup dispatcher metadata (routing information) for hooks to use
# The hooks will call token_dispatch/token_combine which need this metadata
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
Contributor

What would preprocess and postprocess do if we don't use EP, e.g. single-device computation -- would they be no-ops?

Contributor Author

@elfiegg elfiegg Dec 11, 2025

Changed it to fall back to the standard impl when DeepEP is requested but EP is not enabled.

from torchtitan.tools.logging import logger


class MoEWithDeepEP(nn.Module):
Contributor

Can we inherit the existing MoE to reuse most of the code, except for the MoEFlexTokenDispatcher init and the forward call?

dim=model_args.dim,
hidden_dim=model_args.moe_inter_dim,
communication_backend=model_args.moe_comm_backend,
score_before_experts=model_args.moe_args.score_before_experts,
Contributor

this is already available in the first arg

Note that this is still an experimental feature.
"""

moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
Contributor

maybe follow the convention and call it

Suggested change
moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
expert_parallel_comm_backend: Literal["standard", "deep_ep"] = "standard"

Contributor Author

Done

# Select parallelism style based on use_deepep flag
if use_deepep:
from torchtitan.distributed import ExpertParallelDeepEP
from torchtitan.tools.logging import logger as parallelism_logger
Contributor

right now it's the same as the logger in this class

Contributor Author

Done

Comment on lines 15 to 16
"MoEWithDeepEP",
"MoEFlexTokenDispatcher",
Contributor

for now maybe let's not expose them here, at the cost of using HAS_DEEPEP everywhere

Contributor Author

SG, done

    from deep_ep.utils import EventOverlap, EventHandle
    HAS_DEEPEP = True
except ImportError:
    HAS_DEEPEP = False
Contributor

Instead of testing HAS_DEEPEP at multiple locations, I wonder if we can just error out here and let the callsites be careful about importing this file.

Contributor Author

Done. Now only the build_moe util in moe.py deals with HAS_DEEPEP, and it falls back if HAS_DEEPEP is false.

)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
Contributor

does DeepEP work with TP? If not, we should error out when they are enabled together.

Contributor Author

Done
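A minimal form of such a check, using the flag and attribute names that appear elsewhere in this thread (the backend string and exact placement are illustrative):

def validate_deepep_config(parallel_dims, job_config) -> None:
    """Sketch: fail fast when DeepEP is combined with an unsupported parallelism."""
    backend = job_config.parallelism.expert_parallel_comm_backend
    if backend == "deepep" and parallel_dims.tp_enabled:
        raise NotImplementedError(
            "expert_parallel_comm_backend='deepep' is not supported together with tensor parallelism yet"
        )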

from torchtitan.distributed import ExpertParallelDeepEP
from torchtitan.tools.logging import logger as parallelism_logger
experts_plan = ExpertParallelDeepEP()
parallelism_logger.info(f" Applying DeepEP to MoE layer")
Contributor

I think we should add a general log message like "Applied Expert Parallel with xxx comm backend".

Contributor Author

Done

)


class ExpertParallelDeepEP(ExpertParallel):
Contributor

nit

Suggested change
class ExpertParallelDeepEP(ExpertParallel):
class DeepEPExpertParallel(ExpertParallel):

Contributor Author

Done

ep_group = device_mesh.get_group()
routed_input, routed_prob = mod.deepep_dispatcher.token_dispatch(routed_input, ep_group)
routed_input, num_tokens_per_expert, routed_prob = mod.deepep_dispatcher.dispatch_postprocess(routed_input, None)
return routed_input, num_tokens_per_expert
Contributor

Is there a place where padding is done so that each expert always gets a multiple of 8 / 16 tokens (required by torch._grouped_mm), similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L129-L136

Contributor Author

It looks like the alignment requirement on the M dim isn't about functionality (there should be a functional alignment requirement on the contracting dim K, but the DeepSeek MoE intermediate size is always a multiple of 16 bytes). I also found that padding doesn't improve performance, so I added a configurable pad_to_alignment to let the user choose.
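For reference, a hedged sketch of what that optional padding could look like: each expert's token block is padded with zero rows up to a multiple of alignment so grouped GEMMs see aligned group sizes (tensor names are illustrative; the tolist() calls imply a device sync):

import torch


def pad_tokens_per_expert(routed_input: torch.Tensor, num_tokens_per_expert: torch.Tensor, alignment: int = 8):
    """Pad each expert's block of tokens so its row count is a multiple of `alignment`."""
    padded_counts = (num_tokens_per_expert + alignment - 1) // alignment * alignment
    chunks = torch.split(routed_input, num_tokens_per_expert.tolist(), dim=0)
    padded_chunks = []
    for chunk, target in zip(chunks, padded_counts.tolist()):
        pad_rows = target - chunk.shape[0]
        if pad_rows > 0:
            chunk = torch.cat([chunk, chunk.new_zeros(pad_rows, chunk.shape[1])], dim=0)
        padded_chunks.append(chunk)
    return torch.cat(padded_chunks, dim=0), padded_counts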

logger.info(f"Allocated fallback RDMA buffer: {num_rdma_bytes} bytes")

low_latency_mode = is_multinode or group.size() > 8
buffer = Buffer(group=group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=low_latency_mode)
Contributor

What is this Buffer for? It seems it's not proportional to the num of tokens.

Contributor Author

Besides the input and output tokens, DeepEP needs to initialize symmetric buffers (same address on all ranks) for chunked RDMA/NVLink comm; since the output size is unknown until a CPU sync, it takes a pre-configured byte size.

Comment on lines 142 to 144
num_recv_tokens_per_expert_tensor = torch.tensor(
    num_recv_tokens_per_expert_list, dtype=torch.int64, device='cpu'
).to(recv_x.device, non_blocking=True)
Contributor

Curious what this is doing? I know in the "standard" impl we do a D2H sync, but this seems to be an H2D sync?

Contributor Author

@elfiegg elfiegg Dec 9, 2025

Great catch! c:
It's not meant to move any memory between host and device; it just converts a Python list to a tensor on the CPU side, so the to(recv_x.device) isn't necessary at all!

previous_event = _create_event_if_async(async_finish)

recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, after_event = \
    buffer.dispatch(
Contributor

I wonder how activation checkpointing is done. We need to save the forward comm result so that backward doesn't do the comms again.

Contributor Author

Created a torch custom op for DeepEP so that it works with SAC. Without the custom op (i.e., with manual caching), SAC would track tensors and assert, because the total number of created tensors doesn't align with its registry.
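As a rough illustration of that approach (not the exact op registered in this PR), a communication-like function can be wrapped with torch.library.custom_op plus a fake implementation, so compile/SAC see it as one opaque node whose output can be saved for backward instead of re-running the comm:

import torch


@torch.library.custom_op("deepep_sketch::dispatch", mutates_args=())
def deepep_dispatch_sketch(x: torch.Tensor) -> torch.Tensor:
    # In the real op this would call buffer.dispatch(...); a clone stands in here.
    return x.clone()


@deepep_dispatch_sketch.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used under torch.compile / fake tensors.
    return torch.empty_like(x)

The resulting torch.ops.deepep_sketch.dispatch can then be listed in the SAC policy's save list so the forward communication result is stored and reused in backward.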

@elfiegg elfiegg changed the title from "Integrate DeepEP to experimental torchtitan" to "Integrate DeepEP to torchtitan" on Dec 8, 2025
Contributor

Do we still need this now that it's implemented outside of experiments?

Contributor Author

Removed

Comment on lines 101 to 105
# Allow use_flex_attn to be set from config
if hasattr(job_config.model, 'use_flex_attn') and job_config.model.use_flex_attn is not None:
    self.use_flex_attn = job_config.model.use_flex_attn
    logger.info(f"Setting use_flex_attn={self.use_flex_attn} from config")

Contributor

Does the current version support flex attention? Is this block still needed?

Contributor Author

Removed

@elfiegg
Contributor Author

elfiegg commented Dec 9, 2025

Thanks for the feedback! A summary of our discussion:

  1. Design direction (API unification):
    Move dispatch pre/post processing into ExpertParallel hooks, and extract the standard MoE token_reorder logic into a preprocess stage.
    Let both DeepEP and the standard MoE initialize a common Dispatcher (subclassing an abstract AllToAllDispatcher).
    This further isolates the parallelism logic from the expert computation and unifies the APIs across the standard and library-specific paths.

  2. Build-time selection (DeepEP vs standard):
    Introduce a single top-level HAS_DEEPEP switch to control module selection in build_moe, and enforce consistent failure behavior in the underlying implementations when the configuration is unsupported

  3. TP + EP restriction:
    Not aware of any restriction; I thought ETP would orthogonally handle the EP and TP meshes (so the DeepEP all-to-all still applies):

    class ExpertTensorParallel(ExpertParallel):
        def _token_dispatch(self, mod, inputs, device_mesh):
            routed_input, num_tokens_per_expert = inputs
            # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
            # The grad_placements on inputs is set to Partial so that necessary
            # reductions are performed during backward.
            routed_input = DTensor.from_local(
                routed_input, device_mesh["tp"], (Replicate(),)
            ).to_local(grad_placements=(Partial(),))
            inputs = (routed_input, num_tokens_per_expert)
            # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
            return super()._token_dispatch(mod, inputs, device_mesh["ep"])

    After we modify the design to wrap token dispatching logic in ExpertParallel, I assume this should automatically work?

  4. Additional suggestions
    (a) Add a more general log message
    (b) Pad to a multiple of 8 tokens for BF16 due to the CUTLASS 16-byte alignment requirement
    (c) Add SAC support for saving forward communication results for backward
    (d) Inherit the standard MoE and reduce code duplication
    (e) Perform general code cleanup across the module

@yuankaichen-amd

Thanks for the feedback! A summary of our discussion:

  1. Design direction (API unification):
    Move dispatch pre/post processing into ExpertParallel hooks, and extract the standard MoE token_reorder logic into a preprocess stage.
    Let both DeepEP and the standard MoE initialize a common Dispatcher (subclassing an abstract AllToAllDispatcher).
    This further isolates the parallelism logic from the expert computation and unifies the APIs across the standard and library-specific paths.

  2. Build-time selection (DeepEP vs standard):
    Introduce a single top-level HAS_DEEPEP switch to control module selection in build_moe, and enforce consistent failure behavior in the underlying implementations when the configuration is unsupported

  3. TP + EP restriction:
    Not aware of any restriction; I thought ETP would orthogonally handle the EP and TP meshes (so the DeepEP all-to-all still applies):

    class ExpertTensorParallel(ExpertParallel):
        def _token_dispatch(self, mod, inputs, device_mesh):
            routed_input, num_tokens_per_expert = inputs
            # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
            # The grad_placements on inputs is set to Partial so that necessary
            # reductions are performed during backward.
            routed_input = DTensor.from_local(
                routed_input, device_mesh["tp"], (Replicate(),)
            ).to_local(grad_placements=(Partial(),))
            inputs = (routed_input, num_tokens_per_expert)
            # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
            return super()._token_dispatch(mod, inputs, device_mesh["ep"])

    After we modify the design to wrap token dispatching logic in ExpertParallel, I assume this should automatically work?

  4. Additional suggestions
    (a) Add a more general log message
    (b) Pad to a multiple of 8 tokens for BF16 due to the CUTLASS 16-byte alignment requirement
    (c) Add SAC support for saving forward communication results for backward
    (d) Inherit the standard MoE and reduce code duplication
    (e) Perform general code cleanup across the module

Thanks for the summary! It looks good to me in general. Just one suggestion -- I think we could easily have a unified MoE module once we have (1). So build_moe may retire if this works.

@elfiegg
Contributor Author

elfiegg commented Dec 9, 2025

Thanks for the summary! It looks good to me in general. Just one suggestion -- I think we could easily have a unified MoE module once we have (1). So build_moe may retire if this works.

Agree - we can then construct the right Dispatcher based on the environment and stay compatible with the current flow.

Contributor

@tianyu-l tianyu-l left a comment

LGTM in general. Please address remaining comments and I believe we can merge.

Could you also show a comparison of the loss curves between the standard and DeepEP impls?

You may find https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#reproducibility-between-runs useful, but such a test only requires setting the seed, I guess.

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
"Failed to use grouped_mm, which is only supported on SM90 or later",
Contributor

oops not seeing it

@elfiegg elfiegg requested a review from tianyu-l December 17, 2025 01:18
@elfiegg
Contributor Author

elfiegg commented Dec 17, 2025

SG - will try getting them done by EOD (also please ignore my accidentally triggered review request lol)

@elfiegg
Contributor Author

elfiegg commented Dec 17, 2025

Baseline & DeepEP loss curves (overlapped):
[screenshot: loss curves]

MFU on the 16B model, 2 nodes:
[screenshot: MFU comparison]

Run via NODES=2 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --debug.seed 42 --parallelism.pipeline-parallel-degree 2 --parallelism.expert-parallel-comm-backend deepep

@syed-ahmed syed-ahmed moved this to In Progress in PyTorch + CUDA Dec 17, 2025
Contributor

@tianyu-l tianyu-l left a comment

LGTM, please fix the linting error so we can merge.

@tianyu-l
Contributor

There's a typing issue. Please resolve. @elfiegg

@tianyu-l tianyu-l merged commit 36a4b69 into pytorch:main Dec 18, 2025
8 of 9 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in PyTorch + CUDA Dec 18, 2025
@tianyu-l
Contributor

@elfiegg Thank you very much for the contribution!

@elfiegg elfiegg deleted the loss_bug branch December 18, 2025 17:42