
[Kernel][FP8] Initial support with dynamic per-tensor scaling #4118

Merged: 16 commits merged into vllm-project:main on Apr 20, 2024

Conversation

comaniac (Contributor) commented on Apr 16, 2024

This PR provides initial support for FP8 computation. It is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726

This feature can be enabled with --quantization fp8 or -q fp8 when launching an engine.

Algorithm

We still load model checkpoints in FP16/BF16. After the weights are loaded, Fp8LinearMethod calculates the per-tensor scaling factor of the weights and quantizes the weights accordingly; the scaling factor is then stored for future use. The per-tensor scaling factor for activations, in contrast, is calculated dynamically in every forward pass.
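
For illustration, here is a minimal sketch of the per-tensor quantization step described above (the helper name and exact numerics are assumptions for this example, not the PR's actual code):

```python
import torch

def per_tensor_quantize(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    """Return (x_fp8, scale) such that x ≈ x_fp8.to(x.dtype) * scale."""
    finfo = torch.finfo(dtype)
    # Map the largest observed magnitude onto the FP8 representable range.
    amax = x.abs().max().clamp(min=1e-12)
    scale = amax / finfo.max
    x_fp8 = (x / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float()

# Weights are quantized once after loading and their scale is cached;
# activations are quantized this way in every forward pass, and both
# scales are then passed to the FP8 GEMM (torch._scaled_mm in this PR).
```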

Initial Results

Currently tested with Mistral-7B on 1xH100, with prompt length ~5 and decoding length 128:

  • BF16: 1.47s
  • FP8: 1.66s

I'll try larger models and look for further performance bottlenecks. Meanwhile, you're welcome to try this PR.

To Do

  • Unit tests
  • Comprehensive benchmarking
  • Refine the interface

cc @WoosukKwon @zhuohan123 @robertgshaw2-neuralmagic @mgoin

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented so that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behavior of vLLM. It helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not review the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

mgoin (Collaborator) commented on Apr 16, 2024

Nice job here for a clean implementation. For performance improvement, I think the easiest/clearest measurement would be to measure TTFT for a large prompt since we are trying to improve compute efficiency here, rather than just memory bandwidth like most existing weight-only quantization methods in vllm.

comaniac (Contributor, Author) commented:

I ran a benchmark with 50 requests sent at QPS 1. The average prompt length is 512 and the decoding length is 128. The results are as follows:

| Case | TTFT (ms) | ITL (ms) |
|---|---|---|
| FP16 | 28.8 | 8.5 |
| FP8 | 38.8 | 10.2 |
| FP8-test | 27.4 | 9.5 |

In FP8-test, I removed activation quantization to understand its impact. Although this produces garbage output, it shows that activation quantization does introduce a large amount of overhead. Meanwhile, even with activation quantization removed, the speedup isn't obvious. This looks odd to me because in that case we only have .to and torch._scaled_mm in the forward pass. I'm going to run a microbenchmark with torch._scaled_mm to see what speedup we could actually achieve against F.linear.
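
For reference, a rough sketch of such a microbenchmark is below. It assumes the PyTorch 2.2-era form of the private torch._scaled_mm API (which returned an (output, amax) tuple and requires the second operand to be column-major); the shapes and helper names here are illustrative, not taken from the PR.

```python
import torch
import torch.nn.functional as F

def to_float8(t, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = t.abs().max().clamp(min=1e-12) / finfo.max
    return (t / scale).clamp(finfo.min, finfo.max).to(dtype), scale.float()

M, N, K = 512, 4096, 4096  # one example shape; sweep M/N/K in practice
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w = torch.randn(N, K, dtype=torch.float16, device="cuda")
x_fp8, x_scale = to_float8(x)
w_fp8, w_scale = to_float8(w)

def bench(fn, iters=100):
    # Warm up, then time with CUDA events and a final synchronize.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

fp16_ms = bench(lambda: F.linear(x, w))
fp8_ms = bench(lambda: torch._scaled_mm(
    x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale,
    out_dtype=torch.float16))
print(f"F.linear: {fp16_ms:.4f} ms | torch._scaled_mm: {fp8_ms:.4f} ms")
```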

comaniac (Contributor, Author) commented on Apr 17, 2024

Microbenchmark between torch._scaled_mm and F.linear with shapes extracted from Mistral-7B, run on an H100 with PyTorch 2.2.1+cu121. Each configuration runs 100 trials with torch.cuda.synchronize(); the reported results are average latencies.

  • FP16: F.linear.
  • FP8: torch._scaled_mm with pre-quantized inputs.
  • FP8 w. x-to: torch._scaled_mm with pre-quantized weights and a runtime x.to(torch.float8_e4m3fn). Note that this produces incorrect output values.
  • FP8 w. x-quantized: torch._scaled_mm with pre-quantized weights and a runtime quantize(x), where quantize is compiled with torch.compile.

Summary:

  1. Shape matters: the larger, the better. Without other overheads such as quantization, we see a fairly clear speedup once the number of input tokens is >= 32.
  2. Even a bare x.to(torch.float8_e4m3fn) can offset much of the speedup. This means we may not get decent speedups if only the linear layers are computed in FP8. The activation quantization overhead can only be amortized if most ops are computed in FP8, so that a single quantization benefits more of the computation and we achieve some speedup on average.
| M (tokens) | N | K | FP16 (ms) | FP8 (ms) | Speedup | FP8 w. x-to (ms) | Speedup | FP8 w. x-quantized (ms) | Speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 4096 | 6144 | 0.3205 | 0.7489 | 0.43 | 1.1397 | 0.28 | 0.7889 | 0.41 |
| 1 | 4096 | 4096 | 0.0303 | 0.0435 | 0.7 | 0.0491 | 0.62 | 0.1142 | 0.27 |
| 1 | 4096 | 28672 | 3.2293 | 0.1094 | 29.52 | 0.1055 | 30.61 | 0.1738 | 18.58 |
| 1 | 14336 | 4096 | 0.0667 | 0.0799 | 0.83 | 0.0733 | 0.91 | 0.1443 | 0.46 |
| 2 | 4096 | 6144 | 0.2818 | 0.0464 | 6.07 | 0.0514 | 5.48 | 0.1521 | 1.85 |
| 2 | 4096 | 4096 | 0.0298 | 0.0412 | 0.72 | 0.0474 | 0.63 | 0.1446 | 0.21 |
| 2 | 4096 | 28672 | 0.1003 | 0.1118 | 0.9 | 0.1075 | 0.93 | 0.2046 | 0.49 |
| 2 | 14336 | 4096 | 0.0639 | 0.0721 | 0.89 | 0.071 | 0.90 | 0.1709 | 0.37 |
| 4 | 4096 | 6144 | 0.04 | 0.0456 | 0.88 | 0.052 | 0.77 | 0.1507 | 0.27 |
| 4 | 4096 | 4096 | 0.0294 | 0.0401 | 0.73 | 0.0468 | 0.63 | 0.1443 | 0.20 |
| 4 | 4096 | 28672 | 0.1013 | 0.1092 | 0.93 | 0.1116 | 0.91 | 0.2077 | 0.49 |
| 4 | 14336 | 4096 | 0.0637 | 0.0689 | 0.92 | 0.0718 | 0.89 | 0.1695 | 0.38 |
| 8 | 4096 | 6144 | 0.04 | 0.0471 | 0.85 | 0.0505 | 0.79 | 0.147 | 0.27 |
| 8 | 4096 | 4096 | 0.0295 | 0.0403 | 0.73 | 0.0452 | 0.65 | 0.1422 | 0.21 |
| 8 | 4096 | 28672 | 0.1003 | 0.111 | 0.9 | 0.1137 | 0.88 | 0.2121 | 0.47 |
| 8 | 14336 | 4096 | 0.0652 | 0.0879 | 0.74 | 0.0721 | 0.90 | 0.1692 | 0.39 |
| 16 | 4096 | 6144 | 0.0406 | 0.0455 | 0.89 | 0.0502 | 0.81 | 0.1473 | 0.28 |
| 16 | 4096 | 4096 | 0.03 | 0.0406 | 0.74 | 0.0469 | 0.64 | 0.1432 | 0.21 |
| 16 | 4096 | 28672 | 0.1017 | 0.109 | 0.93 | 0.1078 | 0.94 | 0.2034 | 0.50 |
| 16 | 14336 | 4096 | 0.0689 | 0.0718 | 0.96 | 0.0703 | 0.98 | 0.1688 | 0.41 |
| 32 | 4096 | 6144 | 0.0436 | 0.0421 | 1.04 | 0.0472 | 0.92 | 0.1432 | 0.30 |
| 32 | 4096 | 4096 | 0.0282 | 0.0344 | 0.82 | 0.0404 | 0.70 | 0.1363 | 0.21 |
| 32 | 4096 | 28672 | 0.1059 | 0.0786 | 1.35 | 0.0799 | 1.33 | 0.1746 | 0.61 |
| 32 | 14336 | 4096 | 0.066 | 0.054 | 1.22 | 0.058 | 1.14 | 0.1531 | 0.43 |
| 64 | 4096 | 6144 | 0.0436 | 0.0377 | 1.16 | 0.0424 | 1.03 | 0.149 | 0.29 |
| 64 | 4096 | 4096 | 0.0342 | 0.0345 | 0.99 | 0.0423 | 0.81 | 0.1378 | 0.25 |
| 64 | 4096 | 28672 | 0.1064 | 0.0791 | 1.35 | 0.0808 | 1.32 | 0.1776 | 0.60 |
| 64 | 14336 | 4096 | 0.065 | 0.054 | 1.21 | 0.0608 | 1.07 | 0.1578 | 0.41 |
| 128 | 4096 | 6144 | 0.0509 | 0.0423 | 1.2 | 0.0511 | 1.00 | 0.1452 | 0.35 |
| 128 | 4096 | 4096 | 0.036 | 0.0388 | 0.93 | 0.0432 | 0.83 | 0.1395 | 0.26 |
| 128 | 4096 | 28672 | 0.1237 | 0.0894 | 1.38 | 0.0884 | 1.40 | 0.1874 | 0.66 |
| 128 | 14336 | 4096 | 0.0681 | 0.0544 | 1.25 | 0.0585 | 1.16 | 0.1538 | 0.44 |
| 256 | 4096 | 6144 | 0.0511 | 0.0428 | 1.19 | 0.0483 | 1.06 | 0.1462 | 0.35 |
| 256 | 4096 | 4096 | 0.0416 | 0.0372 | 1.12 | 0.0433 | 0.96 | 0.1416 | 0.29 |
| 256 | 4096 | 28672 | 0.1482 | 0.1095 | 1.35 | 0.111 | 1.34 | 0.2085 | 0.71 |
| 256 | 14336 | 4096 | 0.0758 | 0.0613 | 1.24 | 0.0651 | 1.16 | 0.1604 | 0.47 |
| 512 | 4096 | 6144 | 0.0668 | 0.0513 | 1.3 | 0.0568 | 1.18 | 0.1532 | 0.44 |
| 512 | 4096 | 4096 | 0.0493 | 0.0438 | 1.13 | 0.0491 | 1.00 | 0.1463 | 0.34 |
| 512 | 4096 | 28672 | 0.1957 | 0.1326 | 1.48 | 0.1361 | 1.44 | 0.2299 | 0.85 |
| 512 | 14336 | 4096 | 0.1239 | 0.1013 | 1.22 | 0.1011 | 1.23 | 0.1938 | 0.64 |
| 1024 | 4096 | 6144 | 0.102 | 0.0746 | 1.37 | 0.0788 | 1.29 | 0.1748 | 0.58 |
| 1024 | 4096 | 4096 | 0.072 | 0.0579 | 1.24 | 0.0641 | 1.12 | 0.1578 | 0.46 |
| 1024 | 4096 | 28672 | 0.3453 | 0.2298 | 1.5 | 0.2473 | 1.40 | 0.3565 | 0.97 |
| 1024 | 14336 | 4096 | 0.2188 | 0.1573 | 1.39 | 0.1589 | 1.38 | 0.2487 | 0.88 |
| 2048 | 4096 | 6144 | 0.1739 | 0.1162 | 1.5 | 0.1246 | 1.40 | 0.2125 | 0.82 |
| 2048 | 4096 | 4096 | 0.1239 | 0.0854 | 1.45 | 0.0931 | 1.33 | 0.1872 | 0.66 |
| 2048 | 4096 | 28672 | 0.7443 | 0.418 | 1.78 | 0.4744 | 1.57 | 0.6277 | 1.19 |
| 2048 | 14336 | 4096 | 0.401 | 0.251 | 1.6 | 0.2566 | 1.56 | 0.3548 | 1.13 |

mgoin (Collaborator) commented on Apr 17, 2024

@comaniac I did some lower-level benchmarking just on torch._scaled_mm as you mentioned and got some fairly interesting results that I'm still investigating (EDIT: nice timing on the above 🥈 😄). I did this on an RTX 4090, so I had to use the pytorch-nightly build in order to support CUDA capability 8.9. You can find my code here: https://github.com/mgoin/torch-fp8

Spreadsheet of data here for the following graphs: https://docs.google.com/spreadsheets/d/1X-N8h8_O-Z-S7tN5WRkAXXKkhCEniJciUpEzaUaksqs/edit?usp=sharing

Here is a fine-grained benchmark showing various M (num tokens dim, so either batch size or prefill length) for the layer shapes of Llama 7B:

(image: per-layer speedup vs. M for Llama 7B shapes)

Obviously there is some interesting stuff here, with more than 2x speedup in places, but I wouldn't take that too seriously since this is individual-layer benchmarking and it likely fits into cache at small M (however, I am unsure about the bump at M between 100-130, which is very interesting).

Here is a benchmark that is less fine-grained but showing the impact of dynamic activation quantization:

(image: impact of dynamic activation quantization)

I'll continue benchmarking and investigating how we can minimize overhead on the individual layer level and on different GPUs, but it does look like the performance of scaled_mm is highly variable for M<192 on an RTX 4090. This is bound to be worse on a GPU with more memory bandwidth like an H100.

comaniac (Contributor, Author) commented on Apr 17, 2024

Thanks for the benchmarking; it basically aligns with my results on H100. I'll also investigate further to see whether we can amortize this overhead.

By the way, my benchmark code is pretty much the same as yours. The only difference is that I added @torch.compile to to_float8. It helps a little.
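
For context, the torch.compile tweak mentioned here amounts to decorating the quantization helper; a minimal sketch, assuming a to_float8 along the lines of the one in mgoin's benchmark repo:

```python
import torch

@torch.compile  # fuses the abs/max/div/clamp/cast chain into fewer kernels
def to_float8(t, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    scale = t.abs().max().clamp(min=1e-12) / finfo.max
    return (t / scale).clamp(finfo.min, finfo.max).to(dtype), scale.float()
```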

comaniac marked this pull request as ready for review April 17, 2024 23:40
comaniac (Contributor, Author) commented:

@WoosukKwon @zhuohan123 @robertgshaw2-neuralmagic @mgoin @pcmoritz

Although we haven't achieved decent performance with this FP8 support yet, I'd like to get this PR reviewed first: it seems better to build follow-up optimizations on top of this PR, so merging it could unblock others to try more interesting things (unless you have concerns about the current interface and API implementation).

In short, this PR now offers:

  1. Loading FP16 checkpoints and per-tensor quantizing linear weights to FP8 during the first forward pass.
  2. Per-tensor quantizing activations to FP8 in each forward pass.
  3. We also accept FP8 checkpoints. Such a checkpoint has to store linear weights in torch.float8_e4m3fn along with an additional w_scale parameter. The major benefit of loading an FP8 checkpoint is a lower memory footprint.

Some follow-ups (I probably don't have time to do all of these, so you're welcome to take anything you're interested in):

  1. Confirm the performance of torch._scaled_mm with PyTorch and NVIDIA.
  2. Benchmark TGI v2.0.0 with FP8 and confirm that we are now in the same boat.
  3. If the above two points hold, we need a highly efficient GEMM kernel for small numbers of tokens. @pcmoritz suggested that we could try Triton (https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html).
  4. Investigate how we can reduce/amortize activation quantization overheads. There are two possible directions: first, compute more ops in FP8 so that a single quantization benefits more of the computation; second, leverage delayed scaling (i.e., reuse the scaling factor; see the sketch after this list), though this may affect result quality significantly.
  5. Enable FP8 in the attention layers: FlashAttention should already support FP8, so we should be good for prefill. For decoding, we had better integrate FlashInfer directly since it supports FP8 too.
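
As a rough illustration of the delayed-scaling direction in item 4 (bookkeeping only; this is not part of the PR, and the history length is an arbitrary choice):

```python
import torch

class DelayedScale:
    """Reuse the max amax seen over recent steps instead of recomputing
    the scale from the current activation every time (illustrative)."""

    def __init__(self, history_len=16, dtype=torch.float8_e4m3fn):
        self.finfo = torch.finfo(dtype)
        self.dtype = dtype
        self.history_len = history_len
        self.amax_history = []

    def quantize(self, x):
        cur_amax = x.abs().max()
        # Use the running max of past amaxes; fall back to the current amax
        # on the very first call.
        amax = max(self.amax_history) if self.amax_history else cur_amax
        scale = (amax / self.finfo.max).clamp(min=1e-12)
        self.amax_history.append(cur_amax)
        if len(self.amax_history) > self.history_len:
            self.amax_history.pop(0)
        x_fp8 = (x / scale).clamp(self.finfo.min, self.finfo.max).to(self.dtype)
        return x_fp8, scale.float()
```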

pcmoritz (Collaborator) left a review comment:

Very nicely done! We should add some tests pretty quickly as a follow-up PR. I tested it locally with

from vllm import LLM, SamplingParams


llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", tensor_parallel_size=1, quantization="fp8")

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

and making sure FP8 is indeed used with

In [9]: llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[0].mlp.down_proj.linear_method
Out[9]: <vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod at 0x7f66163aa280>

(TP2 also works, but only after master is merged into this PR -- it seems it was broken when the PR was branched off)

comaniac (Contributor, Author) commented:

Thanks for the review! Will add some tests tomorrow along with your suggestions.

mgoin (Collaborator) left a review comment:

In order to test the FP8 weight loading, could you include an example of how you save the FP8 model format? Something like the FP8 KV Cache example would be nice

comaniac (Contributor, Author) commented on Apr 18, 2024

> In order to test the FP8 weight loading, could you include an example of how you save the FP8 model format? Something like the FP8 KV Cache example would be nice

Thanks for the pointer. I could actually make it compatible with the existing quantization flow.
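
For illustration, an offline conversion along these lines might look as follows. This is a hypothetical sketch: the function name, key naming, and which tensors get quantized are assumptions, not the checkpoint format this PR actually reads.

```python
import torch

def convert_to_fp8_state_dict(fp16_sd: dict) -> dict:
    """Quantize 2-D linear weights to float8_e4m3fn and add a per-tensor
    w_scale entry next to each; leave everything else untouched."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_sd = {}
    for name, tensor in fp16_sd.items():
        if name.endswith(".weight") and tensor.dim() == 2:
            scale = tensor.abs().max().clamp(min=1e-12) / finfo.max
            fp8_sd[name] = (tensor / scale).clamp(
                finfo.min, finfo.max).to(torch.float8_e4m3fn)
            fp8_sd[name.replace(".weight", ".w_scale")] = scale.float()
        else:
            fp8_sd[name] = tensor
    return fp8_sd
```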

pcmoritz enabled auto-merge (squash) April 20, 2024 00:03
pcmoritz merged commit a22cdea into vllm-project:main Apr 20, 2024
46 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 21, 2024
ZackBradshaw pushed a commit to ZackBradshaw/vllm that referenced this pull request Apr 22, 2024
ZackBradshaw pushed a commit to ZackBradshaw/vllm that referenced this pull request Apr 22, 2024
pcmoritz added a commit that referenced this pull request Apr 24, 2024
This PR is the first step towards fixing #3208

It implements dynamic per-tensor scaling (see #4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance); there is a bunch of optimizations that we will submit as a follow-up PR that significantly improve the performance (similar to the numbers in #3954). With this PR, the results are as follows:

<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">


**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```
This compares favorably with the FP16 results, which are:
```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
xjpang pushed a commit to xjpang/vllm that referenced this pull request Apr 25, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Apr 25, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
comaniac deleted the fp8-init branch April 26, 2024 23:33
alexeykondrat pushed a commit to alexeykondrat/ci-vllm that referenced this pull request May 1, 2024
alexeykondrat pushed a commit to alexeykondrat/ci-vllm that referenced this pull request May 1, 2024
joerunde pushed a commit to IBM/vllm that referenced this pull request May 6, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
pcmoritz added a commit that referenced this pull request May 9, 2024
This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)).

We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16, so this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance (see the sketch after this commit message).

Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4), all FP8 measurements are for dynamic quantization:

qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16)
qps = 2: 26 ms (FP8, this PR), 34 ms (FP8, previous main), 28 ms (FP16)
qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16)
qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16)
qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16)
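
A minimal sketch of the padding idea described in the commit message above (assumed shapes and helper names; not the PR's exact implementation):

```python
import torch
import torch.nn.functional as F

def fp8_linear_padded(x, w_fp8_t, w_scale, min_rows=17):
    """x: (M, K) fp16 activation; w_fp8_t: (K, N) column-major FP8 weight."""
    m = x.shape[0]
    if m < min_rows:
        # Pad the token dimension so cuBLASLt sees M > 16 and can pick a
        # faster FP8 GEMM algorithm; the extra rows are sliced off below.
        x = F.pad(x, (0, 0, 0, min_rows - m))
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scale = (x.abs().max().clamp(min=1e-12) / finfo.max).float()
    x_fp8 = (x / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    out = torch._scaled_mm(x_fp8, w_fp8_t, scale_a=x_scale,
                           scale_b=w_scale, out_dtype=torch.float16)
    # Older PyTorch versions return (output, amax); newer ones return output.
    out = out[0] if isinstance(out, tuple) else out
    return out[:m]
```
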
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024