
Add AWQ quantization #102

Merged

merged 7 commits into main from awq on Dec 7, 2023
Conversation

@flozi00 (Collaborator) commented Dec 5, 2023

Moved to the predibase repo from #100.

flozi00 and others added 2 commits December 5, 2023 11:18
@flozi00 (Collaborator, Author) commented Dec 5, 2023

I just forgot to init the submodules:

git submodule update --init --recursive

Now the Docker build no longer hangs at the punica build step.

@flozi00 (Collaborator, Author) commented Dec 5, 2023

File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/awq/awq.py", line 40, in forward
    return out.reshape(out_shape)
RuntimeError: shape '[32000, 6144]' is invalid for input of size 131072000

That's the error message I am struggling with.
I've tried several ideas, but none of them worked.
I checked the Hugging Face TGI main branch and compared the implementations, but I can't see what the mistake is.
I even used their commit in the Makefile, but it still errors.

Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 827, in warmup
    _, batch = self.generate_token(batch)
  File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 923, in generate_token
    raise e
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_causal_lm.py", line 920, in generate_token
    out = self.forward(batch, adapter_data)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/flash_mistral.py", line 399, in forward
    logits = self.model.forward(
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 578, in forward
    logits = self.lm_head(hidden_states, adapter_data)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/layers.py", line 476, in forward
    result = self.base_layer(input)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/layers.py", line 308, in forward
    output = super().forward(input)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/layers.py", line 253, in forward
    return self.linear.forward(x)
  File "/opt/conda/lib/python3.9/site-packages/lorax_server/utils/layers.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

Any ideas at first glance?

@tgaddair (Contributor) commented Dec 5, 2023

Hey @flozi00, I spent some time playing around with it last night. At least for the first issue, it seems that AWQ changed the format of their weights in this commit: mit-han-lab/llm-awq@1480555#diff-cd7278928f5da471b08f4aedab4f33e560067768adf06ff06beec1972e9e7240

That seems to be causing the shape mismatch error. I want to spend some time figuring out whether AWQ weights saved before this change can still be loaded and used with the newer code, since ideally we'd want to be on the newer version of AWQ.
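
For reference, a quick arithmetic check of the mismatch (a sketch based only on the numbers in the error message, not on the model config):

```python
# Numbers taken from the RuntimeError above.
elements = 131_072_000        # size of the tensor the AWQ kernel actually produced
target_shape = (32000, 6144)  # shape the reshape in awq.py tried to produce

print(elements // 32000)                   # 4096 -> per-row size the kernel produced
print(target_shape[0] * target_shape[1])   # 196608000 -> size the reshape would need
```

A format change that alters how out_features is inferred from the packed qweight would produce exactly this kind of invalid reshape.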

@tgaddair (Contributor) commented Dec 5, 2023

@flozi00 The Docker image has been built and pushed to https://github.com/predibase/lorax/pkgs/container/lorax/154831836?tag=awq-test.

Any time you push to this branch, it will rebuild the image with the same tag.

@flozi00 (Collaborator, Author) commented Dec 5, 2023

Using the format from before the changes you linked results in the second error I posted above, but I think that one is misleading, since it's not a real CUDA error. I read a thread where the PyTorch team said that error also occurs sometimes when linear layer shapes mismatch.

As far as I understand, both errors are related to the lm_head.

I'd definitely prefer the newer AWQ version, since it's faster than the one used in TGI, if I remember correctly.
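
To illustrate, a minimal sketch of the kind of mismatched lm_head call the traceback shows (assumes a CUDA device; the dimensions are illustrative, not taken from the model):

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
# lm_head weight loaded in an unexpected layout: in_features 6144 instead of 4096
wrong_weight = torch.randn(32000, 6144, device="cuda", dtype=torch.float16)

# Fails because 4096 != 6144; whether this surfaces as a clean shape error or as a
# cuBLAS-level error depends on the torch / CUDA build.
F.linear(hidden, wrong_weight)
```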

@tgaddair (Contributor) commented Dec 6, 2023

Sounds like the newer AWQ is quite a bit faster, so I agree we should try to get it working with the newer version.

@flozi00 (Collaborator, Author) commented Dec 6, 2023

@tgaddair what do you think about using the kernels from the AutoAWQ project?

https://github.com/casper-hansen/AutoAWQ/blob/5a673bf8435e019f50470b1b8878abf4ee63de57/awq/modules/linear.py#L213C7-L213C7
He uses customized kernels in his project; for example, he added a v2 of a function where the original project uses the same one twice.
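
Roughly, such an AWQ linear layer dispatches to a fused dequantize-and-matmul kernel. The sketch below is only for illustration: the extension name awq_ext and the gemm_forward_cuda signature are assumptions here, and the real kernels are the ones in the linked AutoAWQ module.

```python
import torch
import awq_ext  # hypothetical import of the compiled AWQ CUDA kernels


class WQLinearSketch(torch.nn.Module):
    def __init__(self, qweight, qzeros, scales, bias, out_features):
        super().__init__()
        self.qweight = qweight          # int32-packed 4-bit weights
        self.qzeros = qzeros            # int32-packed zero points
        self.scales = scales            # fp16 per-group scales
        self.bias = bias
        self.out_features = out_features

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        # fused dequantize + matmul; 8 = number of 4-bit values packed per int32
        out = awq_ext.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]).half(), self.qweight, self.scales, self.qzeros, 8
        )
        if self.bias is not None:
            out = out + self.bias
        return out.reshape(out_shape)
```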

@tgaddair (Contributor) commented Dec 6, 2023

@flozi00 sounds good to me!

@flozi00 (Collaborator, Author) commented Dec 7, 2023


It's working now: near-fp16 performance, better than GPTQ.

Ready to be merged from my side.

@tgaddair

flozi00 requested a review from tgaddair on December 7, 2023 10:48
@flozi00 (Collaborator, Author) commented Dec 7, 2023

time_per_token="52.262122ms" on an A2000 12GB

Similar to an A6000 48GB with fp16

flozi00 changed the title from "[untested] AWQ (#100)" to "AWQ" on Dec 7, 2023
tgaddair changed the title from "AWQ" to "Add AWQ quantization" on Dec 7, 2023
@tgaddair (Contributor) left a comment


Amazing! Just tested it myself and verified results look good!

tgaddair merged commit bf3901b into main on Dec 7, 2023
1 check failed
tgaddair deleted the awq branch on December 7, 2023 18:03