Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix Caching- fix t4 triton error #2517

Merged
merged 5 commits into from
Feb 16, 2024
Merged

Conversation

caoshiyi
Copy link
Contributor

Fix #2513, need a smaller block size for Turing GPUs

@@ -5,6 +5,8 @@
import triton
import triton.language as tl

TESLA = 'Tesla' in torch.cuda.get_device_name(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to check for compute capability instead? also, we should do this inside context_attention_fwd, as calling CUDA APIs before we set CUDA_VISIBLE_DEVICES will lead to errors.

Copy link
Collaborator

@esmeetu esmeetu Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can set prefix_block_size as a parameter in CacheConfig and allow user configure in LLM?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this sort of a thing should be ideally derived automatically.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 @caoshiyi Does the block size affect the memory utilization or prefix speed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@esmeetu The block size is mainly dependent on the shared mem size for different GPU architectures. It will affect the prefix-prefill kernel speed a little bit but has nothing to do with the GPU memory utilization.

@esmeetu
Copy link
Collaborator

esmeetu commented Jan 20, 2024

Amazing! @caoshiyi Thanks for your help! This is good for me now and speed is indeed a x2-x3 speedup. But when doing further testing, i encountering engine stuck issue when GPU KV Cache is full (i change prefix 5~6 times). And the request is always at the pending state.

INFO 01-20 19:58:39 llm_engine.py:823] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 1 reqs, GPU KV cache usage: 94.3%, CPU KV cache usage: 0.0%

After one more change for prefix(will take up >10% KV cache), the engine will stuck.
So when will the Prefix will release the KV Cache?

@esmeetu
Copy link
Collaborator

esmeetu commented Jan 20, 2024

#2511 looks solving my second issue.

Copy link
Collaborator

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Left a small comment.

@@ -5,6 +5,8 @@
import triton
import triton.language as tl

TESLA = 'Tesla' in torch.cuda.get_device_name(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set this variable in a function instead of a global variable? Setting it in global variable may lead to issues in distributed setting.

@WoosukKwon
Copy link
Collaborator

@caoshiyi What is the blocker to this PR? Could you address @Yard1 and @zhuohan123's comments?

@WoosukKwon WoosukKwon mentioned this pull request Feb 15, 2024
5 tasks
@@ -618,7 +618,9 @@ def context_attention_fwd(q,
b_ctx_len,
max_input_len,
alibi_slopes=None):
BLOCK = 128

cap = torch.cuda.get_device_capability()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does prefix caching adapt other hardware? like AMD? This only considers cuda arch. Might it better that we define a global utility to get block size which handles different hardwares.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this kernel only works for NVIDIA right now. Let me merge this fix first and we can systematically test for AMD later.

Copy link
Collaborator

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@zhuohan123 zhuohan123 merged commit 64da65b into vllm-project:main Feb 16, 2024
17 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 20, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 22, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

prefix caching error with baichuan model
5 participants