Skip to content

Commit

Permalink
Rename ignore_hash to allow_output_mutation (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrhone committed Oct 17, 2019
1 parent 25efbef commit 8a2f26f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 34 deletions.
64 changes: 38 additions & 26 deletions lib/streamlit/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _build_caching_func_error_message(persisted, func, caller_frame):
{copy_code}
```
2. Add `ignore_hash=True` to the `@streamlit.cache` decorator for
2. Add `allow_output_mutation=True` to the `@streamlit.cache` decorator for
`{name}`. This is an escape hatch for advanced users who really know
what they're doing.
Expand Down Expand Up @@ -282,7 +282,7 @@ def _build_caching_block_error_message(persisted, code, line_number_range):
1. *Preferred:* fix the code by removing the mutation. The simplest way
to do this is to copy the cached value to a new variable, which you are
allowed to mutate.
2. Add `ignore_hash=True` to the constructor of `streamlit.Cache`. This
2. Add `allow_output_mutation=True` to the constructor of `streamlit.Cache`. This
is an escape hatch for advanced users who really know what they're
doing.
Expand Down Expand Up @@ -317,11 +317,11 @@ def _build_args_mutated_message(func):
return message.format(name=func.__name__)


def _read_from_mem_cache(key, ignore_hash):
def _read_from_mem_cache(key, allow_output_mutation):
if key in _mem_cache:
entry = _mem_cache[key]

if ignore_hash or get_hash(entry.value) == entry.hash:
if allow_output_mutation or get_hash(entry.value) == entry.hash:
LOGGER.debug("Memory cache HIT: %s", type(entry.value))
return entry.value, entry.args_mutated
else:
Expand All @@ -332,10 +332,10 @@ def _read_from_mem_cache(key, ignore_hash):
raise CacheKeyNotFoundError("Key not found in mem cache")


def _write_to_mem_cache(key, value, ignore_hash, args_mutated):
def _write_to_mem_cache(key, value, allow_output_mutation, args_mutated):
_mem_cache[key] = CacheEntry(
value=value,
hash=None if ignore_hash else get_hash(value),
hash=None if allow_output_mutation else get_hash(value),
args_mutated=args_mutated,
)

Expand Down Expand Up @@ -374,15 +374,15 @@ def _write_to_disk_cache(key, value, args_mutated):
raise CacheError("Unable to write to cache: %s" % e)


def _read_from_cache(key, persisted, ignore_hash, func_or_code, message_opts):
def _read_from_cache(key, persisted, allow_output_mutation, func_or_code, message_opts):
"""
Read the value from the cache. Our goal is to read from memory
if possible. If the data was mutated (hash changed), we show a
warning. If reading from memory fails, we either read from disk
or rerun the code.
"""
try:
return _read_from_mem_cache(key, ignore_hash)
return _read_from_mem_cache(key, allow_output_mutation)
except (CacheKeyNotFoundError, CachedObjectWasMutatedError) as e:
if isinstance(e, CachedObjectWasMutatedError):
if inspect.isroutine(func_or_code):
Expand All @@ -397,23 +397,24 @@ def _read_from_cache(key, persisted, ignore_hash, func_or_code, message_opts):

if persisted:
value, args_mutated = _read_from_disk_cache(key)
_write_to_mem_cache(key, value, ignore_hash, args_mutated)
_write_to_mem_cache(key, value, allow_output_mutation, args_mutated)
return value, args_mutated
raise e


def _write_to_cache(key, value, persist, ignore_hash, args_mutated):
_write_to_mem_cache(key, value, ignore_hash, args_mutated)
def _write_to_cache(key, value, persist, allow_output_mutation, args_mutated):
_write_to_mem_cache(key, value, allow_output_mutation, args_mutated)
if persist:
_write_to_disk_cache(key, value, args_mutated)


def cache(
func=None,
persist=False,
ignore_hash=False,
allow_output_mutation=False,
show_spinner=True,
suppress_st_warning=False,
**kwargs
):
"""Function decorator to memoize function executions.
Expand All @@ -427,9 +428,11 @@ def cache(
persist : boolean
Whether to persist the cache on disk.
ignore_hash : boolean
Disable hashing return values. These hash values are otherwise
used to validate that return values are not mutated.
allow_output_mutation : boolean
Streamlit normally shows a warning when return values are mutated, as that
can have unintended consequences. This is done by hashing the return value internally.
If you know what you're doing and would like to override this warning, set this to True.
show_spinner : boolean
Enable the spinner. Default is True to show a spinner when there is
Expand Down Expand Up @@ -464,21 +467,28 @@ def cache(
... # Fetch data from URL here, and then clean it up.
... return data
To disable hashing return values, set the `ignore_hash` parameter to `True`:
To disable hashing return values, set the `allow_output_mutation` parameter to `True`:
>>> @st.cache(ignore_hash=True)
>>> @st.cache(allow_output_mutation=True)
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
"""
# Help users migrate to the new kwarg
# Remove this check (it raises an exception, not a warning) after 2020-03-16.
if "ignore_hash" in kwargs:
raise Exception(
"The `ignore_hash` argument has been renamed to `allow_output_mutation`."
)

# Support passing the params via function decorator, e.g.
# @st.cache(persist=True, ignore_hash=True)
# @st.cache(persist=True, allow_output_mutation=True)
if func is None:
return lambda f: cache(
func=f,
persist=persist,
ignore_hash=ignore_hash,
allow_output_mutation=allow_output_mutation,
show_spinner=show_spinner,
suppress_st_warning=suppress_st_warning,
)
Expand Down Expand Up @@ -519,7 +529,7 @@ def get_or_set_cache():
caller_frame = inspect.currentframe().f_back
try:
return_value, args_mutated = _read_from_cache(
key, persist, ignore_hash, func, caller_frame
key, persist, allow_output_mutation, func, caller_frame
)
except (CacheKeyNotFoundError, CachedObjectWasMutatedError):
with _calling_cached_function():
Expand All @@ -533,7 +543,9 @@ def get_or_set_cache():
args_hasher_after.update([args, kwargs])
args_mutated = args_digest_before != args_hasher_after.digest()

_write_to_cache(key, return_value, persist, ignore_hash, args_mutated)
_write_to_cache(
key, return_value, persist, allow_output_mutation, args_mutated
)

if args_mutated:
# If we're inside a _nested_ cached function, our
Expand Down Expand Up @@ -587,9 +599,9 @@ class Cache(dict):
"""

def __init__(self, persist=False, ignore_hash=False):
def __init__(self, persist=False, allow_output_mutation=False):
self._persist = persist
self._ignore_hash = ignore_hash
self._allow_output_mutation = allow_output_mutation

dict.__init__(self)

Expand Down Expand Up @@ -643,20 +655,20 @@ def has_changes(self):
value, _ = _read_from_cache(
key,
self._persist,
self._ignore_hash,
self._allow_output_mutation,
code,
[caller_lineno + 1, caller_lineno + len(lines)],
)
self.update(value)
except (CacheKeyNotFoundError, CachedObjectWasMutatedError):
if self._ignore_hash and not self._persist:
if self._allow_output_mutation and not self._persist:
# If we don't hash the results, we don't need to use exec and can just return True.
# This way line numbers will be correct.
_write_to_cache(key, self, False, True, None)
return True

exec(code, caller_frame.f_globals, caller_frame.f_locals)
_write_to_cache(key, self, self._persist, self._ignore_hash, None)
_write_to_cache(key, self, self._persist, self._allow_output_mutation, None)

# Return False so that we have control over the execution.
return False
Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _hashing_error_message(start):
following:
* **Preferred:** modify your code to avoid using this type of object.
* Or add the argument `ignore_hash=True` to the `st.cache` decorator.
* Or add the argument `allow_output_mutation=True` to the `st.cache` decorator.
"""
% {"start": start}
).strip("\n")
Expand Down
17 changes: 15 additions & 2 deletions lib/tests/streamlit/caching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""st.caching unit tests."""
import threading
import unittest
import pytest

from mock import patch

Expand All @@ -39,6 +40,18 @@ def foo():
self.assertEqual(foo(), 42)
self.assertEqual(foo(), 42)

def test_deprecated_kwarg(self):
with pytest.raises(Exception) as e:

@st.cache(ignore_hash=True)
def foo():
return 42

assert (
"The `ignore_hash` argument has been renamed to `allow_output_mutation`."
in str(e.value)
)

@patch.object(st, "warning")
def test_args(self, warning):
called = [False]
Expand Down Expand Up @@ -207,11 +220,11 @@ def off_test_simple(self):

self.assertEqual(c.value, val)

def off_test_ignore_hash(self):
def off_test_allow_output_mutation(self):
val = 42

for _ in range(2):
c = st.Cache(ignore_hash=True)
c = st.Cache(allow_output_mutation=True)
if c:
c.value = val

Expand Down
9 changes: 4 additions & 5 deletions lib/tests/streamlit/help_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ def test_deltagenerator_func(self):
self.assertEqual("streamlit", ds.module)
if is_python_2:
self.assertEqual("<type 'function'>", ds.type)
self.assertEqual("(data, format=u'audio/wav', start_time=0)",
ds.signature)
self.assertEqual("(data, format=u'audio/wav', start_time=0)", ds.signature)
else:
self.assertEqual("<class 'function'>", ds.type)
self.assertEqual("(data, format='audio/wav', start_time=0)",
ds.signature)
self.assertEqual("(data, format='audio/wav', start_time=0)", ds.signature)
self.assertTrue(ds.doc_string.startswith("Display an audio player"))

def test_unwrapped_deltagenerator_func(self):
Expand Down Expand Up @@ -117,7 +115,8 @@ def test_st_cache(self):
ds.signature,
(
"(func=None, persist=False, "
"ignore_hash=False, show_spinner=True, suppress_st_warning=False)"
"allow_output_mutation=False, show_spinner=True, suppress_st_warning=False, "
"**kwargs)"
),
)
self.assertTrue(ds.doc_string.startswith("Function decorator to"))
Expand Down

0 comments on commit 8a2f26f

Please sign in to comment.