
[OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel #420

Conversation

naed90 (Contributor) commented Jul 10, 2023

Instead of having each thread group fetch the query head separately (which causes 64x as much memory to be read as necessary), we have all threads in the block share the task of loading the query head into shared memory. On a benchmark running 1000 sequences through LLaMA-13B on an A100 (80 GB), this improves throughput by 1.10x.
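For illustration, here is a minimal, self-contained CUDA sketch of the cooperative-load pattern described above. It is not the vLLM kernel itself: the kernel name, HEAD_SIZE, NUM_THREADS, and the plain-float layout are simplifying assumptions, and the real kernel operates on vectorized Q_vec types with many heads per grid.

// Hypothetical sketch, not the vLLM kernel: all threads in a block cooperatively
// load one query head into shared memory once, instead of each thread group
// re-reading it from global memory.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int HEAD_SIZE   = 128;  // elements per query head (assumed)
constexpr int NUM_THREADS = 128;  // threads per block (assumed)

__global__ void load_query_head_shared(const float* __restrict__ q,
                                       float* __restrict__ out,
                                       int head_idx) {
  // One copy of the query head per block, visible to every thread group.
  __shared__ float q_shared[HEAD_SIZE];
  const float* q_ptr = q + head_idx * HEAD_SIZE;

  // Strided cooperative load: thread i fetches elements i, i + NUM_THREADS, ...
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    q_shared[i] = q_ptr[i];
  }
  __syncthreads();  // every thread now sees the full query head

  // Placeholder use of the shared data so the sketch runs end to end; the real
  // kernel would compute q.k dot products against the cached keys here.
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    out[i] = q_shared[i];
  }
}

int main() {
  float *q = nullptr, *out = nullptr;
  cudaMallocManaged(&q, HEAD_SIZE * sizeof(float));
  cudaMallocManaged(&out, HEAD_SIZE * sizeof(float));
  for (int i = 0; i < HEAD_SIZE; ++i) q[i] = static_cast<float>(i);

  load_query_head_shared<<<1, NUM_THREADS>>>(q, out, /*head_idx=*/0);
  cudaDeviceSynchronize();
  printf("out[0]=%.1f out[%d]=%.1f\n", out[0], HEAD_SIZE - 1, out[HEAD_SIZE - 1]);

  cudaFree(q);
  cudaFree(out);
  return 0;
}

With the shared buffer, the query head is read from global memory once per block rather than once per thread group, which is what the 64x figure above refers to.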

naed90 (Contributor, Author) commented Jul 10, 2023

See #421 for a detailed description and analysis of this commit.

zhyncs (Contributor) commented Jul 11, 2023

Hi @naed90

overall LGTM, just had one small nitpick and looks like some formatting issues to address

naed90 (Contributor, Author) commented Jul 11, 2023

> Hi @naed90
>
> overall LGTM, just had one small nitpick and looks like some formatting issues to address

ty.
can't seem to find your review, can you send a link to it?

naed90 (Contributor, Author) commented Jul 13, 2023

@WoosukKwon @zhuohan123 hey, what do you think?

WoosukKwon (Collaborator) commented

Hey @naed90, thanks for submitting the PR and apologies for the late response. I was busy for the last few days. Will take a look at your issue and PR today.

naed90 (Contributor, Author) commented Jul 18, 2023

> Hey @naed90, thanks for submitting the PR and apologies for the late response. I was busy for the last few days. Will take a look at your issue and PR today.

@WoosukKwon bump :)

zhuohan123 (Member) commented

Tested a bit on the latency side:

Before optimization

$ python benchmark_latency.py --model huggyllama/llama-13b --input-len 128 --output-len 128 --num-iters 20
Namespace(model='huggyllama/llama-13b', tokenizer=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=8, n=1, use_beam_search=False, num_iters=20, trust_remote_code=False)
INFO 07-24 21:53:31 llm_engine.py:67] Initializing an LLM engine with config: model='huggyllama/llama-13b', tokenizer='huggyllama/llama-13b', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
INFO 07-24 21:53:31 tokenizer.py:29] For some LLaMA-based models, initializing the fast tokenizer may take a long time. To eliminate the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 07-24 21:54:01 llm_engine.py:183] # GPU blocks: 899, # CPU blocks: 327
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, use_beam_search=False, stop=[], ignore_eos=True, max_tokens=128, logprobs=None)
Warming up...
Profiling iterations: 100%|████████████████████████████████████████| 20/20 [01:11<00:00,  3.56s/it]
Avg latency: 3.5580986022949217 seconds

After optimization

$ python benchmark_latency.py --model huggyllama/llama-13b --input-len 128 --output-len 128 --num-iters 20
Namespace(model='huggyllama/llama-13b', tokenizer=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=8, n=1, use_beam_search=False, num_iters=20, trust_remote_code=False)
INFO 07-24 21:55:36 llm_engine.py:67] Initializing an LLM engine with config: model='huggyllama/llama-13b', tokenizer='huggyllama/llama-13b', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
INFO 07-24 21:55:36 tokenizer.py:29] For some LLaMA-based models, initializing the fast tokenizer may take a long time. To eliminate the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 07-24 21:56:08 llm_engine.py:183] # GPU blocks: 899, # CPU blocks: 327
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, use_beam_search=False, stop=[], ignore_eos=True, max_tokens=128, logprobs=None)
Warming up...
Profiling iterations: 100%|████████████████████████████████████████| 20/20 [01:09<00:00,  3.49s/it]
Avg latency: 3.4891188383102416 seconds

zhuohan123 (Member) left a comment


Thank you for your contribution! Left some small comments. We should be able to merge this after the changes.

csrc/attention/attention_kernels.cu

@@ -116,12 +117,15 @@ __global__ void single_query_cached_kv_attention_kernel(
   // th vectors of the query, and so on.
   // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  Q_vec q_vecs[NUM_VECS_PER_THREAD];
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+  if (thread_group_idx <= NUM_THREAD_GROUPS_LOWER_BOUND) {
zhuohan123 (Member) commented on this diff:

This if seems redundant if we assume NUM_THREADS is divisible by THREAD_GROUP_SIZE?

naed90 (Contributor, Author) replied Aug 4, 2023

Replaced with an assert.
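For context, here is a minimal sketch of the kind of check being discussed, assuming NUM_THREADS and THREAD_GROUP_SIZE are compile-time template parameters; the exact form used in the merged kernel may differ.

// Hypothetical sketch only: make the divisibility assumption explicit instead
// of guarding the cooperative load with a runtime if.
#include <cassert>

template <int NUM_THREADS, int THREAD_GROUP_SIZE>
__global__ void attention_like_kernel() {
  // Compile-time form, available when both values are template parameters:
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "NUM_THREADS must be divisible by THREAD_GROUP_SIZE");
  // Device-side runtime form, closer to a plain assert:
  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
  // ... cooperative query-head load and attention computation would follow ...
}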

(A further review comment on csrc/attention/attention_kernels.cu was marked outdated and resolved.)
zhuohan123 (Member) left a comment

LGTM! Thank you again for your hard work and detailed profiling!

zhuohan123 merged commit 79af7e9 into vllm-project:main on Aug 4, 2023. 2 checks passed.