Conversation

xxxxyu (Contributor) commented Oct 2, 2025

Purpose

Add support for GPTQv2 format model checkpoints by adapting the gptq_gemm CUDA kernel to handle zero points correctly. The format is detected via "checkpoint_format": "gptq_v2" in quantize_config.json.

Currently, vLLM treats GPTQv2 and GPTQv1 (the default GPTQ) format checkpoints identically, which causes gibberish output (repeated !!!!) with GPTQv2 format checkpoints (details in #26343).

What is GPTQv2?

GPTQv2 is both a quantization algorithm and a checkpoint format. This PR adds support for the GPTQv2 checkpoint format, not the quantization algorithm. Specifically:

  • As a quantization algorithm: GPTQv2 (i.e., https://github.com/Intelligent-Computing-Lab-Panda/GPTAQ) can provide higher model accuracy than GPTQv1 (i.e., the default GPTQ). It is integrated into the GPTQModel library (https://github.com/ModelCloud/GPTQModel) as an experimental feature (enabled by setting v2=True).
  • As a checkpoint format: GPTQv2 uses a slightly different checkpoint format from GPTQv1 in how it stores the zero points of quantized weights. GPTQModel supports both the GPTQv1 format (the default) and the GPTQv2 format (enabled by setting format='gptq_v2'). Checkpoints in GPTQv2 format have "checkpoint_format": "gptq_v2" in quantize_config.json, as sketched below.
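
A minimal, hypothetical sketch (not vLLM's actual loader code) of how the checkpoint format can be detected from quantize_config.json:

import json

# Load the quantization config shipped with the checkpoint.
with open("quantize_config.json") as f:
    quantize_config = json.load(f)

# GPTQv2 checkpoints declare themselves via this field; anything else is
# treated as the default GPTQ (v1) format.
use_v2_format = quantize_config.get("checkpoint_format", "gptq") == "gptq_v2"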

Why support GPTQv2 format?

GPTQv2, as a checkpoint format, preserves higher accuracy for asymmetrically quantized (especially low-bit) models.

By default, GPTQModel uses the GPTQv2 format internally but converts to the GPTQv1 format when saving the quantized checkpoint, mainly for compatibility. However, this conversion is not lossless and can harm model accuracy, especially for low-bit quantized models.

Specifically, the v2 -> v1 conversion subtracts 1 from each zero point (which ranges over [0, 2^b - 1] for b-bit quantization), causing zero=0 and zero=1 to be stored as the same value. For example, in INT2 quantization this reduces the usable zero-point values from {0,1,2,3} to {1,2,3} (stored as {0,1,2} in GPTQv1 format).
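
To make the loss concrete, here is a toy illustration (not GPTQModel's actual conversion code; it assumes the stored value is clamped into the representable range, consistent with the collision described above):

# True 2-bit zero points in GPTQv2 format.
v2_zeros = [0, 1, 2, 3]
# GPTQv1 stores zero - 1; a value below 0 cannot be represented and collapses to 0.
v1_stored = [max(z - 1, 0) for z in v2_zeros]   # -> [0, 0, 1, 2]
# At runtime the v1 kernel adds 1 back, so zero = 0 can never be recovered.
recovered = [z + 1 for z in v1_stored]          # -> [1, 1, 2, 3]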

How to support GPTQv2 format?

The only difference between the GPTQv1 and GPTQv2 formats is how they store the zero points: GPTQv1 subtracts 1 from the zero points before storing them, while GPTQv2 stores them as-is.

The current gptq_gemm kernel restores the original zero points by adding 1 back during weight dequantization. For example:

dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);

To support the GPTQv2 format, in the newly added gptq_gemm_v2 kernel I simply removed all of the zero + 1 / -zero - 1 logic and used the zero points as-is.
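
In other words, the dequantization math differs only by a constant offset on the zero point. The following NumPy sketch (illustrative only, not the actual CUDA kernel; q, stored_zeros, and scales are hypothetical unpacked tensors) shows the idea:

import numpy as np

def dequantize(q, stored_zeros, scales, use_v2_format):
    # GPTQv1 stores (zero - 1), so the kernel adds 1 back;
    # GPTQv2 stores the zero point as-is, so no offset is needed.
    zero_offset = 0 if use_v2_format else 1
    zeros = stored_zeros + zero_offset
    return (q.astype(np.float32) - zeros) * scales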

Test Plan

I added/updated two tests in the codebase.

For this PR specifically, I am testing with the following code:

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Test models
MODELS = [
    "BitDistiller/Qwen-8B-w2g64-gptq",
    "BitDistiller/Llama-3.1-8B-Instruct-w2g64-gptq",
]

def test_gptq_v2_models():
    results = []
    
    for model_id in MODELS:
        # Collect output for this model
        model_output = []
        model_output.append(f"\nTesting {model_id}")
        model_output.append("-" * 50)
        
        try:
            llm = LLM(model=model_id, dtype=torch.float16, max_model_len=512)
            sampling_params = SamplingParams(max_tokens=128, temperature=0)
            
            # Prepare prompt
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the meaning of life?"}
            ]
            
            # Handle thinking models
            enable_thinking = "Qwen" in model_id
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                enable_thinking=enable_thinking
            )
            
            model_output.append(f"Prompt:\n{text}")
            
            # Generate response
            output = llm.generate(text, sampling_params)
            response = output[0].outputs[0].text
            
            model_output.append(f"Response:\n{response}")
            
            del llm
            
        except Exception as e:
            model_output.append(f"Error: {str(e)}")
        
        results.append("\n".join(model_output))
    
    # Print all results at the end
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    for result in results:
        print(result)

if __name__ == "__main__":
    test_gptq_v2_models()

Test Result

Running the above test code, I get:
(screenshot of the generated outputs)

The output now is correct.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

github-actions bot commented Oct 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the ci/build label Oct 2, 2025

gemini-code-assist bot left a comment


Code Review

This pull request adds support for GPTQv2 format checkpoints by introducing a new CUDA kernel, gptq_gemm_v2. The changes are well-structured and include necessary updates to the build system, Python bindings, and tests. My main feedback is regarding the implementation of the new CUDA kernel, which is a near-duplicate of the existing q_gemm.cu. This creates a significant maintenance overhead. I've suggested a refactoring approach to merge the v1 and v2 logic into a single file using template parameters to avoid code duplication.

Comment on lines 1 to 13
/*
Adapted from `q_gemm.cu`, which is adapted from
https://github.com/turboderp/exllamav2 and
https://github.com/qwopqwop200/GPTQ-for-LLaMa.
This supports GPTQ v2 format checkpoints (checkpoint_format: 'gptq_v2'),
by removing the v1-specific "zero + 1" logic during dequantization.
Specifically, GPTQ v1 format checkpoints store (zero - 1), and need to + 1 at
runtime during dequantization. GPTQ v2 format checkpoints store the zero point
as is, and doesn't require + 1 at runtime. For more details, please refer to
ModelCloud/GPTQModel:
https://github.com/ModelCloud/GPTQModel/blob/020ac04b74f6263f22491e6a6a034cb4fa5bf181/gptqmodel/utils/model.py#L625
*/

gemini-code-assist bot commented (severity: high)

This new file q_gemm_v2.cu is almost a complete duplicate of q_gemm.cu, with the only difference being the handling of the zero point (+1 for v1, no offset for v2). This introduces a significant maintenance burden, as any future bug fixes or performance improvements in q_gemm.cu would need to be manually ported to this file.

To avoid this code duplication, I suggest merging the logic into the existing q_gemm.cu file and using a template parameter to differentiate between the v1 and v2 formats.

For example, you could template the kernels on a boolean V2_FORMAT:

template <bool V2_FORMAT, bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_4bit_kernel(...) {
  // ...
  const int zero_point_offset = V2_FORMAT ? 0 : 1;
  // ...
  dequant_4bit_8_prep_zero(zeros[0] + zero_point_offset, z1z16[0], y1y16[0]);
  // ...
}

The pick_..._kernel functions can then be updated to select the correct template instantiation based on a new is_v2 boolean parameter. This approach would apply to all duplicated kernels (gemm_..., reconstruct_..., etc.), consolidating the logic into a single, more maintainable file. The author has already mentioned being open to this.

xxxxyu (Contributor, Author) replied

I agree that this improves maintainability, if this fallback GPTQ kernel is expected to be updated in the future (not likely, since vLLM already has several optimized GPTQ kernels like Marlin).

For now, maybe it is also OK to keep the duplicated code? I can continue working on templating the functions if the reviewers think it's necessary.

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xxxxyu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mgoin (Member) commented Oct 7, 2025

Could you have this format just use the gptq_marlin kernels? These are much more performant, should already support zero point, and are widely used by default for >sm80
Since pytorch only supports sm75 below sm80, I think we want to focus on marlin

xxxxyu (Contributor, Author) commented Oct 7, 2025

Could you have this format just use the gptq_marlin kernels? These are much more performant, should already support zero point, and are widely used by default for >sm80 Since pytorch only supports sm75 below sm80, I think we want to focus on marlin

Thanks for your time : )

In my case, I am serving some low-bit (e.g., 2/3-bit) models. Since gptq_marlin only supports 4/8-bit models, vllm falls back to gptq_gemm and produces gibberish outputs.

Although 4/8-bit models are more popular, it is good to support 2/3-bit ones for efficiency reasons. On my machines, the 2-bit models run faster than 4-bit ones. Some benchmark results (single A100 80GB):

# 2-bit GPTQv2
vllm bench latency --model BitDistiller/Qwen-8B-w2g64-gptq  --input-len 8192 --output-len 1 --batch-size 16 --dtype float16
Avg latency: 8.880313996163506 seconds

# 4-bit GPTQ
vllm bench latency --model JunHowie/Qwen3-8B-GPTQ-Int4 --input-len 8192 --output-len 1 --batch-size 16
Avg latency: 12.523354894698908 seconds

# 4-bit AWQ
vllm bench latency --model Qwen/Qwen3-8B-AWQ --input-len 8192 --output-len 1 --batch-size 16
Avg latency: 12.61037957224374 seconds

  • About using Marlin: I'm not quite sure about the feasibility of adapting gptq_marlin for 2/3-bit inference, but I'll take a look, and discussion is welcome.

  • Plus, if supporting this is not planned after discussion, a doc update would be nice: users may not be able to tell the difference between the v1/v2 formats, and it took me quite a while to locate this issue.

mgoin (Member) commented Oct 7, 2025

Ah, I did not consider that gptq_gemm supports lower precisions than 4-bit. Okay @xxxxyu, I think this is a valid case if we can keep the code and binary size impact of this format addition down. It also gives some justification for keeping gptq_gemm around for the long term.
First, I would like to prevent duplication and just add the format as a templated/conditional case in the kernel. We can use the same function and just add a new arg for the format, much like the exllama arg.
Second, I don't see any logic update to the gptq_marlin override to make it aware of this new format, so we should make sure 4-bit models in the new format aren't running with marlin incorrectly.
After those changes I think this PR should be in a reasonable state - thanks for the contribution!

xxxxyu (Contributor, Author) commented Oct 8, 2025

I've merged GPTQv2 format support into gptq_gemm, with a new arg bool use_v2_format.
The CUDA kernels now compute zero_offset = use_v2_format ? 0 : 1 and use zero + zero_offset instead of zero + 1.

Now I will check the calling logic of 4-bit models.

@xxxxyu xxxxyu changed the title [Kernel] Add gptq_gemm_v2 CUDA kernel to correctly support GPTQv2 format checkpoints [Kernel] Adapt gptq_gemm CUDA kernel to correctly support GPTQv2 format checkpoints Oct 8, 2025

xxxxyu (Contributor, Author) commented Oct 8, 2025

Hi @mgoin, some updates on the marlin override:

Both GPTQMarlinConfig and GPTQBitBLASConfig require is_sym == True. So I guess they bypass this zero-point issue internally and don't require a logic update unless we want to support asymmetric quantization. To verify, I tested a symmetrically quantized GPTQv2 format model, and gptq_marlin generates correct outputs. (It's also not meaningful to store a symmetrically quantized model in GPTQv2 format.)

# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
    (4, True): scalar_types.uint4b8,
    (8, True): scalar_types.uint8b128,
}


Other quantization/linear kernels to consider:

  • Marlin24 seems to use a specialized checkpoint format, "marlin_24", so it won't be affected.
  • ExllamaLinearKernel calls gptq_gemm, but it inherits MPLinearKernel and doesn't have access to the quantization config, so I hardcode use_v2_format=False in apply_weights() to keep its behavior unchanged.
    • I doubt ExllamaLinearKernel is still in use; I don't see it selected by any quantization method.

xxxxyu (Contributor, Author) commented Oct 8, 2025

Now 2/3/8-bit GPTQv2 models run correctly, but 4-bit still has issues with gptq_gemm; this seems to be a pre-existing bug in vLLM:

With vLLM v0.10.1.1 from PyPI, when running https://huggingface.co/JunHowie/Qwen3-8B-GPTQ-Int4 (GPTQv1 format) and forcing quantization="gptq", I got:
(screenshot of the error output)

Usually this won't be triggered as "gptq_marlin" is the default quantization for 4-bit.

My proposal to deal with this:

  • Remove 4 from supported bits in GPTQConfig, since marlin is used by default for >SM80.
  • Or, mark 4-bit as unsupported only when checkpoint_format is "gptq_v2", though the underlying bug would persist.

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xxxxyu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 14, 2025

xxxxyu (Contributor, Author) commented Oct 14, 2025

Hi @mgoin, I think I've addressed the concerns in #26092 (comment). Do you have time to review?

  1. Merge duplicate kernels (see [Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm #26092 (comment)).
  2. Make sure other quantizations like gptq_marlin are correct. See below.

After reading the source code and running some tests, I've worked out the current state of vLLM's support for GPTQ(v2) format models:

quantization | supported bits | supported sym | supported format
gptq_marlin  | 4, 8           | True          | gptq, gptq_v2
gptq_bitblas | 4, 8           | True          | gptq, gptq_v2
gptq         | 2, 3, 8        | True, False   | gptq

In summary, this PR completes vLLM's GPTQ format support by adding GPTQv2 format support for 2/3-bit or asymmetric quantization. It maintains vLLM's original override priorities (gptq_marlin > gptq_bitblas > gptq) and won't cause errors with marlin/bitblas.

@xxxxyu xxxxyu changed the title [Kernel] Adapt gptq_gemm CUDA kernel to correctly support GPTQv2 format checkpoints [Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm Oct 15, 2025

mgoin (Member) left a comment

LGTM, sorry for the delay @xxxxyu. I would just like to prune the test time down a bit if possible.

Comment on lines 16 to 19
MODELS = [
    ("BitDistiller/Qwen-8B-w2g64-gptq", "gptq_v2", True),
    ("BitDistiller/Llama-3.1-8B-Instruct-w2g64-gptq", "gptq_v2", False),
]

mgoin (Member) commented

Can you use smaller models? I would appreciate at most 1B parameters, for CI cost.

Also, do you need both? It seems that enable_thinking isn't related to quantization, so it's not important to test here.

xxxxyu (Contributor, Author) replied

I can try quantizing a dummy model < 1B for the test. One is enough.

xxxxyu (Contributor, Author) replied

Qwen3-1.7B can generate normal text after simple 2-bit quantization, but Qwen3-0.6B cannot. So I uploaded https://huggingface.co/XXXXyu/Qwen3-1.7B-w2g64-gptq_v2 for testing. The model size is 1.0 GB; should that be OK?

xxxxyu (Contributor, Author) replied

I can switch to a 0.6B 2-bit model too, but the checking logic would be simplified to only checking whether !!!! appears in the output.

@mgoin mgoin added quantization ready ONLY add when PR is ready to merge/full CI is needed labels Oct 20, 2025
@xxxxyu xxxxyu requested a review from pavanimajety as a code owner October 21, 2025 17:07

xxxxyu (Contributor, Author) commented Oct 22, 2025

Hi @mgoin, deprecating the buggy 4-bit gptq_gemm is causing some tests to fail. How do we deal with this?

mgoin (Member) commented Oct 22, 2025

@xxxxyu Let's just leave the 4-bit error change out for now, and we can address it in a follow-up PR.

xxxxyu (Contributor, Author) commented Oct 24, 2025

@mgoin I've replaced the 4-bit error message with a warning. Is this ready to merge?

@mgoin mgoin merged commit 5cc6bdd into vllm-project:main Oct 24, 2025
87 checks passed
atalhens pushed a commit to atalhens/vllm that referenced this pull request Oct 24, 2025
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
rohin-garg pushed a commit to rohin-garg/vllm that referenced this pull request Oct 25, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…tion, by adapting gptq_gemm (vllm-project#26092)

Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>