Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/mtmd/src/mtmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ pub struct MtmdCliParams {

/// State of the MTMD CLI application.
#[allow(missing_debug_implementations)]
pub struct MtmdCliContext {
pub struct MtmdCliContext<'a> {
/// The MTMD context for multimodal processing.
pub mtmd_ctx: MtmdContext,
/// The batch used for processing tokens.
pub batch: LlamaBatch,
pub batch: LlamaBatch<'a>,
/// The list of loaded bitmaps (images/audio).
pub bitmaps: Vec<MtmdBitmap>,
/// The number of past tokens processed.
Expand All @@ -87,7 +87,7 @@ pub struct MtmdCliContext {
pub chat: Vec<LlamaChatMessage>,
}

impl MtmdCliContext {
impl<'a> MtmdCliContext<'a> {
/// Creates a new MTMD CLI context
///
/// # Errors
Expand Down
12 changes: 8 additions & 4 deletions llama-cpp-2/src/llama_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

use crate::token::LlamaToken;
use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
use std::marker::PhantomData;

/// A safe wrapper around `llama_batch`.
#[derive(Debug)]
pub struct LlamaBatch {
pub struct LlamaBatch<'a> {
/// The number of tokens the batch was allocated with. They are safe to write to, but not necessarily safe to read from, as they may not be initialized.
allocated: usize,
/// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
pub(crate) initialized_logits: Vec<i32>,
#[allow(clippy::doc_markdown)]
/// The llama_cpp batch. Always initialized by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`.
pub(crate) llama_batch: llama_batch,
phantom: PhantomData<&'a [LlamaToken]>,
}

/// Errors that can occur when adding a token to a batch.
Expand All @@ -26,7 +28,7 @@ pub enum BatchAddError {
EmptyBuffer,
}

impl LlamaBatch {
impl<'a> LlamaBatch<'a> {
/// Clear the batch. This does not free the memory associated with the batch, but it does reset
/// the number of tokens to 0.
pub fn clear(&mut self) {
Expand Down Expand Up @@ -150,6 +152,7 @@ impl LlamaBatch {
allocated: n_tokens,
initialized_logits: vec![],
llama_batch: batch,
phantom: PhantomData,
}
}

Expand All @@ -163,7 +166,7 @@ impl LlamaBatch {
///
/// # Panics
/// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
pub fn get_one(tokens: &'a [LlamaToken]) -> Result<Self, BatchAddError> {
if tokens.is_empty() {
return Err(BatchAddError::EmptyBuffer);
}
Expand All @@ -183,6 +186,7 @@ impl LlamaBatch {
.try_into()
.expect("number of tokens exceeds i32::MAX + 1")],
llama_batch: batch,
phantom: PhantomData,
};
Ok(batch)
}
Expand All @@ -194,7 +198,7 @@ impl LlamaBatch {
}
}

impl Drop for LlamaBatch {
impl<'a> Drop for LlamaBatch<'a> {
/// Drops the `LlamaBatch`.
///
/// ```
Expand Down
Loading