Optimize MQA Kernel #452

Merged
merged 9 commits into from Jul 15, 2023

Conversation

zhuohan123
Member

This PR implements the MQA paged attention kernel and modifies the GPT-BigCode model to use the optimized kernel.

TODO: Check performance gain.
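
For context, in multi-query attention (MQA) all query heads share a single key/value head, which shrinks the KV cache (and its memory traffic) by roughly a factor of num_heads. The snippet below is a minimal, hypothetical PyTorch reference of that idea, not the paged CUDA kernel added in this PR; all names and shapes are illustrative assumptions.

import torch

# Illustrative shapes only (not taken from this PR).
batch, seq_len, num_heads, head_dim = 1, 128, 48, 128

q = torch.randn(batch, num_heads, seq_len, head_dim)  # one query per head
k = torch.randn(batch, 1, seq_len, head_dim)          # MQA: a single shared key head
v = torch.randn(batch, 1, seq_len, head_dim)          # MQA: a single shared value head

# Broadcast the shared K/V head across all query heads (a view, no copy).
k = k.expand(batch, num_heads, seq_len, head_dim)
v = v.expand(batch, num_heads, seq_len, head_dim)

scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5
attn = torch.softmax(scores, dim=-1)                  # causal masking omitted for brevity
out = torch.matmul(attn, v)                           # (batch, num_heads, seq_len, head_dim)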

@nivibilla

How hard is it to port MQA to LLaMA models?

@WoosukKwon (Collaborator) left a comment

Awesome! Thanks for the PR. Left some comments.

vllm/config.py (outdated; resolved)
Comment on lines 59 to 62
assert self.num_heads % self.num_kv_heads == 0
self.head_mapping = torch.repeat_interleave(
    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
    num_heads // self.num_kv_heads)

Style nit:

Suggested change
-assert self.num_heads % self.num_kv_heads == 0
-self.head_mapping = torch.repeat_interleave(
-    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
-    num_heads // self.num_kv_heads)
+assert self.num_heads % self.num_kv_heads == 0
+self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+self.head_mapping = torch.repeat_interleave(
+    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+    self.num_queries_per_kv)

self.num_queries_per_kv can also be used at L97 and L100.
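
As a side note for readers, the head_mapping built here is what lets one kernel cover MQA and grouped-query attention: entry i is the KV head that query head i reads from. A small standalone sketch with made-up head counts (not the values used by StarCoder):

import torch

# Made-up head counts for illustration: 8 query heads, 2 KV heads (GQA).
# MQA is the special case num_kv_heads == 1; plain MHA is num_kv_heads == num_heads.
num_heads, num_kv_heads = 8, 2
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads

# head_mapping[i] = index of the KV head that query head i attends with.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_queries_per_kv)

print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)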

@zhuohan123
Member Author

zhuohan123 commented Jul 17, 2023

StarCoder latency after this PR on a single GCP A100:

$ python benchmark_latency.py --model bigcode/starcoder --batch-size 1 --input-len 128 --output-len 128 --num-iters 1
Namespace(model='bigcode/starcoder', tokenizer=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=1, n=1, use_beam_search=False, num_iters=1, profile=False)
INFO 07-17 07:23:28 llm_engine.py:60] Initializing an LLM engine with config: model='bigcode/starcoder', tokenizer='bigcode/starcoder', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
INFO 07-17 07:25:40 llm_engine.py:134] # GPU blocks: 20280, # CPU blocks: 13107
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, use_beam_search=False, stop=[], ignore_eos=True, max_tokens=128, logprobs=None)
Avg latency: 3.6188764572143555 seconds

before this PR:

(baseline-branch) ubuntu@zhuohan-vllm-4-manual:~/nfs/cacheflow/base-branch/vllm/benchmarks$ python benchmark_latency.py --model bigcode/starcoder --batch-size 1 --input-len 128 --output-len 128 --num-iters 1
Namespace(model='bigcode/starcoder', tokenizer=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=1, n=1, use_beam_search=False, num_iters=1)
INFO 07-17 07:28:23 llm_engine.py:60] Initializing an LLM engine with config: model='bigcode/starcoder', tokenizer='bigcode/starcoder', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
INFO 07-17 07:30:55 llm_engine.py:134] # GPU blocks: 49, # CPU blocks: 273
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, use_beam_search=False, stop=[], ignore_eos=True, max_tokens=128, logprobs=None)
Avg latency: 4.571176528930664 seconds

Note that throughput can be boosted even further because of the extra KV cache blocks now available (20280 GPU blocks after this PR vs. 49 before).
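
For reference, comparing the two averages reported above:

# Latencies copied from the logs above (seconds).
before, after = 4.571176528930664, 3.6188764572143555
print(f"speedup: {before / after:.2f}x, latency reduction: {1 - after / before:.1%}")
# speedup: 1.26x, latency reduction: 20.8%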

@zhyncs
Contributor

zhyncs commented Jul 17, 2023

Hi @zhuohan123

May I ask why the MQA optimization is implemented inside single_query_cached_kv_attention_kernel instead of being abstracted into a separate kernel?

zhuohan123 deleted the mqa-optimization branch July 18, 2023 22:18
@zhuohan123
Member Author

> Hi @zhuohan123
>
> May I ask why the MQA optimization is implemented inside single_query_cached_kv_attention_kernel instead of being abstracted into a separate kernel?

Hi! MQA can be combined with other attention modifications, including RoPE and ALiBi embeddings. Adding a separate class for MQA would require many new classes to cover all of these embedding variations.
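
To illustrate the point, here is a hypothetical pure-Python reference (not the actual CUDA kernel or any vLLM API): a single attention routine can cover MHA, MQA, and GQA through a head_mapping argument, and ALiBi becomes just an optional bias, so no per-variant class is needed.

import torch
from typing import Optional

def attention_reference(
    q: torch.Tensor,                    # (num_heads, seq_len, head_dim)
    k: torch.Tensor,                    # (num_kv_heads, seq_len, head_dim)
    v: torch.Tensor,                    # (num_kv_heads, seq_len, head_dim)
    head_mapping: torch.Tensor,         # (num_heads,) query head -> KV head index
    alibi_slopes: Optional[torch.Tensor] = None,  # (num_heads,) or None
) -> torch.Tensor:
    """Hypothetical reference covering MHA/MQA/GQA in one routine; not vLLM code."""
    scale = q.shape[-1] ** -0.5
    k = k[head_mapping.long()]          # gather the KV head used by each query head
    v = v[head_mapping.long()]
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if alibi_slopes is not None:
        # ALiBi: per-head linear bias over key positions (softmax is shift-invariant
        # per query row, so the usual "key_pos - query_pos" form reduces to this).
        key_pos = torch.arange(k.shape[-2], device=q.device)
        scores = scores + alibi_slopes[:, None, None] * key_pos
    attn = torch.softmax(scores, dim=-1)  # causal masking omitted for brevity
    return torch.matmul(attn, v)

With num_kv_heads == 1 the mapping sends every query head to the shared KV head (MQA); with num_kv_heads == num_heads it reduces to standard MHA, so only the mapping changes, not the class hierarchy.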

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024