
One small step towards continuous generation for Outlines #781

Draft · wants to merge 1 commit into base: main
Conversation

miftahmoha
Contributor

@miftahmoha miftahmoha commented Apr 1, 2024

This PR addresses #667.

[Screenshots: last_tests, last_pre_commit]

API

import outlines
from outlines.generate import continuous

# `model` is any Outlines model, e.g. outlines.models.transformers("gpt2")
generator = outlines.generate.text(model)
generator_c = continuous(generator)

response = generator_c(prompt, max_tokens=30)

continuous

continuous can wrap any SequenceGenerator object, such as:

  • outlines.generate.choice
  • outlines.generate.text
  • outlines.generate.json
  • ...
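For example, a structured generator can be wrapped in exactly the same way. A minimal sketch (the model name and JSON schema below are only illustrative; continuous itself is the piece introduced by this PR):

import outlines
from outlines.generate import continuous

# Illustrative model and schema.
model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")
schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

generator_json = outlines.generate.json(model, schema)
generator_json_c = continuous(generator_json)

response = generator_json_c("Describe a character:", max_tokens=50)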

The continuous wrapper allows the generator to save the state of a Sequence. This means that if you continuously generate a sequence as shown below:

import outlines
from outlines.generate import continuous

generator = outlines.generate.text(model)
generator_c = continuous(generator)

response_1 = generator_c(prompt, max_tokens=100)
response_2 = generator_c(response_1)

the KV cache (under some conditions) will be saved. Algorithms such as beam search can then optimize the whole sequence rather than each part separately:

import outlines
from outlines.generate import continuous
from outlines.samplers import BeamSearchSampler

generator = outlines.generate.text(model, sampler=BeamSearchSampler(3))
generator_c = continuous(generator)

response_1 = generator_c(prompt, max_tokens=100)
response_2 = generator_c(response_1)

You can mix different types of SequenceGenerator objects:

import outlines
from outlines.generate import continuous

generator_text = outlines.generate.text(model)
generator_choice  = outlines.generate.choice(model, ["Positive", "Negative"])

generator_text_c = continuous(generator_text)
generator_choice_c = continuous(generator_choice)

response_1 = generator_text_c(prompt, max_tokens=100)
response_2 = generator_choice_c(response_1)

Once a prompt is given to the continuous wrapper, it becomes a SequenceState object.

class SequenceState:
    token_ids: torch.Tensor
    weights: torch.Tensor
    attention_masks: torch.Tensor
    kv_cache: torch.Tensor
    tokenizer: "Tokenizer"
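Concretely, the object returned by a wrapped generator exposes these fields directly. A minimal sketch (only the field names come from the class above; the shapes and printed values are assumptions):

state = generator_c(prompt, max_tokens=30)

print(type(state).__name__)   # SequenceState
print(state.token_ids.shape)  # token ids of the generated sequence
print(state.weights.shape)    # cumulative sequence weights
print(list(state))            # decoded strings, see "Using list(SequenceState)" below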

SequenceState

Indexing

Each SequenceState can be indexed along three dimensions: SequenceState[batch_key, sample_key, ids_size_key], where each key is a Union[int, slice].

However, there are three cases where indexing is handled differently (see the sketch after this list):

  1. batch_size == 1 and sample_size == 1
    SequenceState[ids_size_key], instead of SequenceState[0, 0, ids_size_key].

  2. batch_size == 1
    SequenceState[sample_key, ids_size_key], instead of SequenceState[0, sample_key, ids_size_key].

  3. sample_size == 1
    SequenceState[batch_key, ids_size_key], instead of SequenceState[batch_key, 0, ids_size_key].
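For illustration, the three cases might look as follows. A sketch (state, state_bs, state_b and state_s are hypothetical SequenceState objects with the stated batch/sample sizes, and the indices are made up):

# General case: batch_size > 1 and sample_size > 1, all three keys are needed.
full = state[0, 1, :50]

# Case 1: batch_size == 1 and sample_size == 1, only the position key is needed.
prefix = state_bs[:50]           # instead of state_bs[0, 0, :50]

# Case 2: batch_size == 1, the batch key is dropped.
sample_prefix = state_b[1, :50]  # instead of state_b[0, 1, :50]

# Case 3: sample_size == 1, the sample key is dropped.
batch_prefix = state_s[2, :50]   # instead of state_s[2, 0, :50]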

Operations

You can apply two operations to a SequenceState:

  • Slicing

  • Adding (SequenceState to a SequenceState and SequenceState to a prompt)

Adding

  1. SequenceState to a SequenceState

For now this does not save the KV cache of the first part, but it does accumulate the weights of both sequences.

I don't have a good idea of how to implement it: the Hugging Face KV cache implementation accepts either (1) a None value or (2) a KV cache whose context size is exactly one less than that of the token_ids.

I did run an experiment where I use the model to compute (or complete) the KV cache of the second sequence in order to satisfy (2). The function is called complete_kv_cache_from_token_ids; it is not part of the implementation because it is slow.

  2. SequenceState to a prompt

This will reinitialize everything.
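A sketch of both cases, assuming addition is exposed through the + operator (the prompts and variable names are made up):

state_1 = generator_c(prompt_1, max_tokens=100)
state_2 = generator_c(prompt_2, max_tokens=100)

# SequenceState + SequenceState: weights accumulate, the first KV cache is not kept.
combined = state_1 + state_2

# SequenceState + prompt: everything is reinitialized.
restarted = state_1 + "Now summarize the above:"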

Slicing

Conditions under which KV Cache is saved:
  1. The slice selects only one element (batch_size_after_the_slice == 1 and sample_size_after_the_slice == 1); slicing more than one element resets the KV cache. This condition includes the base case where batch_size == 1 and sample_size == 1.

  2. The slice starts from the first index of the prompt (e.g. SequenceState[..., :M] or SequenceState[..., 0:M]).

There are some technical intricacies that prevent saving the KV cache even under 1. and 2.; see the [NOTE] [SPECIAL CASE] flags in the token_level_slice_from_string_level_slice utility.

This is also one of the reasons I didn't try to make the KV cache work when batch_size > 1 and num_samples > 1: the complexity-to-usefulness tradeoff seems way off to me.
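A sketch of a slice that can keep the KV cache versus one that resets it (the indices are made up; the string-to-token mapping is handled by the utility mentioned above):

state = generator_c(prompt, max_tokens=100)

# Starts at index 0 and keeps a single (batch, sample) element:
# the KV cache can be preserved, modulo the special cases noted above.
head = state[:200]

# Selects more than one element: the KV cache is reset.
several = state[0:2, :200]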

Using list(SequenceState)

list(SequenceState) converts the SequenceState object into a list of strings.
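For example (assuming the list contains one decoded string per generated sample):

state = generator_c(prompt, max_tokens=30)
strings = list(state)  # decoded strings for the generated sequence(s)
print(strings[0])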

Exceptions

Three types of exceptions can be raised while using the continuous wrapper.

  1. SampleMismatch: raised when (1) a sequence's samples are sliced and the result is fed back to the wrapper (a mismatch between the number of samples in the sequence and the number expected by the generator), or (2) two sequences with different numbers of samples are added.

  2. BatchMismatch: raised when two sequences with different batch sizes are added.

  3. SlicingError: raised when a slice doesn't allow the KV cache to be saved; it is handled by resetting the KV cache.
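For instance, the mismatch exceptions might surface like this. A sketch (the import path of the exceptions is an assumption, not something this PR documents):

# Hypothetical import path for the new exceptions.
from outlines.generate import SampleMismatch, BatchMismatch

try:
    combined = state_a + state_b   # states with different numbers of samples
except SampleMismatch:
    ...

try:
    combined = state_a + state_c   # states with different batch sizes
except BatchMismatch:
    ...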

FLAGs

You will see multiple flags that I've put in the code comments:

  • [NOTE]: Those are general notes, explaining how I approached some problems.

  • [QUESTION]: Those are questions that I had when I was coding certain mechanisms.

  • [POTENTIAL BUG]: Those are lines of code that could potentially trigger bugs.

Modifications

  1. GenerationState returns attention_masks as well.

  2. sequence_generator takes kv_cache as a keyword argument, with None as the default value.

test_continuous.py

Those are some tests I added for the different parts of the code.

PS: I use the name SequenceState instead of Sequence just because it made the coding more obvious to me; tell me if you want to switch it back to Sequence.

@miftahmoha miftahmoha marked this pull request as draft April 1, 2024 11:18
@rlouf
Member

rlouf commented Apr 3, 2024

This is cool, I really like the continuous function, to the point that I'd consider doing the same thing for stream. I'll review the PR soon, please bear with me.
