[OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel #420
Conversation
single_query_cached_kv_attention kernel
See #421 for a detailed description and analysis of this commit.
Hi @naed90, overall LGTM. Just had one small nitpick, and it looks like there are some formatting issues to address.
ty.
@WoosukKwon @zhuohan123 hey, what do you think?
Hey @naed90, thanks for submitting the PR and apologies for the late response. I was busy for the last few days. Will take a look at your issue and PR today.
@WoosukKwon bump :)
Tested a bit on the latency side:
Before optimization
After optimization
Thank you for your contribution! Left some small comments. We should be able to merge this after the changes.
csrc/attention/attention_kernels.cu
@@ -116,12 +117,15 @@ __global__ void single_query_cached_kv_attention_kernel(
   // th vectors of the query, and so on.
   // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  Q_vec q_vecs[NUM_VECS_PER_THREAD];
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+  if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
This `if` seems redundant if we assume `NUM_THREADS` is divisible by `THREAD_GROUP_SIZE`?
Replaced with an assert.
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
LGTM! Thank you again for your hard work and detailed profiling!
Instead of having each thread group fetch the query head independently (which causes the same data to be read from global memory 64x), we have all threads in the block share the task of loading the query head into shared memory. On a benchmark running 1000 sequences through LLaMA-13B on an A100 (80GB), this improves throughput by 1.10x.
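The idea in the paragraph above can be sketched as a standalone kernel fragment. This is a simplified illustration, not the actual vLLM kernel: the names `HEAD_SIZE`, `NUM_THREADS`, `THREAD_GROUP_SIZE`, `VEC_SIZE`, `Q_vec`, and `q_vecs` mirror the diff, but the surrounding attention logic is omitted and `Q_vec` is stood in by `float4`.

```cuda
// Hedged sketch of the cooperative query-head load. Assumptions: VEC_SIZE == 4
// (to match the float4 stand-in for Q_vec) and NUM_THREADS divisible by
// THREAD_GROUP_SIZE, per the review discussion above.
template <int HEAD_SIZE, int NUM_THREADS, int THREAD_GROUP_SIZE, int VEC_SIZE>
__global__ void cooperative_query_load_sketch(const float* __restrict__ q_ptr) {
  using Q_vec = float4;  // stand-in for the kernel's templated vector type
  static_assert(VEC_SIZE == 4, "float4 stand-in holds 4 floats");
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "thread groups must tile the block evenly");
  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;

  const int thread_idx = threadIdx.x;
  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Before the change: each thread group kept a private q_vecs copy, so the
  // same query head was read from global memory once per thread group.
  // After: one __shared__ copy per block, filled cooperatively, so each
  // vector is fetched from global memory exactly once.
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

  // All thread groups stride over the vector slots together: for a fixed
  // thread_group_offset (row), groups 0..NUM_THREAD_GROUPS-1 cover disjoint
  // slots i, i + NUM_THREAD_GROUPS, ... between them.
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // every thread must see the full query head before use

  // ... dot products against the cached keys would follow here ...
}
```

Because every element of the query head now lands in shared memory exactly once per block, the `__syncthreads()` barrier is required before any thread reads `q_vecs` entries it did not write itself.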