This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Merge pull request #424 from Andreybest/main
Add "context swap" functions to session and add "decoded_tokens" to snapshot read/write
philpax committed Nov 12, 2023
2 parents e5e0fe1 + 2e3c6f7 commit 2c127df
Showing 1 changed file with 188 additions and 22 deletions.
210 changes: 188 additions & 22 deletions crates/llm-base/src/inference_session.rs
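Before the diff, a minimal calling sketch of the new `infer_with_swap` session method. Only the method name and its parameter list are taken from the diff below; everything else (the `generate_with_swap` wrapper, the `n_keep` value, the import paths, and how the model, session, RNG, and request are built) is assumed for illustration and is not part of this commit.

```rust
use llm_base::{
    InferenceError, InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
    InferenceStats, Model, OutputRequest,
};

/// Illustrative wrapper (not from this commit): run generation, letting the
/// session swap out old context instead of failing with ContextFull.
fn generate_with_swap(
    session: &mut InferenceSession,
    model: &dyn Model,
    rng: &mut impl rand::Rng,
    request: &InferenceRequest,
) -> Result<InferenceStats, InferenceError> {
    // Keep the first 32 tokens (e.g. a system prompt) whenever the context
    // window fills up; everything after them is eligible for removal.
    let n_keep = 32;

    session.infer_with_swap(
        model,
        rng,
        request,
        n_keep,
        &mut OutputRequest::default(),
        |resp| {
            // Stream generated text to stdout; other response variants are ignored here.
            if let InferenceResponse::InferredToken(t) = resp {
                print!("{t}");
            }
            Ok::<_, std::convert::Infallible>(InferenceFeedback::Continue)
        },
    )
}
```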
@@ -372,15 +372,22 @@ impl InferenceSession {
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let beginning_of_sentence = self.n_past == 0;

let vocab = model.tokenizer();
let prompt_tokens = prompt.into().to_tokens(vocab, beginning_of_sentence)?;
let prompt_tokens = self.get_prompt_tokens(model, prompt)?;

if self.n_past + prompt_tokens.len() >= model.context_size() {
return Err(InferenceError::ContextFull);
}

self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
}

fn feed_prompt_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
prompt_tokens: Vec<TokenId>,
) -> Result<(), InferenceError> {
'outer: for batch in prompt_tokens.chunks(self.config.n_batch) {
model.evaluate(self, batch, output_request);
for &tk in batch {
@@ -414,10 +421,46 @@ impl InferenceSession {
}
}
log::trace!("Finished feed prompt");

Ok(())
}

fn get_prompt_tokens<'a, P: Into<Prompt<'a>>>(
&self,
model: &dyn Model,
prompt: P,
) -> Result<Vec<TokenId>, TokenizationError> {
let beginning_of_sentence = self.n_past == 0;

let vocab = model.tokenizer();
prompt.into().to_tokens(vocab, beginning_of_sentence)
}

/// Feed a prompt to the model for this session.
/// Same as [Self::feed_prompt], but if the prompt would overflow the context window, removes enough tokens after the first `n_keep` to make room.
#[instrument(skip_all)]
pub fn feed_prompt_with_swap<
'a,
E: std::error::Error + Send + Sync + 'static,
P: Into<Prompt<'a>>,
>(
&mut self,
model: &dyn Model,
prompt: P,
n_keep: usize,
output_request: &mut OutputRequest,
mut callback: impl FnMut(&[u8]) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let prompt_tokens = self.get_prompt_tokens(model, prompt)?;

if self.n_past + prompt_tokens.len() >= model.context_size() {
let rewind_by = self.n_past + prompt_tokens.len() - model.context_size();
self.remove_tokens(model, n_keep, rewind_by)
.map_err(|_e| InferenceError::ContextFull)?;
}

self.feed_prompt_tokens(model, output_request, &mut callback, prompt_tokens)
}

/// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
pub fn rewind(&mut self, model: &dyn Model, num: usize) -> Result<Vec<TokenId>, RewindError> {
if !model.supports_rewind() {
@@ -445,6 +488,46 @@ impl InferenceSession {
Ok(deleted_tokens)
}

/// Removes `num` tokens from the buffer, starting at `start_from`. Similar to [Self::rewind].
fn remove_tokens(
&mut self,
model: &dyn Model,
start_from: usize,
num: usize,
) -> Result<Vec<TokenId>, RewindError> {
if !model.supports_rewind() {
return Err(RewindError::UnsupportedArchitecture);
}

if start_from + num >= self.n_past {
return Err(RewindError::NotEnoughTokens);
}

// Remove the tokens from self.tokens.
let end = start_from + num;
let deleted_tokens: Vec<_> = self.tokens.drain(start_from..end).collect();

// Remove the corresponding chars from decoded
let mut decoded_start = 0;
let mut decoded_end = 0;
if start_from != 0 {
for id in &self.tokens[0..start_from] {
decoded_start += model.tokenizer().token(*id as usize).len();
}
decoded_end += decoded_start;
}

for id in &deleted_tokens {
decoded_end += model.tokenizer().token(*id as usize).len();
}
self.decoded_tokens.drain(decoded_start..decoded_end);

// Decrement the n_past tokens counter.
self.n_past -= num;

Ok(deleted_tokens)
}

/// Infer the next token for this session.
#[instrument(level = "trace", skip_all)]
pub fn infer_next_token(
@@ -510,19 +593,7 @@ impl InferenceSession {
) -> Result<InferenceStats, InferenceError> {
let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
if request.play_back_previous_tokens {
// "Play back" the existing tokens, so that loading from an inference snapshot works
// as expected.
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) =
token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}
}
self.play_back_previous_tokens(model, &mut callback)?
}
log::trace!(
"Starting inference request with max_token_count: {}",
@@ -547,10 +618,25 @@ impl InferenceSession {
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

// After the prompt is consumed, sample tokens by repeatedly calling
// `infer_next_token`. We generate tokens until the model returns an
// EndOfText token, or we run out of space in the context window,
// or we reach the specified limit.
self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
stats.predict_duration = start_at.elapsed().unwrap();
stats.predict_tokens = self.n_past;

Ok(stats)
}

/// Samples tokens by repeatedly calling [Self::infer_next_token], until the
/// model returns an EndOfText token, the context window runs out of space,
/// or the specified token limit is reached.
fn infer_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
maximum_token_count: usize,
parameters: &InferenceParameters,
) -> Result<(), InferenceError> {
let mut tokens_processed = 0;
let mut token_utf8_buf = TokenUtf8Buffer::new();
while tokens_processed < maximum_token_count {
@@ -574,6 +660,79 @@

tokens_processed += 1;
}
Ok(())
}

/// "Play back" the existing tokens, so that loading from an inference snapshot works
/// as expected.
fn play_back_previous_tokens<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
) -> Result<(), InferenceError> {
let mut token_utf8_buf = TokenUtf8Buffer::new();
for token_id in &self.tokens {
// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(&model.tokenizer().token(*token_id as usize))
{
if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) {
return Err(InferenceError::UserCallback(Box::new(e)));
}
}
}
Ok(())
}

/// Generate text by using the provided [Model] to evaluate the `prompt`.
/// Works the same way as [Self::infer], but allows unbounded text generation via context swapping.
#[instrument(skip_all)]
pub fn infer_with_swap<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
request: &InferenceRequest,
n_keep: usize,
output_request: &mut OutputRequest,
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E>,
) -> Result<InferenceStats, InferenceError> {
let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
if request.play_back_previous_tokens {
self.play_back_previous_tokens(model, &mut callback)?
}

// Infinite text generation via context swapping.
// If we run out of context:
// - keep the first `n_keep` tokens of the original prompt
// - remove half of the tokens after `n_keep` ((n_past - n_keep) / 2)
if self.n_past >= model.context_size() {
self.remove_tokens(model, n_keep, (self.n_past - n_keep) / 2)
.map_err(|_e| InferenceError::ContextFull)?;
}

log::trace!(
"Starting inference request with max_token_count: {}",
maximum_token_count
);

let mut stats = InferenceStats::default();
let start_at = std::time::SystemTime::now();

let parameters = request.parameters;

// Feed the initial prompt through the transformer, to update its
// context window with new data, if necessary.
if !request.prompt.is_empty() {
self.feed_prompt(
model,
request.prompt,
output_request,
feed_prompt_callback(&mut callback),
)?;
}
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

self.infer_tokens(model, rng, &mut callback, maximum_token_count, parameters)?;
stats.predict_duration = start_at.elapsed().unwrap();
stats.predict_tokens = self.n_past;

@@ -677,6 +836,7 @@ impl InferenceSession {
npast: self.n_past,
config: self.config,
tokens: self.tokens.clone(),
decoded_tokens: self.decoded_tokens.clone(),
last_logits: self.last_logits.clone(),
memory_k,
memory_v,
@@ -709,6 +869,7 @@ impl InferenceSession {

session.n_past = snapshot.npast;
session.tokens = snapshot.tokens;
session.decoded_tokens = snapshot.decoded_tokens;
session.last_logits = snapshot.last_logits;

Ok(session)
@@ -814,6 +975,8 @@ pub struct InferenceSnapshotRef<'a> {
pub config: InferenceSessionConfig,
/// All tokens generated by this inference session.
pub tokens: Vec<TokenId>,
/// All decoded tokens generated by this inference session.
pub decoded_tokens: Vec<u8>,
/// The vector of logits that was produced after the last inference.
pub last_logits: Vec<f32>,
/// The contents of the 'key' memory tensor.
Expand All @@ -832,6 +995,7 @@ impl InferenceSnapshotRef<'_> {
npast: self.npast,
config: self.config,
tokens: self.tokens.clone(),
decoded_tokens: self.decoded_tokens.clone(),
last_logits: self.last_logits.clone(),
memory_k: self.memory_k.to_vec(),
memory_v: self.memory_v.to_vec(),
Expand All @@ -850,6 +1014,8 @@ pub struct InferenceSnapshot {
pub config: InferenceSessionConfig,
/// All tokens generated by this inference session.
pub tokens: Vec<TokenId>,
/// All decoded tokens generated by this inference session.
pub decoded_tokens: Vec<u8>,
/// The vector of logits that was produced after the last inference.
pub last_logits: Vec<f32>,
/// The contents of the 'key' memory tensor.
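A worked example of the swap arithmetic used by `infer_with_swap` above. The `(n_past - n_keep) / 2` formula and the `remove_tokens(model, n_keep, …)` call come from the diff; the concrete numbers below are illustrative only.

```rust
fn main() {
    // Illustrative numbers, not taken from the commit.
    let context_size: usize = 2048; // model.context_size()
    let n_keep: usize = 32;         // leading tokens to preserve (e.g. a system prompt)
    let n_past: usize = context_size; // the context window is full

    // infer_with_swap removes half of the tokens that follow the first n_keep:
    let rewind_by = (n_past - n_keep) / 2; // (2048 - 32) / 2 = 1008
    assert_eq!(rewind_by, 1008);

    // remove_tokens(model, n_keep, rewind_by) then drains tokens 32..1040 and the
    // matching bytes from decoded_tokens, leaving n_past = 1040, so roughly half
    // of the window is freed for new generation while the kept prefix survives.
    assert_eq!(n_past - rewind_by, 1040);
}
```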
