
[Experimental] Prefix Caching Support #1669

Merged 44 commits into vllm-project:main on Jan 18, 2024

Conversation

caoshiyi
Contributor

@caoshiyi caoshiyi commented Nov 15, 2023

add prefix caching support

Section 1 (Basic Functionality):

  • Test on single request for llama model (no cpu swap)
  • Test on batches where there are requests with prefix and without prefix (no cpu swap)
  • Benchmark performance for batched analytics tasks (no cpu swap)
  • Alibi (Thanks @DouHappy )
  • Clean code

Todo:
Automatic Prefix Caching Support -- SGLang RadixAttention

@zhuohan123 zhuohan123 changed the title from [WIP] Prefix to [WIP] Prefix Caching on Dec 4, 2023
@DouHappy
Contributor

DouHappy commented Dec 5, 2023

It seems that the prefix has not updated its physical block?

I tested on meta-llama/Llama-2-70b-chat-hf and baichuan2-13b-chat, but there seems to be no acceleration effect.
I then added a logger at line 358 of worker.py to check whether the prefix updated its physical block.

The logger's output is as follows:
------start generating------
prefix length: 592
block size: 16
Processed prompts: 0%| | 0/500 [00:00<?, ?it/s] (RayWorker pid=1138291) INFO 12-05 10:46:31 worker.py:358] prefix_block_tables: [[], [], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:36 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:42 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:47 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []]
(RayWorker pid=1138291) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: []
(RayWorker pid=1138290) INFO 12-05 10:46:51 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 39x across cluster]
Processed prompts: 0%| | 1/500 [00:23<3:18:23, 23.86s/it] (RayWorker pid=1138290) INFO 12-05 10:46:52 worker.py:358] prefix_block_tables: [[], [], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:46:55 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:46:57 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 21x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:02 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:08 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:13 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 44x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []]
(RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: []
Processed prompts: 43%|████▎ | 217/500 [00:45<00:50, 5.65it/s] (RayWorker pid=1138291) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []]
Processed prompts: 44%|████▍ | 219/500 [00:45<00:49, 5.65it/s] (RayWorker pid=1138291) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]]
(RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 19x across cluster]
Processed prompts: 44%|████▍ | 220/500 [00:47<00:53, 5.24it/s] (RayWorker pid=1138290) INFO 12-05 10:47:15 worker.py:358] prefix_block_tables: [[], [], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:18 worker.py:358] prefix_block_tables: [] [repeated 59x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:16 worker.py:358] prefix_block_tables: [[], []] [repeated 3x across cluster]
(RayWorker pid=1138290) INFO 12-05 10:47:17 worker.py:358] prefix_block_tables: [[]] [repeated 3x across cluster]
(RayWorker pid=1138291) INFO 12-05 10:47:23 worker.py:358] prefix_block_tables: [[], [], [], [], []] [repeated 41x across cluster]
Processed prompts: 100%|██████████| 500/500 [00:54<00:00, 9.13it/s] cost time 56.1341028213501
saving output

(RayWorker pid=1138291) INFO 12-05 10:47:26 worker.py:358] prefix_block_tables: []


By the way, my prompts are generated by this script:

# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"

prompt_template = '''\
You are a helpful assistant in recongnizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.
# Table
{}

# Question
What' s the content in the ({},{}) cells
'''
# Write one prompt per (row, column) pair; newlines are escaped so each
# prompt occupies a single line of prompt.txt.
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            print(tmp_str.replace("\n", "\\n"), file=outer)

Maybe there are some bugs in my test?

@caoshiyi
Contributor Author

caoshiyi commented Dec 13, 2023

@DouHappy Can you try calling llm.generate() with one prompt first to warm up? prefix_block_tables=[[],[],[]] indicates that the KV cache for the prefix part hasn't been computed yet. Also, can you share your testing script?
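
A minimal sketch of the warmup pattern being suggested, using the experimental prefix_pos argument from this PR; the model path, prompts, and prefix length below are placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="path/to/model")  # placeholder model path
sampling_params = SamplingParams(temperature=0)

# All prompts share the same leading prefix of `prefix_len` tokens (placeholder value).
prompts = ["<shared prefix> question 1", "<shared prefix> question 2"]
prefix_len = 591

# Warmup: run a single prompt so the KV cache for the prefix is computed and stored.
llm.generate(prompts[0], sampling_params=sampling_params, prefix_pos=[prefix_len])

# Subsequent batched calls can then reuse the cached prefix blocks.
outputs = llm.generate(prompts, sampling_params=sampling_params,
                       prefix_pos=[prefix_len] * len(prompts))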

@DouHappy
Contributor

DouHappy commented Dec 13, 2023

@DouHappy Can you try calling llm.generate() with one prompt first to warm up? prefix_block_tables=[[],[],[]] indicates that the KV cache for the prefix part hasn't been computed yet. Also, can you share your testing script?

My test script:

# %%
# generate test prompt
test_table = "|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||"

prompt_template = '''\
You are a helpful assistant in recongnizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.
# Table
{}

# Question
What' s the content in the ({},{}) cells
'''
# Write one prompt per (row, column) pair; newlines are escaped so each
# prompt occupies a single line of prompt.txt.
with open('prompt.txt', 'w') as outer:
    for row in range(50):
        for column in range(10):
            tmp_str = prompt_template.format(test_table, row + 1, column + 1)
            print(tmp_str.replace("\n", "\\n"), file=outer)

# %%
import time
import datetime
import os

from vllm import LLM
from vllm import SamplingParams

import torch

def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None, save_file=None):
    # Default sampling parameters: greedy decoding.
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0)

    print("------start generating------")
    start_time = time.time()
    # Use prefix caching only when a prefix length is given.
    if prefix_len is not None:
        # Warmup: run one prompt first so the prefix KV cache is computed.
        print("warmup")
        outputs = llm.generate(prompts[0], sampling_params=sampling_params, prefix_pos=[prefix_len])
        # Start inference over the full batch, reusing the cached prefix.
        outputs = llm.generate(prompts, sampling_params=sampling_params, prefix_pos=[prefix_len] * len(prompts))
    else:
        outputs = llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"cost time {end_time - start_time}")

    if save_file is not None:
        print("saving output......")
        for output in outputs:
            print(output, file=save_file)
        print(f"output saved in {save_file.name} {datetime.datetime.now()}")

# %%
# set GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
# init model and sampling parameters
tensor_parallel_size = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
# set Baichuan model
model = "/data/images/llms/models--baichuan-inc--Baichuan2-13B-Chat"
# model = "/data/images/llms/chatdoc-llama-2-70b-chat-hf-checkpoint-8000"

# Create an LLM.
llm = LLM(model=model, tokenizer_mode='auto', trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)

# %%
# get prompts
prompts = []
with open("prompt.txt", 'r') as reader:
    prompts = reader.readlines()[:500]

with open("output.txt", 'w') as f:
    test_prefix(llm=llm,
                prompts=prompts[:50],
                prefix_len=591,
                save_file=f
                )



I found a bug: using two GPUs is slower than a single GPU. The prefix's 'on_gpu' state is always False before prepare_inputs() when using two GPUs, while it works fine on a single GPU. This means multi_query_cached_kv_attention is never used when running on multiple GPUs. My last test also passes on a single GPU.

One GPU, with prefix:
------start generating------
Processed prompts: 100%|██████████| 500/500 [00:22<00:00, 22.55it/s]cost time 23.27279233932495
saving output

One GPU, without prefix:
------start generating------
Processed prompts: 100%|██████████| 500/500 [01:21<00:00,  6.16it/s]cost time 82.11750793457031
saving output

But it costs about 60 s on two GPUs.
Although it is very fast, I can't get the right output when I use the prefix.

if sampling_params.prompt_logprobs is not None:
    selected_token_indices.extend(
        range(selected_token_start_idx,
              selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
                              prompt_len - 1)
selected_token_start_idx += max_seq_len

# set the prefix state


When tp > 1, seq_group_metadata.prefix here is copied by the Ray workers, so setting on_gpu=True won't take effect on multiple GPUs.
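
A small standalone Ray sketch of the underlying pitfall (illustrative only, not vLLM's actual Prefix class): arguments passed to Ray tasks arrive as deserialized copies, so mutating a flag inside a worker never updates the driver's object.

import ray

ray.init()

class Prefix:
    def __init__(self):
        self.on_gpu = False  # driver-side flag

@ray.remote
def compute_prefix(prefix):
    # `prefix` here is a deserialized copy of the driver's object.
    prefix.on_gpu = True
    return prefix.on_gpu

prefix = Prefix()
print(ray.get(compute_prefix.remote(prefix)))  # True: the worker's copy was mutated
print(prefix.on_gpu)                           # still False on the driver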


cc. @DouHappy


Thank you for your reply. This helps me a lot.

@caoshiyi caoshiyi marked this pull request as ready for review January 2, 2024 02:36
Co-authored-by: DouHappy <2278958187@qq.com>
@caoshiyi caoshiyi changed the title from [WIP] Prefix Caching to Prefix Caching on Jan 2, 2024
@zhuohan123 zhuohan123 self-requested a review January 3, 2024 23:55
Collaborator

@zhuohan123 zhuohan123 left a comment

Thanks for the great work! Can you also merge with the latest main branch? I will test the PR after the merge.

Review threads on: examples/output2.txt, examples/api_client.py
Collaborator

Consider adding another example just for prefix caching?

Review threads on: vllm/core/block_manager.py (2), vllm/model_executor/input_metadata.py, vllm/prefix.py (4)
@franklyd

Thanks a lot for this great feature!
I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)

Hi @DouHappy , did you observe any speed improvement afterwards?

@DouHappy
Contributor

Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)

Hi @DouHappy , did you observe any speed improvement afterwards?

Yes, I did observe a speedup. Could you show me your test script? Maybe you forgot the warmup? BTW, I am writing an introduction to prefix caching, but only a Chinese version for now. See vLLM-prefix浅析(System Prompt,大模型推理加速). @franklyd

@zhuohan123
Collaborator

@franklyd @DouHappy There was a bug in my refactor. If you try now, you should be able to see speedups.

Collaborator

@zhuohan123 zhuohan123 left a comment

LGTM! Thanks for the great work! I pushed some style refactors myself.

@zhuohan123 zhuohan123 merged commit d10f8e1 into vllm-project:main Jan 18, 2024
12 of 16 checks passed
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Jan 18, 2024
Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Contributor

@HaiShaw HaiShaw left a comment

@zhuohan123

I think this merge now invalidates the FP8 KV cache (#2279).
Looking at the kernels in prefix_prefill.py: when the FP8 KV cache is ON, K/V and K_cache/V_cache now have different types.

Please let me know the best way to move forward, thanks!

@zhuohan123 zhuohan123 mentioned this pull request Feb 19, 2024
@gangooteli

Thanks a lot for this great feature! I tried it with the latest caoshiyi:prefix, but I found that there's no speed improvement. (one V100 GPU, tested with Baichuan2-13B-chat model)
Hi @DouHappy , did you observe any speed improvement afterwards?

Yes, I did observe a speedup. Could you show me your test script? Maybe you forgot the warmup? BTW, I am writing an introduction to prefix caching, but only a Chinese version for now. See vLLM-prefix浅析(System Prompt,大模型推理加速). @franklyd

Could you provide a test script for the speedup?

@ksjadeja

Could you provide a test script for the speedup?

+1
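
A rough timing sketch distilled from DouHappy's script earlier in the thread, building on the warmup pattern shown above; the model path, prompts, and prefix_len are placeholders, and prefix_pos is the experimental argument added by this PR:

import time
from vllm import LLM, SamplingParams

llm = LLM(model="path/to/model")  # placeholder model path
sampling_params = SamplingParams(temperature=0)
prompts = ["<shared prefix> question 1", "<shared prefix> question 2"]  # placeholders that share one prefix
prefix_len = 591  # placeholder: token length of the shared prefix

# Warmup so the prefix KV cache is computed once before timing.
llm.generate(prompts[0], sampling_params=sampling_params, prefix_pos=[prefix_len])

start = time.time()
llm.generate(prompts, sampling_params=sampling_params, prefix_pos=[prefix_len] * len(prompts))
print(f"with prefix caching: {time.time() - start:.2f}s")

start = time.time()
llm.generate(prompts, sampling_params=sampling_params)
print(f"without prefix caching: {time.time() - start:.2f}s")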

@AlpinDale

Hi @HaiShaw

Triton doesn't seem to support mixed-precision dot products, so this kernel fails if k is uint8 while q is another precision. I've been trying to find a solution to this problem, but I keep coming up blank. Do you have any ideas on how to approach this?
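
A PyTorch-level illustration of the same dtype clash (not the Triton kernel itself; the scale factor is a hypothetical stand-in, not vLLM's actual FP8 scheme):

import torch

q = torch.randn(4, 64)                                       # query activations in float32
k_cache = torch.randint(0, 255, (64, 4), dtype=torch.uint8)  # stand-in for the quantized KV cache

try:
    torch.matmul(q, k_cache)  # mixed-dtype dot product is rejected
except RuntimeError as err:
    print(f"mixed-precision matmul fails: {err}")

k_scale = 0.01                             # hypothetical per-tensor scale
k_dequant = k_cache.to(q.dtype) * k_scale  # dequantize to the query dtype first
out = torch.matmul(q, k_dequant)           # works once both operands share a dtype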

@chenxu2048
Contributor

Hi @HaiShaw

Triton doesn't seem to support mixed-precision dot products, so this kernel fails if k is uint8 while q is another precision. I've been trying to find a solution to this problem, but I keep coming up blank. Do you have any ideas on how to approach this?

Hi, @AlpinDale. Are you using prefix caching with the FP8 KV cache? The PyTorch and Triton versions used by vLLM cannot support the FP8 KV cache. There is more information about prefix caching and the FP8 KV cache in #3234.
