Conversation

Contributor

@cboss6 commented Nov 11, 2025

Purpose

Enable the DeepEP low-latency all2all backend with the round-robin expert placement strategy, which yields a significant performance improvement.
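For context, here is a minimal sketch (illustrative only, not the exact vLLM code; the rank-major physical numbering is an assumption) of the routing tables that round-robin placement implies. Global expert g lands on EP rank g % ep_size instead of inside a contiguous slice, and the three mapping tensors discussed in the review threads below fall out of that choice:

```python
import torch

def round_robin_tables(num_experts: int, ep_size: int, ep_rank: int):
    """Sketch of round-robin expert placement.

    Global expert g lives on rank g % ep_size at local slot g // ep_size;
    physical IDs are numbered rank-major: rank * num_local + slot.
    """
    num_local = num_experts // ep_size
    g = torch.arange(num_experts)
    global_to_physical = (g % ep_size) * num_local + g // ep_size
    physical_to_global = torch.empty_like(global_to_physical)
    physical_to_global[global_to_physical] = g  # invert the permutation
    local_expert_global_ids = physical_to_global[
        ep_rank * num_local : (ep_rank + 1) * num_local
    ]
    return global_to_physical, physical_to_global, local_expert_global_ids

# With 8 experts over 4 EP ranks, rank 0 hosts global experts [0, 4] instead
# of the contiguous [0, 1] it gets under the default linear placement.
print(round_robin_tables(8, 4, 0))
```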

Performance

Test platform: CUDA 12.8, driver 550.144.03
Model: DeepSeek-R1-671B
GPUs: 2 nodes × 8 H20
vLLM config: dp=16, tp=1, enable_expert_parallel=1, all2all_backend=deepep_low_latency, use_deep_gemm=1

The functionality is fully implemented, with correct accuracy and significant performance improvements.
In a benchmark with times=10, num_prompts=512, dataset=sharegpt, input_len=1024, output_len=512, max_concurrency=8, and req_rate=8,
round-robin placement increased throughput by 14.57% and improved TPOT by 13.38% compared with the default linear placement.

serving command (head node)

export NCCL_CHECK_DISABLE=1
export NCCL_COLLNET_ENABLE=0
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4,mlx5_bond_5,mlx5_bond_6,mlx5_bond_7,mlx5_bond_8
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_SL=3
export NCCL_IB_TC=160
export NCCL_LL_THRESHOLD=16384
export NCCL_NET_GDR_LEVEL=2
export NCCL_NVLS_ENABLE=0
export NCCL_P2P_DISABLE=0
export NCCL_PXN_DISABLE=0
export NCCL_SOCKET_IFNAME=bond1
export GLOO_SOCKET_IFNAME=bond1
export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=bond1
export NVSHMEM_HCA_LIST=mlx5_bond_1:1,mlx5_bond_2:1,mlx5_bond_3:1,mlx5_bond_4:1,mlx5_bond_5:1,mlx5_bond_6:1,mlx5_bond_7:1,mlx5_bond_8:1
export NVSHMEM_IB_TRAFFIC_CLASS=160
export NVSHMEM_DIR=/usr/local/nvshmem
export LD_LIBRARY_PATH=${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH
export PATH=${NVSHMEM_DIR}/bin:$PATH
export VLLM_ALL2ALL_BACKEND=deepep_low_latency
export VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION=1
export VLLM_ATTENTION_BACKEND=FLASHMLA
export VLLM_USE_DEEP_GEMM=1
export DG_JIT_CACHE_DIR=/root/.cache/vllm/deep_gemm/

vllm serve /path/to/DeepSeek-R1 \
  --host 0.0.0.0 \
  -tp 1 \
  --max-model-len 16384 \
  --max-num-batched-tokens 16384 \
  --expert-placement-strategy round_robin \
  --enable-chunked-prefill \
  --gpu-memory-utilization 0.8 \
  --load-format "auto" \
  --enable-expert-parallel \
  --data-parallel-size 16 \
  --data-parallel-size-local 8 \
  --data-parallel-address ${HOST_IP} \
  --data-parallel-rpc-port 12345 \
  --api-server-count=8

benchmark command

vllm bench serve \
  --backend vllm \
  --base-url "http://127.0.0.1:8500" \
  --port 8500 \
  --endpoint '/v1/completions' \
  --model ${model} \
  --dataset-name sharegpt \
  --num-prompts 512 \
  --max-concurrency 8 \
  --request-rate 8 \
  --random-input-len 1024 \
  --random-output-len 512

Accuracy

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.97   ± 0.0171
gsm8k  3        strict-match      5       exact_match  0.97   ± 0.0171

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for the deepep-low-latency all-to-all backend with a round-robin expert placement strategy, which is a valuable performance enhancement for Mixture-of-Experts models. The implementation is well-structured, adding the necessary logic for mapping global to physical expert IDs and including checks for unsupported backends. My review includes a couple of suggestions: one to improve code quality and performance in the routing table generation, and another to relax a condition that currently limits the new placement strategy to specific model architectures, potentially broadening its applicability.

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.



mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cboss6.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@cboss6 cboss6 force-pushed the bruce/enable-deepepll-with-roundrobin branch from 3fe12ae to 3a13456 Compare November 11, 2025 10:02
Comment on lines 96 to 98
self.global_to_physical: torch.Tensor | None = None
self.physical_to_global: torch.Tensor | None = None
self.local_expert_global_ids: torch.Tensor | None = None
Collaborator

@bnellnm Nov 11, 2025

I think these should just be additional optional parameters to the __init__ method. Then we won't need a setter.

You can pass these through maybe_make_prepare_finalize as optional parameters as well.
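A minimal sketch of that shape, with invented names (the real class and buffer type live in the DeepEP prepare/finalize code):

```python
import torch

class PrepareAndFinalize:
    """Illustrative only: routing tables as optional __init__ parameters."""

    def __init__(
        self,
        buffer,  # the all2all buffer handle this class already takes
        global_to_physical: torch.Tensor | None = None,
        physical_to_global: torch.Tensor | None = None,
        local_expert_global_ids: torch.Tensor | None = None,
    ):
        self.buffer = buffer
        # Tables arrive at construction time, so no setter is needed, and
        # maybe_make_prepare_finalize can forward them as optional arguments.
        self.global_to_physical = global_to_physical
        self.physical_to_global = physical_to_global
        self.local_expert_global_ids = local_expert_global_ids
```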

Contributor Author

Done, the setter is removed.


@mergify mergify bot added the needs-rebase label Nov 11, 2025
Comment on lines 129 to 130
physical = self.global_to_physical[topk_ids.to(torch.long)]
return physical.to(topk_ids.dtype)
Collaborator

Are the casts necessary here? topk_ids should already be an int of some type and if needed, global_to_physical could be cast in set_expert_routing_info (or __init__) using topk_indices_dtype so that there's no issues about type mismatches.
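Concretely, the one-time cast could look like this (a sketch with assumed method names; torch.index_select accepts IntTensor or LongTensor indices, which is what removes the per-call casts):

```python
import torch

class RoutingMixin:
    """Sketch of the up-front cast suggested above (names assumed)."""

    def set_expert_routing_info(
        self, global_to_physical: torch.Tensor, topk_indices_dtype: torch.dtype
    ) -> None:
        # Cast once at install time to the dtype the kernels use for topk
        # indices, so the per-forward lookup needs no casts at all.
        self.global_to_physical = global_to_physical.to(topk_indices_dtype)

    def map_global_to_physical(self, topk_ids: torch.Tensor) -> torch.Tensor:
        # index_select takes int32 or int64 indices directly, and the result
        # already has the table's (= topk_ids') dtype: no .to() on the hot path.
        return torch.index_select(
            self.global_to_physical, 0, topk_ids.flatten()
        ).view_as(topk_ids)
```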

Contributor Author

Done.

Comment on lines 254 to 273
if hasattr(layer, "_ensure_expert_routing_tables"):
    layer._ensure_expert_routing_tables()
prepare_finalize = self.maybe_make_prepare_finalize()

if prepare_finalize is not None:
    logger.debug(
        "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
    )
    if (
        getattr(layer, "use_ep", False)
        and hasattr(prepare_finalize, "set_expert_routing_info")
        and hasattr(layer, "expert_global_to_physical")
        and hasattr(layer, "expert_physical_to_global")
        and hasattr(layer, "expert_local_to_global")
    ):
        prepare_finalize.set_expert_routing_info(
            layer.expert_global_to_physical,
            layer.expert_physical_to_global,
            layer.expert_local_to_global,
        )
Collaborator

@bnellnm Nov 11, 2025

The construction logic was just refactored recently and maybe_init_modular_kernel has moved to FusedMoE. So that should make this a little bit simpler.

Contributor Author

Done.

Comment on lines 1358 to 1360
self.use_deepep_ht_kernels
or self.use_pplx_kernels
or self.use_flashinfer_cutlass_kernels
Collaborator

Can you change the condition to be not self.use_deepep_ll_kernels? That way, if any new all2all mechanisms are added, they won't accidentally skip this code.
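Illustrative fragment of the requested inversion (context elided, not runnable on its own):

```python
# Before: enumerate every backend that should bypass this code.
if (
    self.use_deepep_ht_kernels
    or self.use_pplx_kernels
    or self.use_flashinfer_cutlass_kernels
):
    ...

# After: anything that is not the DeepEP low-latency backend bypasses it,
# so future all2all backends are covered by default.
if not self.use_deepep_ll_kernels:
    ...
```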

Contributor Author

Done.

Collaborator

bnellnm commented Nov 11, 2025

Nice perf numbers! I just had a few minor comments.

@cboss6 cboss6 force-pushed the bruce/enable-deepepll-with-roundrobin branch from 8645f1d to 2a8992d Compare November 14, 2025 11:08
Contributor Author

cboss6 commented Nov 14, 2025

Hi @bnellnm, I’ve updated the code based on your comments. Could you take another look when you have a moment? Thanks a lot!

def _maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig | None,
    layer: torch.nn.Module | None = None,
Collaborator

@bnellnm Nov 14, 2025

Could you pass the maps as individual optional tensors? Or an optional tuple of tensors? In latest main, this function has been moved to a standalone function that can't depend on the layer.

Also, I think adding a layer argument here would break one of the unit tests that depend on this function.
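One possible reading of the tuple variant (the type alias and parameter name are invented here, not taken from main):

```python
import torch

# (global_to_physical, physical_to_global, local_expert_global_ids)
ExpertRoutingTables = tuple[torch.Tensor, torch.Tensor, torch.Tensor]

def _maybe_make_prepare_finalize(
    moe: "FusedMoEConfig",
    quant_config: "FusedMoEQuantConfig | None",
    routing_tables: ExpertRoutingTables | None = None,
):
    # The standalone function no longer depends on the layer object; callers
    # that have routing tables pass them, and existing unit tests can omit them.
    ...
```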

Contributor Author

Thanks for the suggestion! I’ve updated the code to use a tuple of routing tables instead of passing the layer, and I’ve also rebased onto the latest main.

Could you take a look and see if this approach is acceptable to you?

@cboss6 cboss6 force-pushed the bruce/enable-deepepll-with-roundrobin branch from 2a8992d to 33a5de5 Compare November 16, 2025 16:34
@mergify mergify bot removed the needs-rebase label Nov 16, 2025
@cboss6 cboss6 force-pushed the bruce/enable-deepepll-with-roundrobin branch from e2a3038 to 90bca85 Compare November 16, 2025 16:46
Collaborator

@bnellnm left a comment

LGTM!

Member

@hmellor left a comment

Nothing seems obviously incorrect to me, let's see what the tests say

@hmellor hmellor added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 19, 2025
@hmellor hmellor merged commit da2f680 into vllm-project:main Nov 19, 2025
53 checks passed
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
…ent. (vllm-project#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
bhagyashrigai pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Nov 20, 2025
…ent. (vllm-project#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Bhagyashri <Bhagyashri.Gaikwad2@ibm.com>
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
…ent. (vllm-project#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
…ent. (#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
…ent. (vllm-project#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>