Custom all reduce kernels #2192

Merged: 61 commits merged into vllm-project:main on Jan 27, 2024
Conversation

@hanzhi713 (Contributor) commented Dec 19, 2023

See this doc for a detailed writeup and experiments: Latency-optimal allreduce and cuda graph optimization.pdf

Latency and memory

Tested with

python benchmarks/benchmark_latency.py \
--model NAME -tp TP --input-len 256 --output-len 256 --batch-size BS --num-iters 5

L = Latency, M = Memory

| Model | GPU | TP | BS | L before (s) | L after (s) | M before (MB) | M after (MB) |
|---|---|---|---|---|---|---|---|
| Llama 70B | A100-80G | 4 | 64 | 16.37 | 14.57 | 76723 | 75717 |
| Llama 33B | A100-80G | 2 | 64 | 15.88 | 14.38 | 75739 | 74741 |
| Llama 13B | A30 (no NVLink) | 2 | 32 | 10.43 | 9.79 | 23519 | 22911 |
| Llama 7B | T4 | 2 | 32 | 17.90 | 17.51 | 14969 | 14475 |

Hypothesis on why memory usage is lower with fast allreduce:

  • NCCL's internal buffer is captured in the graph.
  • NCCL requires inserting more graph nodes per invocation; for example, it needs a few host nodes to ensure proper operation.

Throughput

| Model | GPU | TP | Throughput before | Throughput after |
|---|---|---|---|---|
| Llama 70B | A100-80G | 4 | 3.68 requests/s, 1761.33 tokens/s | 3.87 requests/s, 1852.65 tokens/s |

Performance and memory note

  1. NVSwitch-based systems should see a larger performance improvement than PCIe systems. Generally, the faster the link, the larger the improvement.
  2. The latency improvement is more significant at smaller batch sizes, where allreduce is more latency-bound.
  3. The smaller memory overhead of fast allreduce can lead to higher throughput and alleviate some OOM issues when the GPU memory budget is tight (e.g. serving a 33B model with 4xA30).

Implementation note

Since I originally implemented fast allreduce on top of my own fork, I made some changes compared to the original version described in the doc. Note that the performance numbers in the writeup doc are not valid because my fork differs significantly from upstream. The main changes are:

  • No fusion with the residual connection: this is because it's already fused with layernorm.
  • No CUDA graph replay optimizations. @WoosukKwon's CUDA graph implementation uses a single graph launch per model (mine needs one per layer), so that's probably not necessary.

Extensive effort was also made to make it work with CUDA graphs automatically (automatic IPC buffer registration). My previous implementation required manually allocating a global buffer and changing model code to write matmul outputs to it.

The one-hop and two-hop all reduce kernels work very similarly to NVIDIA TensorRT-LLM's kernels (https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu). However, they were developed independently before TensorRT-LLM's release.

Note on some source files

  1. fast_allreduce.cuh is the implementation without PyTorch dependencies. You can copy this file to Compiler Explorer and check the PTX/SASS.
  2. fast_allreduce_test.cu is a C++ test for performance and accuracy comparison between NCCL and my implementation. It compiles quickly compared to the torch extension. The code there isn't very neat.

Caveats

Compared to NCCL allreduce, there are some caveats for the fast allreduce path:

  1. It only works for tensors whose byte size is a multiple of 16.
  2. It can only work out-of-place for now.
  3. It doesn't work with hybrid parallelism for now (e.g. TP + PP). I don't know if these are planned to be supported in vLLM.

1 should be automatically handled and checked. 2 should be a non-issue since all usages of tensor_model_parallelism use its return value.
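
A minimal sketch of what the automatic check for caveat 1 could look like (illustrative only, not the PR's actual code; the max_bytes limit and the fallback policy are assumptions):

import torch

def can_use_fast_allreduce(t: torch.Tensor, max_bytes: int = 8 * 1024 * 1024) -> bool:
    # Fall back to NCCL unless the custom kernel's constraints are met:
    # a contiguous tensor whose byte size is a multiple of 16 and fits
    # within a (hypothetical) registered-buffer limit.
    nbytes = t.numel() * t.element_size()
    return t.is_contiguous() and nbytes % 16 == 0 and nbytes <= max_bytes

# Usage sketch: the allreduce wrapper would branch on this check and call
# torch.distributed.all_reduce on the NCCL path otherwise.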

TODOs

  • add configuration option
  • more end-to-end performance testing on other GPUs, model sizes and TP configs
  • end-to-end correctness test with models
  • format code
  • [ ] (maybe) nit: bind the C++ class properly with pybind11 instead of C-style bindings. Since we don't want to introduce PyTorch dependencies into the header file, we need an additional wrapper layer anyway.

@hanzhi713 (Contributor, Author)

It's not quite ready to merge; I'm requesting comments.

cc @WoosukKwon @simon-mo

@WoosukKwon (Collaborator)

@hanzhi713 This is awesome! Many thanks for the PR! A quick question: do you happen to know about the custom all-reduce kernels in TRT-LLM? Is this PR related to those kernels?

@hanzhi713 (Contributor, Author)

This is included in the PR description:

The one-hop and two-hop all reduce kernels work very similarly to NVIDIA TensorRT-LLM's kernels (https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu). However, they were developed independently before TensorRT-LLM's release.

@hanzhi713 (Contributor, Author) commented Dec 21, 2023

@WoosukKwon Correctness- and functionality-wise, this PR should be ready. I checked a few models and there are only occasional generation differences (due to numerical differences). See the diff below for reference. The left side is without fast allreduce and the right side is with fast allreduce.

https://www.diffchecker.com/hiJejMpy/

Tested with

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 32
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)

# Create an LLM.
llm = LLM(model="TheBloke/Llama-2-70B-fp16", tensor_parallel_size=8, disable_fast_allreduce=True) # or False for fast allreduce
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

hanzhi713 changed the title from "[WIP] Custom all reduce kernels" to "Custom all reduce kernels" on Dec 21, 2023
WoosukKwon self-requested a review on December 21, 2023
@casper-hansen (Contributor)

A 5% throughput improvement is quite impressive from optimizing all reduce with custom kernels. Well done!

@hanzhi713 (Contributor, Author)

Yes, especially considering this is mainly a latency optimization.

@scv119 (Contributor) commented Dec 26, 2023

@hanzhi713 have you compared pytorch/pytorch#114001 with your custom reduce ops?

@WoosukKwon (Collaborator) left a review comment

Hi @hanzhi713, thanks again for the awesome PR! I did one third of the review, mostly on code style. I will look into the actual implementation next.

Review comments were left on:

  • setup.py
  • vllm/worker/model_runner.py
  • vllm/model_executor/parallel_utils/fast_allreduce.py
  • vllm/engine/arg_utils.py
  • vllm/model_executor/parallel_utils/communication_op.py
  • tests/distributed/test_fast_allreduce.py
  • csrc/fast_allreduce_test.cu
@WoosukKwon (Collaborator)

@hanzhi713 BTW I got this error when using 2 L4 GPUs:

(RayWorkerVllm pid=51757) INFO 12-26 04:18:45 fast_allreduce.py:21] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
(RayWorkerVllm pid=51757) Failed: Cuda error /home/gcpuser/workspace/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'

@hanzhi713 (Contributor, Author)

I guess I have to check for this. All of the topologies that I have access to support P2P, but some platforms don't.
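
A minimal sketch of the kind of peer-access check that could gate the fast path (illustrative, not the PR's actual code; it assumes torch.cuda.can_device_access_peer, which is available in recent PyTorch):

import torch

def all_pairs_support_p2p(world_size: int) -> bool:
    # Disable the custom allreduce unless every GPU pair supports CUDA
    # peer-to-peer access (the 2x L4 setup above would fail this check).
    for i in range(world_size):
        for j in range(world_size):
            if i != j and not torch.cuda.can_device_access_peer(i, j):
                return False
    return True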

@hanzhi713 (Contributor, Author)

You shouldn't remove this. I used the CUDA driver API, so the CUDA driver must be linked for it to work.

@hanzhi713 Yeah, I just realized it. Thanks for letting us know. BTW, as you might have noticed, I'm making minor changes (mostly code style and imports) to accelerate the merge. I will push my review comments soon. Again, apologies for the delay.

Haha, it's fine. Writing a review is often slower than modifying the code directly.

@WoosukKwon (Collaborator) left a review comment

Hi @hanzhi713, thanks again for the great work!

I left some comments mainly on the C++ part, as I still need a bit more time to complete my review on the CUDA kernel part. Overall, I learned a lot while reading your code and really appreciate it. However, it seems the code can be improved in terms of simplicity. Please take a look at my review comments.

Review comments were left on:

  • vllm/config.py
  • vllm/worker/model_runner.py
  • tests/distributed/test_comm_ops.py
  • vllm/model_executor/parallel_utils/custom_all_reduce.py
  • csrc/custom_all_reduce.cuh
Collaborator:

Just curious: can we port this test to Python so that we can use Ray instead of MPI? This change would also make it easier to include this test in our CI.

@hanzhi713 (Contributor, Author) commented Jan 25, 2024

Might be possible, but it's tricky to get the performance measurement correct in Python, especially for NCCL kernels. Each kernel's runtime is so short (<=10us for the smallest size) that removing any overhead is important.
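
For reference, a rough sketch of how such a measurement could be attempted from Python with CUDA events, amortizing launch overhead over many iterations (illustrative only; the iteration counts are arbitrary):

import torch

def time_op_us(fn, iters: int = 1000, warmup: int = 100) -> float:
    # Time a GPU op with CUDA events, averaging over many iterations so that
    # Python and kernel-launch overhead is amortized. Even then, kernels in
    # the <=10us range remain hard to measure precisely from Python.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # elapsed_time() returns ms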


_CA_HANDLE = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
Collaborator:

Just curious: why should the number of GPUs be even? Which part of the code should we fix if we want to support an odd number of GPUs?

@hanzhi713 (Contributor, Author)

Well, I think my kernels do support odd numbers of GPUs; I just never tested them. Do other parts of vLLM support an odd number of GPUs (e.g. tensor-parallel linear)?

Comment on lines +513 to +521
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
Collaborator:

I actually don't fully understand the underlying principle behind this kernel selection process. How did you set the thresholds (512KB and 256KB)? Why are the thresholds different for ngpus=4 and ngpus=8?

@hanzhi713 (Contributor, Author)

(benchmark plot attached in the original comment)

The thresholds are the result of tuning. The different thresholds come from the latency-bandwidth trade-off between the two methods.
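
A back-of-the-envelope view of that trade-off (my own sketch of the standard one-shot vs. two-shot allreduce traffic model, not taken from the PR):

def one_stage_bytes_per_gpu(nbytes: int, world_size: int) -> float:
    # One-shot: each GPU reads every peer's full buffer and reduces locally.
    # Lowest latency (a single kernel), but traffic grows linearly with GPUs.
    return (world_size - 1) * nbytes

def two_stage_bytes_per_gpu(nbytes: int, world_size: int) -> float:
    # Two-shot (reduce-scatter then all-gather): each GPU handles only
    # 1/world_size of the buffer per stage, so traffic stays roughly constant
    # at the cost of an extra synchronization (higher fixed latency).
    return 2 * (world_size - 1) / world_size * nbytes

# At world_size=2 both move the same amount of data, so one-stage always wins.
# At world_size=8 the bandwidth gap over one-stage is larger than at
# world_size=4, so the crossover (the threshold) occurs at a smaller size.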

@WoosukKwon (Collaborator) left a review comment

LGTM! Many thanks again for submitting the PR and addressing the reviews. I made minor changes again to merge the PR asap. Hope this is ok with you.

Besides, there are a few follow-up items we hope to work on:

  1. Enabling this optimization for cloud A10G/L4 GPUs. For some reason, CUDA P2P access is not possible for these GPUs in cloud environments. We need to investigate the problem.
  2. Refactoring the code for FastAllreduce initialization and buffer registration. I feel we can simplify this further.

We'd be happy to have your (and anyone's) input on the above items!

WoosukKwon merged commit 3801700 into vllm-project:main on Jan 27, 2024 (17 checks passed)
NikolaBorisov pushed a commit to deepinfra/vllm that referenced this pull request Jan 31, 2024
@hanzhi713 (Contributor, Author) commented Feb 1, 2024

@Yard1 I noticed your comment on multiple captures. On my end, multiple captures work and produce correct results (using examples/offline_inference.py). Also, my unit test (tests/distributed/test_custom_all_reduce.py) uses multiple captures too.

I'm not quite sure how you're using it. I just moved with custom_all_reduce.capture(): to the first line of CUDAGraphRunner.capture.
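
A sketch of the arrangement described above (illustrative; the runner function and forward call are hypothetical stand-ins, and warmup/stream details of graph capture are omitted):

import torch
from vllm.model_executor.parallel_utils import custom_all_reduce

def capture_decode_graph(model_forward, example_inputs):
    # Wrap the whole CUDA-graph capture in the custom-allreduce capture
    # context so that IPC buffers used by graphed allreduce calls are
    # registered automatically.
    with custom_all_reduce.capture():
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            model_forward(*example_inputs)  # hypothetical forward call
    return graph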
