[Kernel] Use flashinfer for decoding #4353
Conversation
Passes the correctness test for single-GPU and TP settings. Feel free to take a first pass @rkooo567
But we still need to pad inputs for the other parts of the model besides attention, right? (I think flashinfer-ai/flashinfer#187 could be something nice to add for prefill CUDA graph.)
Very clean!! Many comments are nits. So it seems like it is:
- not working with prefill
- not working with prefix caching
- not working with chunked prefill

Is this correct? If so, we should make sure we raise an exception properly (see the sketch below). Also, does FlashInfer have an equivalent of the prefix attention kernel?
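A minimal sketch of the kind of guard being suggested; the function name and parameters here are hypothetical, not vLLM's actual API:

```python
# Hypothetical guard illustrating the suggestion above; the function name
# and parameters are assumptions, not vLLM's actual API.
def check_flashinfer_support(enable_prefix_caching: bool,
                             enable_chunked_prefill: bool) -> None:
    """Reject configurations the FlashInfer backend cannot handle yet."""
    if enable_prefix_caching:
        raise NotImplementedError(
            "Prefix caching is not supported with the FlashInfer backend yet.")
    if enable_chunked_prefill:
        raise NotImplementedError(
            "Chunked prefill is not supported with the FlashInfer backend yet.")
```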
@LiuXiaoxuanPKU Thank you for this PR! Can we support other prefill backends instead of just flash attention, like xFormers?
I implemented an xFormers version, but the performance was not good (a 10%-20% throughput drop).
I think the goal is to use FlashInfer for prefill eventually, so flash attention will be replaced soon!
The latest torch version supported by the flashinfer Python package is 2.2. Yes, we can build from source with torch 2.3; I tested it locally and it passed the tests.
Let's merge it without adding the tests to CI? I think we don't need to be blocked by them (to merge). I will also create an issue on their repo for 2.3 support (we can enable the tests when it is officially supported).
Left a few comments - Thanks for working on this!
```python
data_type: torch.dtype = None

def __post_init__(self):
    if not self.is_prompt:
```
When using flashinfer, we also create the FlashInferMetadata, which calls __post_init__ by default; here we want to skip the decode-only initialization if it's the prefill phase.
IMO it's worth putting a NOTE in the code regarding this explanation.
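A sketch of how such a NOTE could read, assuming a dataclass; fields other than is_prompt are illustrative:

```python
from dataclasses import dataclass


@dataclass
class FlashInferMetadata:
    is_prompt: bool
    # ... other fields elided ...

    def __post_init__(self):
        # NOTE: FlashInfer is only used for the decoding phase; prefill falls
        # back to flash attention. Dataclasses always run __post_init__, so
        # the decode-only initialization below must be skipped for prefill.
        if not self.is_prompt:
            pass  # decode-only setup would go here
```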
```python
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
```
Just curious - is there any point in making the workspace buffer size configurable?
Currently, it's always 16 MB.
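For reference, the flashinfer docs example this buffer follows looks roughly like the sketch below; exact API details may vary across flashinfer versions:

```python
import torch
import flashinfer

# Shared workspace buffer for flashinfer's batch-decode kernels, mirroring
# the example at https://docs.flashinfer.ai/api/python/decode.html.
workspace_buffer = torch.empty(16 * 1024 * 1024,
                               dtype=torch.uint8,
                               device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD")
```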
Approving to unblock @LiuXiaoxuanPKU given @rkooo567's approval.
@ywang96 if you have any more comments, provide them and @LiuXiaoxuanPKU can follow up in the next PR.
```python
    AttentionMetadataPerStage)


class FlashInferBackend(AttentionBackend):
```
(can do in future PRs): good to have a small description in the docstring so new people can understand what FlashInfer is / why it's useful over other backends / a link to learn more
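One possible shape for such a docstring, purely illustrative (the wording and the import path are assumptions, not from the PR):

```python
from vllm.attention.backends.abstract import AttentionBackend


class FlashInferBackend(AttentionBackend):
    """Attention backend backed by FlashInfer kernels.

    The decoding phase uses FlashInfer's paged-KV batch-decode kernels,
    while prefill still goes through flash attention for now.
    See https://docs.flashinfer.ai/ for more on FlashInfer.
    """
    # (abstract methods elided in this sketch)
```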
```python
is_prompt: bool

use_cuda_graph: bool = False
```
(can do in future PRs): good to add a comment for these values (e.g. CUDA graph is not supported yet)
```
@@ -33,16 +34,19 @@ def test_models(
    dtype: str,
    max_tokens: int,
) -> None:
    enforce_eager = False
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
```
(can do in future PR): we should integrate this with #4548
```python
return self.hf_text_config.num_attention_heads // \
    parallel_config.tensor_parallel_size
```
(can do in future PR): need a TODO for vision models?
LGTM! I don't have any other comments ATM - please resolve the conflict though, thanks!
awesome! Super excited to see the performance with CUDA graph!
Hi @LiuXiaoxuanPKU Great work! After switching to the new backend, has there been any performance improvement compared to before? Have you conducted any relevant benchmarks? Thanks.
Hi @LiuXiaoxuanPKU Is FlashInfer currently enabled by default? After testing the throughput on the ShareGPT dataset, there was no significant improvement on vLLM, and the gap with LMDeploy is still quite large (8.40 vs 19.62).

```
# env
NVIDIA A100-SXM4-80GB
PyTorch: 2.3.0+cu118
vLLM: 0.4.2+cu118
flash-attn: 2.5.8
LMDeploy: 0.4.0+81108ff

# vLLM Server
python3 -m vllm.entrypoints.openai.api_server --model /workdir/Meta-Llama-3-8B-Instruct

# vLLM Client
python3 benchmark_serving.py --backend vllm --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model /workdir/Meta-Llama-3-8B-Instruct

# LMDeploy Server
python3 -m lmdeploy serve api_server /workdir/Meta-Llama-3-8B-Instruct --cache-max-entry-count 0.95

# LMDeploy Client
python3 benchmark_serving.py --backend lmdeploy --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --tokenizer /workdir/Meta-Llama-3-8B-Instruct --model llama3 --port 23333
```
Hi @zhyncs, thanks for the interest and benchmarking. Several things here: we don't turn it on by default because of performance concerns:
1. CUDA graph is not supported with the FlashInfer backend yet.
2. Preparing the FlashInfer inputs currently requires extra CPU/GPU copies.

1/2 are both doable and fixable. We will coordinate with the FlashInfer team. After fixing the performance issues, we can turn FlashInfer on by default.
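Until then, the backend can be opted into explicitly. A minimal sketch, assuming the VLLM_ATTENTION_BACKEND environment variable seen in the test diff above is what selects the backend (the FLASHINFER value is an assumption):

```python
import os

# Assumption: vLLM reads this env var to pick the attention backend (see the
# os.getenv(VLLM_ATTENTION_BACKEND) usage in the test diff above). It must be
# set before the engine is constructed.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(model="/workdir/Meta-Llama-3-8B-Instruct")
```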
Thanks for your reply!
I previously tried flashinfer for decoding in vLLM, but its performance was worse than FA 2.5. I wonder whether I did something wrong; have you compared decoding performance between FA 2.5 and flashinfer?
Hi, when will it be turned on by default?
I found that when the QPS or sequence length is not large, the speed is slower than vLLM's base decode kernel; maybe item 2 above (removing the CPU/GPU copy) is what's needed.
@MichoChan I think it is because it doesn't have CUDA graph support yet (at large QPS, the CPU overhead that CUDA graph removes is usually negligible).
Hi @LiuXiaoxuanPKU, do you have a timeline for CUDA graph support?
We are blocked on the FlashInfer side for enabling CUDA graph support. We should be able to support it after that (I assume 2+ weeks to get it delivered).
This PR is a first attempt to integrate flashinfer for the decoding phase. The PR still uses flash attention for the prefill phase for now.
Updated after discussion with @yzh119
Things that need to be fixed:
- Prefix caching and chunked prefill are not supported with the FlashInfer backend yet.
- CUDA graph is not supported for decoding yet.

Next steps:
- Use FlashInfer for the prefill phase as well.
- Enable CUDA graph support for decoding.