From d279371c19d8bdc387e78ec62b672409f5320fcc Mon Sep 17 00:00:00 2001
From: Juho Peltonen
Date: Fri, 7 Apr 2023 01:08:47 +0300
Subject: [PATCH] Reserve more eval memory and use ggml scratch buffers

---
 ggml/src/lib.rs     | 39 +++++++++++++++++++++++++++++++++++++++
 llama-rs/src/lib.rs | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs
index 22c9eee8..a54b72ca 100644
--- a/ggml/src/lib.rs
+++ b/ggml/src/lib.rs
@@ -303,6 +303,26 @@ impl Context {
     pub fn used_mem(&self) -> usize {
         unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) }
     }
+
+    /// Set scratch buffer
+    pub fn use_scratch(&self, scratch_buffer: Option<&mut Buffer>) {
+        let (size, data) = if let Some(buffer) = scratch_buffer {
+            (buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
+        } else {
+            (0, std::ptr::null_mut())
+        };
+        // SAFETY: this just passes (most likely uninitialized) memory buffer to the ggml C API
+        unsafe {
+            ggml_sys::ggml_set_scratch(
+                self.ptr.as_ptr(),
+                ggml_sys::ggml_scratch {
+                    offs: 0,
+                    size,
+                    data,
+                },
+            );
+        }
+    }
 }
 
 impl Drop for Context {
@@ -315,6 +335,25 @@
     }
 }
 
+/// Pre-allocated buffer
+pub struct Buffer {
+    data: Vec<u8>,
+}
+
+impl Buffer {
+    /// Creates new buffer
+    pub fn new(size: usize) -> Self {
+        let mut data: Vec<u8> = Vec::with_capacity(size);
+        // SAFETY: contents are left uninitialized. Don't use them.
+        #[allow(clippy::uninit_vec)]
+        unsafe {
+            data.set_len(size)
+        };
+
+        Buffer { data }
+    }
+}
+
 /// Tensors are owned by the context. A tensor is alive as long as the
 /// underlying context it was created with is alive.
 pub struct Tensor {
diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index d5ef2a23..8636a86e 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -27,6 +27,8 @@ mod util;
 /// The end of text token.
 pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)
 
+const SCRATCH_SIZE: usize = 512 * 1024 * 1024; // 512MB
+
 /// The hyperparameters of the model.
 #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)]
 pub struct Hyperparameters {
@@ -103,6 +105,9 @@ pub struct InferenceSession {
 
     /// The logits that were last predicted by the network. Zeroed out otherwise.
     last_logits: Vec<f32>,
+
+    /// Scratch buffers
+    scratch: [ggml::Buffer; 2],
 }
 impl InferenceSession {
     fn repetition_penalty_tokens(&self) -> &[TokenId] {
@@ -128,10 +133,18 @@ impl Clone for InferenceSession {
             mem_per_token: self.mem_per_token,
             tokens: self.tokens.clone(),
             last_logits: self.last_logits.clone(),
+            scratch: inference_session_scratch_buffers(),
         }
     }
 }
 
+fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] {
+    [
+        ggml::Buffer::new(SCRATCH_SIZE),
+        ggml::Buffer::new(SCRATCH_SIZE),
+    ]
+}
+
 #[derive(serde::Serialize, Clone, PartialEq)]
 /// A serializable snapshot of the inference process. Can be saved to disk.
 // Keep in sync with [InferenceSession] and [InferenceSnapshot]
@@ -1116,6 +1129,7 @@ impl Model {
             mem_per_token: 0,
             tokens: vec![],
             last_logits: vec![0.0; n_vocab],
+            scratch: inference_session_scratch_buffers(),
         }
     }
 
@@ -1150,7 +1164,15 @@ impl Model {
 
         // For the first run, we need to guess a maximum buffer size so we can measure
         // the actual memory consumption of the temporary ggml context.
-        let mut buf_size = 1024 * 1024 * 1024;
+        let mut buf_size = 1024
+            * 1024
+            * if n_layer >= 80 {
+                1536
+            } else if n_layer >= 60 {
+                1280
+            } else {
+                1024
+            };
         if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
             // add 10% to account for ggml object overhead
             buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
@@ -1189,6 +1211,8 @@ impl Model {
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
 
+            ctx0.use_scratch(Some(&mut session.scratch[0]));
+
             // norm
             {
                 current = ctx0.op_rms_norm(&input_layer);
@@ -1312,6 +1336,8 @@ impl Model {
                 current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
             }
 
+            ctx0.use_scratch(Some(&mut session.scratch[1]));
+
             let input_feed_forward = ctx0.op_add(&current, &input_self_attention);
 
             // feed-forward network
@@ -1345,6 +1371,8 @@ impl Model {
             input_layer = current;
         }
 
+        ctx0.use_scratch(Some(&mut session.scratch[0]));
+
         // Used at the end to optionally extract the embeddings.
         let embeddings_tensor;
 
@@ -1362,6 +1390,8 @@ impl Model {
             input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
         }
 
+        ctx0.use_scratch(None);
+
         // logits -> probs
         // inpL = ctx0.op_soft_max(&inpL);
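
Note (explanatory sketch, not part of the patch): the diff above routes each layer's attention temporaries into session.scratch[0] and its feed-forward temporaries into session.scratch[1], so per-layer intermediates are carved out of two fixed 512 MB buffers instead of growing the evaluation context. The standalone program below illustrates why alternating between two fixed buffers works; the Arena type and all names in it are hypothetical stand-ins for ggml's real ggml_set_scratch mechanism shown in the patch.

// Conceptual illustration only; `Arena` is a made-up bump allocator, not ggml's API.

/// A trivial bump allocator over one pre-allocated byte buffer.
struct Arena {
    data: Vec<u8>,
    offset: usize,
}

impl Arena {
    fn new(size: usize) -> Self {
        // Allocate once up front, analogous to ggml::Buffer::new(SCRATCH_SIZE).
        Arena { data: vec![0u8; size], offset: 0 }
    }

    /// Hand out `len` bytes from the fixed buffer (panics if it runs out).
    fn alloc(&mut self, len: usize) -> &mut [u8] {
        let start = self.offset;
        self.offset += len;
        &mut self.data[start..start + len]
    }

    /// Discard all previous allocations so the memory can be reused.
    fn reset(&mut self) {
        self.offset = 0;
    }
}

fn main() {
    // Two fixed scratch arenas, mirroring `scratch: [ggml::Buffer; 2]`
    // (sizes here are arbitrary and much smaller than SCRATCH_SIZE).
    let mut scratch = [Arena::new(1 << 20), Arena::new(1 << 20)];
    let [attn_scratch, ffn_scratch] = &mut scratch;

    for layer in 0..4 {
        // Attention temporaries for this layer live in scratch 0...
        attn_scratch.reset();
        let attn_tmp = attn_scratch.alloc(4 * 1024);

        // ...and feed-forward temporaries in scratch 1, so the feed-forward
        // step can still read this layer's attention output without
        // overwriting it. Peak memory stays bounded by the two buffers.
        ffn_scratch.reset();
        let ffn_tmp = ffn_scratch.alloc(8 * 1024);

        println!(
            "layer {layer}: {} B attention scratch, {} B feed-forward scratch",
            attn_tmp.len(),
            ffn_tmp.len()
        );
    }
}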