hash_funcs for st.cache_data and st.cache_resource #6502

Merged: 18 commits, Jun 8, 2023
8 changes: 8 additions & 0 deletions lib/streamlit/runtime/caching/cache_data_api.py
@@ -44,6 +44,7 @@
MsgData,
MultiCacheResults,
)
from streamlit.runtime.caching.hashing import HashFuncsDict
from streamlit.runtime.caching.storage import (
CacheStorage,
CacheStorageContext,
@@ -80,11 +81,13 @@ def __init__(
max_entries: int | None,
ttl: float | timedelta | str | None,
allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
hash_funcs=hash_funcs,
)
self.persist = persist
self.max_entries = max_entries
@@ -365,6 +368,7 @@ def __call__(
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
):
return self._decorator(
func,
@@ -373,6 +377,7 @@
persist=persist,
show_spinner=show_spinner,
experimental_allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)

def _decorator(
@@ -384,6 +389,7 @@ def _decorator(
show_spinner: bool | str,
persist: CachePersistType | bool,
experimental_allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).

@@ -529,6 +535,7 @@ def wrapper(f):
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
)

@@ -543,6 +550,7 @@ def wrapper(f):
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
)

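
The cache_data_api.py changes above thread a new hash_funcs argument from the @st.cache_data decorator down into the cached-function info object. The snippet below is a minimal usage sketch, assuming hash_funcs keeps the semantics of the legacy st.cache parameter (a dict mapping a type, or its fully qualified name, to a callable that returns a hashable stand-in); MyCustomClass and multiply are hypothetical names, not part of this PR.

import streamlit as st


class MyCustomClass:
    """Hypothetical type that Streamlit's default hasher cannot handle."""

    def __init__(self, value: int):
        self.value = value


# The dict key is the unhashable type (or its fully qualified name as a
# string); the dict value is a callable that reduces an instance to
# something hashable.
@st.cache_data(hash_funcs={MyCustomClass: lambda obj: obj.value})
def multiply(item: MyCustomClass, factor: int) -> int:
    return item.value * factor


result = multiply(MyCustomClass(7), 3)  # cache key derives from obj.value and factor
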
8 changes: 8 additions & 0 deletions lib/streamlit/runtime/caching/cache_resource_api.py
@@ -45,6 +45,7 @@
MsgData,
MultiCacheResults,
)
from streamlit.runtime.caching.hashing import HashFuncsDict
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
@@ -153,11 +154,13 @@ def __init__(
ttl: float | timedelta | str | None,
validate: ValidateFunc | None,
allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
hash_funcs=hash_funcs,
)
self.max_entries = max_entries
self.ttl = ttl
@@ -245,6 +248,7 @@ def __call__(
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
):
return self._decorator(
func,
@@ -253,6 +257,7 @@
show_spinner=show_spinner,
validate=validate,
experimental_allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)

def _decorator(
@@ -264,6 +269,7 @@ def _decorator(
show_spinner: bool | str,
validate: ValidateFunc | None,
experimental_allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).

@@ -389,6 +395,7 @@ def _decorator(
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
)

@@ -400,6 +407,7 @@ def _decorator(
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
)

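
The cache_resource_api.py changes mirror the data-cache ones: @st.cache_resource now accepts hash_funcs and forwards it into its cached-function info. A sketch under the same assumed semantics follows; DatabaseClient and get_session are hypothetical names.

import streamlit as st


class DatabaseClient:
    """Hypothetical connection wrapper that the default hasher rejects."""

    def __init__(self, dsn: str):
        self.dsn = dsn


# Hash the client by its DSN string so that two clients pointing at the
# same database share a single cached resource.
@st.cache_resource(hash_funcs={DatabaseClient: lambda client: client.dsn})
def get_session(client: DatabaseClient) -> dict:
    return {"dsn": client.dsn}  # stand-in for an expensive session object


session = get_session(DatabaseClient("postgresql://localhost/app"))
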
24 changes: 19 additions & 5 deletions lib/streamlit/runtime/caching/cache_utils.py
@@ -50,7 +50,7 @@
MsgData,
replay_cached_messages,
)
from streamlit.runtime.caching.hashing import update_hash
from streamlit.runtime.caching.hashing import HashFuncsDict, update_hash

_LOGGER = get_logger(__name__)

@@ -169,10 +169,12 @@ def __init__(
func: types.FunctionType,
show_spinner: bool | str,
allow_widgets: bool,
hash_funcs: HashFuncsDict | None,
):
self.func = func
self.show_spinner = show_spinner
self.allow_widgets = allow_widgets
self.hash_funcs = hash_funcs

@property
def cache_type(self) -> CacheType:
@@ -254,6 +256,7 @@ def _get_or_create_cached_value(
func=self._info.func,
func_args=func_args,
func_kwargs=func_kwargs,
hash_funcs=self._info.hash_funcs,
)

try:
@@ -351,6 +354,7 @@ def _make_value_key(
func: types.FunctionType,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
hash_funcs: HashFuncsDict | None,
) -> str:
"""Create the key for a value within a cache.

@@ -388,9 +392,20 @@

try:
update_hash(
(arg_name, arg_value),
arg_name,
hasher=args_hasher,
cache_type=cache_type,
hash_source=func,
)
# We call update_hash twice here: first for `arg_name` without
# `hash_funcs`, then for `arg_value` with `hash_funcs`, so that
# user-defined `hash_funcs` are only applied when computing the
# `arg_value` hash.
update_hash(
arg_value,
hasher=args_hasher,
cache_type=cache_type,
hash_funcs=hash_funcs,
hash_source=func,
)
except UnhashableTypeError as exc:
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
@@ -417,6 +432,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
(func.__module__, func.__qualname__),
hasher=func_hasher,
cache_type=cache_type,
hash_source=func,
)

# Include the function's source code in its hash. If the source code can't
@@ -432,9 +448,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
source_code = func.__code__.co_code

update_hash(
source_code,
hasher=func_hasher,
cache_type=cache_type,
source_code, hasher=func_hasher, cache_type=cache_type, hash_source=func
)

cache_key = func_hasher.hexdigest()
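
The cache_utils.py hunks show where hash_funcs actually takes effect: _make_value_key hashes each argument name with the default mechanism and each argument value through update_hash with the user-supplied hash_funcs. The snippet below is a standalone sketch of that two-step keying, not Streamlit's internals; make_value_key and its signature are illustrative only.

from __future__ import annotations

import hashlib
from typing import Any, Callable


def make_value_key(
    args: dict[str, Any],
    hash_funcs: dict[type, Callable[[Any], Any]] | None = None,
) -> str:
    hasher = hashlib.md5()
    for name, value in args.items():
        # Step 1: the argument name is hashed without consulting hash_funcs.
        hasher.update(repr(name).encode())
        # Step 2: the value is hashed after an optional user-defined
        # reduction registered for its type.
        if hash_funcs and type(value) in hash_funcs:
            value = hash_funcs[type(value)](value)
        hasher.update(repr(value).encode())
    return hasher.hexdigest()


key = make_value_key({"n": 3, "rows": [1, 2, 3]}, hash_funcs={list: tuple})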