Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] W8A16 Int8 inside FusedMoE #7415

Merged
merged 33 commits into from
Aug 16, 2024

Conversation

mzusman
Copy link
Contributor

@mzusman mzusman commented Aug 12, 2024

🚀 The feature, motivation and pitch

Additional feature for fused_moe triton kernel to support W8A16 with Int8, supports Ampere/Ada lovelace/Hopper, called ExpertsInt8.
Based on symmetric per-column per-expert Int8 quantization, casting the weights to FP16/BF16 before matmul inside the fused_moe kernel (compute_type in FP/BF16).
Support quantization and scales extraction on startup (takes 1min on Jamba).
We've ran quality benchmarks on Jamba and it shows no quality degradation:

Method BF16 ExpertsInt8
gsm8k cot 59.5 59.8
MMLU 67.3 67.3
gsm8k 50.6 50.1
narrative_qa 67.8 68.7
ppl/c4 -0.4301 -0.4301

Performance:
E2E latency in seconds on requests with Prompt length=1024, decode length=128:

Model Hardware Method BS=1 BS=4 BS=8
Mixtral8x22B H100*8 FP8 (MoE quant-only) 1.3 1.79 2.26
Mixtral8x22B H100*8 ExpertsInt8 1.3 1.84 2.205
-- -- -- -- -- --
Mixtral8x7B H100*4 FP8 (MoE quant-only) 0.835 1.22 1.44
Mixtral8x7B H100*4 ExpertsInt8 0.83 1.22 1.42
-- -- -- -- -- --
Jamba H100*2 FP8 (MoE quant-only) 1.3 2 3.2
Jamba H100*2 ExpertsInt8 1.16 2 3.14
-- -- -- -- -- --
Jamba A100*2 FP16 1.75 3.2 4.6
Jamba A100*2 ExpertsInt8 1.65 2.7 4
Jamba A100*2 GPTQ 8bit (W/o FusedMoe) 3.8 5.9 9.4

Advantages:

  • Doesn't require calibration preprocess.
  • No quality degradation.
  • Quantized FusedMoE methodology that runs on A100s.
  • Safer in case of large activations since they can be saved in BF16 reducing the risk of overflow.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mzusman
Copy link
Contributor Author

mzusman commented Aug 12, 2024

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 12, 2024
@mzusman mzusman changed the title [Kernel] W8A16 Int8 MoE [Kernel] W8A16 Int8 inside FusedMoE Aug 12, 2024
@halexan
Copy link

halexan commented Aug 13, 2024

@mzusman

How to convert models to ExpertsInt8?

@robertgshaw2-neuralmagic
Copy link
Sponsor Collaborator

Hey - just FYI there are some ongoing efforts to extend Marlin to support W4A16 and W8A16

Right now the kernels load GPTQ models, but we could really connect them to any models type

We should run benchmarks against these as well in deciding which kernel to use

#7079

@mzusman
Copy link
Contributor Author

mzusman commented Aug 13, 2024

@mzusman

How to convert models to ExpertsInt8?

You would just need to run vLLM with --quantization experts_int8, It supports quantization on the fly

@mzusman
Copy link
Contributor Author

mzusman commented Aug 13, 2024

Hey - just FYI there are some ongoing efforts to extend Marlin to support W4A16 and W8A16

Right now the kernels load GPTQ models, but we could really connect them to any models type

We should run benchmarks against these as well in deciding which kernel to use

#7079

I understand, I was trying to benchmark this method against the PR #7079 on Nous-Hermes-2-Mixtral-8x7B-DPO-GPTQ-8bit-128g/TheBloke/Mixtral-8x7B-v0.1-GPTQ and it ended up unsuccessful (it hangs on startup).

@robertgshaw2-neuralmagic Do you think this is a blocker for merging this PR? We can have the two options available.

@jeejeelee
Copy link
Contributor

It seems that #6502 is also a similar PR.

@halexan
Copy link

halexan commented Aug 15, 2024

I tested this pull request on deepseek-v2-chat-236B. Indeed more concurrency.

Copy link
Sponsor Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick changes! I think these are my last round of comments

.buildkite/run-cpu-test.sh Outdated Show resolved Hide resolved
benchmarks/kernels/benchmark_moe.py Outdated Show resolved Hide resolved
benchmarks/kernels/benchmark_moe.py Outdated Show resolved Hide resolved
benchmarks/kernels/benchmark_moe.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/fused_moe/fused_moe.py Outdated Show resolved Hide resolved
@mzusman
Copy link
Contributor Author

mzusman commented Aug 15, 2024

This generally looks good to me and experiments with some interesting ideas. I think the biggest issue is the clashing assumption between use_fp8 and use_int8, where one is W8A8 and the other is W8A16 - we should definitely be explicit with the level of quantization if we're going this route.

It would be nice if we could select between INT8 W8A16 or W8A8 (since we already have efficient activation quant methods), like @qingquansong has proposed in #6978

In the short-term it seems like we might end up with use_fp8_w8a8, use_int8_w8a8, and use_int8_w8a16 - so it would be nice if we could work towards a more sustainable future for the code.

Thank you for the thorough review! I've changed the terms use_int8/use_fp8 to use_int8_w8a16/use_fp8_w8a8 and delivered this idea to the moe config files dtype as well.

@mzusman
Copy link
Contributor Author

mzusman commented Aug 15, 2024

CI failures seem to not be related to this PR

else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
weight_loader(param, loaded_weight, weight_name, shard_id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just an fyi - we're updating/expanding the weight loading for Fused MoE layers: #7527

Copy link
Sponsor Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Ditto with Dipika that it'd be good to make this work with the MoE Parameter refactor eventually

@mzusman
Copy link
Contributor Author

mzusman commented Aug 16, 2024

Thanks! I'll rebase, maybe will resolve the CI issues

@simon-mo simon-mo merged commit 7fc23be into vllm-project:main Aug 16, 2024
52 of 56 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
zifeitong pushed a commit to zifeitong/vllm that referenced this pull request Aug 20, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants