Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPTQ support #916

Merged
merged 27 commits into from
Dec 15, 2023
Merged

Add GPTQ support #916

merged 27 commits into from
Dec 15, 2023

Conversation

chu-tianxiang
Copy link
Contributor

@chu-tianxiang chu-tianxiang commented Aug 31, 2023

Add GPTQ support to vllm.

Update(0925): the previous implementation is hardly compatible with recent introduced AWQ interface. So I moved the previous implementation to gptq_hf_old branch and reimplemented most parts. Kernel files are copied from AutoGPTQ which itself is modified from exllama-v1
Unfortunately new implementation only supports 4-bit quantization. If you wanna use 3-bit model, please refer to the old branch.

Note:

  • Currently when desc_act=True and group_size!=-1, row parallel layers won't be partitioned. Honestly I haven't found an elegant way to deal with act order without giving up exllama's reorder trick for acceleration. Similar with huggingface TGI, now a separate kernel is used when desc_act=True and group_size!=-1 and world_size > 1
  • Weight loading of 'g_idx' is messy for the same reason.

For old branch

Code is based on the recent integration of AutoGPTQ into Transformers, see this blog. This requires that the config.json contains a quantization_config (example: TheBloke/Llama-2-7b-Chat-GPTQ) which has to manually added for old GPTQ models.

Models can be loaded in the same way as normal.

llm = LLM(model="TheBloke/Llama-2-7b-Chat-GPTQ", tensor_parallel_size=2)

Currently tested models:

  • LLaMA: TheBloke/Llama-2-7b-Chat-GPTQ, TheBloke/Llama-2-13b-Chat-GPTQ, TheBloke/Llama-2-70b-Chat-GPTQ
  • QWen: Qwen/Qwen-7B-Chat-Int4
  • Baichuan: TheBloke/baichuan-7B-GPTQ
  • GPT-2: mlabonne/gpt2-GPTQ-4bit
  • GPT-J: PanEa/dolly-v2-gptj-enhanced-auto-gptq
  • GPT-Neox: TheBloke/stablecode-instruct-alpha-3b-GPTQ
  • Bloom: TheBloke/BLOOMChat-176B-v1-GPTQ
  • Falcon: TheBloke/falcon-7b-instruct-GPTQ, TheBloke/Falcon-180B-Chat-GPTQ
  • InternLM: cczhong/internlm-chat-7b-4bit-gptq
  • GPT-Bigcode: TheBloke/starcoderplus-GPTQ
  • MPT: casperhansen/mpt-7b-8k-chat-gptq
  • Aquila: local quantized model
  • OPT: local quantized model

There're various different configurations for GPTQ and the code has not been thoroughly tested yet.

@chu-tianxiang chu-tianxiang marked this pull request as draft August 31, 2023 04:09
@chu-tianxiang chu-tianxiang changed the title Gptq hf Add GPTQ support Aug 31, 2023
@TheBloke
Copy link

TheBloke commented Sep 2, 2023

Great to see!

Someone ping me when it's merged and I'll mention support in my GPTQ readmes.

@chu-tianxiang chu-tianxiang marked this pull request as ready for review September 5, 2023 09:56
@esmeetu
Copy link
Collaborator

esmeetu commented Sep 5, 2023

@chu-tianxiang Thanks for your efforts! 👍

It should better to add use_safetensors args in api endpoints files. 😄
Besides, reminds maintainer to add auto-gptq and optimum to deps file.

I try this PR, all Fine for me.
Give me a 25% speed boost.

@chu-tianxiang
Copy link
Contributor Author

Thanks for trying this out. use-safetensors is already included in EngineArgs

python -m vllm.entrypoints.openai.api_server --model TheBloke/Llama-2-7b-Chat-GPTQ --use-safetensors

In my local machine (A100-80G), the quantized model is only faster than fp16 when batch_size = 1 and the performance quickly degenerate as batch size increases. The throughput of TheBloke/Llama-2-7b-Chat-GPTQ is 4.68 requests/s which is 15% less than the fp16 counterpart (meta-llama/Llama-2-7b-chat-hf at 5.45 requests/s).
output

@eugenepentland
Copy link

From my testing with offline inference of Open-Orca/OpenOrca-Platypus2-13B vs quantized, quantized is about 3% faster. The test was a batch of 500 prompts, with max_tokens=512.

The biggest difference is I can now run inference on both of my 3090's individually, so it's effectively twice as fast.

@esmeetu
Copy link
Collaborator

esmeetu commented Sep 7, 2023

@chu-tianxiang Does this support the latest Falcon-180B-GPTQ?🙃

@eugenepentland
Copy link

Can you merge this with the latest main branch commit? There is some performance testing I want to do using code-llama which support was just added in main.

@chu-tianxiang
Copy link
Contributor Author

Can you merge this with the latest main branch commit? There is some performance testing I want to do using code-llama which support was just added in main.

I just merged the main branch. Safetensors is enabled by default in main branch now, so use-safetensors argument is no longer needed and removed.

to @esmeetu : I haven't tested Falcon-180B-GPTQ yet. I'll give it a try.

@teraktor2006
Copy link

teraktor2006 commented Sep 9, 2023

Thanks for this. A noob question, I tried to install from pip but seems broken and not installed. Can you tell me how I can install from yours?

Failed to build vllm ERROR: Could not build wheels for vllm, which is required to install pyproject.toml-based projects

@chu-tianxiang
Copy link
Contributor Author

chu-tianxiang commented Sep 11, 2023

Thanks for this. A noob question, I tried to install from pip but seems broken and not installed. Can you tell me how I can install from yours?

Failed to build vllm ERROR: Could not build wheels for vllm, which is required to install pyproject.toml-based projects

I'm not sure what caused the error, this branch can be installed from source with pip

pip install git+https://github.com/chu-tianxiang/vllm-gptq@gptq_hf

Besides you need to install huggingface/optimum and AutoGPTQ to use GPTQ models. There was a bug in AutoGPTQ which was not fixed until recently, so it probably has to be installed from source too.

@WoosukKwon
Copy link
Collaborator

Hi @chu-tianxiang, thanks a million for the PR. We are in the process of merging the AWQ PR (#1032), which includes common interface for different quantization methods. We will get back to this PR after that.

@chu-tianxiang
Copy link
Contributor Author

An initial implementation of GPTQ compatible with the interface of AWQ is added to the gptq_compat branch. Only LLaMA 4-bit quantization is supported.

Following the practice of AWQ, kernels are copied into vLLM so there's no dependency on AutoGPTQ. However AutoGPTQ implements many different kernels to be imported dynamically. For simplicity only the default 4-bit kernel codes are copied.

@gururise
Copy link

Now that AWQ is in, will this PR get merged? Unfortunately, I cannot use AWQ due to its Ampere requirement. Would love to use GPTQ though!

@Trapper4888
Copy link

Trapper4888 commented Sep 24, 2023 via email

@curname
Copy link

curname commented Sep 25, 2023

Thank for your great work! I did some experiments on bigcode model(starcoder), the generation speed is slower than gptq_hf_old, almost twice. Have you done similar experiments?

@chu-tianxiang
Copy link
Contributor Author

Thank for your great work! I did some experiments on bigcode model(starcoder), the generation speed is slower than gptq_hf_old, almost twice. Have you done similar experiments?

Thanks for your experiment. It turns out to be related to the fact that vllm by default pads the sequence length to multiple of 8 here which is exactly the kernel switch threshold for GPTQ. I disabled the padding in old branch but forgot to do the same in the new. It seems the padding has slight negative effect on AWQ as well. I'll do more tests and fix it later.

@pweglik
Copy link

pweglik commented Sep 27, 2023

Hi @chu-tianxiang, please, consider adding 'gtpq' as possible choice in arg_utils.py, and benchmarks. Running code that utilizes AsyncEngineArgs results in error error: argument --quantization/-q: invalid choice: 'gptq' (choose from 'awq', None). Thanks a lot. I can also create a PR to your branch if you'd prefer that

@esmeetu
Copy link
Collaborator

esmeetu commented Sep 27, 2023

@chu-tianxiang When i using gptq model. There is no throughput difference with batch size 1 and 2.
Do you have any ideas?

@pweglik
Copy link

pweglik commented Sep 27, 2023

Another error I found is: File "/root/przemek/miniconda3/envs/documents-demo/lib/python3.10/site-packages/vllm/model_executor/model_loader.py", line 99, in get_model if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[ KeyError: None when running official example: https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py . I think you need to check whether model_config.quantization is set at all before using it as key in dict or use deafultdict/ get() with default value instead: https://docs.python.org/3/library/collections.html#collections.defaultdict

@@ -266,10 +286,11 @@ def forward(

class OPTForCausalLM(nn.Module):

def __init__(self, config):
def __init__(self, config, quant_config):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quant_config should be optional, or have default value = None

@chu-tianxiang
Copy link
Contributor Author

@pweglik Thank you very much! The issue with loading the model that you mentioned should have been fixed earlier. Regarding the other two problems, I have added gptq to the quantization argument and fixed the bug with OPT model. You're more than welcome to create a PR for any additional problems.

@esmeetu I'm not sure what you mean by throughput of batch size 2. The performance of GPTQ kernels is more likely to be affected by the batch size compared to regular fp16 models.

@chu-tianxiang
Copy link
Contributor Author

While refactoring the GPTQ kernel code, I benchmarked most currently available kernels for 4-bit GEMM. The result is as below.
下载 (4)

Some kernels don't natively support act-order models, so I insert an extra reorder operation (e.g. x = x[:, q_perm]) for those kernels, this is marked as w/ reorder in the graph. size_k and size_n are kept as 4096, while size_m (batch_size) varies.

Benchmarked method include exllamav2(exllamav2 will use dequant & matmul when batch size is above threshold, which behavior is disabled here), gpt-fast, GPTQ act-order(with slight modification of block number), GPTQ sequential, GPTQ triton, AWQ GEMM, AWQ GEMV, Dequant & matmul.

When the batch size is small, exllamav2 is the fastest while gpt-fast is competitive enough (especially considering an extra reorder operation is inserted). When the batch is large enough, simple dequant & matmul is so strong a baseline that it beats all other custom kernels.

While the kernels differ a lot in implementation details, say, some using tensor cores while some don't, AWQ GEMM and GPTQ triton launches way fewer blocks and do more calculation per thread. I think that may be the part of the reason their performance are better at big batch size.

@a1164714
Copy link

a1164714 commented Dec 14, 2023

this mr don't support 8bits?

i got the log

Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chu-tianxiang LGTM! Thanks a million for this great work! And apologies for the very delayed reviews. The GPTQ support would have not been possible without your continuous updates on the PR. Also, thanks for the clean and high-quality code. We'd really really appreciate it. Thanks again for this amazing work!

Comment on lines +281 to +285
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While it doesn't ned to be addressed right now, can we somehow factor out this part? This seems to be error prone.

@WoosukKwon WoosukKwon merged commit 0fbfc4b into vllm-project:main Dec 15, 2023
2 checks passed
@hex-plex
Copy link

hey, i just noticed the recent most push fails on rocm

vllm/setup.py

Line 222 in b81a6a6

"csrc/quantization/gptq/q_gemm.cu",

can you add this line in the setup.py along with awq quantization like so

vllm/setup.py

Line 228 in b81a6a6

vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

If this feature is suppose to support rocm can you also check as hipify is not able to import #include<hipblas.h>

Thank you

@WoosukKwon
Copy link
Collaborator

@hex-plex Thanks for reporting the issue! We will temporarily set it as CUDA-only feature and add it back once it is tested.

@renwuli
Copy link

renwuli commented Dec 25, 2023

Thanks for trying this out. use-safetensors is already included in EngineArgs

python -m vllm.entrypoints.openai.api_server --model TheBloke/Llama-2-7b-Chat-GPTQ --use-safetensors

In my local machine (A100-80G), the quantized model is only faster than fp16 when batch_size = 1 and the performance quickly degenerate as batch size increases. The throughput of TheBloke/Llama-2-7b-Chat-GPTQ is 4.68 requests/s which is 15% less than the fp16 counterpart (meta-llama/Llama-2-7b-chat-hf at 5.45 requests/s). output

Same result

@guangzlu
Copy link

Hi, I want to load "PanEa/dolly-v2-gptj-enhanced-auto-gptq" and it shows up "KeyError: 'transformer.h.0.attn.qkv_proj.qweight'". Can someone tell me what is happening and how to solve it? Here is my command and error log:
image

@chu-tianxiang
Copy link
Contributor Author

Hi, I want to load "PanEa/dolly-v2-gptj-enhanced-auto-gptq" and it shows up "KeyError: 'transformer.h.0.attn.qkv_proj.qweight'". Can someone tell me what is happening and how to solve it? Here is my command and error log: image

Please add quantization="gptq" argument. Earlier GPTQ models doesn't include the quantization method in config.json so you have to specify it by yourself.

@guangzlu
Copy link

Hi, I want to load "PanEa/dolly-v2-gptj-enhanced-auto-gptq" and it shows up "KeyError: 'transformer.h.0.attn.qkv_proj.qweight'". Can someone tell me what is happening and how to solve it? Here is my command and error log: image

Please add quantization="gptq" argument. Earlier GPTQ models doesn't include the quantization method in config.json so you have to specify it by yourself.

It works! Thank you very much!

@esmeetu
Copy link
Collaborator

esmeetu commented Dec 27, 2023

Hi, @chu-tianxiang. I have a performance-related question. The speed is faster when group_size is set to -1, but slower when it's set to 32. Do they use different kernels? Thanks!
Model(group size = -1): https://huggingface.co/TheBloke/WizardCoder-Python-34B-V1.0-GPTQ/tree/main

Model(group size = 32): https://huggingface.co/TheBloke/WizardCoder-Python-34B-V1.0-GPTQ/tree/gptq-4bit-32g-actorder_True

The g-1'speed is x2 faster than g32.

@xunfeng1980
Copy link

Good job

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
@raduaerpo
Copy link

this mr don't support 8bits?

i got the log

Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.

Encountering the same. did you find any solution to this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet