Optimize MQA Kernel #452
Conversation
How hard is it to port MQA to LLaMA models?
Awesome! Thanks for the PR. Left some comments.
```python
assert self.num_heads % self.num_kv_heads == 0
self.head_mapping = torch.repeat_interleave(
    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
    num_heads // self.num_kv_heads)
```
Style nit: name the quotient so it can be reused:

```diff
 assert self.num_heads % self.num_kv_heads == 0
+self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 self.head_mapping = torch.repeat_interleave(
     torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
-    num_heads // self.num_kv_heads)
+    self.num_queries_per_kv)
```
`self.num_queries_per_kv` can also be used in L97 and L100.
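For reference, the mapping being built here can be sketched without torch. The helper below is hypothetical (not part of this PR); it produces the same list as `torch.repeat_interleave`, routing each group of `num_queries_per_kv` consecutive query heads to one shared KV head:

```python
def make_head_mapping(num_heads: int, num_kv_heads: int) -> list[int]:
    """Map each query head index to the KV head it attends with.

    Plain-Python equivalent of
    torch.repeat_interleave(torch.arange(num_kv_heads),
                            num_heads // num_kv_heads).
    """
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    return [kv for kv in range(num_kv_heads)
            for _ in range(num_queries_per_kv)]

# MQA: all 8 query heads share the single KV head.
print(make_head_mapping(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]
# Grouped-query: 8 query heads over 2 KV heads -> groups of 4.
print(make_head_mapping(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Note that standard multi-head attention is the special case `num_kv_heads == num_heads`, where the mapping is the identity.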
Starcoder latency after this PR on a single GCP A100:
Before this PR:
Note that throughput can be boosted further because of the extra cache blocks this frees up.
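The extra cache blocks come from the KV cache itself: with MQA only one KV head is cached instead of one per query head. A quick back-of-the-envelope illustration (all model numbers below are assumed for the sake of the arithmetic, not Starcoder's exact configuration):

```python
# Illustrative KV-cache sizing; the shapes and counts are assumptions.
num_layers, seq_len, head_dim = 40, 2048, 128
num_heads, num_kv_heads = 48, 1  # MQA: a single shared KV head
bytes_per_elem = 2               # fp16

def kv_cache_bytes(kv_heads: int) -> int:
    # Factor of 2 covers both the K and the V tensors.
    return 2 * num_layers * seq_len * kv_heads * head_dim * bytes_per_elem

mha_bytes = kv_cache_bytes(num_heads)
mqa_bytes = kv_cache_bytes(num_kv_heads)
print(mha_bytes // mqa_bytes)  # 48: per-sequence cache shrinks by num_heads
```

The freed memory translates directly into more cache blocks, hence larger batches and higher throughput.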
Hi @zhuohan123, may I ask why the MQA kernel optimization is implemented in the existing attention module rather than in a separate class?
Hi! MQA can be combined with other attention modifications, including RoPE and ALiBi positional embeddings. A separate class for MQA would multiply into many new classes, one per embedding variation.
This PR implements the MQA paged attention kernel and modifies the GPT Bigcode model to utilize the optimized MQA kernel.
TODO: Check performance gain.
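To make the semantics concrete, here is a minimal NumPy reference for what the optimized kernel computes (non-causal, unpaged, and purely illustrative; the PR's CUDA paged-attention kernel is far more involved). Each query head looks up its KV head through the head mapping discussed above:

```python
import numpy as np

def mqa_attention(q, k, v, head_mapping):
    """Reference multi-/grouped-query attention.

    q: [num_heads, seq_len, head_dim]
    k, v: [num_kv_heads, seq_len, head_dim]
    head_mapping[h] is the KV head used by query head h.
    Full (non-causal) attention, for simplicity.
    """
    num_heads, seq_len, head_dim = q.shape
    out = np.empty_like(q)
    for h in range(num_heads):
        kv = head_mapping[h]
        scores = q[h] @ k[kv].T / np.sqrt(head_dim)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[h] = weights @ v[kv]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 4, 16))
k = rng.standard_normal((1, 4, 16))  # MQA: one KV head
v = rng.standard_normal((1, 4, 16))
out = mqa_attention(q, k, v, head_mapping=[0] * 8)
print(out.shape)  # (8, 4, 16)
```

MQA with this mapping is numerically identical to standard multi-head attention with K and V replicated across all heads; it just avoids materializing the copies.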