From b9b1391ee8233853990e036cb001a52a06ea7ce0 Mon Sep 17 00:00:00 2001
From: Andrew
Date: Wed, 23 Aug 2023 14:59:59 +0200
Subject: [PATCH 1/2] Remove error on context window overflow

---
 crates/llm-base/src/inference_session.rs | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index e3f5a785..e8dbb1c2 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -381,9 +381,10 @@ impl InferenceSession {
         output_request: &mut OutputRequest,
         rng: &mut impl rand::Rng,
     ) -> Result<Vec<u8>, InferenceError> {
-        if self.n_past + 1 >= model.context_size() {
-            return Err(InferenceError::ContextFull);
-        }
+        // Disable the error on context-size overflow so that a llama.cpp-style "context window slide" can be used (if it exists).
+        // if self.n_past + 1 >= model.context_size() {
+        //     return Err(InferenceError::ContextFull);
+        // }
 
         let next_token = crate::samplers::sample_token(
             params.sampler.clone(),

From 99a9fb4dfefa455b983c016a89f54618d062469a Mon Sep 17 00:00:00 2001
From: Andrii Kotliar
Date: Tue, 12 Sep 2023 14:40:26 +0200
Subject: [PATCH 2/2] Add "context swap" functions to session and add
 "decoded_tokens" to snapshot read/write

---
 crates/llm-base/src/inference_session.rs | 217 ++++++++++++++++++++---
 1 file changed, 191 insertions(+), 26 deletions(-)

diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index e8dbb1c2..12f66f5e 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -299,15 +299,22 @@ impl InferenceSession {
         output_request: &mut OutputRequest,
         mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
     ) -> Result<(), InferenceError> {
-        let beginning_of_sentence = self.n_past == 0;
-
-        let vocab = model.tokenizer();
-        let prompt_tokens = prompt.into().to_tokens(vocab, beginning_of_sentence)?;
+        let prompt_tokens = self.get_prompt_tokens(model, prompt)?;
 
         if self.n_past + prompt_tokens.len() >= model.context_size() {
             return Err(InferenceError::ContextFull);
         }
 
+        self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
+    }
+
+    fn feed_prompt_tokens<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        model: &dyn Model,
+        output_request: &mut OutputRequest,
+        mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
+        prompt_tokens: Vec<TokenId>,
+    ) -> Result<(), InferenceError> {
         'outer: for batch in prompt_tokens.chunks(self.config.n_batch) {
             model.evaluate(self, batch, output_request);
             for &tk in batch {
@@ -341,10 +348,46 @@ impl InferenceSession {
             }
         }
         log::trace!("Finished feed prompt");
-
         Ok(())
     }
 
+    fn get_prompt_tokens<'a, P: Into<Prompt<'a>>>(
+        &self,
+        model: &dyn Model,
+        prompt: P,
+    ) -> Result<Vec<TokenId>, TokenizationError> {
+        let beginning_of_sentence = self.n_past == 0;
+
+        let vocab = model.tokenizer();
+        prompt.into().to_tokens(vocab, beginning_of_sentence)
+    }
+
+    /// Feed a prompt to the model for this session.
+    /// Same as [Self::feed_prompt], but removes old tokens (keeping the first `n_keep`) when the prompt would overflow the context window.
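+    ///
+    /// A rough usage sketch (illustrative only, not a compiled doctest; the
+    /// `session` and `model` values and the choice of `n_keep = 4` are
+    /// assumptions for the example, not part of this API):
+    ///
+    /// ```ignore
+    /// // Keep the first 4 tokens (e.g. a fixed preamble) whenever old context is dropped.
+    /// session.feed_prompt_with_swap(
+    ///     model,
+    ///     "the next user message",
+    ///     4,
+    ///     &mut OutputRequest::default(),
+    ///     |_bytes| Ok::<InferenceFeedback, std::convert::Infallible>(InferenceFeedback::Continue),
+    /// )?;
+    /// ```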
+    #[instrument(skip_all)]
+    pub fn feed_prompt_with_swap<
+        'a,
+        E: std::error::Error + Send + Sync + 'static,
+        P: Into<Prompt<'a>>,
+    >(
+        &mut self,
+        model: &dyn Model,
+        prompt: P,
+        n_keep: usize,
+        output_request: &mut OutputRequest,
+        mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
+    ) -> Result<(), InferenceError> {
+        let prompt_tokens = self.get_prompt_tokens(model, prompt)?;
+
+        if self.n_past + prompt_tokens.len() >= model.context_size() {
+            let rewind_by = self.n_past + prompt_tokens.len() - model.context_size();
+            self.remove_tokens(model, n_keep, rewind_by)
+                .map_err(|_e| InferenceError::ContextFull)?;
+        }
+
+        self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
+    }
+
     /// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
     pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> {
         if !model.supports_rewind() {
@@ -372,6 +415,46 @@ impl InferenceSession {
         Ok(deleted_tokens)
     }
 
+    /// Removes `num` tokens from the buffer, starting at `start_from`. Similar to [Self::rewind].
+    fn remove_tokens(
+        &mut self,
+        model: &dyn Model,
+        start_from: usize,
+        num: usize,
+    ) -> Result<Vec<TokenId>, RewindError> {
+        if !model.supports_rewind() {
+            return Err(RewindError::UnsupportedArchitecture);
+        }
+
+        if start_from + num >= self.n_past {
+            return Err(RewindError::NotEnoughTokens);
+        }
+
+        // Remove the tokens from self.tokens.
+        let end = start_from + num;
+        let deleted_tokens: Vec<_> = self.tokens.drain(start_from..end).collect();
+
+        // Remove the corresponding bytes from decoded_tokens.
+        let mut decoded_start = 0;
+        let mut decoded_end = 0;
+        if start_from != 0 {
+            for id in &self.tokens[0..start_from] {
+                decoded_start += model.tokenizer().token(*id as usize).len();
+            }
+            decoded_end += decoded_start;
+        }
+
+        for id in &deleted_tokens {
+            decoded_end += model.tokenizer().token(*id as usize).len();
+        }
+        self.decoded_tokens.drain(decoded_start..decoded_end);
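+        // Worked illustration (values assumed, one byte per token for simplicity):
+        // with tokens [a, b, c, d, e], start_from = 1 and num = 2, the drains
+        // above delete tokens b and c plus bytes 1..3 of decoded_tokens, since
+        // decoded_start = len("a") = 1 and decoded_end = 1 + len("b") + len("c") = 3.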
+
+        // Decrement the n_past tokens counter.
+        self.n_past -= num;
+
+        Ok(deleted_tokens)
+    }
+
     /// Infer the next token for this session.
     #[instrument(level = "trace", skip_all)]
     pub fn infer_next_token(
@@ -381,10 +464,9 @@ impl InferenceSession {
         output_request: &mut OutputRequest,
         rng: &mut impl rand::Rng,
     ) -> Result<Vec<u8>, InferenceError> {
-        // Disable the error on context-size overflow so that a llama.cpp-style "context window slide" can be used (if it exists).
-        // if self.n_past + 1 >= model.context_size() {
-        //     return Err(InferenceError::ContextFull);
-        // }
+        if self.n_past + 1 >= model.context_size() {
+            return Err(InferenceError::ContextFull);
+        }
 
         let next_token = crate::samplers::sample_token(
             params.sampler.clone(),
@@ -438,19 +520,7 @@ impl InferenceSession {
     ) -> Result<InferenceStats, InferenceError> {
         let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
         if request.play_back_previous_tokens {
-            // "Play back" the existing tokens, so that loading from an inference snapshot works
-            // as expected.
-            let mut token_utf8_buf = TokenUtf8Buffer::new();
-            for token_id in &self.tokens {
-                // Buffer the token until it's valid UTF-8, then call the callback.
-                if let Some(tokens) =
-                    token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
-                {
-                    if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
-                        return Err(InferenceError::UserCallback(Box::new(e)));
-                    }
-                }
-            }
+            self.play_back_previous_tokens(model, &mut callback)?
         }
         log::trace!(
             "Starting inference request with max_token_count: {}",
             maximum_token_count
         );
@@ -475,10 +545,25 @@ impl InferenceSession {
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;
 
-        // After the prompt is consumed, sample tokens by repeatedly calling
-        // `infer_next_token`. We generate tokens until the model returns an
-        // EndOfText token, or we run out of space in the context window,
-        // or we reach the specified limit.
+        self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
+        stats.predict_duration = start_at.elapsed().unwrap();
+        stats.predict_tokens = self.n_past;
+
+        Ok(stats)
+    }
+
+    /// Samples tokens by repeatedly calling
+    /// [Self::infer_next_token]. Generates tokens until the model returns an
+    /// EndOfText token, we run out of space in the context window,
+    /// or we reach the specified limit.
+    fn infer_tokens<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        model: &dyn Model,
+        rng: &mut impl rand::Rng,
+        mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
+        maximum_token_count: usize,
+        parameters: &InferenceParameters,
+    ) -> Result<(), InferenceError> {
         let mut tokens_processed = 0;
         let mut token_utf8_buf = TokenUtf8Buffer::new();
         while tokens_processed < maximum_token_count {
@@ -502,6 +587,79 @@ impl InferenceSession {
             tokens_processed += 1;
         }
 
+        Ok(())
+    }
+
+    /// "Play back" the existing tokens, so that loading from an inference snapshot works
+    /// as expected.
+    fn play_back_previous_tokens<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        model: &dyn Model,
+        mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
+    ) -> Result<(), InferenceError> {
+        let mut token_utf8_buf = TokenUtf8Buffer::new();
+        for token_id in &self.tokens {
+            // Buffer the token until it's valid UTF-8, then call the callback.
+            if let Some(tokens) = token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
+            {
+                if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
+                    return Err(InferenceError::UserCallback(Box::new(e)));
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Generate text by using the provided [Model] to evaluate the `prompt`.
+    /// Works the same way as [Self::infer], except it allows "infinite" text generation via context swapping.
+    #[instrument(skip_all)]
+    pub fn infer_with_swap<E: std::error::Error + Send + Sync + 'static>(
+        &mut self,
+        model: &dyn Model,
+        rng: &mut impl rand::Rng,
+        request: &InferenceRequest,
+        n_keep: usize,
+        output_request: &mut OutputRequest,
+        mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
+    ) -> Result<InferenceStats, InferenceError> {
+        let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
+        if request.play_back_previous_tokens {
+            self.play_back_previous_tokens(model, &mut callback)?
+        }
+
+        // Infinite text generation via context swapping.
+        // If we run out of context:
+        // - take the first n_keep tokens from the original prompt
+        // - remove half of the tokens after n_keep: (n_ctx - n_keep) / 2
+        if self.n_past >= model.context_size() {
+            self.remove_tokens(model, n_keep, (self.n_past - n_keep) / 2)
+                .map_err(|_e| InferenceError::ContextFull)?;
+        }
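+        // Worked example (numbers assumed for illustration): with
+        // context_size = 2048 and n_keep = 64, hitting the limit removes
+        // (2048 - 64) / 2 = 992 tokens, so generation resumes at
+        // n_past = 1056 with 992 slots free before the next swap.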
+
+        log::trace!(
+            "Starting inference request with max_token_count: {}",
+            maximum_token_count
+        );
+
+        let mut stats = InferenceStats::default();
+        let start_at = std::time::SystemTime::now();
+
+        let parameters = request.parameters;
+
+        // Feed the initial prompt through the transformer, to update its
+        // context window with new data, if necessary.
+        if !request.prompt.is_empty() {
+            self.feed_prompt(
+                model,
+                request.prompt,
+                output_request,
+                feed_prompt_callback(&mut callback),
+            )?;
+        }
+        stats.feed_prompt_duration = start_at.elapsed().unwrap();
+        stats.prompt_tokens = self.n_past;
+
+        self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
         stats.predict_duration = start_at.elapsed().unwrap();
         stats.predict_tokens = self.n_past;
 
@@ -605,6 +763,7 @@ impl InferenceSession {
             npast: self.n_past,
             config: self.config,
             tokens: self.tokens.clone(),
+            decoded_tokens: self.decoded_tokens.clone(),
             logits: self.last_logits.clone(),
             memory_k,
             memory_v,
@@ -637,6 +796,7 @@ impl InferenceSession {
 
         session.n_past = snapshot.npast;
         session.tokens = snapshot.tokens;
+        session.decoded_tokens = snapshot.decoded_tokens;
         session.last_logits = snapshot.last_logits;
 
         Ok(session)
@@ -742,6 +902,8 @@ pub struct InferenceSnapshotRef<'a> {
     pub config: InferenceSessionConfig,
     /// All tokens generated by this inference session.
     pub tokens: Vec<TokenId>,
+    /// All decoded tokens generated by this inference session.
+    pub decoded_tokens: Vec<u8>,
     /// The vector of logits that was produced after the last inference.
     pub logits: Vec<f32>,
     /// The contents of the 'key' memory tensor.
@@ -760,6 +922,7 @@ impl InferenceSnapshotRef<'_> {
             npast: self.npast,
             config: self.config,
             tokens: self.tokens.clone(),
+            decoded_tokens: self.decoded_tokens.clone(),
             last_logits: self.logits.clone(),
             memory_k: self.memory_k.to_vec(),
             memory_v: self.memory_v.to_vec(),
@@ -778,6 +941,8 @@ pub struct InferenceSnapshot {
     pub config: InferenceSessionConfig,
     /// All tokens generated by this inference session.
    pub tokens: Vec<TokenId>,
+    /// All decoded tokens generated by this inference session.
+    pub decoded_tokens: Vec<u8>,
     /// The vector of logits that was produced after the last inference.
     pub last_logits: Vec<f32>,
     /// The contents of the 'key' memory tensor.