Feature add lora support for Qwen2 #3177

Merged

Conversation

@whyiug (Contributor) commented Mar 4, 2024

Closes #3054

@simon-mo (Collaborator) commented Mar 4, 2024

Can you post a working model/example?

@Yard1 Yard1 self-requested a review March 4, 2024 19:11
@whyiug whyiug marked this pull request as draft March 5, 2024 02:44
@whyiug (Contributor, Author) commented Mar 5, 2024

Qwen's architecture is similar to Llama's, but its vocab size is 150k and its intermediate size is 13696.
I have tested it locally and the results look normal, but LoRA is not fully tested. Can you provide some suggestions or run more tests? @JustinLin610 @Yard1
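
For reference, the sizes mentioned above can be read straight from the HF config. A small sketch (it assumes the Qwen/Qwen1.5-14B-Chat checkpoint on the Hub):

from transformers import AutoConfig

# Print the architecture hyperparameters that the LoRA/punica kernels need to cover.
cfg = AutoConfig.from_pretrained("Qwen/Qwen1.5-14B-Chat", trust_remote_code=True)
print(cfg.vocab_size, cfg.hidden_size, cfg.intermediate_size)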

server:

CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --trust-remote-code \
    --max-model-len 4096 \
    --model ~/qwen/Qwen1.5-14B-Chat \
    --enable-lora \
    --lora-modules lora1=~/lora/xxx lora2=~/lora/xxx

test:

curl --request POST \
  --url http://localhost:8000/v1/chat/completions \
  --header 'content-type: application/json' \
  --data '{
    "model": "lora2",
    "messages": [
      { "role": "system", "content": "You are a helpful assistant." },
      { "role": "user", "content": "China is a" }
    ],
    "stop_token_ids": [151645, 151644, 151643],
    "max_tokens": 5,
    "temperature": 0.7
  }'

response:

{"id":"cmpl-7a38720c85814f1faf3d83e1d8718573","object":"chat.completion","created":1122584,"model":"lora2","choices":[{"index":0,"message":{"role":"assistant","content":"country located in East Asia"},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":22,"total_tokens":27,"completion_tokens":5}}

@whyiug whyiug marked this pull request as ready for review March 5, 2024 08:09
@whyiug whyiug changed the title Feature add lora support for qwen Feature add lora support for qwen2 Mar 5, 2024
@whyiug whyiug changed the title Feature add lora support for qwen2 Feature add lora support for qwen/qwen2 Mar 6, 2024
@Yard1 (Collaborator) commented Mar 6, 2024

I think this looks good. Have you verified that the outputs look correct (i.e. the adapter is being applied)?

@whyiug (Contributor, Author) commented Mar 7, 2024 via email

@whyiug (Contributor, Author) commented Mar 7, 2024

I think this looks good. Have you verified that the outputs look correct (i.e. the adapter is being applied)?

Yes, I tested two LoRAs on both Qwen and Qwen2 (14B). The LoRA output matches the output of the merged model.
I have added all the relevant sizes for Qwen and Qwen2 to bgmv_config.h.

@WangxuP commented Mar 7, 2024

VLLM_INSTALL_PUNICA_KERNELS=1 pip install -e .

From: WangxuP
Sent: March 6, 2024, 23:52
To: vllm-project/vllm
Cc: whyiug; Author
Subject: Re: [vllm-project/vllm] Feature add lora support for qwen/qwen2 (PR #3177)

I downloaded the code for the branch to be merged, compiled and installed it with python setup.py install, then loaded two LoRA models and ran into the following problem. How can I solve it?

Traceback (most recent call last):
  File "/root/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm-0.3.3+cu120-py3.10-linux-x86_64.egg/vllm/lora/punica.py", line 83, in add_lora
    import vllm._punica_C as punica_kernels
ModuleNotFoundError: No module named 'vllm._punica_C'

The above exception was the direct cause of the following exception:
ImportError: punica LoRA kernels could not be imported. If you built vLLM from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var was set.

Thank you! After compiling with the method you provided, it looks like the build does not support the V100 GPU. The error message is below:

VLLM_INSTALL_PUNICA_KERNELS=1 pip install -e .

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server --trust-remote-code  --max-model-len 2048  --model /home/models/dem_14b/base/  --enable-lora  --lora-modules lora1=lora1 lora2=lora2

Traceback (most recent call last):
  File "/home/wxp/vllm-feature_add_lora_support_for_qwen/vllm/lora/punica.py", line 83, in add_lora
    import vllm._punica_C as punica_kernels
ModuleNotFoundError: No module named 'vllm._punica_C'
...
ImportError: punica LoRA kernels require compute capability >= 8.0

@whyiug (Contributor, Author) commented Mar 7, 2024

Thank you! After compiling with the method you provided, it looks like the build does not support the V100 GPU. [...]

ImportError: punica LoRA kernels require compute capability >= 8.0

Yes, the V100 is not supported (compute capability 7.0); same issue: #3197.
You can look up your GPU's compute capability on NVIDIA's Compute Capability page.
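
A quick way to check what compute capability a GPU reports, as a small sketch using PyTorch (which vLLM already depends on):

import torch

# The punica LoRA kernels require compute capability >= 8.0 (Ampere or newer),
# so a V100 (7.0) fails with the ImportError shown above.
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: {major}.{minor}")
if (major, minor) < (8, 0):
    print("This GPU cannot run the punica LoRA kernels.")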


@junior-zsy

@whyiug The current PR supports Qwen1.5, but it cannot support the original Qwen, because the implementation of class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA) in vllm/lora/layers.py does not fit it. For the original Qwen the code needs to follow the logic of ColumnParallelLinearWithLoRA rather than QKVParallelLinearWithLora, because Qwen's qkv projection weights are fused: the model has no separate q_proj, k_proj, and v_proj, only a single qkv_proj. So your stacked_params_mapping

# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),

is problematic. If you run a test on the original Qwen you will find the problem: the current code runs Qwen without errors, but the results are not correct. You can fine-tune a LoRA model and try it out.

@whyiug (Contributor, Author) commented Mar 7, 2024

@whyiug The current PR supports Qwen1.5, but it cannot support the original Qwen, because the implementation of class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA) in vllm/lora/layers.py does not fit it. [...]

You are right, I found the difference in https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py.
I used two LoRAs from actual production to verify the effect on Qwen1.5. On the original Qwen I only checked that inference ran normally and did not verify the LoRA effect.
Since the original Qwen is officially no longer maintained, let's limit this change to Qwen1.5 (Qwen2).
I've updated the code. @Yard1
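
For context, LoRA support for a model in vLLM mostly comes down to declaring which modules (including packed ones such as qkv_proj) may carry adapters. A rough sketch of the class attributes involved, modeled on the existing Llama implementation (the exact Qwen2 values in this PR may differ slightly):

import torch.nn as nn

class Qwen2ForCausalLM(nn.Module):
    # Packed modules: one fused weight backed by several logical projections.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    # Modules that LoRA adapters are allowed to target.
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]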

@whyiug whyiug changed the title Feature add lora support for qwen/qwen2 Feature add lora support for Qwen2 Mar 7, 2024
@simon-mo simon-mo merged commit c59e120 into vllm-project:main Mar 8, 2024
22 checks passed
@whyiug whyiug deleted the feature_add_lora_support_for_qwen branch March 8, 2024 07:03
@GennVa commented Mar 11, 2024

@simon-mo If I may ask, will this feature be added in version 0.3.4? Thanks

@simon-mo simon-mo mentioned this pull request Mar 12, 2024
dtransposed pushed a commit to afeldman-nm/vllm that referenced this pull request Mar 26, 2024
Successfully merging this pull request may close these issues.

Support Lora for qwen2