How to execute uniform quantization instead of NF4 quantization? #12

Open
LuletterSoul opened this issue Dec 26, 2023 · 4 comments

@LuletterSoul

Hi, thanks for your amazing work. I found that the code uses NF4 quantization by default and offers no option to switch to uniform quantization (UQ). If I have a model quantized with GPTQ, how can I use LoftQ on it?

I tried a GPTQ-quantized model with PEFT, but it raised the following exception:

Traceback (most recent call last):
  File "quantize_save.py", line 221, in <module>
    base_dir, lora_dir = quantize_and_save()
  File "quantize_save.py", line 191, in quantize_and_save
    lora_model = get_peft_model(model, lora_config)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/mapping.py", line 133, in get_peft_model
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/peft_model.py", line 1043, in __init__
    super().__init__(model, peft_config, adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/peft_model.py", line 125, in __init__
    self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 111, in __init__
    super().__init__(model, config, adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/tuners_utils.py", line 87, in __init__
    self.inject_adapter(self.model, adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/tuners_utils.py", line 244, in inject_adapter
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 181, in _create_and_replace
    new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 283, in _create_new_module
    new_module = QuantLinear(target, adapter_name, **kwargs)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/gptq.py", line 40, in __init__
    self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/layer.py", line 96, in update_layer
    self.loftq_init(adapter_name)
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/peft/tuners/lora/layer.py", line 134, in loftq_init
    weight = self.get_base_layer().weight
  File "/root/miniconda3/envs/chatglm3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'QuantLinear' object has no attribute 'weight'
@yxli2123
Owner

yxli2123 commented Jan 7, 2024

Hi, thank you for your interest in our work. In theory, LoftQ supports any existing quantization function, but the GPTQ implementation AutoGPTQ doesn't support dequantization, which LoftQ requires (see Section 2.2 of the LoftQ paper).

If you can find a GPTQ implementation that has a dequantization method, please let me know. I'd be glad to add it to LoftQ :)
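
To illustrate why dequantization is needed, here is a minimal NumPy sketch of an alternating initialization in the spirit of LoftQ: quantize the part of the weight that the low-rank term does not explain, dequantize it, then fit a rank-r SVD to the residual. It is not the code from this repository or from PEFT. The names `quantize`, `dequantize`, and `loftq_init` are hypothetical, and the per-tensor min-max quantizer is only a stand-in for NF4 or GPTQ.

```python
import numpy as np

def quantize(w, num_bits=4):
    """Stand-in per-tensor min-max uniform quantizer (an assumption, not NF4/GPTQ)."""
    levels = 2 ** num_bits - 1
    offset = float(w.min())
    scale = (float(w.max()) - offset) / levels or 1.0  # avoid a zero scale
    codes = np.clip(np.round((w - offset) / scale), 0, levels).astype(np.uint8)
    return codes, scale, offset

def dequantize(codes, scale, offset):
    """Map integer codes back to floats; this is the step the loop cannot skip."""
    return codes.astype(np.float64) * scale + offset

def loftq_init(weight, rank, num_iter=5, num_bits=4):
    """Alternate between quantizing the residual and a rank-r SVD fit."""
    out_dim, in_dim = weight.shape
    a = np.zeros((out_dim, rank))
    b = np.zeros((in_dim, rank))
    for _ in range(num_iter):
        # Quantize what the low-rank term a @ b.T does not yet explain.
        q = dequantize(*quantize(weight - a @ b.T, num_bits))
        # Best rank-r approximation of the quantization residual.
        u, s, vt = np.linalg.svd(weight - q, full_matrices=False)
        sqrt_s = np.sqrt(s[:rank])
        a = u[:, :rank] * sqrt_s
        b = vt[:rank].T * sqrt_s
    return q, a, b
```

The SVD step operates on `weight - q`, so `q` has to exist as a floating-point tensor. A backend that only exposes a fused quantized matmul gives the loop nothing to subtract.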

@yxli2123
Owner

yxli2123 commented Jan 7, 2024

Also, we do have an experimental uniform quantization method at https://github.com/yxli2123/LoftQ/blob/main/glue/utils.py#L103. However, it's not the same uniform quantization as the one used in GPTQ.

@LuletterSoul
Author

> Hi, thank you for your interest in our work. In theory, LoftQ supports any existing quantization function, but the GPTQ implementation AutoGPTQ doesn't support dequantization, which LoftQ requires (see Section 2.2 of the LoftQ paper).
>
> If you can find a GPTQ implementation that has a dequantization method, please let me know. I'd be glad to add it to LoftQ :)

Do you mean that vecquant4matmul is not a separate dequantization function, but a fused one (dequantization + matmul)?

@LuletterSoul
Author

LuletterSoul commented Jan 17, 2024

@yxli2123 Thank you for providing the experimental details, and congratulations on LoftQ being accepted as an oral at ICLR 2024! It is true that AutoGPTQ uses group-wise quantization and bit compression. LoftQ may need a custom dequantization function if it is to be integrated into PEFT.

I found some related discussions about PyTorch-style dequantization functions:
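
As a rough sketch of what such a custom dequantization function might involve, the NumPy code below unpacks 4-bit codes stored eight per 32-bit word and applies one scale and zero point per group of rows. The packing layout and the names `pack_int4`, `unpack_int4`, and `dequantize_groupwise` are assumptions made for illustration. They are not AutoGPTQ's actual storage format or API.

```python
import numpy as np

PACK = 8  # eight 4-bit codes per 32-bit word (assumed layout)

def pack_int4(codes):
    """Pack codes in [0, 15], shape (rows, cols), along rows into uint32 words."""
    rows, cols = codes.shape
    assert rows % PACK == 0, "rows must be a multiple of 8"
    words = np.zeros((rows // PACK, cols), dtype=np.uint32)
    for i in range(PACK):
        # Row k * PACK + i goes into nibble i of word k.
        words |= codes[i::PACK].astype(np.uint32) << np.uint32(4 * i)
    return words

def unpack_int4(words):
    """Inverse of pack_int4: recover the uint8 codes from packed words."""
    codes = np.empty((words.shape[0] * PACK, words.shape[1]), dtype=np.uint8)
    for i in range(PACK):
        nibble = (words >> np.uint32(4 * i)) & np.uint32(0xF)
        codes[i::PACK] = nibble.astype(np.uint8)
    return codes

def dequantize_groupwise(words, scales, zeros, group_size):
    """scales and zeros have shape (rows // group_size, cols)."""
    codes = unpack_int4(words).astype(np.float32)
    group = np.arange(codes.shape[0]) // group_size  # group index of each row
    return (codes - zeros[group]) * scales[group]
```

A standalone function like `dequantize_groupwise` returns the full floating-point weight, which is what a residual-based initialization needs, whereas a fused kernel only returns the product with the activations.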

Faster Pytorch dequantize() + matmul for quantized models

hqq_aten.cpp

A dequantization function seems to be implemented in official PyTorch:
Function at::_weight_int4pack_mm

I hope the above information will help you.
