[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies #4696
Changes from 4 commits
The PR adds a new profiling script (63 lines) that measures how much of the end-to-end generation time is spent in `hash_of_block`:

```python
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # Analyze the runtime of the hashing function.
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of the hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
```
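The script reads the raw `pstats.Stats.stats` dictionary rather than the printed report. As a reminder of that layout, here is a minimal, self-contained sketch (the `hash_of_block` below is just a stand-in workload, not the vLLM method): each key is a `(filename, lineno, function_name)` tuple and each value is `(call_count, num_calls, total_time, cumulative_time, callers)`, which is why the benchmark indexes `func[2]` and `stats.stats[func][3]`.

```python
import cProfile
import pstats


def hash_of_block(i: int) -> int:
    # Stand-in workload so the profile contains an entry with this name.
    return hash(tuple(range(i)))


profiler = cProfile.Profile()
profiler.runctx('sum(hash_of_block(i) for i in range(2000))',
                globals(), locals())

stats = pstats.Stats(profiler)
# stats.stats: {(filename, lineno, func_name): (cc, nc, tt, ct, callers), ...}
for func, (cc, nc, tt, ct, callers) in stats.stats.items():
    if 'hash_of_block' in func[2]:
        share = ct / stats.total_tt * 100
        print(f"{func[2]}: {nc} calls, {ct:.4f}s cumulative ({share:.1f}%)")
```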
The second changed file is vLLM's sequence module. The sequence data object now also stores the prompt token IDs as a tuple at construction time:

```diff
@@ -2,7 +2,7 @@
 import copy
 import enum
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
@@ -119,6 +119,7 @@ def __init__(
             output_token_ids = []
 
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_token_ids_tuple = tuple(prompt_token_ids)
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
```
A new `get_prefix_token_ids` helper returns a hashable (tuple-based) view of the first `num_tokens` tokens:

```diff
@@ -141,6 +142,17 @@ def get_output_len(self) -> int:
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids
 
+    def get_prefix_token_ids(
+            self, num_tokens: int
+    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
+        """Get prefix tokens, and make the return value hashable."""
+        prompt_length = len(self.prompt_token_ids_tuple)
+        if num_tokens > prompt_length:
+            return (self.prompt_token_ids_tuple,
+                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
+        else:
+            return (self.prompt_token_ids_tuple[:num_tokens], None)
+
     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
         return self._num_computed_tokens
```

A review thread on the `num_tokens > prompt_length` branch:

**Reviewer:** When does this happen?

**Author:** This happens when calculating hashes that cover both the user input (i.e. the prompt tokens) and the LLM-generated output (i.e. the output tokens).

**Reviewer:** Can it happen under normal circumstances, or only in the recomputation case?

**Author:** Yes, it will happen in normal cases (inside function …).
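To make the question above concrete, here is a small, purely illustrative example (the block size and lengths are assumptions, not values from the PR) of a block whose hashed token count reaches past the prompt into the generated output, using the same formula as `num_hashed_tokens_of_block` in the next hunk:

```python
block_size = 16     # assumed block size, for illustration only
prompt_length = 20  # assumed prompt length


def num_hashed_tokens_of_block(logical_idx: int) -> int:
    # Same formula as num_hashed_tokens_of_block in the hunk below.
    return logical_idx * block_size + block_size


# Block 0 lies entirely inside the prompt; block 1 ends at token 32, so once
# at least 12 output tokens have been generated its hash covers output tokens
# as well, and get_prefix_token_ids() takes the num_tokens > prompt_length
# branch.
for logical_idx in (0, 1):
    num_tokens = num_hashed_tokens_of_block(logical_idx)
    print(logical_idx, num_tokens, num_tokens > prompt_length)
# 0 16 False
# 1 32 True
```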
`hash_of_block` then hashes that structure together with the LoRA ID, instead of concatenating and copying the full token list on every call:

```diff
@@ -245,14 +257,9 @@ def get_output_text_to_return(self, buffer_length: int):
                 self.output_text)
 
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
 
         # Compute the number of tokens in the sequence
-        # TODO: The current hashing function is O(L^2). We should optimize
-        # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(
-            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
+        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
+        return hash((hashed_tokens, self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
```
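For intuition about what the change buys, here is a rough, self-contained sketch (not the vLLM classes; plain module-level variables stand in for the sequence data state) contrasting the old and new ways of building the hashable prefix. The old path rebuilds `prompt + output` as a fresh list and copies a prefix of it on every call; the new path reuses the tuple created once in `__init__` and only materializes the short output suffix:

```python
from typing import List, Tuple

# Stand-ins for the sequence data state (illustration only).
prompt_token_ids: List[int] = list(range(15_000))
prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
output_token_ids: List[int] = [7, 8, 9]


def old_hash_of_prefix(num_tokens: int, lora_int_id: int = 0) -> int:
    # Old approach: concatenate prompt + output into a brand-new list, slice
    # it, and convert the slice to a tuple on every call.
    token_ids = prompt_token_ids + output_token_ids
    return hash((tuple(token_ids[:num_tokens]), lora_int_id))


def new_hash_of_prefix(num_tokens: int, lora_int_id: int = 0) -> int:
    # New approach: reuse the tuple built once at construction time; only the
    # (short) output suffix is turned into a tuple when needed.
    prompt_length = len(prompt_token_ids_tuple)
    if num_tokens > prompt_length:
        prefix = (prompt_token_ids_tuple,
                  tuple(output_token_ids[:num_tokens - prompt_length]))
    else:
        prefix = (prompt_token_ids_tuple[:num_tokens], None)
    return hash((prefix, lora_int_id))
```

The two functions produce different hash values, which is fine for prefix caching: the hashes only need to be consistent within one scheme.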
A second thread discusses whether the prompt token IDs need to be stored as both a list and a tuple:

**Reviewer:** Do we need to store it as both a list and a tuple?

**Author:** Technically we do not have to, as long as there is no code that directly changes the value of the `prompt_token_ids` attribute inside the class `SequenceGroup`. I did a code search that confirms there is no such code inside the vLLM repo now (the search keyword is `.prompt_token_ids`), and nothing changes `prompt_token_ids` besides the initialization function.

**Reviewer:** We should make it a tuple then, IMO. This also signifies prompt tokens are immutable.

**Author:** However, most of the existing code assumes the attribute `prompt_token_ids` to be of type `List[int]`, and only the hashing function requires it to be `Tuple[int, ...]` so that it is immutable and thus hashable. From this perspective, I think it is worthwhile to have an extra copy of `prompt_token_ids` so that hashing is fast and developers can still treat `prompt_token_ids` as a list.

**Reviewer:** How about we at least make the list attribute private and only allow access (and not setting) through a property? Not perfect, since it can still be modified in place, but better than nothing. Another idea is to subclass `list` to create a `FrozenList` with the mutating methods set to raise an exception, but that is a lot more complex.
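As a rough illustration of the two alternatives floated above (purely a sketch with made-up names, not code from the PR; the thread below ends up going with a plain tuple instead):

```python
from typing import Any, List


class SequenceDataWithProperty:
    """Option 1: keep the list private and expose it read-only."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._prompt_token_ids = list(prompt_token_ids)

    @property
    def prompt_token_ids(self) -> List[int]:
        # Callers can still mutate the returned list in place, which is the
        # caveat raised in the comment above.
        return self._prompt_token_ids


class FrozenList(List[int]):
    """Option 2: a list subclass whose mutating methods raise."""

    def _frozen(self, *args: Any, **kwargs: Any) -> None:
        raise TypeError("FrozenList may not be modified")

    append = extend = insert = remove = pop = clear = sort = reverse = _frozen
    __setitem__ = __delitem__ = __iadd__ = __imul__ = _frozen


tokens = FrozenList([1, 2, 3])
print(tokens[:2])   # reads still work: [1, 2]
try:
    tokens.append(4)
except TypeError as exc:
    print(exc)      # FrozenList may not be modified
```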
**Reviewer:** Since immutable prompt token IDs is the most important assumption, I'd also suggest changing it to a tuple directly. It does not seem to be an issue to change all `List[int]` to `Tuple[int, ...]`; at least the type annotations should not be a blocker.

**Author:** I have changed the data type of `prompt_token_ids` to `Tuple[int, ...]` and fixed the typing conflicts caused by this change.
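A minimal sketch of the direction described in this last comment (assumed shape only, not the exact code the PR ended up with): the constructor converts the incoming token IDs to a tuple, so the attribute is immutable and directly hashable.

```python
from typing import List, Optional, Tuple


class SequenceData:
    """Illustrative only; field names follow the diff above."""

    def __init__(self,
                 prompt_token_ids: List[int],
                 output_token_ids: Optional[List[int]] = None) -> None:
        if output_token_ids is None:
            output_token_ids = []
        # Stored as a tuple: immutable, hence safe to hash and share.
        self.prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)
        self.output_token_ids: List[int] = output_token_ids
```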