From 41c18834093b52be3e6f8778f441cbaad7e1c49d Mon Sep 17 00:00:00 2001
From: "Marek Hradil jr."
Date: Fri, 28 Nov 2025 16:43:12 +0100
Subject: [PATCH] feat: Improve the grammar error handling

Return `Result<Self, GrammarError>` from `LlamaSampler::grammar` and
`LlamaSampler::grammar_lazy` instead of `Option<Self>`, and report null
bytes in the grammar inputs as errors instead of panicking.

---
 llama-cpp-2/src/lib.rs      | 17 ++++++++
 llama-cpp-2/src/sampling.rs | 82 +++++++++++++++++++++++++++----------
 2 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 8139725f..16b5bd8e 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -156,6 +156,23 @@ pub enum EmbeddingsError {
     NonePoolType,
 }
 
+/// Errors that can occur when initializing a grammar sampler
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum GrammarError {
+    /// The grammar root was not found in the grammar string
+    #[error("Grammar root not found in grammar string")]
+    RootNotFound,
+    /// The trigger word contains null bytes
+    #[error("Trigger word contains null bytes")]
+    TriggerWordNullBytes,
+    /// The grammar string or root contains null bytes
+    #[error("Grammar string or root contains null bytes")]
+    GrammarNullBytes,
+    /// The grammar call returned null
+    #[error("Grammar call returned null")]
+    NullGrammar,
+}
+
 /// Decode a error from llama.cpp into a [`DecodeError`].
 impl From<NonZeroI32> for DecodeError {
     fn from(value: NonZeroI32) -> Self {
diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index d1275aec..de232005 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -9,6 +9,7 @@ use crate::model::LlamaModel;
 use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::logit_bias::LlamaLogitBias;
 use crate::token::LlamaToken;
+use crate::GrammarError;
 
 /// A safe wrapper around `llama_sampler`.
 pub struct LlamaSampler {
@@ -274,13 +275,18 @@ impl LlamaSampler {
     }
 
     /// Grammar sampler
     ///
-    /// # Panics
-    /// If either of ``grammar_str`` or ``grammar_root`` contain null bytes.
-    #[must_use]
-    pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Option<Self> {
-        let grammar_str = CString::new(grammar_str).unwrap();
-        let grammar_root = CString::new(grammar_root).unwrap();
+    /// # Errors
+    /// - [`GrammarError::RootNotFound`] if `grammar_root` does not occur in `grammar_str`
+    /// - [`GrammarError::GrammarNullBytes`] if `grammar_str` or `grammar_root` contain null bytes
+    /// - [`GrammarError::NullGrammar`] if llama.cpp returns a null sampler
+    pub fn grammar(
+        model: &LlamaModel,
+        grammar_str: &str,
+        grammar_root: &str,
+    ) -> Result<Self, GrammarError> {
+        let (grammar_str, grammar_root) =
+            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_grammar(
@@ -291,9 +297,9 @@ impl LlamaSampler {
         };
 
         if sampler.is_null() {
-            None
+            Err(GrammarError::NullGrammar)
         } else {
-            Some(Self { sampler })
+            Ok(Self { sampler })
         }
     }
 
@@ -301,9 +307,10 @@ impl LlamaSampler {
     ///
     /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
     ///
-    /// # Panics
-    /// - If `grammar_str` or `grammar_root` contain null bytes
-    /// - If any trigger word contains null bytes
-    #[must_use]
+    /// # Errors
+    /// - [`GrammarError::RootNotFound`] if `grammar_root` does not occur in `grammar_str`
+    /// - [`GrammarError::GrammarNullBytes`] if `grammar_str` or `grammar_root` contain null bytes
+    /// - [`GrammarError::TriggerWordNullBytes`] if any trigger word contains null bytes
+    /// - [`GrammarError::NullGrammar`] if llama.cpp returns a null sampler
     pub fn grammar_lazy(
         model: &LlamaModel,
@@ -311,17 +318,13 @@ impl LlamaSampler {
         grammar_root: &str,
         trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
         trigger_tokens: &[LlamaToken],
-    ) -> Option<Self> {
-        let grammar_str = CString::new(grammar_str).unwrap();
-        let grammar_root = CString::new(grammar_root).unwrap();
-
-        let trigger_word_cstrings: Vec<CString> = trigger_words
-            .into_iter()
-            .map(|word| CString::new(word.as_ref()).unwrap())
-            .collect();
+    ) -> Result<Self, GrammarError> {
+        let (grammar_str, grammar_root) =
+            Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
+        let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
 
         let mut trigger_word_ptrs: Vec<*const c_char> =
-            trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect();
+            trigger_words.iter().map(|cs| cs.as_ptr()).collect();
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
@@ -336,10 +339,43 @@ impl LlamaSampler {
         };
 
         if sampler.is_null() {
-            None
+            Err(GrammarError::NullGrammar)
         } else {
-            Some(Self { sampler })
+            Ok(Self { sampler })
         }
     }
 
+    /// Validate the grammar text and root rule name and convert both to C strings.
+    ///
+    /// Fails with [`GrammarError::RootNotFound`] when `grammar_root` does not occur in
+    /// `grammar_str`. This is a plain substring test, so it only catches roots that are
+    /// missing from the grammar text altogether.
+    /// Fails with [`GrammarError::GrammarNullBytes`] when either string contains a null byte.
+    fn sanitize_grammar_strings(
+        grammar_str: &str,
+        grammar_root: &str,
+    ) -> Result<(CString, CString), GrammarError> {
+        if !grammar_str.contains(grammar_root) {
+            return Err(GrammarError::RootNotFound);
+        }
+
+        // `CString::new` rejects interior null bytes itself, so its error is mapped instead of
+        // scanning both strings a second time and unwrapping.
+        let to_cstring = |s: &str| CString::new(s).map_err(|_| GrammarError::GrammarNullBytes);
+        Ok((to_cstring(grammar_str)?, to_cstring(grammar_root)?))
+    }
+
+    /// Convert the trigger words to C strings.
+    ///
+    /// Fails with [`GrammarError::TriggerWordNullBytes`] when any word contains a null byte.
+    fn sanitize_trigger_words(
+        trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
+    ) -> Result<Vec<CString>, GrammarError> {
+        trigger_words
+            .into_iter()
+            .map(|word| CString::new(word.as_ref()))
+            .collect::<Result<Vec<_>, _>>()
+            .map_err(|_| GrammarError::TriggerWordNullBytes)
+    }
+
     /// DRY sampler, designed by p-e-w, as described in: