From 737f3eeff1bae501c3e7f383364a23e3c169a575 Mon Sep 17 00:00:00 2001
From: Yusuf Simonson
Date: Fri, 28 Nov 2025 11:45:35 -0500
Subject: [PATCH] Added lifetime to LlamaBatch to track underlying data from
 get_one constructor

---
 examples/mtmd/src/mtmd.rs      |  6 +++---
 llama-cpp-2/src/llama_batch.rs | 12 ++++++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/examples/mtmd/src/mtmd.rs b/examples/mtmd/src/mtmd.rs
index e0f52de9..a2af6328 100644
--- a/examples/mtmd/src/mtmd.rs
+++ b/examples/mtmd/src/mtmd.rs
@@ -72,11 +72,11 @@ pub struct MtmdCliParams {
 
 /// State of the MTMD CLI application.
 #[allow(missing_debug_implementations)]
-pub struct MtmdCliContext {
+pub struct MtmdCliContext<'a> {
     /// The MTMD context for multimodal processing.
     pub mtmd_ctx: MtmdContext,
     /// The batch used for processing tokens.
-    pub batch: LlamaBatch,
+    pub batch: LlamaBatch<'a>,
     /// The list of loaded bitmaps (images/audio).
     pub bitmaps: Vec<MtmdBitmap>,
     /// The number of past tokens processed.
@@ -87,7 +87,7 @@ pub struct MtmdCliContext {
     pub chat: Vec<LlamaChatMessage>,
 }
 
-impl MtmdCliContext {
+impl<'a> MtmdCliContext<'a> {
     /// Creates a new MTMD CLI context
     ///
     /// # Errors
diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index b96588c7..acb7ecec 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -2,10 +2,11 @@
 
 use crate::token::LlamaToken;
 use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
+use std::marker::PhantomData;
 
 /// A safe wrapper around `llama_batch`.
 #[derive(Debug)]
-pub struct LlamaBatch {
+pub struct LlamaBatch<'a> {
     /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
     allocated: usize,
     /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
@@ -13,6 +14,7 @@ pub struct LlamaBatch {
     #[allow(clippy::doc_markdown)]
     /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <embd>, <n_seq_max>)`
     pub(crate) llama_batch: llama_batch,
+    phantom: PhantomData<&'a [LlamaToken]>,
 }
 
 /// Errors that can occur when adding a token to a batch.
@@ -26,7 +28,7 @@ pub enum BatchAddError {
     EmptyBuffer,
 }
 
-impl LlamaBatch {
+impl<'a> LlamaBatch<'a> {
     /// Clear the batch. This does not free the memory associated with the batch, but it does reset
     /// the number of tokens to 0.
     pub fn clear(&mut self) {
@@ -150,6 +152,7 @@ impl LlamaBatch {
             allocated: n_tokens,
             initialized_logits: vec![],
             llama_batch: batch,
+            phantom: PhantomData,
         }
     }
 
@@ -163,7 +166,7 @@ impl LlamaBatch {
     ///
     /// # Panics
     /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
-    pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
+    pub fn get_one(tokens: &'a [LlamaToken]) -> Result<Self, BatchAddError> {
         if tokens.is_empty() {
             return Err(BatchAddError::EmptyBuffer);
         }
@@ -183,6 +186,7 @@ impl LlamaBatch {
                 .try_into()
                 .expect("number of tokens exceeds i32::MAX + 1")],
             llama_batch: batch,
+            phantom: PhantomData,
         };
         Ok(batch)
     }
@@ -194,7 +198,7 @@ impl LlamaBatch {
     }
 }
 
-impl Drop for LlamaBatch {
+impl<'a> Drop for LlamaBatch<'a> {
    /// Drops the `LlamaBatch`.
    ///
    /// ```