Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Support Fp8 Checkpoints (Dynamic + Static) #4332

Merged
merged 93 commits into from
Apr 30, 2024

Conversation

robertgshaw2-neuralmagic
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic commented Apr 24, 2024

This PR does two things:

  • Supports loading serialized fp8 models with static weight and act scales
  • Supports loading serialized fp8 models with static weight scales and dynamic activations scales

For loading serialized models, vllm auto-detects if merged Layers (QKV / Gate-UpProj) have a shared weight_scale

As a result, vLLM now supports the following cases on H100:

  • Loading fp16 checkpoints and converting to dynamic per tensor activation scaling (but without memory savings)
  • (new) Loading fp8 checkpoints with static dynamic activation scales
  • (new) Loading fp8 checkpoints with dynamic activations scales

Performance will be bad if merged layers (QKV / GateUp) do not have shared weight scales. We are working on cutlass kernels to replace the naive for loop

Next steps (to be handled in another PR):

Usage

from vllm import LLM

# fp16 model at fp8
model = LLM("mistralai/Mistral-7B-Instruct-v0.2", quantization="fp8")

# fp8 static act_scale model
model = LLM("nm-testing/mistral-fp8-static")

# fp8 dynamic act_scale model
model = LLM("nm-testing/mistral-fp8-dynamic")

Performance

Note: all models have shared weight scales for merged layers (QKV, GateUpProj), so there is no naive for loop

Benchmarking shapes

We can see that fp8 is faster end-to-end for prefills, but is slower than fp16 for decode
image

Benchmark serving (1xH100, ShareGPT, Request Rate=5.0, Num Prompts=1000:

  • fp16: meta-llama/Meta-Llama-3-8B-Instruct
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests
python3 benchmark_serving.py --backend openai --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --request-rate 5.0 --num-prompts 1000
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  205.16    
Total input tokens:                      215196    
Total generated tokens:                  186204    
Request throughput (req/s):              4.87      
Input token throughput (tok/s):          1048.94   
Output token throughput (tok/s):         907.62    
---------------Time to First Token----------------
Mean TTFT (ms):                          20.49     
Median TTFT (ms):                        19.88     
P99 TTFT (ms):                           42.07     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.42     
Median TPOT (ms):                        10.36     
P99 TPOT (ms):                           13.23     
==================================================
  • fp8 static with shared scales: nm-testing/llama-3-instruct-fp8-static-shared-scales
python3 -m vllm.entrypoints.openai.api_server --model nm-testing/llama-3-instruct-fp8-static-shared-scales --disable-log-requests
python3 benchmark_serving.py --backend openai --model nm-testing/llama-3-instruct-fp8-static-shared-scales --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --request-rate 5.0 --num-prompts 1000
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  205.65    
Total input tokens:                      215196    
Total generated tokens:                  142560    
Request throughput (req/s):              4.86      
Input token throughput (tok/s):          1046.41   
Output token throughput (tok/s):         693.21    
---------------Time to First Token----------------
Mean TTFT (ms):                          22.48     
Median TTFT (ms):                        22.34     
P99 TTFT (ms):                           37.28     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          27.67     
Median TPOT (ms):                        11.58     
P99 TPOT (ms):                           94.82     
==================================================

^not sure why P99 TPOT is so bad here

  • fp8 static with dynamic scales: meta-llama/Meta-Llama-3-8B-Instruct --quantization fp8
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --quantization fp8 --disable-log-requests
python3 benchmark_serving.py --backend openai --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --request-rate 5.0 --num-prompts 1000
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  206.00    
Total input tokens:                      215196    
Total generated tokens:                  186083    
Request throughput (req/s):              4.85      
Input token throughput (tok/s):          1044.65   
Output token throughput (tok/s):         903.32    
---------------Time to First Token----------------
Mean TTFT (ms):                          23.52     
Median TTFT (ms):                        24.31     
P99 TTFT (ms):                           40.56     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.80     
Median TPOT (ms):                        11.93     
P99 TPOT (ms):                           14.11     
==================================================

Experimental checkpoint structure

Here we detail the experimental structure for the fp8 checkpoints.

We plan to expand upon this as we start adding support for int8, sparsity, etc

The following is added to config.json

"quantization_config": {
    "quant_method": "fp8",
    "activation_scheme": "static" or "dynamic"
  },

Each quantized layer in the state_dict will have:

If the config has "activation_scheme": "static":

model.layers.0.mlp.down_proj.weight              < F8_E4M3
model.layers.0.mlp.down_proj.act_scale           < F32
model.layers.0.mlp.down_proj.weight_scale        < F32

If config has "activation_scheme": "dynamic":

model.layers.0.mlp.down_proj.weight              < F8_E4M3
model.layers.0.mlp.down_proj.weight_scale        < F32
PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

qinput, x_scale = per_tensor_quantize(x)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't you meant to use the static input scale stored in the model state dict?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the next step I'm working on

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [WIP] [Kernel] Load + Run Serialized Models [WIP] [Kernel] Load + Run Serialized Fp8 Models Apr 24, 2024
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pcmoritz
Copy link
Collaborator

pcmoritz commented Apr 30, 2024

There is a simplification we can make to this PR, which is when loading the checkpoint, always rescale the different parts of the tensor so the resulting tensor has only one scale, for example in the case of qkv weigts (wq, wk, wv) it would be:

w_scale = max(wq_scale, wk_scale, wv_scale)
wq' = (wq_scale / w_scale) * wq
wk' = (wk_scale / w_scale) * wk
wv' = (wv_scale / w_scale) * wv

so the common scale w_scale can be used for the whole tensor.

This way we wouldn't need custom kernels to treat the different parts of the tensor differently.

Since the weights don't have much variance and fp8 is actually a decent amount of accuracy for the weights, this would likely not impact the accuracy (but we would need to test that of course).

@pcmoritz
Copy link
Collaborator

Note that at the moment the "nm-testing/mistral-fp8-static" checkpoint doesn't seem to be working well, these are the results I got:

In [1]: from vllm import LLM, SamplingParams

In [2]: model = LLM("nm-testing/mistral-fp8-static")
fp8 quantization is not fully optimized yet. The speed can be slower than non-quantized models.
Detected fp8 checkpoint. Please note that the format is experimental and subject to change.

In [3]: prompts = [
   ...:     "Hello, my name is",
   ...:     "The president of the United States is",
   ...:     "The capital of France is",
   ...:     "The future of AI is",
   ...: ]
   ...: sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

In [4]: 

In [4]: outputs = model.generate(prompts, sampling_params)
   ...: 
   ...: # Print the outputs.
   ...: for output in outputs:
   ...:     prompt = output.prompt
   ...:     generated_text = output.outputs[0].text
   ...:     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
   ...: 
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.62it/s]
Prompt: 'Hello, my name is', Generated text: ' Marissa'
Prompt: 'The president of the United States is', Generated text: ''
Prompt: 'The capital of France is', Generated text: ' an iconic city that is famous for its landmarks, museums, caf'
Prompt: 'The future of AI is', Generated text: ' bright and the potential of AI to revolutionize industries, improve lives and drive eff'

@robertgshaw2-neuralmagic
Copy link
Collaborator Author

There is a simplification we can make to this PR, which is when loading the checkpoint, always rescale the different parts of the tensor so the resulting tensor has only one scale, for example in the case of qkv weigts (wq, wk, wv) it would be:

w_scale = max(wq_scale, wk_scale, wv_scale)
wq' = (wq_scale / w_scale) * wq
wk' = (wk_scale / w_scale) * wk
wv' = (wv_scale / w_scale) * wv

so the common scale w_scale can be used for the whole tensor.

This way we wouldn't need custom kernels to treat the different parts of the tensor differently.

Since the weights don't have much variance and fp8 is actually a decent amount of accuracy for the weights, this would likely not impact the accuracy (but we would need to test that of course).

This is a good idea. I will make this change

@robertgshaw2-neuralmagic
Copy link
Collaborator Author

Note that at the moment the "nm-testing/mistral-fp8-static" checkpoint doesn't seem to be working well, these are the results I got:

In [1]: from vllm import LLM, SamplingParams

In [2]: model = LLM("nm-testing/mistral-fp8-static")
fp8 quantization is not fully optimized yet. The speed can be slower than non-quantized models.
Detected fp8 checkpoint. Please note that the format is experimental and subject to change.

In [3]: prompts = [
   ...:     "Hello, my name is",
   ...:     "The president of the United States is",
   ...:     "The capital of France is",
   ...:     "The future of AI is",
   ...: ]
   ...: sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

In [4]: 

In [4]: outputs = model.generate(prompts, sampling_params)
   ...: 
   ...: # Print the outputs.
   ...: for output in outputs:
   ...:     prompt = output.prompt
   ...:     generated_text = output.outputs[0].text
   ...:     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
   ...: 
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 19.62it/s]
Prompt: 'Hello, my name is', Generated text: ' Marissa'
Prompt: 'The president of the United States is', Generated text: ''
Prompt: 'The capital of France is', Generated text: ' an iconic city that is famous for its landmarks, museums, caf'
Prompt: 'The future of AI is', Generated text: ' bright and the potential of AI to revolutionize industries, improve lives and drive eff'

Yes, this checkpoint was only calibrated with 10 datapoints so I could work on the implementation - Its not intended to be an accurate model

@WoosukKwon
Copy link
Collaborator

There is a simplification we can make to this PR, which is when loading the checkpoint, always rescale the different parts of the tensor so the resulting tensor has only one scale, for example in the case of qkv weigts (wq, wk, wv) it would be:

w_scale = max(wq_scale, wk_scale, wv_scale)
wq' = (wq_scale / w_scale) * wq
wk' = (wk_scale / w_scale) * wk
wv' = (wv_scale / w_scale) * wv

so the common scale w_scale can be used for the whole tensor.

This way we wouldn't need custom kernels to treat the different parts of the tensor differently.

Since the weights don't have much variance and fp8 is actually a decent amount of accuracy for the weights, this would likely not impact the accuracy (but we would need to test that of course).

Agreed. IIRC, TRT-LLM also takes this approach.

@robertgshaw2-neuralmagic
Copy link
Collaborator Author

We will need to eval the accuracy of this (especially for big models), but I think good for now

vllm/model_executor/layers/quantization/fp8.py Outdated Show resolved Hide resolved
Co-authored-by: Michael Goin <michael@neuralmagic.com>
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 111815d into vllm-project:main Apr 30, 2024
48 checks passed
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic deleted the fp8-static branch April 30, 2024 21:46
pcmoritz pushed a commit that referenced this pull request May 4, 2024
… Dynamic/Static Activations) (#4527)

Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants