Expose kv_cache in generator API #765

Open
gautierdag opened this issue Mar 21, 2024 · 1 comment


gautierdag commented Mar 21, 2024

Presentation of the new feature

I'm dealing with long chat/agent contexts, where the context grows with each interaction. An easy optimisation is to keep the KV cache in memory between interactions. This can be done in the native transformers library by passing past_key_values, and it is in fact already done under the hood in outlines.

The problem is that outlines does not expose this functionality and therefore the model has to recompute the kv_cache after every interaction (user chat message).

Where does it fit in Outlines?

I think the easiest way to fix this would be to expose the kv_cache variable in sequence_generator and set it as a function parameter instead.

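from typing import Callable, Iterator, List, Optional, Tuple

import torch

def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: torch.Tensor,
    sequence_weights: torch.Tensor,
    attention_masks: torch.Tensor,
    fsm_states: List[int],
    rng: torch.Generator = torch.Generator(),
    kv_cache: Optional[Tuple] = None,  # add this: cache from a previous call, or None to start fresh
) -> Iterator["GenerationState"]:
    ...  # existing body, starting from the kv_cache passed in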

Then do the same for the SequenceGenerator class.
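
To illustrate what this would save, here is a minimal toy sketch of the pattern, not the outlines implementation. ToyModel, toy_sequence_generator and run_chat are hypothetical names made up for this example, and the "cache" is just the list of tokens already processed. It only assumes that the model returns an updated cache on each call and that the generator accepts an initial cache as a parameter.

```python
from typing import Iterator, List, Optional, Tuple


class ToyModel:
    """Hypothetical stand-in for a transformer; counts tokens it has to process."""

    def __init__(self) -> None:
        self.processed = 0

    def __call__(
        self, token_ids: List[int], kv_cache: Optional[List[int]] = None
    ) -> Tuple[int, List[int]]:
        cache = list(kv_cache) if kv_cache is not None else []
        new_tokens = token_ids[len(cache):]  # only tokens the cache does not cover
        self.processed += len(new_tokens)
        cache.extend(new_tokens)
        next_token = sum(cache) % 7  # deterministic fake "sampling"
        return next_token, cache


def toy_sequence_generator(
    model: ToyModel,
    token_ids: List[int],
    max_tokens: int,
    kv_cache: Optional[List[int]] = None,  # the proposed extra parameter
) -> Iterator[Tuple[List[int], List[int]]]:
    token_ids = list(token_ids)
    for _ in range(max_tokens):
        next_token, kv_cache = model(token_ids, kv_cache)
        token_ids.append(next_token)
        yield token_ids, kv_cache


def run_chat(reuse_cache: bool) -> int:
    """Simulate three chat turns and return the total number of tokens processed."""
    model = ToyModel()
    context: List[int] = []
    cache = None
    for user_message in ([1, 2, 3], [4, 5], [6]):
        context = context + user_message
        start_cache = cache if reuse_cache else None
        for context, cache in toy_sequence_generator(model, context, 2, start_cache):
            pass  # consume the generator, keeping the last context and cache
    return model.processed


print(run_chat(reuse_cache=False))  # 23: every turn re-processes the whole context
print(run_chat(reuse_cache=True))   # 11: each token is processed only once
```

In this toy setup the gap grows with every turn, since without the parameter each call starts again from an empty cache.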

The downside is that this doesn't really apply to other models, and there are already a lot of arguments being passed around.

Are you willing to open a PR?

Yes, I'm just curious to hear whether that is something that would be accepted before I draft this.

Distantly related to #667 #452

@miftahmoha
Contributor

I'm currently working on #667; this is something I've done.
