28 changes: 18 additions & 10 deletions llama_cpp/llama.py
@@ -38,6 +38,7 @@
LlamaDiskCache, # type: ignore
LlamaStaticDiskCache, # type: ignore
LlamaRAMCache, # type: ignore
StateReloadError, # type: ignore
)

import numpy as np
@@ -1234,16 +1235,23 @@ def logit_bias_processor(
file=sys.stderr,
)

before = time.time()
self.load_state(cache_item)
after = time.time()
if self.verbose:
print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
if self.verbose:
print(
f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
file=sys.stderr,
)
try:
before = time.time()
self.cache.reload_from_cache_state(self, cache_item)
after = time.time()
if self.verbose:
print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
print(
f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
file=sys.stderr,
)
except StateReloadError as e:
if self.verbose:
print(
f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}, but failed to reload state: {e}",
file=sys.stderr,
)
print("Falling back to re-evaluating prompt", file=sys.stderr)
elif self.verbose:
print(
f"Llama._create_completion: not reloading from cache, cache prefix len {cache_prefix_len} < eval prefix len {eval_prefix_len}",
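Note: a usage sketch of the new fallback path, assuming a hypothetical model path, prompt set, and cache directory (of the `build_cache` keyword arguments below, only `save_logits` is shown in this diff; the other names are assumed for illustration). On a cache hit, `_create_completion` now calls `reload_from_cache_state` and, if that raises `StateReloadError`, falls back to re-evaluating the prompt instead of failing.

```python
from llama_cpp import Llama
from llama_cpp.llama_cache import LlamaStaticDiskCache

llm = Llama(model_path="model.gguf")  # hypothetical model path

# Hypothetical build_cache call: only `save_logits` appears in this diff,
# the remaining keyword names are assumed for illustration.
cache = LlamaStaticDiskCache.build_cache(
    model=llm,
    prompts=["You are a helpful assistant."],
    cache_dir="./prompt_cache",
    save_logits=False,  # scores dropped; reload pulls last-token logits from the llama.cpp context
)
llm.set_cache(cache)

# On a cache hit this reloads the saved state; if the state is incompatible,
# StateReloadError is caught internally and the prompt is simply re-evaluated.
out = llm("You are a helpful assistant. Hello!", max_tokens=16)
```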
116 changes: 116 additions & 0 deletions llama_cpp/llama_cache.py
@@ -1,17 +1,25 @@
import ctypes
import pickle
import sys
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Optional, Sequence, Tuple

import diskcache
import numpy as np
import pytrie

import llama_cpp.llama

from .llama_types import *


class StateReloadError(Exception):
"""
Raised when a state from the cache cannot be loaded by the current model.
"""


class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""

@@ -48,6 +56,20 @@ def __setitem__(
) -> None:
raise NotImplementedError

@classmethod
def reload_from_cache_state(
cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
) -> None:
"""
Reload the state onto the model. Normally this is done with `load_state`
(since the state is created with the corresponding `save_state`), but some
caches may need special handling as an optimization.

Raises a StateReloadError if the state is not compatible with the model
(for example, when the state was saved without the logits the model needs).
"""
model.load_state(state)
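A minimal sketch (hypothetical subclass name; the other abstract methods of `BaseLlamaCache` are omitted) of how a cache implementation can override this hook to reject an incompatible state rather than loading it blindly:

```python
# Hypothetical cache subclass sketch: refuse states saved without scores when
# the model actually needs them, otherwise defer to the default load_state path.
class ScorelessCache(BaseLlamaCache):
    # __getitem__ / __setitem__ / __contains__ etc. omitted for brevity

    @classmethod
    def reload_from_cache_state(
        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
    ) -> None:
        if state.scores is None and model.context_params.logits_all:
            raise StateReloadError(
                "cached state has no logits but the model was created with logits_all"
            )
        model.load_state(state)
```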


class LlamaRAMCache(BaseLlamaCache):
"""Cache for a llama.cpp model using RAM."""
@@ -223,6 +245,7 @@ def build_cache(
capacity_bytes: int = 2 << 30,
seed: Optional[int] = None,
add_bos=True,
save_logits: bool = False,
) -> "LlamaStaticDiskCache":
"""
Using model passed in, evaluates each prompt and stores LlamaState in cache.
@@ -246,6 +269,19 @@
print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
model.eval(eval_toks)
state = model.save_state()

if not save_logits:
if (
model.context_params.logits_all
or model.draft_model is not None
or model.context_params.embeddings
):
# Erroring instead of falling back to just saving with scores
raise ValueError(
"Cannot save state without logits - model requires logits to sample."
)
state.scores = None

cache._private_setitem(toks, state) # pylint: disable=protected-access

# Set up Trie for efficient prefix search
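Rough arithmetic for what `save_logits=False` avoids persisting per cached prompt (numbers are illustrative; the real size depends on the stored scores shape and vocabulary size):

```python
# Illustrative back-of-the-envelope for the scores array that save_logits=False
# skips writing to disk (assumes float32 logits).
n_positions = 512      # cached prompt length (illustrative)
n_vocab = 32_000       # vocabulary size (illustrative)
bytes_per_logit = 4    # float32
scores_bytes = n_positions * n_vocab * bytes_per_logit
print(f"~{scores_bytes / 2**20:.0f} MiB of logits skipped per cached prompt")  # ~62 MiB
```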
@@ -278,3 +314,83 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
# Should this just be a warning?
raise ValueError("Cannot set items in a static cache")

@classmethod
def reload_from_cache_state(
cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
) -> None:
"""
Skip reloading logits and set last logits from llama.cpp context struct
as the scores for last token of prompt.
"""
# pylint: disable=protected-access

# Check if model needs logits (draft model, log probs required, etc.)
model_needs_scores_to_reload = (
# May be overly pessimistic if embeddings are not needed for prompt tokens.
model.context_params.embeddings
or model.context_params.logits_all
# Same: is this really a hard requirement? We need token IDs from the
# draft model and all the logits from the base model to verify candidate
# tokens, but not for prompt tokens.
or model.draft_model is not None
)

if model_needs_scores_to_reload:
if state.scores is None:
raise StateReloadError(
"Model requires logits to be reloaded, but static cache does not store logits"
)
else:
model.load_state(state)
return

# Case where we don't need the full numpy logits and can just pull the
# last-token logits from the llama.cpp context struct
model.n_tokens = state.n_tokens
model.input_ids = state.input_ids.copy()
model.scores[:] = 0.0

state_size = state.llama_state_size

try:
llama_state_array_type = ctypes.c_uint8 * state_size
# Have to do from_buffer_copy since LlamaState.llama_state is
# non-mutable bytes, not mutable bytearray.
llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
reloaded_state_size = llama_cpp.llama_set_state_data(
model._ctx.ctx, llama_state
)

if reloaded_state_size != state_size:
raise StateReloadError(
"Failed to set llama state data - reloaded state size "
f"{reloaded_state_size} does not match original size {state_size}"
)

# ctypes dtype, compatible with numpy through duck typing
dtype = llama_cpp.llama_cpp.llama_get_logits_ith.restype._type_

# If the model's scores dtype doesn't match the dtype from the signature,
# we can't copy it.
if model.scores.dtype != dtype:
raise StateReloadError(
f"Expected scores to be {dtype} but got "
f"{model.scores.dtype} - are you running this in the future? Or the past?"
)

# Accessing .contents raises a ValueError for null pointers
last_position_logits = np.array(
ctypes.cast(
model._ctx.get_logits_ith(-1),
ctypes.POINTER(dtype * model.n_vocab()),
).contents,
# Otherwise will be a view into C array on llama.cpp context
copy=True,
dtype=dtype,
)

model._scores[-1, :] = last_position_logits

except ValueError as e:
raise StateReloadError from e
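A standalone sketch of the `from_buffer_copy` pattern used above, since `LlamaState.llama_state` is immutable `bytes` (the payload here is a stand-in, not real state data):

```python
import ctypes

payload = b"\x01\x02\x03\x04"              # stands in for the immutable LlamaState.llama_state bytes
buf_type = ctypes.c_uint8 * len(payload)
buf = buf_type.from_buffer_copy(payload)   # from_buffer() would raise, since bytes are read-only
assert bytes(buf) == payload               # the C API receives a mutable copy of the same data
```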