
Add Automatic Prefix Caching #2762

Merged: 88 commits, Mar 2, 2024

Conversation

@SageMoore (Contributor) commented Feb 5, 2024

resolves #2614

The goal of this diff is to enable automatic prefix caching. This is done by adding an additional level of indirection between the logical and physical blocks, which allows identical logical blocks to map to the same physical block.

This diff replaces the existing manual prefix caching mechanism added in #1669

Before:

Logical block table --> physical block table.

After:

Logical block table --> hash table --> physical block table.
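As a rough illustration of this extra level of indirection (a self-contained toy sketch with made-up names, not the code in this PR), identical prefixes hash to the same key and therefore resolve to the same physical block:

```python
# Toy illustration of the new mapping (names are illustrative, not the PR's classes):
# logical block index -> hash of the token prefix -> physical block number.

BLOCK_SIZE = 4

hash_to_physical = {}   # the new middle level of indirection
next_free_block = 0

def block_hash(token_ids, logical_idx):
    # The hash covers every token up to and including this block,
    # so it identifies the whole prefix, not just this block's tokens.
    num_tokens = (logical_idx + 1) * BLOCK_SIZE
    return hash(tuple(token_ids[:num_tokens]))

def physical_block_for(token_ids, logical_idx):
    global next_free_block
    key = block_hash(token_ids, logical_idx)
    if key not in hash_to_physical:
        hash_to_physical[key] = next_free_block
        next_free_block += 1
    return hash_to_physical[key]

seq_a = [1, 2, 3, 4, 5, 6, 7, 8]
seq_b = [1, 2, 3, 4, 9, 9, 9, 9]   # shares only the first block with seq_a
assert physical_block_for(seq_a, 0) == physical_block_for(seq_b, 0)  # shared prefix block
assert physical_block_for(seq_a, 1) != physical_block_for(seq_b, 1)  # diverged block
```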

The BlockAllocator class now contains a hash table that maps to full PhysicalTokenBlocks which have already been computed and can be read from by multiple Sequences in parallel. This table is accessed when PhysicalTokenBlocks are allocated and, if the caller provides a hash value and that value is in the table, allocate will return a cached block instead of making a new one. If a hash value is not passed into allocate, a "unique" block will be generated using a timestamp as the hash value. This is primarily used to allocate new partial blocks.
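A hedged sketch of that allocate() behavior follows; the class, attribute, and method names are assumptions for illustration rather than the PR's exact implementation, and eviction (described further below) is omitted:

```python
import time
from typing import Dict, List, Optional


class PhysicalTokenBlock:
    def __init__(self, block_number: int) -> None:
        self.block_number = block_number
        self.ref_count = 0
        self.block_hash = -1
        self.last_accessed = 0.0


class BlockAllocatorSketch:
    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[PhysicalTokenBlock] = [
            PhysicalTokenBlock(i) for i in range(num_blocks)
        ]
        # hash of all tokens up to and including a full block -> cached block
        self.table: Dict[int, PhysicalTokenBlock] = {}

    def allocate(self, block_hash: Optional[int] = None) -> PhysicalTokenBlock:
        if block_hash is None:
            # No hash given (e.g. a new partial block): use a timestamp so the
            # key is effectively unique and the block is not shared until it is
            # later promoted with a real content hash.
            block_hash = hash(time.monotonic())
        if block_hash in self.table:
            block = self.table[block_hash]      # cache hit: reuse the computed block
        else:
            block = self.free_blocks.pop()      # cache miss: hand out a fresh block
            block.block_hash = block_hash
            self.table[block_hash] = block
        block.ref_count += 1
        block.last_accessed = time.monotonic()
        return block
```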

The hash value passed into BlockAllocator is computed by the Sequence class. The Sequence's hash method takes in a logical block index and uses all tokens leading up to and including that block to compute a unique hash.

This caching system does not currently work for partial blocks, but there is a mechanism inside the BlockSpaceManager class that will "promote" a partial block to a cacheable full block once it fills up. At that point the BlockSpaceManager uses the sequence to compute the hash, and that hash is added to the BlockAllocator's table, making the block usable by other sequences.
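A hedged sketch of this promotion step, building on the BlockAllocatorSketch above and assuming a Sequence-like object with get_len() and hash_of_block() (as in the snippet quoted later in this thread); the bookkeeping is deliberately simplified:

```python
def promote_last_block(seq, block_table, allocator, block_size: int) -> None:
    """Once the last block fills up, give it its real content hash so it can be shared."""
    if seq.get_len() % block_size != 0:
        return  # last block is still partial: keep its unique timestamp-based hash
    last_idx = len(block_table) - 1
    block = block_table[last_idx]
    real_hash = seq.hash_of_block(last_idx)      # hash of all tokens through this block
    allocator.table.pop(block.block_hash, None)  # drop the placeholder entry
    cached = allocator.table.get(real_hash)
    if cached is not None and cached is not block:
        # An identical full block is already cached: point at it and return the
        # duplicate to the free list (ref-count handling simplified here).
        cached.ref_count += 1
        block.ref_count -= 1
        if block.ref_count == 0:
            allocator.free_blocks.append(block)
        block_table[last_idx] = cached
    else:
        # First time this content is seen: register the block under its real hash
        # so other sequences can now allocate() it by hash.
        block.block_hash = real_hash
        allocator.table[real_hash] = block
```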

There is an eviction system as well to manage PhysicalTokenBlocks coming in and out of the cache. Eviction is triggered on allocation when there are no more available PhysicalTokenBlocks. Only PhysicalTokenBlocks with a ref count of 0 are eligible for eviction. PhysicalTokenBlocks with a ref count of 0 can be "brought back" since they are not removed from the hash table until they are evicted.
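Continuing the allocator sketch above, a block whose ref count drops to zero stays in the table, so a later allocate() with the same hash simply revives it rather than recomputing it:

```python
allocator = BlockAllocatorSketch(num_blocks=8)
block = allocator.allocate(block_hash=12345)
block.ref_count -= 1                                  # owning sequence finished; KV stays warm
assert allocator.allocate(block_hash=12345) is block  # "brought back" from the cache
```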

The eviction policy has two "levels". The first level is Least Recently Used: each PhysicalTokenBlock carries a timestamp recording when it was last accessed, and the eviction function simply finds the oldest block and removes it from the cache. When multiple PhysicalTokenBlocks share the same last-accessed time, the eviction function falls back to the number of prefix tokens in each block: the PhysicalTokenBlock with the highest number of prefix tokens is evicted first. If there are still multiple blocks with the same number of prefix tokens, one is chosen arbitrarily.
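A hedged sketch of that two-level policy as a selection function over PhysicalTokenBlock-like objects (the num_hashed_tokens attribute name is an assumption for illustration):

```python
def select_block_to_evict(candidate_blocks):
    # Only blocks that no sequence currently references are eligible for eviction.
    evictable = [b for b in candidate_blocks if b.ref_count == 0]
    if not evictable:
        raise RuntimeError("No evictable blocks: every cached block is still in use.")
    # Level 1: least recently used, i.e. the smallest last_accessed timestamp.
    # Level 2: among blocks with the same timestamp, evict the one covering the
    # most prefix tokens first (the intuition being that deeper blocks of a
    # prefix are less likely to be shared). Remaining ties are broken by min().
    return min(evictable, key=lambda b: (b.last_accessed, -b.num_hashed_tokens))
```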

```python
    prefix_block_tables.append(prefix.get_block_numbers())
else:
    prefix_block_tables.append([])
```
A Collaborator commented:

@SageMoore Was thinking more about this. Is there a separate spot in the code where we propagate the information about which prefixes were found in the hash table?

In the manual prefix caching case, the prefix is passed by the user, and we set the context_lens (i.e. the length of the prefix) based on this. This metadata is then used by the model to only run forward on the new input_tokens and to use the kv_caches of the cached prefix during attention (e.g. by calling context_attention_fwd).

context_attention_fwd(

I wasn't sure if there is a separate spot in the code where this logic sits.

@mgoin (Contributor) commented Feb 16, 2024

This should be addressed now, thanks for catching.

@jadielam commented Feb 8, 2024

Related PR here: #2511
Not exactly the same work; this one bounds the cache so it does not run out of memory.

@zhuohan123 mentioned this pull request Mar 2, 2024
@zhuohan123 (Collaborator) left a comment

LGTM! Thanks for the great contribution! I will submit some small style fixes in a separate PR.

@shixianc commented Mar 4, 2024

I have a probably dumb question:
Can anyone explain the difference between this CR and what's mentioned in the original vLLM paper, Page 7, "Shared prefix" section:

In vLLM, this can be conveniently achieved by reserving a set of physical blocks for a set of predefined shared prefixes by the LLM service provider, as how OS handles shared library across processes. A user input prompt with the shared prefix can simply map its logical blocks to the cached physical blocks (with the last block marked copy-on-write). The prompt phase computation only needs to execute on the user's task input.

I've been using vLLM with the belief that it already does automatic prefix caching between input prompts, but from your CR it apparently doesn't. What does the paper actually suggest then?

@robertgshaw2-neuralmagic (Collaborator)

> Can anyone explain the difference between this CR and what's mentioned in the original vLLM paper, Page 7, "Shared prefix" section? ... I've been using vLLM with the belief that it already does automatic prefix caching between input prompts, but from your CR it apparently doesn't. What does the paper actually suggest then?

In the paper, they were talking about beam search or generating n>1 samples for the same prompt. In each case, there is one prefill and many sequences generated, so vLLM shares the KVs across the many sequences generated for ONE REQUEST.

This diff caches KVs and automatically shares them ACROSS REQUESTS.

```python
def hash_of_block(self, logical_idx: int) -> int:
    # Number of tokens up to and including this logical block.
    num_tokens = self.num_hashed_tokens_of_block(logical_idx)
    return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
```


Could you please explain why the hash of a block does not take its position into account? Prompts built from the same blocks of input IDs can have different meanings, because positional embeddings are applied.

dtransposed pushed commits to afeldman-nm/vllm that referenced this pull request Mar 26, 2024
@SteveJonathon

QQ: Is there a scheduling policy based on longest prefix match, like in SGLang? Thanks.

@matthieu-zimmer

@SageMoore will this also cache multi-turn/conversation queries?

Assuming the first request is A and that vllm generates B, if the second request is A+B+C, will A+B already be cached?

Successfully merging this pull request may close these issues: [RFC] Automatic Prefix Caching