Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] Allow loading of quantized lm_head (ParallelLMHead) #4442

Open
wants to merge 101 commits into
base: main
Choose a base branch
from

Conversation

Qubitium
Copy link
Contributor

@Qubitium Qubitium commented Apr 29, 2024

Reason for PR:

  • lm_head can be quantized with minimal loss to output
  • save vram by allowing quantized lm_head

Changes:

  • Read lm_head from quantize_config
  • Allow ParallelLmHead to be loaded quantized
  • Rename QuantizeMethodBase to more accurate QuantizableMethodBase since non-quantized methods also inherit this
  • Added QUANTIZED bool property to QuantizableMethodBase to avoid all the isinstance calls
  • Refractor repeating gptq param check code into utils/skip_gptq_extra_param

Tooling Cross Dependency (tools that make quantized lm_head using GPTQ):

Test Model (quantized by auto-round and load tested with autogptq):

https://huggingface.co/LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse

Intel/auto-round project @wenhuach21 has demonostrated that lm_head is a good candidate for quantization with minimal-loss.

https://github.com/intel/auto-round/blob/8a3da144423322dfedb0b3fa702ae35d242496d8/docs/Meta-Llama-3-8B-Instruct-acc.md?plain=1#L3

Metric BF16 w4g128 w/o lm-head w4g128 with quantized lm-head
Avg. 0.6352 0.6312 0.6303
mmlu 0.6386 0.6306 0.6318
winogrande 0.7143 0.7238 0.7269
truthfulqa_mc1 0.3623 0.3537 0.3525
rte 0.6751 0.6859 0.6679
piqa 0.7867 0.7797 0.7802
openbookqa 0.3400 0.3300 0.3320
lambada_openai 0.7182 0.7200 0.7173
hellaswag 0.5769 0.5699 0.5701
boolq 0.8297 0.8309 0.8284
arc_easy 0.8152 0.8089 0.8106
arc_challenge 0.5299 0.5102 0.5154

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@robertgshaw2-neuralmagic
Copy link
Collaborator

Running into a big issue loading OPT quantized lm_head due to the fact for OPT (and other models) lm_head just a soft-linked in code to embeddings (not a unique lm_head layer like llama). But the thing is OPT model has both lm_head and embedding in original weights except they are the same tensor shape/size and values.

Checking with intel/auto-round team to see if this is a potential issue that should be addressed at the quantization stage or a remapping (here in vlllm) that we need to do. I just don't know the correct answer yet. This could be a bug in vllm OPT code in which assumptions are made that lm_head is always the same as embedding and should be ignored on load. This may be true for pre-quant but post-quant it may not be.

Once this is addressed, I believe limits to ParallelLmHead can be fully unlocked and we just need to check for ParallelVocabEmbedding

Asking Intel/auto-round devs for clarification/insight: intel/auto-round#100

Change PR to draft-mode until this issue is resolved.

Let's handle this in a follow up PR since the scope of this is already big

@Qubitium
Copy link
Contributor Author

Qubitium commented May 1, 2024

@robertgshaw2-neuralmagic I overwrote your test changes in commit 2f63a72 Will re-merge with your changes later.

Fixed OPT model compat with lm_head. lm-head tests now passing.

New problem Marlin kernel loading with lm_head is broken. For now I disabled Marlin auto-upconvert when lm_head is detected/True.

  • Fixed OPT compat with lm_head: load both lm_head and embed_tokens separately. Do not assume they are the same.
  • Fixed: Marlin runtime convert of compatible models with lm_head enabled is failing. Add TODO and disable marlin upconvert when lm_head quant enabeld.

@Qubitium Qubitium marked this pull request as ready for review May 1, 2024 00:23
@robertgshaw2-neuralmagic
Copy link
Collaborator

@Qubitium no worries - to make it easier for us to both work on it, im going to move the model testing refactor to another PR

@robertgshaw2-neuralmagic
Copy link
Collaborator

robertgshaw2-neuralmagic commented May 1, 2024

@Qubitium I'm going to merge this first, so we can test properly

#4510

2. optimize opt vram usage by keeping only one copy of lm_head vs embed_tokens when possible
@Qubitium
Copy link
Contributor Author

Qubitium commented May 1, 2024

@robertgshaw2-neuralmagic I do not plan to make any more changes unless pending CI build shows something broken. https://buildkite.com/vllm/ci/builds/6198

Feel free to add/mod.

Changes and notes:

  • Fixed: Identified runtime auto-convertion to marlin format for lm_head enabled quants is not working. I am not sure what's going on here since the loader path is different in this runtime conversion case. Need to fix in a new PR. This PR is getting bloated as is. For now, disable auto-marlin if lm_head is enabled.
  • Moved lm_head_quantized property from GPTQConfig to base QuantizationConfig.
  • Fixed OPT compat with lm_head quants
  • OPT on end of load_weights() will delete model.lm_head (softlink to embed_tokens) if model is not quantized or quantized with lm_head false as not to duplicate memory as in this case, the weights are the same.

On last point, it may be a good idea to add an api (future PR) to model so there is an equivalent on_load_weights_end method that does model weight cleanups/logic post load.

Extra notes:

  • Many models that have similar code like OPT where both lm_head and embedding are the same but de-duplicated in model code with skipping logic in load_weights needs to be modified in a similar fashion as well. I did not check if there are any others nor do I want to bloat this PR even more.

@Qubitium
Copy link
Contributor Author

Qubitium commented May 1, 2024

All the relevant tests are passing. Of note that the code changes for OPT works but is actually not efficient. There is negative advantage to have an OPT model enable quantization of lm_head as it would, at the moment, disable the sharing of embed_tokens and load 2 different layers causing he overall OPT memory footprint to increase. For llama models where lm_head are separate, the memory saving is significant.

@Qubitium
Copy link
Contributor Author

Qubitium commented May 6, 2024

To align with BaseConfig inheritance by interface and not property vars, refractored BaseConfig.lm_head_quantized variable into BaseConfig.is_lm_head_quantized() interface. 9981620

@Qubitium Qubitium changed the title [CORE] Allow loading of GPTQ quantized lm_head (ParallelLMHead) [CORE] Allow loading of quantized lm_head (ParallelLMHead) May 6, 2024
@robertgshaw2-neuralmagic
Copy link
Collaborator

Going to pick this back up this weekend

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants