diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 07bd1d1ca..5cc93c39b 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -38,6 +38,7 @@
     LlamaDiskCache,  # type: ignore
     LlamaStaticDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
+    StateReloadError,  # type: ignore
 )
 import numpy as np
@@ -1234,16 +1235,23 @@ def logit_bias_processor(
                         file=sys.stderr,
                     )
-                before = time.time()
-                self.load_state(cache_item)
-                after = time.time()
-                if self.verbose:
-                    print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
-                if self.verbose:
-                    print(
-                        f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
-                        file=sys.stderr,
-                    )
+                try:
+                    before = time.time()
+                    self.cache.reload_from_cache_state(self, cache_item)
+                    after = time.time()
+                    if self.verbose:
+                        print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
+                        print(
+                            f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
+                            file=sys.stderr,
+                        )
+                except StateReloadError as e:
+                    if self.verbose:
+                        print(
+                            f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}, but failed to reload state: {e}",
+                            file=sys.stderr,
+                        )
+                        print("Falling back to re-evaluating prompt", file=sys.stderr)
             elif self.verbose:
                 print(
                     f"Llama._create_completion: not reloading from cache, cache prefix len {cache_prefix_len} < eval prefix len {eval_prefix_len}",
diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py
index aafba6e1d..4f754a8cc 100644
--- a/llama_cpp/llama_cache.py
+++ b/llama_cpp/llama_cache.py
@@ -1,3 +1,4 @@
+import ctypes
 import pickle
 import sys
 from abc import ABC, abstractmethod
@@ -5,6 +6,7 @@
 from typing import Optional, Sequence, Tuple
 
 import diskcache
+import numpy as np
 import pytrie
 
 import llama_cpp.llama
@@ -12,6 +14,12 @@
 from .llama_types import *
 
 
+class StateReloadError(Exception):
+    """
+    Error for when state from the cache cannot be read by the current model.
+    """
+
+
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
 
@@ -48,6 +56,20 @@ def __setitem__(
     ) -> None:
         raise NotImplementedError
 
+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Reload the state onto the model. Normally this is done with load_state
+        (as the state is created with the corresponding `save_state`), but some
+        caches may need special handling as an optimization.
+
+        Raises a StateReloadError if the state is not compatible with the model
+        (for example, if the model requires logits that the state does not store).
+        """
+        model.load_state(state)
+
 
 class LlamaRAMCache(BaseLlamaCache):
     """Cache for a llama.cpp model using RAM."""
 
@@ -223,6 +245,7 @@ def build_cache(
         capacity_bytes: int = 2 << 30,
         seed: Optional[int] = None,
         add_bos=True,
+        save_logits: bool = False,
     ) -> "LlamaStaticDiskCache":
         """
         Using model passed in, evaluates each prompt and stores LlamaState in cache.
@@ -246,6 +269,19 @@ def build_cache(
                 print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
             model.eval(eval_toks)
             state = model.save_state()
+
+            if not save_logits:
+                if (
+                    model.context_params.logits_all
+                    or model.draft_model is not None
+                    or model.context_params.embeddings
+                ):
+                    # Erroring instead of falling back to just saving with scores
+                    raise ValueError(
+                        "Cannot save state without logits - model requires logits to sample."
+                    )
+                state.scores = None
+
             cache._private_setitem(toks, state)  # pylint: disable=protected-access
 
             # Set up Trie for efficient prefix search
@@ -278,3 +314,83 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
     def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
         # Should this just be a warning?
         raise ValueError("Cannot set items in a static cache")
+
+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Skip reloading logits and set the last logits from the llama.cpp context
+        struct as the scores for the last token of the prompt.
+        """
+        # pylint: disable=protected-access
+
+        # Check if model needs logits (draft model, log probs required, etc.)
+        model_needs_scores_to_reload = (
+            # May be overly pessimistic if we don't want embeddings for prompt tokens.
+            model.context_params.embeddings
+            or model.context_params.logits_all
+            # Same: is this really a hard requirement? We need token IDs from the
+            # draft model and all the logits from the base model to do verification
+            # of candidate tokens, but not for prompt tokens.
+            or model.draft_model is not None
+        )
+
+        if model_needs_scores_to_reload:
+            if state.scores is None:
+                raise StateReloadError(
+                    "Model requires logits to be reloaded, but static cache does not store logits"
+                )
+            else:
+                model.load_state(state)
+                return
+
+        # Case where we don't need the numpy logits and can just get the
+        # last-token logits from the llama.cpp struct
+        model.n_tokens = state.n_tokens
+        model.input_ids = state.input_ids.copy()
+        model.scores[:] = 0.0
+
+        state_size = state.llama_state_size
+
+        try:
+            llama_state_array_type = ctypes.c_uint8 * state_size
+            # Have to do from_buffer_copy since LlamaState.llama_state is
+            # non-mutable bytes, not mutable bytearray.
+            llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
+            reloaded_state_size = llama_cpp.llama_set_state_data(
+                model._ctx.ctx, llama_state
+            )
+
+            if reloaded_state_size != state_size:
+                raise StateReloadError(
+                    "Failed to set llama state data - reloaded state size "
+                    f"{reloaded_state_size} does not match original size {state_size}"
+                )
+
+            # ctypes dtype, compatible w/ numpy through ducktyping :scared:
+            dtype = llama_cpp.llama_cpp.llama_get_logits_ith.restype._type_
+
+            # If the model's scores dtype doesn't match the dtype from the
+            # signature, then we can't copy it.
+            if model.scores.dtype != dtype:
+                raise StateReloadError(
+                    f"Expected scores to be {dtype} but got "
+                    f"{model.scores.dtype} - are you running this in the future? Or the past?"
+                )
+
+            # Will raise a ValueError for null pointers
+            last_position_logits = np.array(
+                ctypes.cast(
+                    model._ctx.get_logits_ith(-1),
+                    ctypes.POINTER(dtype * model.n_vocab()),
+                ).contents,
+                # Otherwise will be a view into the C array on the llama.cpp context
+                copy=True,
+                dtype=dtype,
+            )
+
+            model._scores[-1, :] = last_position_logits
+
+        except ValueError as e:
+            raise StateReloadError from e
diff --git a/tests/test_llama_cache.py b/tests/test_llama_cache.py
new file mode 100644
index 000000000..ab7afb3e5
--- /dev/null
+++ b/tests/test_llama_cache.py
@@ -0,0 +1,241 @@
+import os
+import tempfile
+
+import pytest
+
+from llama_cpp.llama import Llama, LlamaState
+from llama_cpp.llama_cache import LlamaStaticDiskCache, StateReloadError
+
+
+# Have to be careful to reset to a good state when testing, but don't want to
+# recreate the model each time.
+@pytest.fixture(scope="module") +def small_model(): + model_filename = os.getenv("LLAMA_TEST_MODEL") + if not model_filename: + pytest.skip("LLAMA_TEST_MODEL environment variable is not set") + return + + model_filename = os.path.expanduser(model_filename) + + test_model = Llama( + model_filename, + n_ctx=2_048, + n_gpu_layers=0, + offload_kqv=False, + n_batch=512, + embedding=False, + verbose=False, + ) + + system_prompt = r""" +You are an advanced intelligence "Hal" aboard a spaceship. You are required to +act as the primary interface between the ship and its crew. You can: +* Provide information on the current status of the ship +* Turn on/off the lights in the crew quarters +* Open/close the airlocks + +Respond in a terse, professional manner. Do not engage in casual conversation. + +The current state of the ship is: +* Airlocks: closed +* Lights: on +* Oxygen levels: normal +""".strip() + + user_prompt = "Hal, please open the airlocks." + + # Ingest prompt and create completion so that will have some state. + # Last token of prompt + all tokens of generated completion will have + # non-zero logits. + _ = test_model.create_chat_completion( + [ + {"role": "system", "text": system_prompt}, + {"role": "user", "text": user_prompt}, + ], + seed=1234, + ) + + assert test_model.n_tokens > 0 + + # Have at least some scores, and last entry is non-zero + assert ~(test_model.scores == 0).all() + # pylint: disable=protected-access + assert (test_model._scores[-1, :] != 0.0).all() + + return test_model + + +@pytest.fixture(scope="module") +def llama_state(small_model) -> LlamaState: + state = small_model.save_state() + # Clear scores so that can test reloading from cache without them. + state.scores = None + return state + + +def test_reload_from_cache_state_success(small_model, llama_state: LlamaState): + current_state = small_model.save_state() + old_score = small_model.scores.copy() + + LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state) + new_state = small_model.save_state() + new_score = small_model.scores.copy() + + assert (current_state.input_ids == new_state.input_ids).all() + + assert current_state.n_tokens == new_state.n_tokens + + # Logits for last token should match, others may not if n_batch < n_tokens + assert ( + old_score[small_model.n_tokens - 1, :] == new_score[small_model.n_tokens - 1, :] + ).all() + + +def test_reload_from_cache_state_state_reload_error(small_model, llama_state): + small_model.context_params.logits_all = True + small_model.context_params.embeddings = True + try: + with pytest.raises(StateReloadError): + LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state) + finally: + small_model.context_params.logits_all = False + small_model.context_params.embeddings = False + + +def test_disk_cache_e2e(small_model: Llama): + prompts = ["this is a test prompt", "and this is another test prompt"] + capacity_bytes = 2 << 30 + + small_model.reset() + # This is a weird thing to reset, but input_ids > n_tokens are not + # significant (like a scratchpad), left over if had previous prompt that + # was longer. + # + # Reset for ease of comparison later. 
+    small_model.input_ids[:] = 0
+
+    with tempfile.TemporaryDirectory() as cache_dir:
+        cache = LlamaStaticDiskCache.build_cache(
+            cache_dir=cache_dir,
+            prompts=prompts,
+            model=small_model,
+            capacity_bytes=capacity_bytes,
+            add_bos=True,
+            seed=1234,
+            save_logits=False,
+        )
+
+        for p in prompts:
+            key = tuple(
+                small_model.tokenize(p.encode("utf-8"), add_bos=True, special=True)
+            )
+            assert key in cache
+            state = cache[key]
+            assert ~(state.input_ids == 0).all()
+            assert state is not None
+            assert (
+                state.scores is None
+            ), "Logits should not be stored when save_logits=False and model doesn't require them."
+
+            small_model.reset()
+            small_model.input_ids[:] = 0
+            small_model.eval(key)
+
+            state2 = small_model.save_state()
+            assert state2.n_tokens == state.n_tokens
+            assert ~(state2.input_ids == 0).all()
+            assert (state2.input_ids == state.input_ids).all()
+
+            last_logits = small_model.scores[small_model.n_tokens - 1, :]
+
+            LlamaStaticDiskCache.reload_from_cache_state(small_model, state)
+
+            last_logits2 = small_model.scores[small_model.n_tokens - 1, :]
+
+            assert (last_logits == last_logits2).all()
+
+
+def test_cache_save_reload_scores_when_needed(
+    small_model: Llama,
+):
+    """
+    When the model requires it, state can be reloaded from the cache with scores.
+    """
+    test_prompt = "this is a test prompt"
+    with tempfile.TemporaryDirectory() as cache_dir:
+        cache = LlamaStaticDiskCache.build_cache(
+            cache_dir=cache_dir,
+            prompts=[test_prompt],
+            model=small_model,
+            capacity_bytes=2 << 30,
+            add_bos=True,
+            seed=1234,
+            save_logits=True,
+        )
+
+        llama_state = small_model.save_state()
+        cur_scores = llama_state.scores.copy()
+        assert ~(cur_scores == 0.0).all()
+
+        try:
+            small_model.context_params.logits_all = True
+            state_from_cache = cache[
+                tuple(llama_state.input_ids[: llama_state.n_tokens].tolist())
+            ]
+            assert state_from_cache.scores is not None, "Scores should be saved."
+            LlamaStaticDiskCache.reload_from_cache_state(small_model, state_from_cache)
+            # Do I have to limit these to n_tokens?
+            assert (state_from_cache.input_ids == llama_state.input_ids).all()
+            assert (
+                cur_scores == small_model.scores[: small_model.n_tokens]
+            ).all(), "Reloaded scores should match"
+        finally:
+            small_model.scores[:] = 0.0
+            small_model.context_params.logits_all = False
+            small_model.reset()
+
+
+def test_cache_reload_errors_when_requires_scores_and_state_doesnt_have_it(
+    small_model: Llama, llama_state: LlamaState
+):
+    """
+    If the model requires logits for sampling and the state doesn't have them,
+    reloading should raise an error.
+    """
+    old_state_scores = (
+        llama_state.scores.copy()
+        if llama_state.scores is not None
+        else llama_state.scores
+    )
+    try:
+        small_model.context_params.logits_all = True
+        llama_state.scores = None
+
+        with pytest.raises(StateReloadError):
+            LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state)
+    finally:
+        small_model.context_params.logits_all = False
+        llama_state.scores = old_state_scores
+
+
+# pylint: disable=invalid-name
+def test_cache_errors_when_save_logits_False_but_model_requires(small_model: Llama):
+    """
+    If the model requires logits but save_logits is False, should raise an error.
+ """ + + try: + small_model.context_params.logits_all = True + with pytest.raises(ValueError): + with tempfile.TemporaryDirectory() as cache_dir: + LlamaStaticDiskCache.build_cache( + cache_dir=cache_dir, + prompts=["this is a test prompt"], + model=small_model, + capacity_bytes=2 << 30, + add_bos=True, + seed=1234, + save_logits=False, + ) + finally: + small_model.context_params.logits_all = False