One small step towards continuous generation for Outlines #781
This PR is a step toward solving #667.
## API

`continuous` wraps any `SequenceGenerator` object; it could be:

- `outlines.generate.choice`
- `outlines.generate.text`
- `outlines.generate.json`
- ...
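The intended flow might look like the following toy sketch. The `continuous` name and `SequenceState` come from this PR; the dummy generator and the exact call pattern are my assumptions, not the PR's actual API:

```python
# Toy stand-in for the continuous wrapper: each call extends the previous
# state instead of starting from scratch. Illustrative only.

class SequenceState:
    """Minimal stand-in: holds the text generated so far."""
    def __init__(self, text: str):
        self.text = text

    def __str__(self) -> str:
        return self.text


def continuous(generate_fn):
    """Wrap a generator function so each call continues the prior state."""
    def wrapped(prompt_or_state):
        prompt = str(prompt_or_state)  # accepts a plain prompt or a state
        return SequenceState(prompt + generate_fn(prompt))
    return wrapped


# Dummy generator standing in for e.g. outlines.generate.text(model).
generator = continuous(lambda prompt: " world")

state = generator("hello")   # first call: plain prompt -> SequenceState
state = generator(state)     # second call: continues from the saved state
assert str(state) == "hello world world"
```

In the real wrapper, feeding the returned state back in is what allows the KV cache to be reused across calls.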
The `continuous` wrapper allows the generator to save the state of a `Sequence`. This means that, if you continuously generate a sequence, the KV cache (under some conditions) will be saved, and algorithms such as beam search could be used to optimize the whole sequence rather than each part separately.

You can mix different types of `SequenceGenerator` objects. Once a prompt is given to the `continuous` wrapper, it becomes a `SequenceState` object.

## `SequenceState` Indexing
Each `SequenceState` has three dimensions: `SequenceState[batch_key: Union[int, slice], sample_key: Union[int, slice], ids_size_key: Union[int, slice]]`. However, there are three cases where this is handled differently:

- `batch_size == 1` and `sample_size == 1`: use `SequenceState[ids_size_key]` instead of `SequenceState[0, 0, ids_size_key]`.
- `batch_size == 1`: use `SequenceState[sample_key, ids_size_key]` instead of `SequenceState[0, sample_key, ids_size_key]`.
- `sample_size == 1`: use `SequenceState[batch_key, ids_size_key]` instead of `SequenceState[batch_key, 0, ids_size_key]`.
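The three shorthand cases can be sketched as a toy `__getitem__` that normalizes any key to the full `(batch_key, sample_key, ids_size_key)` triple. The class name and bodies here are illustrative assumptions, not the PR's code:

```python
# Toy sketch of the indexing convention: normalize shorthand keys to the
# full three-dimensional form, depending on batch_size and sample_size.

class ToySequenceState:
    def __init__(self, batch_size: int, sample_size: int):
        self.batch_size = batch_size
        self.sample_size = sample_size

    def __getitem__(self, key):
        # A bare int/slice counts as a one-element tuple.
        if not isinstance(key, tuple):
            key = (key,)
        if len(key) == 3:
            return key
        if len(key) == 1 and self.batch_size == 1 and self.sample_size == 1:
            return (0, 0, key[0])        # only ids_size_key was given
        if len(key) == 2 and self.batch_size == 1:
            return (0, key[0], key[1])   # (sample_key, ids_size_key)
        if len(key) == 2 and self.sample_size == 1:
            return (key[0], 0, key[1])   # (batch_key, ids_size_key)
        raise IndexError("key does not match batch/sample dimensions")


assert ToySequenceState(1, 1)[5:10] == (0, 0, slice(5, 10))
assert ToySequenceState(1, 4)[2, :8] == (0, 2, slice(None, 8))
assert ToySequenceState(3, 1)[1, :8] == (1, 0, slice(None, 8))
```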
## Operations

You can apply two operations on a `SequenceState`:

- Slicing
- Adding (a `SequenceState` to a `SequenceState`, and a `SequenceState` to a prompt)

### Adding a `SequenceState` to a `SequenceState`

This won't save the first part of the KV cache for the moment, but it does accumulate the weights between both sequences.
I don't have an idea how to implement it: the KV cache implementation from Hugging Face accepts either (1) a `None` value or (2) a KV cache with a context size one less than that of the `token_ids`. I've done an experiment where I use the model to compute (or complete) the KV cache for the second sequence so as to satisfy (2). The function is called `complete_kv_cache_from_token_ids`; it's not implemented because it's slow.

### Adding a `SequenceState` to a prompt

This will reinitialize everything.
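The idea behind constraint (2) above can be illustrated with a toy. The name `complete_kv_cache_from_token_ids` comes from the PR, but this body is a guess, and the "model" here is a fake stand-in rather than an HF-style forward pass:

```python
# Toy sketch: to satisfy constraint (2), run the model over all but the
# last token so the resulting cache has context size len(token_ids) - 1.

def complete_kv_cache(model_forward, token_ids):
    """model_forward(tokens) -> a cache covering those tokens (toy stand-in
    for a causal-LM forward pass with caching enabled)."""
    return model_forward(token_ids[:-1])


# Fake model: its "cache" is just the list of tokens it has seen.
fake_forward = lambda tokens: list(tokens)

cache = complete_kv_cache(fake_forward, [101, 7, 8, 9])
assert len(cache) == 4 - 1   # one less than the token count
```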
### Slicing

Conditions under which the KV cache is saved:

1. The slice selects only one element (`batch_size_after_the_slice == 1`, `sample_size_after_the_slice == 1`); slicing more than one element will reset the KV cache. The condition includes the base case where `(batch_size == 1, sample_size == 1)`.
2. The slice starts from the first index for the prompt (`SequenceState[..., :M]`, `SequenceState[..., 0:M]`).

There are some technical intricacies that prevent saving the KV cache even under 1. and 2.; see the `[NOTE] [SPECIAL CASE]` flags in the `token_level_slice_from_string_level_slice` utility. It's also one of the reasons I didn't go out of my way to get the KV cache working when `batch_size > 1` and `num_samples > 1`: the complexity-usefulness tradeoff seems just way off to me.
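Conditions 1. and 2. can be sketched as a small predicate (illustrative only; the function name and signature are my assumptions, and the real check also handles the token-level intricacies mentioned above):

```python
# Toy check: does a slice let the KV cache survive?

def kv_cache_survives(batch_size_after: int, sample_size_after: int,
                      ids_slice: slice) -> bool:
    # Condition 1: the slice must select a single element.
    single_element = batch_size_after == 1 and sample_size_after == 1
    # Condition 2: the ids slice must start at the first index (None or 0).
    starts_at_zero = ids_slice.start in (None, 0)
    return single_element and starts_at_zero


assert kv_cache_survives(1, 1, slice(None, 8))      # [..., :M]  -> saved
assert kv_cache_survives(1, 1, slice(0, 8))         # [..., 0:M] -> saved
assert not kv_cache_survives(2, 1, slice(None, 8))  # more than one element
assert not kv_cache_survives(1, 1, slice(3, 8))     # doesn't start at 0
```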
### Using `list(SequenceState)`

`list(SequenceState)` converts the `SequenceState` object into a list of strings.

## Exceptions
Three types of exceptions can be raised while using the `continuous` wrapper:

- `SampleMismatch`: raised when (1) the sequence's samples are sliced and the result is then passed to the wrapper (a mismatch between the number of samples in the sequence and the one in the generator), or (2) two sequences with different numbers of samples are added.
- `BatchMismatch`: raised when two sequences with different batch sizes are added.
- `SlicingError`: raised when the slice doesn't allow the KV cache to be saved; it is handled by resetting the KV cache.
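The two mismatch checks for addition can be sketched as follows. The exception names come from the PR; the check function, its signature, and its placement are assumptions:

```python
# Toy sketch of the shape checks performed before adding two sequences.

class SampleMismatch(Exception):
    """Numbers of samples disagree."""

class BatchMismatch(Exception):
    """Batch sizes disagree."""


def check_addition(batch_a: int, samples_a: int,
                   batch_b: int, samples_b: int) -> None:
    if batch_a != batch_b:
        raise BatchMismatch(f"batch sizes differ: {batch_a} vs {batch_b}")
    if samples_a != samples_b:
        raise SampleMismatch(f"sample counts differ: {samples_a} vs {samples_b}")


check_addition(2, 3, 2, 3)   # same shapes: no exception raised
```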
## FLAGs

You will see multiple flags that I've put in the code comments:

- `[NOTE]`: general notes explaining how I approached some problems.
- `[QUESTION]`: questions I had while coding certain mechanisms.
- `[POTENTIAL BUG]`: lines of code that could potentially trigger bugs.
## Modifications

- `GenerationState` returns `attention_masks` as well.
- `sequence_generator` takes `kv_cache` as a keyword argument with `None` as the default value.
- `test_continuous.py`: some tests I added for the different parts of the code.
PS: I use the name `SequenceState` instead of `Sequence` just because it made the coding more obvious to me; tell me if you want to switch it back to `Sequence`.