From 7ca7f284de69c622e8758c1d6ef6481b67904309 Mon Sep 17 00:00:00 2001 From: Sloane Date: Mon, 14 Oct 2024 12:46:21 -0500 Subject: [PATCH 1/9] Allow reloading without scores - Create `reload_from_cache_state` method - Still using LLamaState as container - Use low level `ctx.get_logits_ith` to get last calculated logits. - Add StateReloadError so that can be fallible. - Change Llama class to use this instead of `load_state` directly. - Default implementation still uses `load_state`. --- llama_cpp/llama.py | 2 +- llama_cpp/llama_cache.py | 79 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 07bd1d1ca..48ca6241a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1235,7 +1235,7 @@ def logit_bias_processor( ) before = time.time() - self.load_state(cache_item) + self.cache.reload_from_cache_state(self, cache_item) after = time.time() if self.verbose: print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index aafba6e1d..9bf90edf1 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -1,3 +1,4 @@ +import ctypes import pickle import sys from abc import ABC, abstractmethod @@ -5,6 +6,7 @@ from typing import Optional, Sequence, Tuple import diskcache +import numpy as np import pytrie import llama_cpp.llama @@ -12,6 +14,12 @@ from .llama_types import * +class StateReloadError(Exception): + """ + Error for when state from cache cannot be read by current model. + """ + + class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -48,6 +56,20 @@ def __setitem__( ) -> None: raise NotImplementedError + @classmethod + def reload_from_cache_state( + cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState" + ) -> None: + """ + Reload the state onto the model. Normally this is done with load_state + (as state is created with the corresponding `save_state`), but for some + caches may need special handling as an optimization. + + Throws a StateReloadError if the state is not compatible with the model + (for example, logits ) + """ + model.load_state(state) + class LlamaRAMCache(BaseLlamaCache): """Cache for a llama.cpp model using RAM.""" @@ -278,3 +300,60 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): # Should this just be a warning? raise ValueError("Cannot set items in a static cache") + + @classmethod + def reload_from_cache_state( + cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState" + ) -> None: + """ + Skip reloading logits and set last logits from llama.cpp context struct + as the scores for last token of prompt. + """ + # pylint: disable=protected-access + # Check if model needs logits (draft model, log probs required, etc.) + if ( + # May be overly pessimistic if don't want embeddings for prompt tokens. + model.context_params.embeddings + or model.context_params.logits_all + # Same: is this really a hard requirement? We need token IDs from + # draft model and all the logits from base model to do verification + # of candidate tokens, but not for prompt tokens. + or model.draft_model is not None + ): + raise StateReloadError( + "Model requires logits to be reloaded, but static cache does not store logits" + ) + + model.n_tokens = state.n_tokens + model.input_ids = state.input_ids.copy() + model.scores[:] = 0.0 + + state_size = state.llama_state_size + + try: + llama_state_array_type = ctypes.c_uint8 * state_size + # Don't have to do `from_buffer_copy` since `llama_set_state_data` + # will copy anyway. + llama_state = llama_state_array_type.from_buffer(state.llama_state) + reloaded_state_size = llama_cpp.llama_set_state_data( + model._ctx.ctx, llama_state + ) + + if reloaded_state_size != state_size: + raise StateReloadError( + "Failed to set llama state data - reloaded state size " + f"{reloaded_state_size} does not match original size {state_size}" + ) + + # Will have a ValueError for null pointers + last_position_logits = np.array( + ctypes.cast( + model._ctx.get_logits_ith(-1), + ctypes.POINTER(ctypes.c_float * model.n_vocab()), + ) + ) + + model._scores[-1, :] = last_position_logits.copy() + + except ValueError as e: + raise StateReloadError from e From ed6c3542e2ece215eb68bf7bd3f5d9bf0181f5cd Mon Sep 17 00:00:00 2001 From: Sloane Date: Mon, 14 Oct 2024 15:55:57 -0500 Subject: [PATCH 2/9] Fix bug - Use ptr.contents, not ptr in `np.array` - Get dtype from return type on annotated signature - Explicitly set copy=True and dtype on np.array - Should not strictly be necessary since pointer is typed --- llama_cpp/llama_cache.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index 9bf90edf1..328584dd0 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -345,15 +345,29 @@ def reload_from_cache_state( f"{reloaded_state_size} does not match original size {state_size}" ) + # cffi dtype, compatible w/ numpy through ducktyping :scared: + dtype = llama_cpp.llama_cpp.llama_get_logits_ith.restype._type_ + + # If model scores dtype doesn't match dtype from sig, then can't + # copy it. + if model.scores.dtype != dtype: + raise StateReloadError( + f"Expected scores to be {dtype} but got " + f"{model.scores.dtype} - are you running this in the future? Or the past?" + ) + # Will have a ValueError for null pointers last_position_logits = np.array( ctypes.cast( model._ctx.get_logits_ith(-1), - ctypes.POINTER(ctypes.c_float * model.n_vocab()), - ) + ctypes.POINTER(dtype * model.n_vocab()), + ).contents, + # Otherwise will be a view into C array on llama.cpp context + copy=True, + dtype=dtype, ) - model._scores[-1, :] = last_position_logits.copy() + model._scores[-1, :] = last_position_logits except ValueError as e: raise StateReloadError from e From 90d42c31fbdc3ccfa958b870377b2a49817d3ad1 Mon Sep 17 00:00:00 2001 From: Sloane Date: Tue, 15 Oct 2024 13:02:32 -0500 Subject: [PATCH 3/9] Catch StateReloadError Catch StateReloadError and add logging if runs into this when running. --- llama_cpp/llama.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 48ca6241a..5cc93c39b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -38,6 +38,7 @@ LlamaDiskCache, # type: ignore LlamaStaticDiskCache, # type: ignore LlamaRAMCache, # type: ignore + StateReloadError, # type: ignore ) import numpy as np @@ -1234,16 +1235,23 @@ def logit_bias_processor( file=sys.stderr, ) - before = time.time() - self.cache.reload_from_cache_state(self, cache_item) - after = time.time() - if self.verbose: - print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr) - if self.verbose: - print( - f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}", - file=sys.stderr, - ) + try: + before = time.time() + self.cache.reload_from_cache_state(self, cache_item) + after = time.time() + if self.verbose: + print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr) + print( + f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}", + file=sys.stderr, + ) + except StateReloadError as e: + if self.verbose: + print( + f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}, but failed to reload state: {e}", + file=sys.stderr, + ) + print("Falling back to re-evaluating prompt", file=sys.stderr) elif self.verbose: print( f"Llama._create_completion: not reloading from cache, cache prefix len {cache_prefix_len} < eval prefix len {eval_prefix_len}", From 46718c9795bb534de4e233fe530ad9d55ed2c9d6 Mon Sep 17 00:00:00 2001 From: Sloane Date: Tue, 15 Oct 2024 14:45:26 -0500 Subject: [PATCH 4/9] Add tests - Fix loading state (from_buffer -> from_buffer_copy since bytes aren't mutable) - Add tests (E2E, errors when should, reloads successfully, logits correct, etc.) Have to set LLAMA_TEST_MODEL to point to model path in order to get this to run. --- llama_cpp/llama_cache.py | 6 +- tests/test_llama_cache.py | 148 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 tests/test_llama_cache.py diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index 328584dd0..edb561027 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -332,9 +332,9 @@ def reload_from_cache_state( try: llama_state_array_type = ctypes.c_uint8 * state_size - # Don't have to do `from_buffer_copy` since `llama_set_state_data` - # will copy anyway. - llama_state = llama_state_array_type.from_buffer(state.llama_state) + # Have to do from_buffer_copy since LlamaState.llama_state is + # non-mutable bytes, not mutable bytearray. + llama_state = llama_state_array_type.from_buffer_copy(state.llama_state) reloaded_state_size = llama_cpp.llama_set_state_data( model._ctx.ctx, llama_state ) diff --git a/tests/test_llama_cache.py b/tests/test_llama_cache.py new file mode 100644 index 000000000..89999af12 --- /dev/null +++ b/tests/test_llama_cache.py @@ -0,0 +1,148 @@ +import os +import tempfile + +import pytest + +from llama_cpp.llama import Llama, LlamaState +from llama_cpp.llama_cache import LlamaStaticDiskCache, StateReloadError + + +@pytest.fixture +def small_model(): + model_filename = os.getenv("LLAMA_TEST_MODEL") + if not model_filename: + pytest.skip("LLAMA_TEST_MODEL environment variable is not set") + return + + model_filename = os.path.expanduser(model_filename) + + test_model = Llama( + model_filename, + n_ctx=2_048, + n_gpu_layers=0, + offload_kqv=False, + n_batch=512, + embedding=False, + verbose=False, + ) + + system_prompt = r""" +You are an advanced intelligence "Hal" aboard a spaceship. You are required to +act as the primary interface between the ship and its crew. You can: +* Provide information on the current status of the ship +* Turn on/off the lights in the crew quarters +* Open/close the airlocks + +Respond in a terse, professional manner. Do not engage in casual conversation. + +The current state of the ship is: +* Airlocks: closed +* Lights: on +* Oxygen levels: normal +""".strip() + + user_prompt = "Hal, please open the airlocks." + + # Ingest prompt and create completion so that will have some state. + # Last token of prompt + all tokens of generated completion will have + # non-zero logits. + _ = test_model.create_chat_completion( + [ + {"role": "system", "text": system_prompt}, + {"role": "user", "text": user_prompt}, + ], + seed=1234, + ) + + assert test_model.n_tokens > 0 + + # Have at least some scores, and last entry is non-zero + assert ~(test_model.scores == 0).all() + # pylint: disable=protected-access + assert (test_model._scores[-1, :] != 0.0).all() + + return test_model + + +@pytest.fixture +def llama_state(small_model) -> LlamaState: + state = small_model.save_state() + return state + + +def test_reload_from_cache_state_success(small_model, llama_state: LlamaState): + current_state = small_model.save_state() + old_score = small_model.scores.copy() + + LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state) + new_state = small_model.save_state() + new_score = small_model.scores.copy() + + assert (current_state.input_ids == new_state.input_ids).all() + + assert current_state.n_tokens == new_state.n_tokens + + # Logits for last token should match, others may not if n_batch < n_tokens + assert ( + old_score[small_model.n_tokens - 1, :] == new_score[small_model.n_tokens - 1, :] + ).all() + + +def test_reload_from_cache_state_state_reload_error(small_model, llama_state): + small_model.context_params.logits_all = True + small_model.context_params.embeddings = True + try: + with pytest.raises(StateReloadError): + LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state) + finally: + small_model.context_params.logits_all = False + small_model.context_params.embeddings = False + + +def test_disk_cache_e2e(small_model: Llama): + prompts = ["this is a test prompt", "and this is another test prompt"] + capacity_bytes = 2 << 30 + + small_model.reset() + # This is a weird thing to reset, but input_ids > n_tokens are not + # significant (like a scratchpad), left over if had previous prompt that + # was longer. + # + # Reset for ease of comparison later. + small_model.input_ids[:] = 0 + + with tempfile.TemporaryDirectory() as cache_dir: + cache = LlamaStaticDiskCache.build_cache( + cache_dir=cache_dir, + prompts=prompts, + model=small_model, + capacity_bytes=capacity_bytes, + add_bos=True, + seed=1234, + ) + + for p in prompts: + key = tuple( + small_model.tokenize(p.encode("utf-8"), add_bos=True, special=True) + ) + assert key in cache + state = cache[key] + assert ~(state.input_ids == 0).all() + assert state is not None + + small_model.reset() + small_model.input_ids[:] = 0 + small_model.eval(key) + + state2 = small_model.save_state() + assert state2.n_tokens == state.n_tokens + assert ~(state2.input_ids == 0).all() + assert (state2.input_ids == state.input_ids).all() + + last_logits = small_model.scores[small_model.n_tokens - 1, :] + + LlamaStaticDiskCache.reload_from_cache_state(small_model, state) + + last_logits2 = small_model.scores[small_model.n_tokens - 1, :] + + assert (last_logits == last_logits2).all() From 8362cfaeae45de5ffd23021c25660fc8f504591d Mon Sep 17 00:00:00 2001 From: Sloane Date: Tue, 15 Oct 2024 17:23:18 -0500 Subject: [PATCH 5/9] Skip saving logits D'oh --- llama_cpp/llama_cache.py | 5 +++++ tests/test_llama_cache.py | 1 + 2 files changed, 6 insertions(+) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index edb561027..d5ea5682d 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -245,6 +245,7 @@ def build_cache( capacity_bytes: int = 2 << 30, seed: Optional[int] = None, add_bos=True, + save_logits: bool = True, ) -> "LlamaStaticDiskCache": """ Using model passed in, evaluates each prompt and stores LlamaState in cache. @@ -268,6 +269,10 @@ def build_cache( print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr) model.eval(eval_toks) state = model.save_state() + + if not save_logits: + state.scores = None + cache._private_setitem(toks, state) # pylint: disable=protected-access # Set up Trie for efficient prefix search diff --git a/tests/test_llama_cache.py b/tests/test_llama_cache.py index 89999af12..ad28567c4 100644 --- a/tests/test_llama_cache.py +++ b/tests/test_llama_cache.py @@ -119,6 +119,7 @@ def test_disk_cache_e2e(small_model: Llama): capacity_bytes=capacity_bytes, add_bos=True, seed=1234, + save_logits=False, ) for p in prompts: From de7f862d190ec34d6caccb54e66ba26c326c4348 Mon Sep 17 00:00:00 2001 From: Sloane Date: Tue, 15 Oct 2024 17:27:29 -0500 Subject: [PATCH 6/9] Add check + note - Check when saving that model doesn't need logits - Ad note in `reload_from_cache` state to revisit --- llama_cpp/llama_cache.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index d5ea5682d..ed133211e 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -271,6 +271,12 @@ def build_cache( state = model.save_state() if not save_logits: + if ( + model.context_params.logits_all + or model.draft_model is not None + or model.context_params.embeddings + ): + raise ValueError("Cannot save state without logits") state.scores = None cache._private_setitem(toks, state) # pylint: disable=protected-access @@ -313,6 +319,9 @@ def reload_from_cache_state( """ Skip reloading logits and set last logits from llama.cpp context struct as the scores for last token of prompt. + + TODO: This always assumes want to skip loading logits, but could check + if state has scores that are not None. """ # pylint: disable=protected-access # Check if model needs logits (draft model, log probs required, etc.) From b4e2156a4e34bac228fc00aa1baf43b6cd1fd6b5 Mon Sep 17 00:00:00 2001 From: Sloane Date: Mon, 21 Oct 2024 11:29:10 -0500 Subject: [PATCH 7/9] Finalize llama cache changes - Make default to *not* save logits - Error if needed and save_logits False in build_cache - Handle reloading with/without scores if needed + available --- llama_cpp/llama_cache.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index ed133211e..f9c1a8cb0 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -245,7 +245,7 @@ def build_cache( capacity_bytes: int = 2 << 30, seed: Optional[int] = None, add_bos=True, - save_logits: bool = True, + save_logits: bool = False, ) -> "LlamaStaticDiskCache": """ Using model passed in, evaluates each prompt and stores LlamaState in cache. @@ -276,7 +276,10 @@ def build_cache( or model.draft_model is not None or model.context_params.embeddings ): - raise ValueError("Cannot save state without logits") + # Erroring instead of falling back to just saving with scores + raise ValueError( + "Cannot save state without logits - model requires logits to sample." + ) state.scores = None cache._private_setitem(toks, state) # pylint: disable=protected-access @@ -319,13 +322,11 @@ def reload_from_cache_state( """ Skip reloading logits and set last logits from llama.cpp context struct as the scores for last token of prompt. - - TODO: This always assumes want to skip loading logits, but could check - if state has scores that are not None. """ # pylint: disable=protected-access + # Check if model needs logits (draft model, log probs required, etc.) - if ( + need_to_reload_without_scores = ( # May be overly pessimistic if don't want embeddings for prompt tokens. model.context_params.embeddings or model.context_params.logits_all @@ -333,11 +334,19 @@ def reload_from_cache_state( # draft model and all the logits from base model to do verification # of candidate tokens, but not for prompt tokens. or model.draft_model is not None - ): - raise StateReloadError( - "Model requires logits to be reloaded, but static cache does not store logits" - ) + ) + + if need_to_reload_without_scores: + if state.scores is None: + raise StateReloadError( + "Model requires logits to be reloaded, but static cache does not store logits" + ) + else: + model.load_state(state) + return + # Case where don't need logits from numpy and can just get last-token + # logits from llama.cpp struct model.n_tokens = state.n_tokens model.input_ids = state.input_ids.copy() model.scores[:] = 0.0 From ca5d1a466d203986e722bd994c800a042a3584ce Mon Sep 17 00:00:00 2001 From: Sloane Date: Mon, 21 Oct 2024 11:29:47 -0500 Subject: [PATCH 8/9] Finalize tests - Add more tests - Make llama_state / small_model module scope (so don't need to reload for each test) - Setting env var in `.env` file --- tests/test_llama_cache.py | 96 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/tests/test_llama_cache.py b/tests/test_llama_cache.py index ad28567c4..ab7afb3e5 100644 --- a/tests/test_llama_cache.py +++ b/tests/test_llama_cache.py @@ -7,7 +7,9 @@ from llama_cpp.llama_cache import LlamaStaticDiskCache, StateReloadError -@pytest.fixture +# Have to be careful to reset to good state when testing, but don't want to +# recreate model each time. +@pytest.fixture(scope="module") def small_model(): model_filename = os.getenv("LLAMA_TEST_MODEL") if not model_filename: @@ -64,9 +66,11 @@ def small_model(): return test_model -@pytest.fixture +@pytest.fixture(scope="module") def llama_state(small_model) -> LlamaState: state = small_model.save_state() + # Clear scores so that can test reloading from cache without them. + state.scores = None return state @@ -130,6 +134,9 @@ def test_disk_cache_e2e(small_model: Llama): state = cache[key] assert ~(state.input_ids == 0).all() assert state is not None + assert ( + state.scores is None + ), "Logits should not be stored when save_logits=False and model doesn't require them." small_model.reset() small_model.input_ids[:] = 0 @@ -147,3 +154,88 @@ def test_disk_cache_e2e(small_model: Llama): last_logits2 = small_model.scores[small_model.n_tokens - 1, :] assert (last_logits == last_logits2).all() + + +def test_cache_save_reload_scores_when_needed( + small_model: Llama, +): + """ + When model requires it, can reload from state with scores. + """ + test_prompt = "this is a test prompt" + with tempfile.TemporaryDirectory() as cache_dir: + cache = LlamaStaticDiskCache.build_cache( + cache_dir=cache_dir, + prompts=[test_prompt], + model=small_model, + capacity_bytes=2 << 30, + add_bos=True, + seed=1234, + save_logits=True, + ) + + llama_state = small_model.save_state() + cur_scores = llama_state.scores.copy() + assert ~(cur_scores == 0.0).all() + + try: + small_model.context_params.logits_all = True + state_from_cache = cache[ + tuple(llama_state.input_ids[: llama_state.n_tokens].tolist()) + ] + assert state_from_cache.scores is not None, "Scores should be saved." + LlamaStaticDiskCache.reload_from_cache_state(small_model, state_from_cache) + # Do I have to limit these to n_tokens? + assert (state_from_cache.input_ids == llama_state.input_ids).all() + assert ( + cur_scores == small_model.scores[: small_model.n_tokens] + ).all(), "Reloaded scores should match" + finally: + small_model.scores[:] = 0.0 + small_model.context_params.logits_all = False + small_model.reset() + + +def test_cache_reload_errors_when_requires_scores_and_state_doesnt_have_it( + small_model: Llama, llama_state: LlamaState +): + """ + If model requires logits for sampling and state doesn't have it, should raise error. + """ + old_state_scores = ( + llama_state.scores.copy() + if llama_state.scores is not None + else llama_state.scores + ) + try: + small_model.context_params.logits_all = True + llama_state.scores = None + + with pytest.raises(StateReloadError): + LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state) + finally: + small_model.context_params.logits_all = False + llama_state.scores = old_state_scores + + +# pylint: disable=invalid-name +def test_cache_errors_when_save_logits_False_but_model_requires(small_model: Llama): + """ + If model requires logits but save_logits is False, should raise error. + """ + + try: + small_model.context_params.logits_all = True + with pytest.raises(ValueError): + with tempfile.TemporaryDirectory() as cache_dir: + LlamaStaticDiskCache.build_cache( + cache_dir=cache_dir, + prompts=["this is a test prompt"], + model=small_model, + capacity_bytes=2 << 30, + add_bos=True, + seed=1234, + save_logits=False, + ) + finally: + small_model.context_params.logits_all = False From 0967eda5326e6894def3286c8eb2b3c9ad4e227d Mon Sep 17 00:00:00 2001 From: Sloane Date: Mon, 21 Oct 2024 11:31:55 -0500 Subject: [PATCH 9/9] Better variable name --- llama_cpp/llama_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_cache.py b/llama_cpp/llama_cache.py index f9c1a8cb0..4f754a8cc 100644 --- a/llama_cpp/llama_cache.py +++ b/llama_cpp/llama_cache.py @@ -326,7 +326,7 @@ def reload_from_cache_state( # pylint: disable=protected-access # Check if model needs logits (draft model, log probs required, etc.) - need_to_reload_without_scores = ( + model_needs_scores_to_reload = ( # May be overly pessimistic if don't want embeddings for prompt tokens. model.context_params.embeddings or model.context_params.logits_all @@ -336,7 +336,7 @@ def reload_from_cache_state( or model.draft_model is not None ) - if need_to_reload_without_scores: + if model_needs_scores_to_reload: if state.scores is None: raise StateReloadError( "Model requires logits to be reloaded, but static cache does not store logits"