[WP] PagedAttention + Prefix Cache for FlashAttention2 #36737
I am not sure whether this is useful for transformers' future development, but this implementation is consistent with vLLM and other text-inference engines, where the input shape is [1, L, D] and L consists of multiple packed sequences. This input shape is not consistent with the typical forward design, which requires an input shape of [B, L, D].
Because FlashAttention can run on [1, L, D] as long as we store the cumulative sequence lengths, this complements how the paged KV cache is designed: the sequences are gathered back from blocks into [1, L, D] without any need for padding.
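For context, here is a minimal sketch of what the packed layout looks like with FlashAttention's varlen kernel (shapes and values are illustrative, not taken from this PR):

```python
# Illustrative sketch only: three sequences packed into one row and attended
# with flash_attn_varlen_func via cumulative lengths, so no padding is needed.
# Requires a CUDA GPU with flash-attn installed.
import itertools

import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 64
seq_lens = [5, 3, 7]                   # three sequences packed together, L = 15
total_tokens = sum(seq_lens)

# Packed Q/K/V: (total_tokens, num_heads, head_dim) instead of (B, L, H, D)
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths mark each sequence's boundaries: [0, 5, 8, 15]
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)),
                          device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
    causal=True,
)
print(out.shape)  # torch.Size([15, 8, 64]) -- still packed, one row per token
```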
Example code:
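The toy sketch below shows the prefix-cache behaviour being described (the class and helper names are purely illustrative, not the actual code in this diff): prompt tokens are split into fixed-size blocks, full blocks are hashed, and a sequence whose block matches an existing hash reuses that block ID instead of allocating a new one.

```python
# Toy sketch of prefix caching on a paged KV cache (illustrative only).
from typing import Dict, List, Tuple

BLOCK_SIZE = 4

class ToyPagedCache:
    def __init__(self):
        self.next_block = 0
        self.prefix_table: Dict[Tuple[int, ...], int] = {}  # block tokens -> block id
        self.block_tables: Dict[int, List[int]] = {}        # seq id -> block ids

    def allocate(self, seq_id: int, tokens: List[int]) -> List[int]:
        blocks = []
        for i in range(0, len(tokens), BLOCK_SIZE):
            chunk = tuple(tokens[i:i + BLOCK_SIZE])
            if len(chunk) == BLOCK_SIZE and chunk in self.prefix_table:
                # Full block with identical tokens already cached: reuse it.
                blocks.append(self.prefix_table[chunk])
            else:
                blocks.append(self.next_block)
                if len(chunk) == BLOCK_SIZE:
                    self.prefix_table[chunk] = self.next_block
                self.next_block += 1
        self.block_tables[seq_id] = blocks
        return blocks

cache = ToyPagedCache()
print(0, cache.allocate(0, [1, 2, 3, 4, 5, 6]))   # 0 [0, 1]
print(1, cache.allocate(1, [9, 9, 9, 9, 7]))      # 1 [2, 3]
print(10, cache.allocate(10, [1, 2, 3, 4, 8]))    # 10 [0, 4] -- shares block 0 with seq 0
```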
As you can see, sequence ID 10 shares the same block as 0 because the input tokens are the same. This work has not yet been validated for accuracy, but the basic implementation is there and it is super simple, so it should be useful for designing lightweight continuous batching. I am open for discussion.