
Support MLA for DeepSeek-V2 with Triton - step 1 #905

Merged · 7 commits into sgl-project:main · Aug 4, 2024

Conversation

@ispobock (Collaborator) commented Aug 3, 2024

Motivation

Implement MLA (Multi-head Latent Attention) for DeepSeek-V2.

Modification

  • MLA forward pass
  • memory layout & Triton kernel adaptation (see the cache-size sketch below)
  • benchmark & evaluation
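
For context, a minimal sketch of the cache-size arithmetic behind MLA, assuming the published DeepSeek-V2 attention dimensions; the actual cache layout in this PR may differ.

# Rough arithmetic only, assuming the published DeepSeek-V2 config
# (num_heads=128, qk_nope_head_dim=128, qk_rope_head_dim=64,
#  v_head_dim=128, kv_lora_rank=512); the Lite model has fewer heads.
num_heads, qk_nope_dim, qk_rope_dim, v_dim = 128, 128, 64, 128
kv_lora_rank = 512

# An MHA-style cache stores full K and V for every head, per token.
mha_per_token = num_heads * ((qk_nope_dim + qk_rope_dim) + v_dim)
# MLA caches one compressed latent plus the shared RoPE key part, per token.
mla_per_token = kv_lora_rank + qk_rope_dim

print(f"MHA cache/token: {mha_per_token} elements")        # 40960
print(f"MLA cache/token: {mla_per_token} elements")        # 576
print(f"reduction: {mha_per_token / mla_per_token:.0f}x")  # ~71x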

@zhyncs (Member) commented Aug 3, 2024

@ispobock Nice work!

@zhyncs self-assigned this Aug 3, 2024
@zhyncs marked this pull request as draft Aug 3, 2024 18:24
@zhyncs added the wip label Aug 3, 2024
@merrymercy marked this pull request as ready for review Aug 4, 2024 02:19
@ispobock (Collaborator, Author) commented Aug 4, 2024

Logits diff

HF:

prefill logits tensor([14.9297, 12.5312, 11.5078,  ...,  3.4199,  3.7383,  3.7520], device='cuda:0')
prefill logits tensor([14.2188, 11.8984, 11.7734,  ...,  3.7812,  4.0625,  4.0664], device='cuda:0')
prefill logits tensor([19.5938, 13.4219, 11.8906,  ...,  3.6426,  3.6973,  3.8281], device='cuda:0')

sglang MLA:

prefill logits (final) tensor([[14.8750, 12.5000, 11.5000,  ...,  3.3594,  3.6719,  3.6875],
        [14.0625, 11.8125, 11.7500,  ...,  3.5000,  3.7656,  3.7812],
        [19.3750, 13.2500, 11.7500,  ...,  3.4844,  3.5312,  3.6562]],
       device='cuda:0')
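
(A minimal sketch, not the script used above, of how such a diff can be checked programmatically; the values are the first three logits of the first prompt from the dumps above.)

import torch

# Compare prefill logits from the HF reference and the sglang MLA run for the
# same prompt. Small deviations are expected from kernel/precision differences,
# so a loose tolerance is the practical pass criterion.
hf = torch.tensor([14.9297, 12.5312, 11.5078])
sgl = torch.tensor([14.8750, 12.5000, 11.5000])

print(f"max abs diff: {(hf - sgl).abs().max():.4f}")  # ~0.05
assert torch.allclose(hf, sgl, atol=0.5), "logits diverged beyond tolerance"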

Benchmark

  • model: DeepSeek-V2-Lite
  • dataset: ShareGPT
  • num_prompts: 5000
  • A100-80G

main branch + triton kernel:

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    128.0
Successful requests:                     5000
Benchmark duration (s):                  561.05
Total input tokens:                      1187865
Total generated tokens:                  1089941
Total generated tokens (retokenized):    1088599
Request throughput (req/s):              8.91
Input token throughput (tok/s):          2117.20
Output token throughput (tok/s):         1942.66
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   229316.60
Median E2E Latency (ms):                 234510.04
---------------Time to First Token----------------
Mean TTFT (ms):                          205445.14
Median TTFT (ms):                        211642.51
P99 TTFT (ms):                           413023.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          117.57
Median TPOT (ms):                        113.63
P99 TPOT (ms):                           307.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           1079.55
Median ITL (ms):                         80.03
P99 ITL (ms):                            681.20
==================================================

main branch + flashinfer kernel:

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    128.0
Successful requests:                     5000
Benchmark duration (s):                  488.49
Total input tokens:                      1187865
Total generated tokens:                  1089941
Total generated tokens (retokenized):    1088617
Request throughput (req/s):              10.24
Input token throughput (tok/s):          2431.72
Output token throughput (tok/s):         2231.26
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   217337.71
Median E2E Latency (ms):                 221728.97
---------------Time to First Token----------------
Mean TTFT (ms):                          195409.99
Median TTFT (ms):                        200734.20
P99 TTFT (ms):                           401025.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          106.12
Median TPOT (ms):                        103.63
P99 TPOT (ms):                           183.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           1023.47
Median ITL (ms):                         71.94
P99 ITL (ms):                            675.86
==================================================

MLA (this PR):

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    128.0
Successful requests:                     5000
Benchmark duration (s):                  454.67
Total input tokens:                      1187865
Total generated tokens:                  1089941
Total generated tokens (retokenized):    1088651
Request throughput (req/s):              11.00
Input token throughput (tok/s):          2612.61
Output token throughput (tok/s):         2397.23
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   157306.30
Median E2E Latency (ms):                 135071.76
---------------Time to First Token----------------
Mean TTFT (ms):                          4326.20
Median TTFT (ms):                        1761.04
P99 TTFT (ms):                           22399.46
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1415.77
Median TPOT (ms):                        890.21
P99 TPOT (ms):                           6931.18
---------------Inter-token Latency----------------
Mean ITL (ms):                           744.06
Median ITL (ms):                         541.06
P99 ITL (ms):                            3870.68
==================================================
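
For a quick read of the three runs, a summary computed from the durations above:

# Summarize the three runs using the benchmark durations from the logs above.
runs = {"main + triton": 561.05, "main + flashinfer": 488.49, "MLA (this PR)": 454.67}
base = runs["main + triton"]
for name, dur in runs.items():
    print(f"{name:<18} {dur:7.2f} s  speedup vs main+triton: {base / dur:.2f}x")
# MLA is ~1.23x faster end-to-end than main+triton and ~1.07x faster than
# flashinfer here, with far lower TTFT (mean ~4.3 s vs ~205 s).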

Reproduce:

# main branch triton
python3 -m sglang.launch_server --model-path DeepSeek-V2-Lite --port 30000 --trust-remote-code --disable-flashinfer --disable-radix-cache
# main branch flashinfer
python3 -m sglang.launch_server --model-path DeepSeek-V2-Lite --port 30000 --trust-remote-code --disable-radix-cache
# mla
python3 -m sglang.launch_server --model-path DeepSeek-V2-Lite --port 30000 --trust-remote-code --disable-radix-cache --enable-mla

python3 -m sglang.bench_serving --backend sglang --tokenizer DeepSeek-V2-Lite --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 5000 --request-rate 128
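
(Optionally, a hedged sketch for sweeping request rates with the same client flags as above, to see where each backend saturates; the rates are illustrative.)

import subprocess

# Re-run the client above at several request rates (illustrative values).
for rate in (16, 32, 64, 128):
    subprocess.run([
        "python3", "-m", "sglang.bench_serving", "--backend", "sglang",
        "--tokenizer", "DeepSeek-V2-Lite",
        "--dataset-path", "ShareGPT_V3_unfiltered_cleaned_split.json",
        "--num-prompts", "5000", "--request-rate", str(rate),
    ], check=True)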

Evaluation

MMLU average accuracy:

  • main branch: 0.579
  • MLA: 0.579

GSM8K accuracy:

  • main branch: 0.360
  • MLA: 0.365

Reproduce:

# main branch
python3 -m sglang.launch_server --model-path DeepSeek-V2-Lite --port 30000 --trust-remote-code --disable-radix-cache
# mla
python3 -m sglang.launch_server --model-path DeepSeek-V2-Lite --port 30000 --trust-remote-code --disable-radix-cache --enable-mla

python3 benchmark/mmlu/bench_sglang.py
python3 benchmark/gsm8k/bench_sglang.py

@zhyncs (Member) commented Aug 4, 2024

Hi @ispobock Very impressive throughput improvement! Could you run the evals, such as MMLU and GSM8K, and also do a regression test on Llama 3 8B Instruct with and without FlashInfer? Thanks.

@zhyncs (Member) commented Aug 4, 2024

Hi @ispobock You may add --disable-radix-cache to disable the tree cache.

@zhyncs (Member) commented Aug 4, 2024

After completing the benchmark and evaluation, we might add a switch for the MLA feature, keeping FlashInfer as the default on DeepSeek V2 for now. Currently, even the main branch struggles to run DeepSeek V2 on H100s due to issues with the Triton implementation (#913).

We could look into adding weight fusion support in another PR. Special thanks to @ispobock for this contribution, and to @grimoire for previously implementing the MLA version of DeepSeek V2 in the LMDeploy PyTorch engine, which has been incredibly helpful and inspiring: https://github.com/InternLM/lmdeploy/pull/1621/files

The initial MLA implementation on Triton already significantly outperforms MHA on Triton. Moving forward, we're considering incorporating MLA attention into FlashInfer, and we would appreciate it if @ispobock could explore this possibility. Looking forward to an update from @yzh119.

Do you have any suggestions? Thanks. @merrymercy @Ying1123 @hnyls2002 @yzh119

@zhyncs removed the wip label Aug 4, 2024
@zhyncs changed the title Support MLA for DeepSeek-V2 → Support MLA for DeepSeek-V2 with Triton - step 1 Aug 4, 2024
@ispobock (Collaborator, Author) commented Aug 4, 2024

> You may add --disable-radix-cache to disable the tree cache.

The benchmark results above have been updated accordingly.

@zhyncs (Member) commented Aug 4, 2024

A100 80G
ShareGPT 1k

# vLLM
python -m vllm.entrypoints.openai.api_server --model deepseek-ai/DeepSeek-V2-Lite --disable-log-requests --trust-remote-code --max-model-len 4096
# main FlashInfer
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code
# main Triton
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code --disable-flashinfer
# MLA Triton (run from this PR's branch; at this point MLA was the default path, before commit 94b1578 added --enable-mla)
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code

# client
python3 -m sglang.bench_serving --backend vllm
python3 -m sglang.bench_serving --backend sglang

# vLLM
============ Serving Benchmark Result ============
Backend:                                 vllm
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  172.72
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215357
Request throughput (req/s):              5.79
Input token throughput (tok/s):          1367.22
Output token throughput (tok/s):         1248.36
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   84630.85
Median E2E Latency (ms):                 85899.10
---------------Time to First Token----------------
Mean TTFT (ms):                          48520.47
Median TTFT (ms):                        43484.37
P99 TTFT (ms):                           135740.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          191.38
Median TPOT (ms):                        190.46
P99 TPOT (ms):                           482.57
---------------Inter-token Latency----------------
Mean ITL (ms):                           394.61
Median ITL (ms):                         150.76
P99 ITL (ms):                            494.64
==================================================

# FlashInfer
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  89.33
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215327
Request throughput (req/s):              11.19
Input token throughput (tok/s):          2643.61
Output token throughput (tok/s):         2413.80
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   41187.40
Median E2E Latency (ms):                 40548.85
---------------Time to First Token----------------
Mean TTFT (ms):                          24502.59
Median TTFT (ms):                        23297.79
P99 TTFT (ms):                           62219.05
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          94.00
Median TPOT (ms):                        86.57
P99 TPOT (ms):                           263.22
---------------Inter-token Latency----------------
Mean ITL (ms):                           192.80
Median ITL (ms):                         59.62
P99 ITL (ms):                            271.23
==================================================

# Triton
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  106.50
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215344
Request throughput (req/s):              9.39
Input token throughput (tok/s):          2217.31
Output token throughput (tok/s):         2024.55
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   44904.45
Median E2E Latency (ms):                 44066.16
---------------Time to First Token----------------
Mean TTFT (ms):                          26872.67
Median TTFT (ms):                        26855.06
P99 TTFT (ms):                           67354.23
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          100.41
Median TPOT (ms):                        92.15
P99 TPOT (ms):                           278.48
---------------Inter-token Latency----------------
Mean ITL (ms):                           210.10
Median ITL (ms):                         62.88
P99 ITL (ms):                            360.94
==================================================

# Triton MLA
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  93.36
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215355
Request throughput (req/s):              10.71
Input token throughput (tok/s):          2529.41
Output token throughput (tok/s):         2309.53
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   38101.47
Median E2E Latency (ms):                 33589.46
---------------Time to First Token----------------
Mean TTFT (ms):                          7483.10
Median TTFT (ms):                        7373.59
P99 TTFT (ms):                           13939.18
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          335.31
Median TPOT (ms):                        168.04
P99 TPOT (ms):                           2129.49
---------------Inter-token Latency----------------
Mean ITL (ms):                           181.06
Median ITL (ms):                         115.35
P99 ITL (ms):                            444.42
==================================================

@zhyncs (Member) commented Aug 4, 2024

After commit 94b1578, MLA must be explicitly enabled by adding --enable-mla.
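
(A minimal sketch of what such opt-in gating looks like; the names below are illustrative, not sglang's actual internals.)

import argparse

# MLA only kicks in when explicitly requested; FlashInfer MHA stays the default.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-mla", action="store_true",
                    help="use the MLA attention path for DeepSeek-V2 models")
args = parser.parse_args()

attention_backend = "mla_triton" if args.enable_mla else "mha_flashinfer"
print(attention_backend)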

@zhyncs (Member) commented Aug 4, 2024

A100 80G x8
ShareGPT 1k

# main FlashInfer
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --trust-remote-code --tp 8
# MLA Triton (--enable-mla)
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --trust-remote-code --tp 8 --enable-mla

# main FlashInfer
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  635.63
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215021
Request throughput (req/s):              1.57
Input token throughput (tok/s):          371.51
Output token throughput (tok/s):         339.22
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   312719.92
Median E2E Latency (ms):                 318283.73
---------------Time to First Token----------------
Mean TTFT (ms):                          289771.06
Median TTFT (ms):                        294630.15
P99 TTFT (ms):                           577199.48
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          98.64
Median TPOT (ms):                        52.54
P99 TPOT (ms):                           1393.25
---------------Inter-token Latency----------------
Mean ITL (ms):                           1467.39
Median ITL (ms):                         46.78
P99 ITL (ms):                            275.39
==================================================

# Triton MLA
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  267.62
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215061
Request throughput (req/s):              3.74
Input token throughput (tok/s):          882.36
Output token throughput (tok/s):         805.66
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   109370.04
Median E2E Latency (ms):                 108220.54
---------------Time to First Token----------------
Mean TTFT (ms):                          58931.99
Median TTFT (ms):                        38989.13
P99 TTFT (ms):                           148840.96
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          276.88
Median TPOT (ms):                        258.92
P99 TPOT (ms):                           748.81
---------------Inter-token Latency----------------
Mean ITL (ms):                           512.84
Median ITL (ms):                         185.41
P99 ITL (ms):                            720.15
==================================================

@zhyncs (Member) left a comment

@Ying1123 (Member) commented Aug 4, 2024

LGTM. Very nice work!! 🎉

@zhyncs (Member) commented Aug 4, 2024

hold on

@zhyncs (Member) commented Aug 4, 2024

@ispobock We should add regression testing with Llama 3.

@ispobock (Collaborator, Author) commented Aug 4, 2024

> We should add regression testing with Llama 3.

There is an issue when disabling FlashInfer; let me fix it.

@zhyncs (Member) commented Aug 4, 2024

> > We should add regression testing with Llama 3.
>
> There is an issue when disabling FlashInfer; let me fix it.

ok

@ispobock (Collaborator, Author) commented Aug 4, 2024

Evaluated the average accuracy on MMLU with python3 benchmark/mmlu/bench_sglang.py --nsub 10:

Llama-3-8B:

  • flashinfer: 0.619
  • triton: 0.618

DeepSeek-V2-Lite:

  • flashinfer: 0.543
  • triton: 0.540
  • mla triton: 0.549

@zhyncs changed the title Support MLA for DeepSeek-V2 with Triton - step 1 → [DO NOT MERGE] Support MLA for DeepSeek-V2 with Triton - step 1 Aug 4, 2024
@zhyncs (Member) commented Aug 4, 2024

TP cases

# flashinfer
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code --tp 2

# triton
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code --tp 2 --disable-flashinfer

# mla
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --disable-radix-cache --trust-remote-code --tp 2 --enable-mla
python3 -m sglang.bench_serving --backend sglang

python3 benchmark/mmlu/bench_sglang.py --nsub 10

The three result pairs below follow the order of the launch commands above (FlashInfer, Triton, MLA):
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  47.37
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215356
Request throughput (req/s):              21.11
Input token throughput (tok/s):          4984.92
Output token throughput (tok/s):         4551.58
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   20699.26
Median E2E Latency (ms):                 18888.38
---------------Time to First Token----------------
Mean TTFT (ms):                          5280.38
Median TTFT (ms):                        4278.32
P99 TTFT (ms):                           12355.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          140.18
Median TPOT (ms):                        87.34
P99 TPOT (ms):                           668.44
---------------Inter-token Latency----------------
Mean ITL (ms):                           98.27
Median ITL (ms):                         54.34
P99 ITL (ms):                            501.68
==================================================

subject: abstract_algebra, #q:100, acc: 0.300
subject: anatomy, #q:135, acc: 0.511
subject: astronomy, #q:152, acc: 0.586
subject: business_ethics, #q:100, acc: 0.600
subject: clinical_knowledge, #q:265, acc: 0.630
subject: college_biology, #q:144, acc: 0.646
subject: college_chemistry, #q:100, acc: 0.420
subject: college_computer_science, #q:100, acc: 0.470
subject: college_mathematics, #q:100, acc: 0.400
subject: college_medicine, #q:173, acc: 0.601
Total latency: 16.818
Average accuracy: 0.541

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  71.50
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215353
Request throughput (req/s):              13.99
Input token throughput (tok/s):          3302.80
Output token throughput (tok/s):         3015.69
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   24554.45
Median E2E Latency (ms):                 21974.87
---------------Time to First Token----------------
Mean TTFT (ms):                          6747.84
Median TTFT (ms):                        5495.57
P99 TTFT (ms):                           14803.86
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          157.37
Median TPOT (ms):                        96.33
P99 TPOT (ms):                           698.46
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.50
Median ITL (ms):                         60.15
P99 ITL (ms):                            536.09
==================================================

subject: abstract_algebra, #q:100, acc: 0.340
subject: anatomy, #q:135, acc: 0.541
subject: astronomy, #q:152, acc: 0.579
subject: business_ethics, #q:100, acc: 0.630
subject: clinical_knowledge, #q:265, acc: 0.642
subject: college_biology, #q:144, acc: 0.639
subject: college_chemistry, #q:100, acc: 0.430
subject: college_computer_science, #q:100, acc: 0.470
subject: college_mathematics, #q:100, acc: 0.400
subject: college_medicine, #q:173, acc: 0.618
Total latency: 17.655
Average accuracy: 0.553

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  71.77
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215342
Request throughput (req/s):              13.93
Input token throughput (tok/s):          3290.49
Output token throughput (tok/s):         3004.44
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   25321.60
Median E2E Latency (ms):                 22072.11
---------------Time to First Token----------------
Mean TTFT (ms):                          6508.90
Median TTFT (ms):                        6170.12
P99 TTFT (ms):                           10115.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          198.78
Median TPOT (ms):                        99.93
P99 TPOT (ms):                           1205.42
---------------Inter-token Latency----------------
Mean ITL (ms):                           121.54
Median ITL (ms):                         67.46
P99 ITL (ms):                            412.39
==================================================

subject: abstract_algebra, #q:100, acc: 0.360
subject: anatomy, #q:135, acc: 0.504
subject: astronomy, #q:152, acc: 0.572
subject: business_ethics, #q:100, acc: 0.630
subject: clinical_knowledge, #q:265, acc: 0.653
subject: college_biology, #q:144, acc: 0.653
subject: college_chemistry, #q:100, acc: 0.410
subject: college_computer_science, #q:100, acc: 0.450
subject: college_mathematics, #q:100, acc: 0.410
subject: college_medicine, #q:173, acc: 0.607
Total latency: 18.795
Average accuracy: 0.550

@zhyncs changed the title [DO NOT MERGE] Support MLA for DeepSeek-V2 with Triton - step 1 → Support MLA for DeepSeek-V2 with Triton - step 1 Aug 4, 2024
@zhyncs (Member) commented Aug 4, 2024

Conclusion:

  1. For the Lite model, TP1 is sufficient and the default FlashInfer backend can be used (for now).
  2. For the full V2 model with TP8, it is recommended to enable MLA with --enable-mla.

@zhyncs merged commit e1eae1f into sgl-project:main Aug 4, 2024 (3 checks passed)
@yzh119 (Collaborator) commented Aug 4, 2024

Great work! I might have some bandwidth to work on flashinfer's MLA next week.
