
[Kernel] Use flashinfer for decoding #4353

Merged
merged 24 commits into vllm-project:main on May 3, 2024

Conversation

@LiuXiaoxuanPKU (Collaborator) commented Apr 25, 2024

This PR is a first attempt to integrate flashinfer for the decoding phase. The PR still uses flash attention for the prefill phase for now.
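For readers new to the codebase, here is a minimal, hypothetical sketch of the prefill/decode split described above (the function names and metadata layout are illustrative, not vLLM's actual backend API): prefill batches keep going through flash attention, while decode batches are routed to flashinfer.

# Hypothetical sketch of the prefill/decode routing; names are illustrative.
from dataclasses import dataclass

import torch


@dataclass
class AttnMetadataSketch:
    is_prompt: bool  # True for the prefill phase, False for decoding.


def run_prefill_attention(q, k, v):
    # Stand-in for the flash-attention prefill kernel.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


def run_flashinfer_decode(q, k, v):
    # Stand-in for flashinfer's paged decode kernel (one new query token per sequence).
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


def attention_forward(q, k, v, metadata: AttnMetadataSketch):
    # Route by phase: prefill still uses flash attention, decode uses flashinfer.
    if metadata.is_prompt:
        return run_prefill_attention(q, k, v)
    return run_flashinfer_decode(q, k, v)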

Updated after discussion with @yzh119
Things that need to be fixed:

  • Use flashinfer's RoPE embedding --> We will turn off FlashInfer's RoPE support and use vLLM's RoPE.
  • ALiBi slope --> FlashInfer will support alibi_slope as an input parameter.
  • CUDA graph support --> FlashInfer will simplify the begin_forward function and add CUDA graph support, so there is no need to pad from the vLLM side.

Next steps:

  • Remove the backend interface change
  • Add tests for e2e correctness
  • Add tests to check that the flashinfer integration works with tensor parallelism

@LiuXiaoxuanPKU marked this pull request as draft on April 25, 2024
@rkooo567 self-assigned this on Apr 25, 2024
@LiuXiaoxuanPKU marked this pull request as ready for review on April 28, 2024
@LiuXiaoxuanPKU (Collaborator, Author)

Passes the correctness tests for the single-GPU and TP settings. Feel free to take a first pass, @rkooo567.

@LiuXiaoxuanPKU changed the title from [WIP][Kernel] Use flashinfer for decoding to [Kernel] Use flashinfer for decoding on Apr 28, 2024
@rkooo567 (Collaborator) commented Apr 28, 2024

CUDA graph support --> FlashInfer will simplify the begin_forward function and add CUDA graph support, so there is no need to pad from the vLLM side.

But we still need to pad inputs for the other parts of the model besides attention? (I think flashinfer-ai/flashinfer#187 could be something nice to add for prefill CUDA graph.)
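For context, a small illustrative sketch of the kind of padding being discussed (the capture sizes are assumed, not vLLM's actual values): CUDA graphs are captured for a fixed set of batch sizes, so a smaller decode batch is padded up to the nearest captured size.

# Illustrative only: pad the decode batch up to the nearest captured graph size.
_CUDA_GRAPH_CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]  # assumed values


def padded_graph_batch_size(batch_size: int) -> int:
    """Return the smallest captured batch size that can hold `batch_size`."""
    for size in _CUDA_GRAPH_CAPTURE_SIZES:
        if batch_size <= size:
            return size
    raise ValueError(f"Batch size {batch_size} exceeds the largest captured graph")


# Example: a decode batch of 13 sequences runs in the graph captured for 16.
assert padded_graph_batch_size(13) == 16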

@rkooo567 (Collaborator) left a comment

Very clean!! Many comments are nits. So it seems like:

  1. not working with prefill
  2. not working with prefix caching
  3. not working with chunked prefill

Is this correct? If so, we should make sure we raise exceptions properly (see the sketch below).

Also, does flashinfer have an equivalent of the prefix attention kernel?
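A minimal sketch of the kind of guard being asked for (the flag names are hypothetical, not the PR's actual code), so unsupported configurations fail loudly instead of silently producing wrong results:

# Hypothetical guard for configurations the FlashInfer backend cannot handle yet.
def check_flashinfer_support(enable_prefix_caching: bool,
                             enable_chunked_prefill: bool) -> None:
    if enable_prefix_caching:
        raise NotImplementedError(
            "Prefix caching is not supported with the FlashInfer backend yet.")
    if enable_chunked_prefill:
        raise NotImplementedError(
            "Chunked prefill is not supported with the FlashInfer backend yet.")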

Review threads on: vllm/attention/backends/flashinfer.py, csrc/cache_kernels.cu, tests/basic_correctness/test_flashinfer.py, tests/distributed/test_flashinfer_distributed.py, vllm/utils.py
@esmeetu (Collaborator) commented Apr 28, 2024

@LiuXiaoxuanPKU Thank you for this PR! Can we support other prefill backends instead of just flash attention, like xFormers?

@esmeetu (Collaborator) commented Apr 29, 2024

@LiuXiaoxuanPKU Thank you for this PR! Can we support other prefill backends instead of just flash attention, like xFormers?

I have implemented an xFormers version, but the performance was not good (10%–20% throughput drop).

@rkooo567 (Collaborator)

I think the goal is to use flashinfer's prefill eventually, so flash attention will be replaced soon!

@LiuXiaoxuanPKU (Collaborator, Author)

@LiuXiaoxuanPKU what's the latest supported version from flashinfer? Also, is this something we can simply build with torch 2.3?

The latest torch version supported by the flashinfer Python package is 2.2. Yes, we can build from source with torch 2.3; I tested it locally and it passed the tests.

@rkooo567 (Collaborator) commented May 2, 2024

Let's merge it without adding tests to CI? I think we don't need to be blocked by them (to merge). I will also create an issue in their repo for torch 2.3 support (we can enable the tests when it is officially supported).


@ywang96 (Collaborator) left a comment

Left a few comments - Thanks for working on this!

Review threads on vllm/attention/backends/flashinfer.py (outdated)
    data_type: torch.dtype = None

    def __post_init__(self):
        if not self.is_prompt:
Reviewer comment (Collaborator):

When using flashinfer, we also create the FlashInferMetadata, which calls __post_init__ by default; here we want to skip the __post_init__ logic if it's the prefill phase.

IMO it's worth putting a NOTE in the code regarding this explanation.
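A simplified sketch of the behavior described above (most fields trimmed; not the exact PR code), showing the early return that skips the decode-only setup during prefill:

# Simplified sketch: skip flashinfer's decode-only setup for prefill batches.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FlashInferMetadataSketch:
    is_prompt: bool
    data_type: Optional[torch.dtype] = None

    def __post_init__(self):
        # NOTE: prefill still goes through flash attention, so only decode
        # batches need the flashinfer-specific initialization below.
        if self.is_prompt:
            return
        # ... build the paged-KV index tensors and the flashinfer decode wrapper here.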

Comment on lines +565 to +568
        # Allocate 16MB workspace buffer
        # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
        self.flashinfer_workspace_buffer = torch.empty(
            16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
Reviewer comment (Collaborator):

Just curious - is there any point in making the workspace buffer size configurable?

@LiuXiaoxuanPKU (Collaborator, Author):

Currently, it's always 16 MB.
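For reference, a hedged sketch of how the fixed 16 MB workspace buffer is handed to flashinfer's decode wrapper, following the flashinfer decode docs linked above (requires the flashinfer package and a CUDA device; the exact vLLM integration may differ):

# Sketch based on https://docs.flashinfer.ai/api/python/decode.html
import torch
import flashinfer

# flashinfer uses this scratch space internally; the PR hard-codes 16 MB.
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD")
# decode_wrapper.begin_forward(...) is then called with the paged-KV layout
# (indptr / indices / last_page_len) before decode_wrapper.forward(query, kv_cache).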

Review threads on: vllm/attention/selector.py (outdated), vllm/sequence.py
@cadedaniel (Collaborator) left a comment

Approving to unblock @LiuXiaoxuanPKU given @rkooo567's approval.

@ywang96 if you have any more comments, please provide them and @LiuXiaoxuanPKU can follow up in the next PR.

AttentionMetadataPerStage)


class FlashInferBackend(AttentionBackend):
Reviewer comment (Collaborator):

(can do in future PRs): it would be good to have a small description in the docstring so new people can understand what FlashInfer is, why it's useful over other backends, and where to learn more.


    is_prompt: bool

    use_cuda_graph: bool = False
Reviewer comment (Collaborator):

(can do in future PRs): it would be good to add a comment for these values (e.g. CUDA graph is not supported yet).

@@ -33,16 +34,19 @@ def test_models(
    dtype: str,
    max_tokens: int,
) -> None:
    enforce_eager = False
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
Reviewer comment (Collaborator):

(can do in future PR): we should integrate this with #4548

Comment on lines +299 to +300
        return self.hf_text_config.num_attention_heads // \
            parallel_config.tensor_parallel_size
Reviewer comment (Collaborator):

(can do in future PR): do we need a TODO for vision models?

@ywang96 (Collaborator) left a comment

LGTM! I don't have any other comments ATM - please resolve the conflict though, thanks!

@cadedaniel merged commit 43c413e into vllm-project:main on May 3, 2024
59 checks passed
@rkooo567 (Collaborator) commented May 4, 2024

Awesome! Super excited to see the performance with CUDA graph!

robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
@zhyncs commented May 6, 2024

Hi @LiuXiaoxuanPKU Great work! After switching to the new backend, has there been any performance improvement compared to before? Have you conducted any relevant benchmarks? Thanks.

@zhyncs commented May 6, 2024

Hi @LiuXiaoxuanPKU Is FlashInfer currently enabled by default? After testing throughput on the ShareGPT dataset, there was no significant improvement for vLLM, and the gap with LMDeploy is still quite large (8.40 vs 19.62 req/s).
I'm not sure if there is a configuration error or something I haven't considered. Here is the reproduction method; please let me know if there are any mistakes. Thanks.

# env
NVIDIA A100-SXM4-80GB
PyTorch: 2.3.0+cu118
vLLM: 0.4.2+cu118
flash-attn: 2.5.8
LMDeploy: 0.4.0+81108ff

# vLLM Server
python3 -m vllm.entrypoints.openai.api_server --model /workdir/Meta-Llama-3-8B-Instruct
# vLLM Client
python3 benchmark_serving.py --backend vllm --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model /workdir/Meta-Llama-3-8B-Instruct

# LMDeploy Server
python3 -m lmdeploy serve api_server /workdir/Meta-Llama-3-8B-Instruct --cache-max-entry-count 0.95
# LMDeploy Client
python3 benchmark_serving.py --backend lmdeploy --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --tokenizer /workdir/Meta-Llama-3-8B-Instruct --model llama3 --port 23333
# vLLM Result

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  119.11
Total input tokens:                      215196
Total generated tokens:                  186473
Request throughput (req/s):              8.40
Input token throughput (tok/s):          1806.76
Output token throughput (tok/s):         1565.61
---------------Time to First Token----------------
Mean TTFT (ms):                          32831.19
Median TTFT (ms):                        24628.54
P99 TTFT (ms):                           88550.36
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          154.28
Median TPOT (ms):                        149.54
P99 TPOT (ms):                           406.19
==================================================

# LMDeploy Result

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  50.98
Total input tokens:                      215196
Total generated tokens:                  187514
Request throughput (req/s):              19.62
Input token throughput (tok/s):          4221.44
Output token throughput (tok/s):         3678.41
---------------Time to First Token----------------
Mean TTFT (ms):                          18022.59
Median TTFT (ms):                        16949.77
P99 TTFT (ms):                           39935.77
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          30.41
Median TPOT (ms):                        28.74
P99 TPOT (ms):                           77.57
==================================================

@LiuXiaoxuanPKU (Collaborator, Author)

Hi @zhyncs, thanks for the interest and the benchmarking. Several things here:
FlashInfer is not turned on by default; it can only be enabled with the environment variable VLLM_ATTENTION_BACKEND=FLASHINFER (see the example below).

We don't turn it on by default because of performance concerns:

  1. We don't have CUDA graph support for flashinfer yet. Without CUDA graph, it might be hard to see any performance benefit, though we have not thoroughly benchmarked it yet.
  2. Flashinfer's begin_forward function causes some extra CPU/GPU communication, which will be optimized on the flashinfer side.

Both 1 and 2 are doable and fixable. We will coordinate with the FlashInfer team. After fixing the performance issues, we can turn FlashInfer on by default.
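For completeness, a small example of enabling the backend as described (the model name and prompt are illustrative):

# Enable the FlashInfer backend via the environment variable mentioned above;
# it must be set before vLLM selects its attention backend.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams  # import after setting the variable

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # illustrative model
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)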

@zhyncs commented May 6, 2024

(quoting @LiuXiaoxuanPKU's reply above)

Thanks for your reply!

@Qiubo1 commented May 7, 2024

I tried flashinfer for decoding in vLLM before, but its performance was poorer than FlashAttention 2.5; I wonder if I did something wrong. Have you compared the decode performance between FlashAttention 2.5 and flashinfer?

z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
@MichoChan

(quoting @LiuXiaoxuanPKU's reply above)

Hi, when will it be turned on by default?

@MichoChan

(quoting @LiuXiaoxuanPKU's reply above)

Hi, when will it be turned on by default?

I found that when the QPS or sequence length is not large, the speed is slower than vLLM's base decode kernel; item 2 above (removing the extra CPU/GPU communication) may be what's needed.

@rkooo567 (Collaborator)

@MichoChan I think it is because it doesn't have CUDA graph support yet (at large QPS, the CPU overhead that CUDA graph removes is usually negligible).

@Calculusss commented May 24, 2024

(quoting @LiuXiaoxuanPKU's reply above)

Hi @LiuXiaoxuanPKU, do you have a timeline regarding support for CUDA graph?

@rkooo567 (Collaborator)

We are blocked on the flashinfer side to enable CUDA graph support. We should be able to support it after that (I assume 2+ weeks to get it delivered).

mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>