Skip to content

Commit

Permalink
Merge pull request #17 from thejohnbackes/split-by-tokens
Browse files Browse the repository at this point in the history
Add ability to tokenize a string and return the decoded tokens using the correct BPE model
  • Loading branch information
zurawiki committed Apr 16, 2023
2 parents 51e1d36 + d3461d3 commit 2ada6ca
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tiktoken-rs/src/api.rs
Expand Up @@ -46,7 +46,7 @@ pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
Ok(context_size.saturating_sub(prompt_tokens))
}

#[derive(Debug, Default, Clone, PartialEq)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
/// The role of the author of this message.
pub role: String,
Expand Down
63 changes: 63 additions & 0 deletions tiktoken-rs/src/vendor_tiktoken.rs
Expand Up @@ -235,6 +235,20 @@ impl CoreBPE {
ret
}

/// Map each token id in `tokens` to its decoded byte sequence.
///
/// Each id is looked up in the ordinary `decoder` first, then in
/// `special_tokens_decoder`. An id present in neither map is a caller
/// bug (it cannot have come from this model's encoder), so we panic
/// with the offending id rather than the opaque `HashMap` index panic.
#[allow(clippy::needless_lifetimes)] // the iterator captures a lifetime outside of the function
fn _decode_native_and_split<'a>(
    &'a self,
    tokens: Vec<usize>,
) -> impl Iterator<Item = Vec<u8>> + '_ {
    tokens.into_iter().map(move |token| {
        self.decoder
            .get(&token)
            .or_else(|| self.special_tokens_decoder.get(&token))
            .unwrap_or_else(|| panic!("no decoding for token id {}", token))
            .clone()
    })
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
Expand Down Expand Up @@ -541,6 +555,55 @@ impl CoreBPE {
Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
}
}

/// Tokenize a string and return the decoded tokens using the correct BPE model.
///
/// This method encodes `text` with the BPE model (special tokens included) and
/// decodes each resulting token id back into its string piece, yielding the
/// pieces lazily as an iterator.
///
/// # Examples
///
/// ```
/// use tiktoken_rs::cl100k_base;
/// let bpe = cl100k_base().unwrap();
/// let tokenized: Result<Vec<_>, _> = bpe
///     .split_by_token_with_special_tokens("This is a test with a lot of spaces")
///     .collect();
/// let tokenized = tokenized.unwrap();
/// assert_eq!(
///     tokenized,
///     vec!["This", " is", " a", " test", " ", " with", " a", " lot", " of", " spaces"]
/// );
/// ```
///
/// # Arguments
///
/// * text: A string slice containing the text to be tokenized.
///
/// # Returns
///
/// * An iterator of `Result<String>`, one item per token: the decoded token
///   text, or an error if that token's bytes cannot be converted into a valid
///   UTF-8 string.
///
/// # Errors
///
/// An item is an `Err` if the token's byte sequence is not valid UTF-8
/// (BPE tokens are byte-level, so a single token need not fall on a
/// character boundary).
pub fn split_by_token_with_special_tokens<'a>(
    &'a self,
    text: &'a str,
) -> impl Iterator<Item = Result<String>> + 'a {
    // First, encode the text using the BPE model
    let encoded = self.encode_with_special_tokens(text);

    // Decode each token's bytes back into a String, using the same error
    // message as the other decode path for consistency.
    self._decode_native_and_split(encoded).map(|token| {
        String::from_utf8(token)
            .map_err(|e| anyhow!("Unable to decode into a valid UTF-8 string: {}", e))
    })
}
}

#[cfg(feature = "python")]
Expand Down
13 changes: 13 additions & 0 deletions tiktoken-rs/tests/tiktoken.rs
Expand Up @@ -82,6 +82,19 @@ fn cl100k_base_test() {
);
}

#[test]
fn cl100k_split_test() {
    // Round-trip check: encode with special tokens, then decode every token
    // id back to its string piece. Collecting into Result short-circuits on
    // the first token whose bytes are not valid UTF-8.
    let bpe = cl100k_base().unwrap();
    let tokenized: Result<Vec<_>, _> = bpe
        .split_by_token_with_special_tokens("This is a test with a lot of spaces")
        .collect();
    let tokenized = tokenized.unwrap();
    // NOTE(review): the expected output contains a standalone " " token,
    // which suggests the input literal originally held a run of consecutive
    // spaces that may have been collapsed when this diff was rendered —
    // verify the input string against the actual tokenizer output.
    assert_eq!(
        tokenized,
        vec!["This", " is", " a", " test", " ", " with", " a", " lot", " of", " spaces"]
    );
}

#[test]
fn p50k_base_singleton_test() {
// let now = std::time::Instant::now();
Expand Down

0 comments on commit 2ada6ca

Please sign in to comment.