[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies #4696

Merged: 12 commits, May 14, 2024
63 changes: 63 additions & 0 deletions benchmarks/overheads/benchmark_hashing.py
@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # Analyze the runtime of the hashing function.
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of the hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
18 changes: 15 additions & 3 deletions vllm/sequence.py
@@ -2,7 +2,7 @@
import copy
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
@@ -119,6 +119,7 @@ def __init__(
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
Collaborator

Do we need to store it as both a list and a tuple?

Collaborator Author

Technically we do not have to, as long as no code directly changes the value of the prompt_token_ids attribute inside the SequenceGroup class. A code search for ".prompt_token_ids" confirms there is currently no such code in the vLLM repo.

[screenshot: code search results]

Nothing changes prompt_token_ids outside the initialization function.

Collaborator

We should make it a tuple then, IMO. This also signifies that prompt tokens are immutable.

Collaborator Author

However, most of the existing code assumes the prompt_token_ids attribute is of type List[int]; only the hashing function requires it to be Tuple[int, ...] so that it is immutable and thus hashable. From this perspective, it seems worthwhile to keep an extra copy of prompt_token_ids so that hashing is fast and developers can still treat prompt_token_ids as a list.

Collaborator

How about we at least make the list attribute private and only allow access (but not setting) through a property? Not perfect, since the list can still be modified in place, but better than nothing.

Another idea is to subclass list to create a FrozenList whose mutating methods raise an exception, but that's a lot more complex.
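A minimal sketch of both ideas, using hypothetical names (FrozenList, SequenceDataSketch) that are not part of this PR:

```python
from typing import List


class FrozenList(list):
    """Hypothetical read-only list: mutating methods raise TypeError."""

    def _blocked(self, *args, **kwargs):
        raise TypeError("FrozenList is read-only")

    append = extend = insert = remove = pop = clear = _blocked
    sort = reverse = __setitem__ = __delitem__ = __iadd__ = __imul__ = _blocked


class SequenceDataSketch:
    """Hypothetical stand-in showing the private-list-plus-property idea."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._prompt_token_ids = list(prompt_token_ids)

    @property
    def prompt_token_ids(self) -> List[int]:
        # Readable but not re-assignable; in-place mutation of the returned
        # list is still possible, as noted above.
        return self._prompt_token_ids
```

With this sketch, FrozenList([1, 2, 3]).append(4) raises TypeError, while the property variant only prevents rebinding the attribute.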

Collaborator

Since immutable prompt token IDs are the most important assumption, I'd also suggest changing it to a tuple directly. Changing all List[int] annotations to Tuple[int, ...] does not seem like an issue to me; at least the type annotations should not be a blocker.

Collaborator Author (@KuntaiDu, May 9, 2024)

I have changed the data type of prompt_token_ids to Tuple[int, ...] and fixed the typing conflicts caused by this change.

        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
@@ -141,6 +142,17 @@ def get_output_len(self) -> int:
    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable."""
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
Collaborator

When does this happen?

Collaborator Author

This happens when calculating hashes that cover both the user input (i.e., the prompt tokens) and the LLM-generated output (i.e., the output tokens).

Collaborator

Can it happen under normal circumstances, or only in the recomputation case?

Collaborator Author (@KuntaiDu, May 14, 2024)

Yes, it happens in normal cases (inside the _allocate_last_physical_block function in vllm/sequence.py).
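For illustration, with block_size = 16 and a 20-token prompt, hashing logical block 1 covers num_hashed_tokens_of_block(1) = 1 * 16 + 16 = 32 tokens, which exceeds the 20-token prompt, so the hash must also include the first 12 output tokens.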

            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens
@@ -251,8 +263,8 @@ def hash_of_block(self, logical_idx: int) -> int:
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size
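To make the before/after difference in hash_of_block concrete, here is a self-contained sketch using simplified, hypothetical stand-ins (SeqDataSketch, hash_of_block_old, hash_of_block_new — not the actual vLLM classes): the old path rebuilds the full token list and copies a slice of it into a tuple for every block, while the new path slices a tuple that was built once in the constructor.

```python
from typing import List, Optional, Tuple


class SeqDataSketch:
    """Simplified, hypothetical stand-in for SequenceData (sketch only)."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        # Built once; block hashing later slices this tuple instead of
        # concatenating and re-copying the token lists.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)

    def get_token_ids(self) -> List[int]:
        # Old path: allocates a fresh full-length list on every call.
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
        self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        # New path: slice the pre-built tuple; only the (usually short)
        # output prefix is converted when num_tokens exceeds the prompt.
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        return (self._prompt_token_ids_tuple[:num_tokens], None)


def hash_of_block_old(data: SeqDataSketch, num_tokens: int,
                      lora_int_id: int = 0) -> int:
    # Concatenates the full token list, then copies the slice into a tuple,
    # for every block: O(L) per block, O(L^2) over all blocks.
    return hash((tuple(data.get_token_ids()[0:num_tokens]), lora_int_id))


def hash_of_block_new(data: SeqDataSketch, num_tokens: int,
                      lora_int_id: int = 0) -> int:
    # Slices the pre-built tuple; no full-list concatenation per block.
    return hash((data.get_prefix_token_ids(num_tokens), lora_int_id))


if __name__ == "__main__":
    data = SeqDataSketch(list(range(15_000)))  # roughly the benchmark's prompt length
    block_size = 16
    for logical_idx in range(len(data.prompt_token_ids) // block_size):
        num_tokens = logical_idx * block_size + block_size
        hash_of_block_new(data, num_tokens)
```

Under this sketch's assumptions, the per-block cost drops from a full-length list copy to a prefix-length tuple slice, which is the overhead the benchmark above measures through the hash_of_block runtime share.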