[WP] PagedAttention + Prefix Cache for FlashAttention2 #36737

Draft
wants to merge 1 commit into main
Conversation

huseinzol05
Contributor

@huseinzol05 huseinzol05 commented Mar 15, 2025

I am not sure whether this is useful for transformers' future development, but this implementation is consistent with vLLM and other text inference engines, where the input shape is [1, L, D] and L consists of multiple packed sequences. This differs from the typical forward design, which requires an input shape of [B, L, D].

Because Flash Attention can run on [1, L, D] as long as we store the cumulative sequence lengths, this complements how the paged KV cache is designed: the sequences are gathered back from blocks into [1, L, D] without any need for padding.
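
For illustration (not part of this PR; plain PyTorch with hypothetical tensor names), this is how two variable-length sequences can be packed into a single [1, L, D] tensor together with the cumulative lengths that the varlen Flash Attention kernels expect instead of a padded [B, L, D] batch:

import torch

D = 8
seq_a = torch.randn(5, D)   # sequence of length 5
seq_b = torch.randn(2, D)   # sequence of length 2

# Pack along the token axis: no padding tokens anywhere.
packed = torch.cat([seq_a, seq_b], dim = 0)[None]   # shape [1, 7, D]

# Cumulative sequence lengths mark where each sequence starts and ends;
# varlen kernels take these in place of an attention mask.
lens = torch.tensor([seq_a.shape[0], seq_b.shape[0]])
cu_seqlens = torch.nn.functional.pad(lens.cumsum(0), (1, 0)).to(torch.int32)

print(packed.shape, cu_seqlens)   # torch.Size([1, 7, 8]) tensor([0, 5, 7], dtype=torch.int32)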

Example code:

import transformers
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers.cache_utils import PagedCache

dtype = torch.float16
device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct')
model = LlamaForCausalLM.from_pretrained(
    'HuggingFaceTB/SmolLM2-135M-Instruct', 
    attn_implementation = 'flash_attention_2', torch_dtype = dtype).to(device)

max_blocks = 256
block_size = 16
num_heads = model.config.num_key_value_heads
head_dim = model.config.head_dim
num_layers = model.config.num_hidden_layers

cache = PagedCache(
    max_blocks, 
    block_size, 
    num_heads, 
    head_dim, 
    num_layers,
    dtype,
    device
)

input_ids = tokenizer(['hello what is your name', 'where is'], 
                      return_tensors = 'pt', padding = True).to('cuda')
# Drop the pad tokens so both sequences are packed into a single [1, total_len] tensor.
input_ids['input_ids'] = input_ids['input_ids'][input_ids['input_ids'] != tokenizer.pad_token_id][None]
# Per-sequence lengths, used to recover the sequence boundaries inside the packed input.
cache_position = input_ids['attention_mask'].sum(dim = 1)
with torch.no_grad():
    outputs = model(**input_ids, 
          use_cache = True, 
          past_key_values = cache,
          cache_position = cache_position,
          prefilling = True,
          sequence_ids = [str(i) for i in range(input_ids['attention_mask'].shape[0])],
          cache_input_ids = input_ids['input_ids'],
    )

# simulate a prefix-cache hit: same input ids, but different sequence ids
with torch.no_grad():
    outputs = model(**input_ids, 
          use_cache = True, 
          past_key_values = cache,
          cache_position = cache_position,
          prefilling = True,
          sequence_ids = [str(i + 10) for i in range(input_ids['attention_mask'].shape[0])],
          cache_input_ids = input_ids['input_ids'],
    )

print(cache.sequence_map)
defaultdict(dict,
            {'0': {0: {'block_mapping': {0: 0}, 'seq_len': 5},
             '10': {0: {'block_mapping': {0: 0}, 'seq_len': 5},

As you can see, sequence ID 10 shares the same block as sequence ID 0 because the input tokens are identical; the sketch below illustrates the idea.
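
Below is a hypothetical, simplified sketch of the block-level prefix-caching idea (not the PR's actual code: real sequence_map entries carry block_mapping and seq_len per layer, and real systems typically only share completely filled blocks). Blocks of token ids are hashed, so a sequence whose prefix blocks match an existing sequence reuses the same physical blocks:

from collections import defaultdict

block_size = 4
block_table = {}                  # hash of a token block -> physical block index
sequence_map = defaultdict(dict)  # sequence id -> logical block index -> physical block index
next_free_block = 0

def allocate(seq_id, token_ids):
    global next_free_block
    for logical, start in enumerate(range(0, len(token_ids), block_size)):
        key = tuple(token_ids[start:start + block_size])
        if key not in block_table:            # miss: allocate a fresh physical block
            block_table[key] = next_free_block
            next_free_block += 1
        sequence_map[seq_id][logical] = block_table[key]   # hit: reuse the existing block

tokens = [101, 7592, 2054, 2003, 2115]
allocate('0', tokens)
allocate('10', tokens)            # same tokens, different sequence id
print(dict(sequence_map))         # {'0': {0: 0, 1: 1}, '10': {0: 0, 1: 1}} -> physical blocks shared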

This work has not yet been validated for accuracy, but the basic implementation is there and it is super simple, so it should be useful for designing lightweight continuous batching. I'm open for discussion.

@github-actions github-actions bot marked this pull request as draft March 15, 2025 07:35

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@Rocketknight1
Member

I don't fully follow but I think the idea is to flatten multiple sequences in a batch into a single long sequence and then process it in chunks using sequence_ids (which I think we usually call position_ids). The idea is to not waste computation on lots of padding tokens when sequences have very different lengths. Is that correct?

cc @gante, but I think this might require quite a large redesign!

@gante
Member

gante commented Mar 19, 2025

cc @ArthurZucker, who's taking the lead in the implementation of continuous batching (#35727)

@huseinzol05
Contributor Author

huseinzol05 commented Mar 21, 2025

> I don't fully follow but I think the idea is to flatten multiple sequences in a batch into a single long sequence and then process it in chunks using sequence_ids (which I think we usually call position_ids). The idea is to not waste computation on lots of padding tokens when sequences have very different lengths. Is that correct?
>
> cc @gante, but I think this might require quite a large redesign!

Yes, that is correct. The current implementation is minimal, just enough to show a working version, but we can standardize it. For prefix caching, there are many things we can do to improve the hit rate.

Collaborator

@ArthurZucker ArthurZucker left a comment


This is nice! I am indeed working on this #35727 to make sure our generation is also faster.
Thanks a lot for your contribution!

The main problem with your approach is that it introduces modeling changes, which are in general the worst, and torch.arange does not work well with CUDA graphs!

Gimme a bit of time, ~1 week, and I should be able to finish my PR!
