Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions llama-cpp-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,23 @@ pub enum EmbeddingsError {
NonePoolType,
}

/// Errors that can occur when initializing a grammar sampler.
///
/// Every variant is a unit variant, so values can be compared directly with `==`.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum GrammarError {
/// The grammar root was not found in the grammar string.
///
/// Reported when the root name does not occur as a substring of the grammar text.
#[error("Grammar root not found in grammar string")]
RootNotFound,
/// At least one trigger word contains a null byte, so it cannot be
/// converted into a C string.
#[error("Trigger word contains null bytes")]
TriggerWordNullBytes,
/// The grammar string or the grammar root contains a null byte, so it
/// cannot be converted into a C string.
#[error("Grammar string or root contains null bytes")]
GrammarNullBytes,
/// The underlying grammar sampler initialization call returned a null pointer.
#[error("Grammar call returned null")]
NullGrammar,
}

/// Decode an error from llama.cpp into a [`DecodeError`].
impl From<NonZeroI32> for DecodeError {
fn from(value: NonZeroI32) -> Self {
Expand Down
74 changes: 51 additions & 23 deletions llama-cpp-2/src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::model::LlamaModel;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::logit_bias::LlamaLogitBias;
use crate::token::LlamaToken;
use crate::GrammarError;

/// A safe wrapper around `llama_sampler`.
pub struct LlamaSampler {
Expand Down Expand Up @@ -274,13 +275,14 @@ impl LlamaSampler {
}

/// Grammar sampler
///
/// # Panics
/// If either of ``grammar_str`` or ``grammar_root`` contain null bytes.
#[must_use]
pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Option<Self> {
let grammar_str = CString::new(grammar_str).unwrap();
let grammar_root = CString::new(grammar_root).unwrap();
pub fn grammar(
model: &LlamaModel,
grammar_str: &str,
grammar_root: &str,
) -> Result<Self, GrammarError> {
let (grammar_str, grammar_root) =
Self::sanitize_grammar_strings(grammar_str, grammar_root)?;

let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_grammar(
Expand All @@ -291,37 +293,29 @@ impl LlamaSampler {
};

if sampler.is_null() {
None
Err(GrammarError::NullGrammar)
} else {
Some(Self { sampler })
Ok(Self { sampler })
}
}

/// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
///
/// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
///
/// # Panics
/// - If `grammar_str` or `grammar_root` contain null bytes
/// - If any trigger word contains null bytes
#[must_use]
pub fn grammar_lazy(
model: &LlamaModel,
grammar_str: &str,
grammar_root: &str,
trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
trigger_tokens: &[LlamaToken],
) -> Option<Self> {
let grammar_str = CString::new(grammar_str).unwrap();
let grammar_root = CString::new(grammar_root).unwrap();

let trigger_word_cstrings: Vec<CString> = trigger_words
.into_iter()
.map(|word| CString::new(word.as_ref()).unwrap())
.collect();
) -> Result<Self, GrammarError> {
let (grammar_str, grammar_root) =
Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
let trigger_words = Self::sanitize_trigger_words(trigger_words)?;

let mut trigger_word_ptrs: Vec<*const c_char> =
trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect();
trigger_words.iter().map(|cs| cs.as_ptr()).collect();

let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
Expand All @@ -336,12 +330,46 @@ impl LlamaSampler {
};

if sampler.is_null() {
None
Err(GrammarError::NullGrammar)
} else {
Some(Self { sampler })
Ok(Self { sampler })
}
}

/// Validates the grammar text and its root name and converts both into
/// NUL-terminated [`CString`]s suitable for passing across the FFI boundary.
///
/// # Errors
/// - [`GrammarError::RootNotFound`] if `grammar_root` does not occur as a
///   substring of `grammar_str`. This check runs first.
/// - [`GrammarError::GrammarNullBytes`] if either string contains a null byte.
fn sanitize_grammar_strings(
    grammar_str: &str,
    grammar_root: &str,
) -> Result<(CString, CString), GrammarError> {
    let root_is_present = grammar_str.contains(grammar_root);
    if !root_is_present {
        return Err(GrammarError::RootNotFound);
    }

    // `CString::new` fails exactly when the input has an interior null byte,
    // so its error maps one-to-one onto `GrammarNullBytes`.
    let to_c_string =
        |text: &str| CString::new(text).map_err(|_| GrammarError::GrammarNullBytes);

    let c_grammar = to_c_string(grammar_str)?;
    let c_root = to_c_string(grammar_root)?;
    Ok((c_grammar, c_root))
}

/// Converts every trigger word into a NUL-terminated [`CString`] suitable for
/// passing across the FFI boundary.
///
/// Words are converted in iteration order and the resulting vector preserves
/// that order.
///
/// # Errors
/// Returns [`GrammarError::TriggerWordNullBytes`] if any word contains a null
/// byte.
fn sanitize_trigger_words(
    trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
) -> Result<Vec<CString>, GrammarError> {
    // `CString::new` already rejects interior null bytes, so a single fallible
    // `collect` replaces the separate pre-scan, the intermediate `Vec`, and the
    // `unwrap` that followed it.
    trigger_words
        .into_iter()
        .map(|word| {
            CString::new(word.as_ref()).map_err(|_| GrammarError::TriggerWordNullBytes)
        })
        .collect()
}

/// DRY sampler, designed by p-e-w, as described in:
/// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
/// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
Expand Down
Loading