
Conversation

@Amir-19 Amir-19 (Contributor) commented Sep 9, 2025

Purpose

Add NCCL symmetric memory for all-reduce. Currently, because of torch.compile limitations, we have to create a custom op and copy the input into a tensor that was allocated and registered with NCCL.
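
For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of the "copy into a registered buffer, then all-reduce" custom op. The op name, the _SYMM_BUFFER placeholder, and the use of torch.distributed.all_reduce as a stand-in collective are illustrative only; they are not this PR's actual identifiers or code path.

import torch
import torch.distributed as dist

# In the PR, a buffer like this would be allocated from a memory pool backed
# by ncclMemAlloc and registered with the NCCL communicator; a plain CUDA
# tensor stands in here so the sketch stays self-contained.
_SYMM_BUFFER = torch.empty(8 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")

@torch.library.custom_op("vllm_sketch::nccl_symm_all_reduce", mutates_args=())
def nccl_symm_all_reduce(inp: torch.Tensor) -> torch.Tensor:
    # The extra copy mentioned above: torch.compile cannot hand us a tensor
    # that already lives in the registered pool, so we stage the input first.
    buf = _SYMM_BUFFER[: inp.numel()].view(inp.shape)
    buf.copy_(inp)
    dist.all_reduce(buf)  # stand-in for the PyNccl symmetric all-reduce
    return buf.clone()

@nccl_symm_all_reduce.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation so torch.compile can trace through the op.
    return torch.empty_like(inp)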

Test Plan

benchmark_throughput.py

Test Result

before:

VLLM_USE_STANDALONE_COMPILE=0 VLLM_WORKER_MULTIPROC_METHOD=spawn python3 benchmarks/benchmark_throughput.py --model=deepseek-ai/DeepSeek-R1  --output-len=1000 --tensor-parallel-size=8 --input-len=1000 --max-model-len=2000 --trust-remote --load-format=dummy --gpu_memory_utilization=0.95 --max-num-seqs=256 --num-prompts=1024

Throughput: 4.50 requests/s, 8989.34 total tokens/s, 4500.46 output tokens/s
Total num prompt tokens:  1021366
Total num output tokens:  1024000

after, with the addition of VLLM_USE_NCCL_SYMM_MEM=1 NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1:

VLLM_USE_NCCL_SYMM_MEM=1 NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_STANDALONE_COMPILE=0 VLLM_WORKER_MULTIPROC_METHOD=spawn python3 benchmarks/benchmark_throughput.py --model=deepseek-ai/DeepSeek-R1  --output-len=1000 --tensor-parallel-size=8 --input-len=1000 --max-model-len=2000 --trust-remote --load-format=dummy --gpu_memory_utilization=0.95 --max-num-seqs=256 --num-prompts=1024

Throughput: 4.66 requests/s, 9310.51 total tokens/s, 4657.81 output tokens/s
Total num prompt tokens:  1022878
Total num output tokens:  1024000

To confirm that the memory allocation and the collective actually use symmetric memory, run with NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=TUNING,REG; you should see NCCL INFO AllReduce [Symmetric] in the logs.

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              pynccl-symm         symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.009               0.013               0.016               0.023               0.017               0.018               ca_1stage (1.87x)
(64, 8192)          1.00 MB        0.013               0.017               0.019               0.020               0.022               0.020               ca_1stage (1.47x)
(96, 8192)          1.50 MB        0.017               0.021               0.020               0.029               0.027               0.021               ca_1stage (1.22x)
(128, 8192)         2.00 MB        0.020               0.026               0.023               0.029               0.031               0.025               ca_1stage (1.17x)
(192, 8192)         3.00 MB        0.026               0.032               0.029               0.035               0.040               0.027               ca_1stage (1.14x)
(256, 8192)         4.00 MB        0.033               0.039               0.041               0.036               0.050               0.033               ca_1stage (1.22x)
(512, 8192)         8.00 MB        0.060               0.066               0.046               0.051               0.086               0.050               pynccl (1.00x)
(1024, 8192)        16.00 MB       0.112               0.121               0.057               0.076               0.159               0.084               pynccl (1.00x)
(2048, 8192)        32.00 MB       0.217               0.232               0.087               0.135               0.310               0.157               pynccl (1.00x)
(3062, 8192)        47.84 MB       0.326               0.358               0.121               0.203               0.463               0.233               pynccl (1.00x)
(4096, 8192)        64.00 MB       0.429               0.495               0.153               0.288               0.616               0.327               pynccl (1.00x)

TP = 4
Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              pynccl-symm         symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.010               0.017               0.021               0.017               0.015               0.016               ca_1stage (2.13x)
(64, 8192)          1.00 MB        0.014               0.018               0.023               0.019               0.018               0.019               ca_1stage (1.60x)
(96, 8192)          1.50 MB        0.019               0.025               0.025               0.020               0.021               0.019               ca_1stage (1.35x)
(128, 8192)         2.00 MB        0.023               0.026               0.028               0.021               0.023               0.023               pynccl-symm (1.34x)
(192, 8192)         3.00 MB        0.030               0.035               0.033               0.023               0.028               0.025               pynccl-symm (1.45x)
(256, 8192)         4.00 MB        0.038               0.042               0.039               0.028               0.033               0.029               pynccl-symm (1.43x)
(512, 8192)         8.00 MB        0.068               0.074               0.049               0.036               0.053               0.041               pynccl-symm (1.36x)
(1024, 8192)        16.00 MB       0.128               0.128               0.079               0.056               0.093               0.064               pynccl-symm (1.39x)
(2048, 8192)        32.00 MB       0.247               0.243               0.123               0.101               0.175               0.117               pynccl-symm (1.22x)
(3062, 8192)        47.84 MB       0.369               0.364               0.153               0.149               0.260               0.170               pynccl-symm (1.03x)
(4096, 8192)        64.00 MB       0.490               0.482               0.187               0.194               0.359               0.227               pynccl (1.00x)

TP = 8
Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              pynccl-symm         symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.013               0.025               0.031               0.015               0.016               0.023               ca_1stage (2.38x)
(64, 8192)          1.00 MB        0.020               0.026               0.032               0.017               0.016               0.024               symm_mem_multimem (1.96x)
(96, 8192)          1.50 MB        0.027               0.027               0.034               0.020               0.019               0.024               symm_mem_multimem (1.80x)
(128, 8192)         2.00 MB        0.033               0.028               0.035               0.021               0.020               0.024               symm_mem_multimem (1.78x)
(192, 8192)         3.00 MB        0.045               0.044               0.049               0.023               0.023               0.026               symm_mem_multimem (2.19x)
(256, 8192)         4.00 MB        0.056               0.045               0.057               0.026               0.025               0.034               symm_mem_multimem (2.22x)
(512, 8192)         8.00 MB        0.106               0.078               0.074               0.037               0.037               0.046               pynccl-symm (2.00x)
(1024, 8192)        16.00 MB       0.194               0.145               0.104               0.055               0.060               0.078               pynccl-symm (1.89x)
(2048, 8192)        32.00 MB       0.381               0.263               0.172               0.097               0.109               0.139               pynccl-symm (1.77x)
(3062, 8192)        47.84 MB       0.566               0.382               0.225               0.143               0.162               0.201               pynccl-symm (1.57x)
(4096, 8192)        64.00 MB       0.758               0.499               0.281               0.183               0.222               0.269               pynccl-symm (1.53x)



mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Amir-19.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from cb370ec to 7a57a5c Compare September 9, 2025 20:49
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces NCCL symmetric memory for all-reduce operations to improve performance. It achieves this by adding a custom PyTorch operation that uses a pluggable allocator based on ncclMemAlloc. The changes are extensive, touching CUDA graph capturing, device communicators, and adding a new environment variable to control the feature. While the implementation is clever, I have identified a few critical issues. The feature has a hard dependency on a pre-release version of PyTorch, which is a significant risk. There is also a potential resource leak related to NCCL communication windows, and verbose logging is enabled by default during JIT compilation. These issues should be addressed to ensure the stability and usability of this new feature.
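
To make the pluggable-allocator idea concrete, here is a rough sketch under stated assumptions: a PyTorch build that exposes torch.cuda.MemPool and torch.cuda.memory.use_mem_pool, and NCCL headers/libraries visible to the compiler. The C++ shim, symbol names, and buffer size are illustrative, not the PR's actual sources.

import torch
from torch.utils.cpp_extension import load_inline

# Illustrative allocation shim: ncclMemAlloc/ncclMemFree exposed with the
# alloc/free signatures that CUDAPluggableAllocator expects.
cpp_src = r"""
#include <nccl.h>
#include <cuda_runtime.h>
extern "C" {
void* nccl_alloc(size_t size, int device, cudaStream_t stream) {
  void* ptr = nullptr;
  ncclMemAlloc(&ptr, size);
  return ptr;
}
void nccl_free(void* ptr, size_t size, int device, cudaStream_t stream) {
  ncclMemFree(ptr);
}
}
"""

mod = load_inline(name="nccl_allocator_sketch", cpp_sources=cpp_src,
                  functions=[], with_cuda=True, extra_ldflags=["-lnccl"])

allocator = torch.cuda.memory.CUDAPluggableAllocator(
    mod.__file__, "nccl_alloc", "nccl_free")
pool = torch.cuda.MemPool(allocator.allocator())

with torch.cuda.memory.use_mem_pool(pool):
    # Tensors allocated inside this context come from ncclMemAlloc and can
    # then be registered with the communicator (ncclCommWindowRegister).
    buf = torch.empty(1 << 20, dtype=torch.bfloat16, device="cuda")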

@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from 7a57a5c to 9431a53 Compare September 9, 2025 20:54
@mergify mergify bot removed the needs-rebase label Sep 9, 2025
@ilmarkov ilmarkov (Contributor) left a comment

Thank you for making things work! I left some comments.
Some benchmarking comparing the all-reduce implementations would be helpful to figure out the input-size bounds for the NCCL symm mem all-reduce.

@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from 9431a53 to db043bc Compare September 11, 2025 21:00
@mergify mergify bot added the performance Performance-related issues label Sep 11, 2025
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from db043bc to 8609bad Compare September 11, 2025 21:02
mergify bot commented Sep 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Amir-19.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 12, 2025
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from 8609bad to 6a1c68c Compare September 15, 2025 17:49
@mergify mergify bot added ci/build and removed needs-rebase labels Sep 15, 2025
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch 4 times, most recently from ad0b592 to 040903f Compare September 15, 2025 22:19
@Amir-19 Amir-19 requested a review from ilmarkov September 15, 2025 23:18
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 17, 2025
@ilmarkov ilmarkov (Contributor) left a comment

Thank you for the update!

  1. We need to decide on all-reduce algorithm dispatching. If we leave it for a follow-up PR, then let's keep the constants in one place.
  2. The PR needs some tests: the allreduce primitive, get_nccl_mem_pool, etc.
  3. Compilation of the NCCL allocator needs to be improved for better UX when there is no system NCCL installed.
  4. The PR needs a bit of cleanup (commented-out code, prints).

"Failed to compile NCCL memory allocator. "
"Symmetric memory will be disabled. "
"This is expected if NCCL headers are not available. "
"optionally set VLLM_NCCL_INCLUDE_PATH to point to NCCL header."
Contributor

nit: ... point to a directory with NCCL header

Contributor

It has to be a list of paths:
[envs.VLLM_NCCL_INCLUDE_PATH], otherwise the string will be interpreted as a list.

Contributor

Compilation also fails for me at linking. torch usually comes with NCCL installed; I'd suggest using that one by default, replacing it with VLLM_NCCL_PATH (for me it worked with some hacking, e.g. adding symlinks).
Please try it on a system without pre-installed NCCL in order to debug and add better UX.

Contributor Author

load_inline's log also shows:

[1/2] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=nccl_allocator -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /usr/local/lib/python3.12/dist-packages/torch/include -isystem /usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -isystem /usr/include/python3.12 -fPIC -std=c++17 -c /tmp/main.cpp -o main.o
[2/2] c++ main.o -shared -lnccl -L/usr/local/lib/python3.12/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o nccl_allocator.so

The compiler looks for headers under /usr/local/lib/python3.12/dist-packages/torch/include, and the NCCL header is located at '/usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/cuda/nccl.h'.

On systems without the header, or with a missing/incorrect location provided by the user, the allocator fails to compile; this disables NCCL symm memory allocation and registration by setting _nccl_allocator_failed_to_compile=True, which is the expected behavior.

Contributor

I understand that we don't crash when it fails to compile. NCCL usually comes with the nvidia-nccl-cu12 package, which is installed along with PyTorch; the path to the NCCL home in my case is ENV_PATH/lib/python3.12/site-packages/nvidia/nccl/. So users are actually able to use NCCL symm mem even if they don't have a system NCCL. Moreover, we already use a "non-system" NCCL when we bind to it; check out the find_nccl_library func. I would suggest doing something similar.

Contributor Author

Added find_nccl_include_paths(), which looks for nccl.h under the nvidia-nccl-cuXX package or in a directory specified via the VLLM_NCCL_INCLUDE_PATH env var. Even without nvidia-nccl-cuXX installed, a proper multi-GPU PyTorch setup ships nccl.h, which load_inline picks up by default via torch.utils.cpp_extension.include_paths.
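
A simplified sketch of what such a lookup could look like (not the PR's actual find_nccl_include_paths; the search order here is just illustrative):

import glob
import os

def find_nccl_include_paths():
    # 1. Explicit override via the env var discussed above.
    path = os.environ.get("VLLM_NCCL_INCLUDE_PATH")
    if path and os.path.isfile(os.path.join(path, "nccl.h")):
        return [path]
    # 2. nccl.h shipped by the nvidia-nccl-cuXX pip package.
    try:
        import nvidia
        for base in nvidia.__path__:
            hits = glob.glob(os.path.join(base, "nccl", "include", "nccl.h"))
            if hits:
                return [os.path.dirname(hits[0])]
    except ImportError:
        pass
    # 3. Fall back to whatever torch.utils.cpp_extension.include_paths()
    #    already provides (load_inline uses those by default).
    return None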

@Amir-19 Amir-19 requested a review from ilmarkov September 19, 2025 00:54
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch 2 times, most recently from 34312a3 to cb36b90 Compare September 19, 2025 07:52
@Amir-19 Amir-19 force-pushed the nccl_symm_ar_custom_op branch from 41d0e5c to 0bdb252 Compare September 22, 2025 23:58
@mergify mergify bot removed the needs-rebase label Sep 22, 2025
@mgoin mgoin merged commit 8c1c81a into vllm-project:main Sep 23, 2025
78 checks passed
@gshtras gshtras (Collaborator) commented Sep 24, 2025

With this PR, on ROCm there are now crashes on, among others,
vllm bench latency --model meta-llama/Llama-3.1-70B-Instruct --batch-size 1 --input-len 1024 --output-len 1024 -tp 8 --num-iters-warmup 1 --num-iters 3 --load-format dummy:
HIP error: operation not permitted when stream is capturing during graph capture
@Amir-19

@Amir-19 Amir-19 (Contributor Author) commented Sep 24, 2025

With this PR on ROCm there are now crashes on, among others, vllm bench latency --model meta-llama/Llama-3.1-70B-Instruct --batch-size 1 --input-len 1024 --output-len 1024 -tp 8 --num-iters-warmup 1 --num-iters 3 --load-format dummy: HIP error: operation not permitted when stream is capturing during the graph capture @Amir-19

Could you share the full log? This feature is behind an env var, so it shouldn't touch the default path.

@gshtras gshtras (Collaborator) commented Sep 24, 2025

With this PR on ROCm there are now crashes on, among others, vllm bench latency --model meta-llama/Llama-3.1-70B-Instruct --batch-size 1 --input-len 1024 --output-len 1024 -tp 8 --num-iters-warmup 1 --num-iters 3 --load-format dummy: HIP error: operation not permitted when stream is capturing during the graph capture @Amir-19

could you share the full log? this feature is behind an env var, so it shouldn't touch the default path.

log.txt
I'm trying to find the exact part of the change that triggers it.
Looks like it's one of these lines 8c1c81a#diff-964c170432ace3a43a4e616bc2d5eec9ad45056873411369e24db75bf6c6da39R44-R243
Oddly enough, from a quick test this doesn't happen on an H100 machine, and ROCm used to be more permissive than CUDA in terms of what's allowed during graph capturing, at least until 7.0.

@gshtras gshtras (Collaborator) commented Sep 24, 2025

Yeah, it's from exposing ncclCommWindowRegister.
My guess is that since these functions are new (at least in RCCL), they may not exist in the library version bundled with the ROCm release.
Update: looks like that is indeed the case.

$ nm -gD /opt/rocm/lib/librccl.so.1.0.70000 | grep "ncclGroupEnd" | wc -l
2
$ nm -gD /opt/rocm/lib/librccl.so.1.0.70000 | grep "ncclCommWindowRegister" | wc -l
0

I can't say I fully understand how it affects graph capturing, but removing these functions (assuming nobody tries to use them later on), or wrapping the for loop over the NCCLLibrary.exported_functions contents in a try-except, eliminates the crash.
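
For illustration, a minimal ctypes sketch of that try/except idea (the library path and function list are placeholders, not vLLM's actual NCCLLibrary code):

import ctypes

lib = ctypes.CDLL("librccl.so.1")  # or libnccl.so.2 on CUDA systems
funcs = {}
for name in ("ncclAllReduce", "ncclGroupEnd", "ncclCommWindowRegister"):
    try:
        funcs[name] = getattr(lib, name)  # dlsym; raises if symbol is absent
    except AttributeError:
        # Symbol missing from this build (e.g. older RCCL): skip it instead
        # of crashing at startup; fail only if something actually calls it.
        funcs[name] = None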

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>