
Added DeciLM-7b and DeciLM-7b-instruct #2062

Merged: 6 commits, Dec 19, 2023

Conversation

avideci (Contributor) commented Dec 12, 2023

New models: DeciLM-7b and DeciLM-7b-instruct

DeciLM-7b and DeciLM-7b-instruct were released today.
They have reached first place on the Open LLM Leaderboard (7B category).

[Screenshot: Open LLM Leaderboard, 7B category]

DeciLM-7B is a 7.04 billion parameter decoder-only text generation model, released under the Apache 2.0 license. At the time of release, DeciLM-7B is the top-performing 7B base language model on the Open LLM Leaderboard. With support for an 8K-token sequence length, this highly efficient model uses variable Grouped-Query Attention (GQA) to achieve a superior balance between accuracy and computational efficiency. The model's architecture was generated using Deci's proprietary Neural Architecture Search technology, AutoNAC.


About the model

DeciLM uses an architecture similar to Llama's, so it was fairly easy to make it work with vLLM.

Deci AI used AutoNAC (its Neural Architecture Search engine) to find this architecture automatically.

The difference is that DeciLM uses variable grouped-query attention ("variable GQA") instead of uniform grouped-query attention (plain "GQA").

The main difference can be spotted in the model configuration:
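For illustration, here is a minimal sketch of the two styles of configuration. The head counts below are made up; the real per-layer list is in the DeciLM-7B config.json (https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json).

# Illustrative sketch only: hypothetical head counts, not DeciLM-7B's real values.

# Uniform GQA (Llama/Mistral style): one KV-head count shared by every layer.
uniform_gqa_config = {
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
}

# Variable GQA (DeciLM style): a separate KV-head count for each decoder layer.
variable_gqa_config = {
    "num_attention_heads": 32,
    "num_key_value_heads_per_layer": [4, 4, 4, 2, 2, 2, 1, 1],  # hypothetical
}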

Useful References:

Future Work

I have tried to modify the code to use variable GQA, but it seems the PagedAttention mechanism relies on a static/constant number of KV heads. With variable GQA, I tried passing each Llama layer a layer_idx and choosing the number of heads for each layer from the per-layer list in the configuration above.

With variable GQA my code ran, but the model output gibberish after the first word.
As a temporary workaround, I added logic that converts variable GQA to uniform GQA in the load_weights() method. After this fix, the attention kernels worked as expected and the model's outputs are great.


Is there any reason variable GQA would not work with PagedAttention?

Can you think of anything that would prevent the model from returning valid outputs while the kernel runs without errors?
The CUDA kernel's Python launcher passed all of the size and boundary assertions, yet the output was still gibberish after the first token.
The current patch in load_weights(), which degroups each KV head, fixes this for now, but the model takes a latency hit since the weights are repeated redundantly. Since you know the PagedAttention kernels better than I do, I hope there is a solution that will enable a different number of kv_heads per layer. We would appreciate your help!
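For context on the first question, here is a hedged sketch of the assumption in play; it is not vLLM's actual allocator or kernel code, and the sizes are made-up example values. The point is only that a paged KV cache is typically allocated with one block shape shared by all layers, parameterized by a single num_kv_heads.

import torch

# Hedged illustration, not vLLM's actual code: every layer's KV-cache blocks
# share one shape, so a single num_kv_heads is assumed model-wide.
num_layers, num_blocks, block_size = 32, 8, 16
num_kv_heads, head_size = 8, 128  # made-up example sizes

kv_caches = [
    # (key/value, num_blocks, block_size, num_kv_heads, head_size) per layer
    torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
    for _ in range(num_layers)
]

# A kernel that computes cache offsets from this shared num_kv_heads would
# stride through the wrong memory for a layer that really has fewer KV heads,
# without necessarily tripping any size assertion.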

Thanks in advance!

        config: Optional[PretrainedConfig] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
avideci (Contributor, Author):

Here, we are converting the model to uniform GQA instead of variable GQA. That's because the PagedAttention kernel did not work well with variable GQA (still no idea why), and this is a workaround.

We choose the maximum number in this list: https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json#L18

Then we convert the "k_proj" and "v_proj" weights to uniform GQA using "degroup_weight".

                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
avideci (Contributor, Author):

This method receives a weight and changes its number of KV heads to match the maximum number found in this list:
https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json#L18

By doing so, all the attention layers end up with the same number of KV heads, which is required for PagedAttention.
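A minimal sketch of that conversion, assuming the usual (num_kv_heads * head_dim, hidden_size) layout for k_proj/v_proj weights and a target head count that is a multiple of each layer's own count. The function name and arguments are illustrative, not the PR's exact degroup_weight().

import torch

def degroup_kv_weight(loaded_weight: torch.Tensor,
                      num_kv_heads: int,
                      target_kv_heads: int,
                      head_dim: int) -> torch.Tensor:
    # Split the projection into per-head blocks of rows.
    hidden_size = loaded_weight.shape[1]
    w = loaded_weight.view(num_kv_heads, head_dim, hidden_size)
    # Duplicate each head (target_kv_heads // num_kv_heads) times, in order,
    # so the layer presents target_kv_heads heads to the attention kernel.
    w = w.repeat_interleave(target_kv_heads // num_kv_heads, dim=0)
    return w.reshape(target_kv_heads * head_dim, hidden_size)

# Example: expand a 2-KV-head projection to 4 KV heads (head_dim=64, hidden=256).
weight = torch.randn(2 * 64, 256)
print(degroup_kv_weight(weight, num_kv_heads=2, target_kv_heads=4, head_dim=64).shape)
# -> torch.Size([256, 256])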

Collaborator:

Makes sense. Thanks for this solution. Let's leave it as future work for now.

@WoosukKwon added the "new model" label Dec 13, 2023
avideci (Contributor, Author) commented Dec 18, 2023

@WoosukKwon self-requested a review December 19, 2023 09:23
@WoosukKwon (Collaborator) left a comment:

Hi @avideci, thanks for adding this model to vLLM, and apologies for the late review. We actually found this model interesting but didn't have the bandwidth last week.

To accelerate the integration, I made some minor changes directly on the PR (mostly code style). Thanks again for contributing the model, and I'm looking forward to your next one!

@WoosukKwon merged commit de60a3f into vllm-project:main Dec 19, 2023
2 checks passed
rkooo567 added a commit to rkooo567/vllm that referenced this pull request Dec 19, 2023
TobyGE commented Jan 3, 2024

I tried DeciLM with vLLM; the inference speed is similar to other 7B LLMs. Does anyone have similar observations?

geifmany replied:

It leverages the same kernels, so it will have similar performance. The HF implementation of DeciLM is around 2x faster than HF Mistral on large batches.

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024