Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update create_states_mapping function to include tokenizer parameter #873

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion outlines/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def cache(key_function: Optional[Callable] = None):
----------
key_function
A callable function used to generate a unique key for each function call. It's
called with the arguments of the decorated function as arguments
called with the arguments of the decorated function as arguments and returns an
iterable of values that are used to create the cache key.
Returns
-------
A decorator function that can be applied to other functions.
Expand Down
17 changes: 14 additions & 3 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,19 @@ class RegexGuide(Guide):
initial_state = 0

def __init__(self, regex_string: str, tokenizer):
@cache()
def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
@cache(
key_function=lambda regex_string, tokenizer: (
regex_string,
tokenizer.eos_token,
tokenizer.eos_token_id,
tokenizer.pad_token_id,
tuple(sorted(tokenizer.vocabulary.items())),
tuple(sorted(tokenizer.special_tokens)),
)
)
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
Expand Down Expand Up @@ -142,7 +153,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string)
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

Expand Down
61 changes: 61 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
from unittest.mock import MagicMock, call

import diskcache
import pytest
Expand Down Expand Up @@ -41,6 +42,21 @@ def test_cache(refresh_environment):
memory.clear()


@pytest.fixture
def test_cache_custom_key_function(refresh_environment):
    """Yield a caching decorator that uses a custom ``key_function``.

    The cache is backed by a temporary directory (exposed to ``outlines``
    through the ``OUTLINES_CACHE_DIR`` environment variable). The custom
    ``key_function`` derives the cache key solely from the sorted keys of
    the decorated function's single argument. The cache is cleared and the
    directory removed once the dependent test finishes.
    """
    with tempfile.TemporaryDirectory() as cache_dir:
        os.environ["OUTLINES_CACHE_DIR"] = cache_dir
        import outlines

        store = outlines.get_cache()
        assert store.directory == cache_dir

        # Build the key from the argument's keys only, ignoring its values.
        def keys_only(x):
            return sorted(x.keys())

        yield outlines.caching.cache(key_function=keys_only)

        store.clear()


def test_get_cache(test_cache):
import outlines

Expand All @@ -67,6 +83,51 @@ def f(x):
assert len(store) == store_size + 1


def test_get_cache_custom_key_function(test_cache_custom_key_function):
    """Verify caching behaviour when a custom ``key_function`` builds the key.

    Checks that (1) a first call populates the cache, (2) a repeated call with
    an equivalent argument hits the cache while still invoking the
    ``key_function``, and (3) a structurally different argument produces a new
    cache entry.
    """
    import outlines

    memory = outlines.get_cache()
    assert isinstance(memory, diskcache.Cache)

    # A cached function whose key is derived from the argument's dict keys.
    @test_cache_custom_key_function
    def f(x: dict):
        return len(x)

    def entry_count():
        return len(list(memory.iterkeys()))

    # Snapshot the cache state before the first call.
    baseline = entry_count()

    # First call with a dictionary-like argument: should miss and store.
    mocked_dict = MagicMock()
    mocked_dict.keys.return_value = ["a", "b", "c"]
    f(mocked_dict)
    assert entry_count() > baseline

    # The custom key_function must have inspected the argument's keys.
    mocked_dict.keys.assert_has_calls([call()])

    # Second call with the same argument: should hit the cache (no growth)...
    baseline = entry_count()
    f(mocked_dict)
    assert entry_count() == baseline

    # ...while the key_function is still evaluated on every call.
    mocked_dict.keys.assert_has_calls([call(), call()])

    # A dictionary with different keys yields a different key, hence a miss.
    f({"foo": "bar"})
    assert entry_count() > baseline


def test_disable_cache(test_cache):
"""Make sure that we can disable the cache."""
import outlines
Expand Down