From 9dc2352b8fb7a6866430c53d6dd84ebaf5793de1 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Mon, 17 Apr 2023 22:17:18 +0400 Subject: [PATCH 01/13] hash_funcs prototype --- .../runtime/caching/cache_data_api.py | 8 ++++ lib/streamlit/runtime/caching/cache_utils.py | 7 +++- lib/streamlit/runtime/caching/hashing.py | 41 +++++++++++++++++-- .../runtime/legacy_caching/caching_test.py | 2 +- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/lib/streamlit/runtime/caching/cache_data_api.py b/lib/streamlit/runtime/caching/cache_data_api.py index 558545d858cf..14116e1da5b9 100644 --- a/lib/streamlit/runtime/caching/cache_data_api.py +++ b/lib/streamlit/runtime/caching/cache_data_api.py @@ -44,6 +44,7 @@ MsgData, MultiCacheResults, ) +from streamlit.runtime.caching.hashing import HashFuncsDict from streamlit.runtime.caching.storage import ( CacheStorage, CacheStorageContext, @@ -80,11 +81,13 @@ def __init__( max_entries: int | None, ttl: float | timedelta | None, allow_widgets: bool, + hash_funcs: HashFuncsDict | None = None, ): super().__init__( func, show_spinner=show_spinner, allow_widgets=allow_widgets, + hash_funcs=hash_funcs, ) self.persist = persist self.max_entries = max_entries @@ -365,6 +368,7 @@ def __call__( show_spinner: bool | str = True, persist: CachePersistType | bool = None, experimental_allow_widgets: bool = False, + hash_funcs: HashFuncsDict | None = None, ): return self._decorator( func, @@ -373,6 +377,7 @@ def __call__( persist=persist, show_spinner=show_spinner, experimental_allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) def _decorator( @@ -384,6 +389,7 @@ def _decorator( show_spinner: bool | str, persist: CachePersistType | bool, experimental_allow_widgets: bool, + hash_funcs: HashFuncsDict | None = None, ): """Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference). @@ -521,6 +527,7 @@ def wrapper(f): max_entries=max_entries, ttl=ttl, allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) ) @@ -535,6 +542,7 @@ def wrapper(f): max_entries=max_entries, ttl=ttl, allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) ) diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py index 2c9581afe432..219d9fec0cb6 100644 --- a/lib/streamlit/runtime/caching/cache_utils.py +++ b/lib/streamlit/runtime/caching/cache_utils.py @@ -49,7 +49,7 @@ MsgData, replay_cached_messages, ) -from streamlit.runtime.caching.hashing import update_hash +from streamlit.runtime.caching.hashing import HashFuncsDict, update_hash _LOGGER = get_logger(__name__) @@ -154,10 +154,12 @@ def __init__( func: types.FunctionType, show_spinner: bool | str, allow_widgets: bool, + hash_funcs: HashFuncsDict | None, ): self.func = func self.show_spinner = show_spinner self.allow_widgets = allow_widgets + self.hash_funcs = hash_funcs @property def cache_type(self) -> CacheType: @@ -239,6 +241,7 @@ def _get_or_create_cached_value( func=self._info.func, func_args=func_args, func_kwargs=func_kwargs, + hash_funcs=self._info.hash_funcs, ) try: @@ -336,6 +339,7 @@ def _make_value_key( func: types.FunctionType, func_args: tuple[Any, ...], func_kwargs: dict[str, Any], + hash_funcs: HashFuncsDict | None, ) -> str: """Create the key for a value within a cache. 
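[Editorial sketch — not part of the patch.] The end-to-end effect of this first patch is easiest to see from the caller's side. The `hash_funcs` keyword and its key/value shape (`HashFuncsDict = Dict[Union[str, Type[Any]], Callable[[Any], Any]]`) are exactly what the patch threads through `st.cache_data`; the `Session` class and its `token` attribute below are invented for illustration, and the sketch assumes a Streamlit build with this series applied.

```python
# Illustrative only: `Session` and `token` are hypothetical; the
# `hash_funcs` keyword is what this patch adds to st.cache_data.
import streamlit as st


class Session:
    """A custom type the cache hasher may not handle natively."""

    def __init__(self, token: str):
        self.token = token


# Map the custom type to a function returning a hashable stand-in; the
# stand-in (a str here) is then hashed by Streamlit's normal machinery.
@st.cache_data(hash_funcs={Session: lambda s: s.token})
def fetch_rows(session: Session, query: str) -> list:
    return [f"result of {query!r} for {session.token}"]
```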
@@ -376,6 +380,7 @@ def _make_value_key( (arg_name, arg_value), hasher=args_hasher, cache_type=cache_type, + hash_funcs=hash_funcs, ) except UnhashableTypeError as exc: raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc) diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 8c995675a537..714b5f255701 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -28,7 +28,7 @@ import uuid import weakref from enum import Enum -from typing import Any, Dict, List, Optional, Pattern +from typing import Any, Callable, Dict, List, Optional, Pattern, Type, Union from streamlit import type_util, util from streamlit.runtime.caching.cache_errors import UnhashableTypeError @@ -43,17 +43,24 @@ _NP_SIZE_LARGE = 1000000 _NP_SAMPLE_SIZE = 100000 +HashFuncsDict = Dict[Union[str, Type[Any]], Callable[[Any], Any]] + # Arbitrary item to denote where we found a cycle in a hashed object. # This allows us to hash self-referencing lists, dictionaries, etc. _CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE" -def update_hash(val: Any, hasher, cache_type: CacheType) -> None: +def update_hash( + val: Any, + hasher, + cache_type: CacheType, + hash_funcs: Optional[HashFuncsDict] = None, +) -> None: """Updates a hashlib hasher with the hash of val. This is the main entrypoint to hashing.py. """ - ch = _CacheFuncHasher(cache_type) + ch = _CacheFuncHasher(cache_type, hash_funcs) ch.update(hasher, val) @@ -161,7 +168,24 @@ def is_simple(obj): class _CacheFuncHasher: """A hasher that can hash objects with cycles.""" - def __init__(self, cache_type: CacheType): + def __init__( + self, cache_type: CacheType, hash_funcs: Optional[HashFuncsDict] = None + ): + # Can't use types as the keys in the internal _hash_funcs because + # we always remove user-written modules from memory when rerunning a + # script in order to reload it and grab the latest code changes. + # (See LocalSourcesWatcher.py:on_file_changed) This causes + # the type object to refer to different underlying class instances each run, + # so type-based comparisons fail. To solve this, we use the types converted + # to fully-qualified strings as keys in our internal dict. + self._hash_funcs: HashFuncsDict + if hash_funcs: + self._hash_funcs = { + k if isinstance(k, str) else type_util.get_fqn(k): v + for k, v in hash_funcs.items() + } + else: + self._hash_funcs = {} self._hashes: Dict[Any, bytes] = {} # The number of the bytes in the hash. 
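[Editorial sketch — not part of the patch.] At this point both halves of the mechanism exist: `__init__` above normalizes type keys to fully-qualified-name strings (so lookups survive the module reloads its comment describes), and the `_to_bytes` hunk that follows performs the lookup and recursively hashes the user function's output. A simplified, self-contained model of that pair, with `_fqn` standing in for `type_util.get_fqn`/`get_fqn_type` and most fallback cases omitted:

```python
# Simplified model of the normalization + escape-hatch lookup; `_fqn` is a
# stand-in for type_util.get_fqn / get_fqn_type, not the real internals.
import hashlib
from typing import Any, Callable, Dict, Type, Union

HashFuncsDict = Dict[Union[str, Type[Any]], Callable[[Any], Any]]


def _fqn(t: Type[Any]) -> str:
    return f"{t.__module__}.{t.__qualname__}"


class MiniHasher:
    def __init__(self, hash_funcs: HashFuncsDict):
        # Type keys become FQN strings so they still match after a reload
        # replaces the underlying class object.
        self._hash_funcs = {
            k if isinstance(k, str) else _fqn(k): v for k, v in hash_funcs.items()
        }

    def to_bytes(self, obj: Any) -> bytes:
        func = self._hash_funcs.get(_fqn(type(obj)))
        if func is not None:
            # Escape hatch: recursively hash whatever the user function returns.
            return self.to_bytes(func(obj))
        if isinstance(obj, (bytes, bytearray)):
            return bytes(obj)
        if isinstance(obj, str):
            return obj.encode()
        if isinstance(obj, int):
            return str(obj).encode()
        raise TypeError(f"no hash rule for {_fqn(type(obj))}")


gen = (x for x in range(3))
hasher = MiniHasher({type(gen): id})  # equivalent key: "builtins.generator"
print(hashlib.md5(hasher.to_bytes(gen)).hexdigest())
```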
@@ -226,6 +250,15 @@ def _to_bytes(self, obj: Any) -> bytes: elif isinstance(obj, bytes) or isinstance(obj, bytearray): return obj + elif type_util.get_fqn_type(obj) in self._hash_funcs: + # Escape hatch for unsupported objects + hash_func = self._hash_funcs[type_util.get_fqn_type(obj)] + try: + output = hash_func(obj) + except Exception as ex: + raise UnhashableTypeError("AAAAAAAAA") from ex + return self.to_bytes(output) + elif isinstance(obj, str): return obj.encode() diff --git a/lib/tests/streamlit/runtime/legacy_caching/caching_test.py b/lib/tests/streamlit/runtime/legacy_caching/caching_test.py index d7e2c761604a..b55bd7baf36c 100644 --- a/lib/tests/streamlit/runtime/legacy_caching/caching_test.py +++ b/lib/tests/streamlit/runtime/legacy_caching/caching_test.py @@ -408,7 +408,7 @@ def test_function_name_does_not_use_hashfuncs(self): def foo(string_arg): return [] - # If our str hash_func is called multiple times, it's probably because + # If our str hash_funcs is called multiple times, it's probably because # it's being used to compute the function's cache_key (as opposed to # the value_key). It should only be used to compute the value_key! foo("ahoy") From a491314280d5d3691b2537bab12d45dada8736e9 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Fri, 28 Apr 2023 20:39:21 +0400 Subject: [PATCH 02/13] add `UserHashError` machinery to @st.cache_data and @st.cache_resource hash_funcs --- .../runtime/caching/cache_resource_api.py | 8 + lib/streamlit/runtime/caching/cache_utils.py | 6 +- lib/streamlit/runtime/caching/hashing.py | 179 +++++++++++++++++- 3 files changed, 188 insertions(+), 5 deletions(-) diff --git a/lib/streamlit/runtime/caching/cache_resource_api.py b/lib/streamlit/runtime/caching/cache_resource_api.py index 7b10e1028270..ee65e29b9918 100644 --- a/lib/streamlit/runtime/caching/cache_resource_api.py +++ b/lib/streamlit/runtime/caching/cache_resource_api.py @@ -45,6 +45,7 @@ MsgData, MultiCacheResults, ) +from streamlit.runtime.caching.hashing import HashFuncsDict from streamlit.runtime.metrics_util import gather_metrics from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx from streamlit.runtime.stats import CacheStat, CacheStatsProvider @@ -153,11 +154,13 @@ def __init__( ttl: float | timedelta | None, validate: ValidateFunc | None, allow_widgets: bool, + hash_funcs: HashFuncsDict | None = None, ): super().__init__( func, show_spinner=show_spinner, allow_widgets=allow_widgets, + hash_funcs=hash_funcs, ) self.max_entries = max_entries self.ttl = ttl @@ -245,6 +248,7 @@ def __call__( show_spinner: bool | str = True, validate: ValidateFunc | None = None, experimental_allow_widgets: bool = False, + hash_funcs: HashFuncsDict | None = None, ): return self._decorator( func, @@ -253,6 +257,7 @@ def __call__( show_spinner=show_spinner, validate=validate, experimental_allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) def _decorator( @@ -264,6 +269,7 @@ def _decorator( show_spinner: bool | str, validate: ValidateFunc | None, experimental_allow_widgets: bool, + hash_funcs: HashFuncsDict | None = None, ): """Decorator to cache functions that return global resources (e.g. database connections, ML models). 
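[Editorial sketch — not part of the patch.] PATCH 02 gives `st.cache_resource` the same keyword. For the kinds of objects this decorator's docstring mentions, an identity-based hash function is a common choice; the `DBConnection` class below is invented for illustration, and only the `hash_funcs` keyword comes from the patch:

```python
# Hypothetical usage of the hash_funcs keyword PATCH 02 adds to
# st.cache_resource; DBConnection stands in for any unhashable client.
import streamlit as st


class DBConnection:
    def __init__(self, dsn: str):
        self.dsn = dsn


# Hash the connection by identity: repeated calls with the same live object
# hit one cache entry, while a new connection object gets a fresh entry.
@st.cache_resource(hash_funcs={DBConnection: id})
def load_model(conn: DBConnection, name: str):
    return f"model {name!r} loaded via {conn.dsn}"
```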
@@ -379,6 +385,7 @@ def _decorator( ttl=ttl, validate=validate, allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) ) @@ -390,6 +397,7 @@ def _decorator( ttl=ttl, validate=validate, allow_widgets=experimental_allow_widgets, + hash_funcs=hash_funcs, ) ) diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py index 219d9fec0cb6..2f7858aa55c4 100644 --- a/lib/streamlit/runtime/caching/cache_utils.py +++ b/lib/streamlit/runtime/caching/cache_utils.py @@ -381,6 +381,7 @@ def _make_value_key( hasher=args_hasher, cache_type=cache_type, hash_funcs=hash_funcs, + hash_source=func, ) except UnhashableTypeError as exc: raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc) @@ -407,6 +408,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str: (func.__module__, func.__qualname__), hasher=func_hasher, cache_type=cache_type, + hash_source=func, ) # Include the function's source code in its hash. If the source code can't @@ -422,9 +424,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str: source_code = func.__code__.co_code update_hash( - source_code, - hasher=func_hasher, - cache_type=cache_type, + source_code, hasher=func_hasher, cache_type=cache_type, hash_source=func ) cache_key = func_hasher.hexdigest() diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 714b5f255701..cb272b1bdffa 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Hashing for st.memo and st.singleton.""" +"""Hashing for st.cache_data and st.cache_resource.""" import collections import dataclasses +import enum import functools import hashlib import inspect @@ -23,6 +24,7 @@ import pickle import sys import tempfile +import textwrap import threading import unittest.mock import uuid @@ -31,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Pattern, Type, Union from streamlit import type_util, util +from streamlit.errors import StreamlitAPIException from streamlit.runtime.caching.cache_errors import UnhashableTypeError from streamlit.runtime.caching.cache_type import CacheType from streamlit.runtime.uploaded_file_manager import UploadedFile @@ -50,16 +53,172 @@ _CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE" +class HashReason(enum.Enum): + CACHING_FUNC_ARGS = 0 + CACHING_FUNC_BODY = 1 + CACHING_FUNC_OUTPUT = 2 + CACHING_BLOCK = 3 + + +def _get_error_message_args( + orig_exc: BaseException, failed_obj: Any, cache_type: Optional[CacheType] = None +) -> Dict[str, Any]: + hash_reason = HashReason.CACHING_FUNC_ARGS + hash_source = hash_stacks.current.hash_source + + failed_obj_type_str = type_util.get_fqn_type(failed_obj) + object_part = "" + + if hash_source is None or hash_reason is None: + object_desc = "something" + + elif hash_reason is HashReason.CACHING_BLOCK: + object_desc = "a code block" + + else: + if hasattr(hash_source, "__name__"): + object_desc = f"`{hash_source.__name__}()`" + else: + object_desc = "a function" + + if hash_reason is HashReason.CACHING_FUNC_ARGS: + object_part = "the arguments of" + elif hash_reason is HashReason.CACHING_FUNC_BODY: + object_part = "the body of" + elif hash_reason is HashReason.CACHING_FUNC_OUTPUT: + object_part = "the return value of" + + decorator_name = "@st.cache" + if 
cache_type is CacheType.RESOURCE: + decorator_name = "@st.cache_resource" + elif cache_type is CacheType.DATA: + decorator_name = "@st.cache_data" + + return { + "orig_exception_desc": str(orig_exc), + "failed_obj_type_str": failed_obj_type_str, + "hash_stack": hash_stacks.current.pretty_print(), + "object_desc": object_desc, + "object_part": object_part, + "cache_primitive": decorator_name, + } + + +def _get_failing_lines(code, lineno: int) -> List[str]: + """Get list of strings (lines of code) from lineno to lineno+3. + + Ideally we'd return the exact line where the error took place, but there + are reasons why this is not possible without a lot of work, including + playing with the AST. So for now we're returning 3 lines near where + the error took place. + """ + source_lines, source_lineno = inspect.getsourcelines(code) + + start = lineno - source_lineno + end = min(start + 3, len(source_lines)) + lines = source_lines[start:end] + + return lines + + +class UserHashError(StreamlitAPIException): + def __init__( + self, + orig_exc, + cached_func_or_code, + hash_func=None, + lineno=None, + cache_type: Optional[CacheType] = None, + ): + self.alternate_name = type(orig_exc).__name__ + self.cache_type = cache_type + + if hash_func: + msg = self._get_message_from_func(orig_exc, cached_func_or_code, hash_func) + else: + msg = self._get_message_from_code(orig_exc, cached_func_or_code, lineno) + + super(UserHashError, self).__init__(msg) + self.with_traceback(orig_exc.__traceback__) + + def _get_message_from_func(self, orig_exc, cached_func, hash_func): + args = _get_error_message_args(orig_exc, cached_func, self.cache_type) + + if hasattr(hash_func, "__name__"): + args["hash_func_name"] = "`%s()`" % hash_func.__name__ + else: + args["hash_func_name"] = "a function" + + return ( + """ +%(orig_exception_desc)s + +This error is likely due to a bug in %(hash_func_name)s, which is a +user-defined hash function that was passed into the `%(cache_primitive)s` decorator of +%(object_desc)s. + +%(hash_func_name)s failed when hashing an object of type +`%(failed_obj_type_str)s`. If you don't know where that object is coming from, +try looking at the hash chain below for an object that you do recognize, then +pass that to `hash_funcs` instead: + +``` +%(hash_stack)s +``` + +If you think this is actually a Streamlit bug, please [file a bug report here.] +(https://github.com/streamlit/streamlit/issues/new/choose) + """ + % args + ).strip("\n") + + def _get_message_from_code(self, orig_exc: BaseException, cached_code, lineno: int): + args = _get_error_message_args(orig_exc, cached_code) + + failing_lines = _get_failing_lines(cached_code, lineno) + failing_lines_str = "".join(failing_lines) + failing_lines_str = textwrap.dedent(failing_lines_str).strip("\n") + + args["failing_lines_str"] = failing_lines_str + args["filename"] = cached_code.co_filename + args["lineno"] = lineno + + # This needs to have zero indentation otherwise %(lines_str)s will + # render incorrectly in Markdown. + return ( + """ +%(orig_exception_desc)s + +Streamlit encountered an error while caching %(object_part)s %(object_desc)s. +This is likely due to a bug in `%(filename)s` near line `%(lineno)s`: + +``` +%(failing_lines_str)s +``` + +Please modify the code above to address this. + +If you think this is actually a Streamlit bug, you may [file a bug report +here.] 
(https://github.com/streamlit/streamlit/issues/new/choose) + """ + % args + ).strip("\n") + + def update_hash( val: Any, hasher, cache_type: CacheType, + hash_source: Callable[..., Any], hash_funcs: Optional[HashFuncsDict] = None, ) -> None: """Updates a hashlib hasher with the hash of val. This is the main entrypoint to hashing.py. """ + + hash_stacks.current.hash_source = hash_source + ch = _CacheFuncHasher(cache_type, hash_funcs) ch.update(hasher, val) @@ -91,6 +250,20 @@ def pop(self): def __contains__(self, val: Any): return id(val) in self._stack + def pretty_print(self): + def to_str(v): + try: + return "Object of type %s: %s" % (type_util.get_fqn_type(v), str(v)) + except Exception: + return "" + + # IDEA: Maybe we should remove our internal "hash_funcs" from the + # stack. I'm not removing those now because even though those aren't + # useful to users I think they might be useful when we're debugging an + # issue sent by a user. So let's wait a few months and see if they're + # indeed useful... + return "\n".join(to_str(x) for x in reversed(self._stack.values())) + class _HashStacks: """Stacks of what has been hashed, with at most 1 stack per thread.""" @@ -256,7 +429,9 @@ def _to_bytes(self, obj: Any) -> bytes: try: output = hash_func(obj) except Exception as ex: - raise UnhashableTypeError("AAAAAAAAA") from ex + raise UserHashError( + ex, obj, hash_func=hash_func, cache_type=self.cache_type + ) from ex return self.to_bytes(output) elif isinstance(obj, str): From d5c3ad1f030c43181924cbf2e237507c2b17cc1b Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Fri, 5 May 2023 19:29:37 +0400 Subject: [PATCH 03/13] add comment --- lib/streamlit/runtime/caching/hashing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index cb272b1bdffa..9ad3f7f1c23c 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -428,6 +428,8 @@ def _to_bytes(self, obj: Any) -> bytes: hash_func = self._hash_funcs[type_util.get_fqn_type(obj)] try: output = hash_func(obj) + # check `output` type, it should be primitive type + # if not, raise error except Exception as ex: raise UserHashError( ex, obj, hash_func=hash_func, cache_type=self.cache_type From 85c31e2155615c1c7f6d3114427297a3d3b8620c Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Tue, 9 May 2023 20:28:59 +0400 Subject: [PATCH 04/13] Remove legacy details from hash_func implementation for new cache primitives --- lib/streamlit/runtime/caching/hashing.py | 163 ++++++----------------- 1 file changed, 44 insertions(+), 119 deletions(-) diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 9ad3f7f1c23c..5abdbbed4e7b 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -15,7 +15,6 @@ """Hashing for st.cache_data and st.cache_resource.""" import collections import dataclasses -import enum import functools import hashlib import inspect @@ -24,7 +23,6 @@ import pickle import sys import tempfile -import textwrap import threading import unittest.mock import uuid @@ -53,101 +51,25 @@ _CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE" -class HashReason(enum.Enum): - CACHING_FUNC_ARGS = 0 - CACHING_FUNC_BODY = 1 - CACHING_FUNC_OUTPUT = 2 - CACHING_BLOCK = 3 - - -def _get_error_message_args( - orig_exc: BaseException, failed_obj: Any, cache_type: Optional[CacheType] = None -) -> Dict[str, Any]: - 
hash_reason = HashReason.CACHING_FUNC_ARGS - hash_source = hash_stacks.current.hash_source - - failed_obj_type_str = type_util.get_fqn_type(failed_obj) - object_part = "" - - if hash_source is None or hash_reason is None: - object_desc = "something" - - elif hash_reason is HashReason.CACHING_BLOCK: - object_desc = "a code block" - - else: - if hasattr(hash_source, "__name__"): - object_desc = f"`{hash_source.__name__}()`" - else: - object_desc = "a function" - - if hash_reason is HashReason.CACHING_FUNC_ARGS: - object_part = "the arguments of" - elif hash_reason is HashReason.CACHING_FUNC_BODY: - object_part = "the body of" - elif hash_reason is HashReason.CACHING_FUNC_OUTPUT: - object_part = "the return value of" - - decorator_name = "@st.cache" - if cache_type is CacheType.RESOURCE: - decorator_name = "@st.cache_resource" - elif cache_type is CacheType.DATA: - decorator_name = "@st.cache_data" - - return { - "orig_exception_desc": str(orig_exc), - "failed_obj_type_str": failed_obj_type_str, - "hash_stack": hash_stacks.current.pretty_print(), - "object_desc": object_desc, - "object_part": object_part, - "cache_primitive": decorator_name, - } - - -def _get_failing_lines(code, lineno: int) -> List[str]: - """Get list of strings (lines of code) from lineno to lineno+3. - - Ideally we'd return the exact line where the error took place, but there - are reasons why this is not possible without a lot of work, including - playing with the AST. So for now we're returning 3 lines near where - the error took place. - """ - source_lines, source_lineno = inspect.getsourcelines(code) - - start = lineno - source_lineno - end = min(start + 3, len(source_lines)) - lines = source_lines[start:end] - - return lines - - class UserHashError(StreamlitAPIException): def __init__( self, orig_exc, - cached_func_or_code, - hash_func=None, - lineno=None, + object_to_hash, + hash_func, cache_type: Optional[CacheType] = None, ): self.alternate_name = type(orig_exc).__name__ + self.hash_func = hash_func self.cache_type = cache_type - if hash_func: - msg = self._get_message_from_func(orig_exc, cached_func_or_code, hash_func) - else: - msg = self._get_message_from_code(orig_exc, cached_func_or_code, lineno) + msg = self._get_message_from_func(orig_exc, object_to_hash) - super(UserHashError, self).__init__(msg) + super().__init__(msg) self.with_traceback(orig_exc.__traceback__) - def _get_message_from_func(self, orig_exc, cached_func, hash_func): - args = _get_error_message_args(orig_exc, cached_func, self.cache_type) - - if hasattr(hash_func, "__name__"): - args["hash_func_name"] = "`%s()`" % hash_func.__name__ - else: - args["hash_func_name"] = "a function" + def _get_message_from_func(self, orig_exc, cached_func): + args = self._get_error_message_args(orig_exc, cached_func) return ( """ @@ -172,44 +94,49 @@ def _get_message_from_func(self, orig_exc, cached_func, hash_func): % args ).strip("\n") - def _get_message_from_code(self, orig_exc: BaseException, cached_code, lineno: int): - args = _get_error_message_args(orig_exc, cached_code) - - failing_lines = _get_failing_lines(cached_code, lineno) - failing_lines_str = "".join(failing_lines) - failing_lines_str = textwrap.dedent(failing_lines_str).strip("\n") - - args["failing_lines_str"] = failing_lines_str - args["filename"] = cached_code.co_filename - args["lineno"] = lineno - - # This needs to have zero indentation otherwise %(lines_str)s will - # render incorrectly in Markdown. 
- return ( - """ -%(orig_exception_desc)s + def _get_error_message_args( + self, + orig_exc: BaseException, + failed_obj: Any, + ) -> Dict[str, Any]: + hash_source = hash_stacks.current.hash_source -Streamlit encountered an error while caching %(object_part)s %(object_desc)s. -This is likely due to a bug in `%(filename)s` near line `%(lineno)s`: + failed_obj_type_str = type_util.get_fqn_type(failed_obj) -``` -%(failing_lines_str)s -``` - -Please modify the code above to address this. + if hash_source is None: + object_desc = "something" + else: + if hasattr(hash_source, "__name__"): + object_desc = f"`{hash_source.__name__}()`" + else: + object_desc = "a function" + + decorator_name = "" + if self.cache_type is CacheType.RESOURCE: + decorator_name = "@st.cache_resource" + elif self.cache_type is CacheType.DATA: + decorator_name = "@st.cache_data" + + if hasattr(self.hash_func, "__name__"): + hash_func_name = f"`{self.hash_func.__name__}()`" + else: + hash_func_name = "a function" -If you think this is actually a Streamlit bug, you may [file a bug report -here.] (https://github.com/streamlit/streamlit/issues/new/choose) - """ - % args - ).strip("\n") + return { + "orig_exception_desc": str(orig_exc), + "failed_obj_type_str": failed_obj_type_str, + "hash_stack": hash_stacks.current.pretty_print(), + "object_desc": object_desc, + "cache_primitive": decorator_name, + "hash_func_name": hash_func_name, + } def update_hash( val: Any, hasher, cache_type: CacheType, - hash_source: Callable[..., Any], + hash_source: Optional[Callable[..., Any]] = None, hash_funcs: Optional[HashFuncsDict] = None, ) -> None: """Updates a hashlib hasher with the hash of val. @@ -237,6 +164,9 @@ class _HashStack: def __init__(self): self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict() + # A function that we decorate with streamlit cache + # primitive (st.cache_data or st.cache_resource). + self.hash_source: Optional[Callable[..., Any]] = None def __repr__(self) -> str: return util.repr_(self) @@ -257,11 +187,6 @@ def to_str(v): except Exception: return "" - # IDEA: Maybe we should remove our internal "hash_funcs" from the - # stack. I'm not removing those now because even though those aren't - # useful to users I think they might be useful when we're debugging an - # issue sent by a user. So let's wait a few months and see if they're - # indeed useful... return "\n".join(to_str(x) for x in reversed(self._stack.values())) From adcc3bdcb80bfd0b6c67567abb78a546f3aed8ac Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Wed, 31 May 2023 17:51:24 +0400 Subject: [PATCH 05/13] fix broken `file a bug report` markdown link --- lib/streamlit/runtime/caching/hashing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 5abdbbed4e7b..9624429daf43 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -88,9 +88,9 @@ def _get_message_from_func(self, orig_exc, cached_func): %(hash_stack)s ``` -If you think this is actually a Streamlit bug, please [file a bug report here.] -(https://github.com/streamlit/streamlit/issues/new/choose) - """ +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose). 
+""" % args ).strip("\n") From b6ca511eb081c3c2db74d7c9ef4f71446ede24d3 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Wed, 31 May 2023 17:56:33 +0400 Subject: [PATCH 06/13] fix broken `file a bug report` markdown link in old st.cache too --- lib/streamlit/runtime/legacy_caching/hashing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/streamlit/runtime/legacy_caching/hashing.py b/lib/streamlit/runtime/legacy_caching/hashing.py index 4a0f04741d30..1c2b48bdd5cf 100644 --- a/lib/streamlit/runtime/legacy_caching/hashing.py +++ b/lib/streamlit/runtime/legacy_caching/hashing.py @@ -867,8 +867,8 @@ def _get_message_from_func(self, orig_exc, cached_func, hash_func): %(hash_stack)s ``` -If you think this is actually a Streamlit bug, please [file a bug report here.] -(https://github.com/streamlit/streamlit/issues/new/choose) +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose). """ % args ).strip("\n") From d1ce8b440993743523396db3dbedf7d2649af7a0 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Wed, 31 May 2023 20:19:33 +0400 Subject: [PATCH 07/13] Add tests for hash_funcs --- .../streamlit/runtime/caching/hashing_test.py | 71 ++++++++++++++++++- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py index 2e067fa68ad7..0b9c251fefcf 100644 --- a/lib/tests/streamlit/runtime/caching/hashing_test.py +++ b/lib/tests/streamlit/runtime/caching/hashing_test.py @@ -19,6 +19,7 @@ import os import re import tempfile +import time import types import unittest import uuid @@ -33,10 +34,12 @@ from parameterized import parameterized from PIL import Image +from streamlit.runtime.caching import cache_data, cache_resource from streamlit.runtime.caching.cache_errors import UnhashableTypeError from streamlit.runtime.caching.hashing import ( _NP_SIZE_LARGE, _PANDAS_ROWS_LARGE, + UserHashError, _CacheFuncHasher, ) @@ -56,9 +59,9 @@ get_main_script_director = MagicMock(return_value=os.getcwd()) -def get_hash(f): +def get_hash(f, hash_funcs=None): hasher = hashlib.new("md5") - ch = _CacheFuncHasher(MagicMock()) + ch = _CacheFuncHasher(MagicMock(), hash_funcs=hash_funcs) ch.update(hasher, f) return hasher.digest() @@ -114,6 +117,23 @@ def test_list(self): b.append(b) self.assertEqual(get_hash(a), get_hash(b)) + @parameterized.expand( + [("cache_data", cache_data), ("cache_resource", cache_resource)] + ) + def test_recursive_hash_func(self, _, cache_decorator): + def hash_int(x): + return x + + @cache_decorator(hash_funcs={int: hash_int}) + def foo(x): + return x + + self.assertEqual(foo(1), foo(1)) + # Note: We're able to break the recursive cycle caused by the identity + # hash func but it causes all cycles to hash to the same thing. + # https://github.com/streamlit/streamlit/issues/1659 + # self.assertNotEqual(foo(2), foo(1)) + def test_tuple(self): self.assertEqual(get_hash((1, 2)), get_hash((1, 2))) self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2))) @@ -391,6 +411,53 @@ def test_generator_not_hashable(self): with self.assertRaises(UnhashableTypeError): get_hash((x for x in range(1))) + def test_hash_funcs_acceptable_keys(self): + test_generator = (x for x in range(1)) + + with self.assertRaises(UnhashableTypeError): + get_hash(test_generator) + + # Assert that hashes are equivalent when hash_func key is supplied both as a + # type literal, and as a type name string. 
+ + self.assertEqual( + get_hash(test_generator, hash_funcs={types.GeneratorType: id}), + get_hash(test_generator, hash_funcs={"builtins.generator": id}), + ) + + def test_hash_funcs_error(self): + with self.assertRaises(UserHashError): + get_hash(1, hash_funcs={int: lambda x: "a" + x}) + + def test_non_hashable(self): + """Test user provided hash functions.""" + + g = (x for x in range(1)) + + # Unhashable object raises an error + with self.assertRaises(UnhashableTypeError): + get_hash(g) + + id_hash_func = {types.GeneratorType: id} + + self.assertEqual( + get_hash(g, hash_funcs=id_hash_func), + get_hash(g, hash_funcs=id_hash_func), + ) + + unique_hash_func = {types.GeneratorType: lambda x: time.time()} + + self.assertNotEqual( + get_hash(g, hash_funcs=unique_hash_func), + get_hash(g, hash_funcs=unique_hash_func), + ) + + def test_override_streamlit_hash_func(self): + """Test that a user provided hash function has priority over a streamlit one.""" + + hash_funcs = {int: lambda x: "hello"} + self.assertNotEqual(get_hash(1), get_hash(1, hash_funcs=hash_funcs)) + def test_function_not_hashable(self): def foo(): pass From 6024d2849ebfa2c1ce0ec9e26194a2fb7762ea64 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Tue, 6 Jun 2023 19:29:22 +0400 Subject: [PATCH 08/13] add tests for `hash_funcs` --- .../runtime/caching/cache_data_api_test.py | 86 +++++++++++++++++++ .../streamlit/runtime/caching/hashing_test.py | 30 +++++-- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py index 5ae59f5ce255..ec1b079dc5c3 100644 --- a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py @@ -41,6 +41,7 @@ MultiCacheResults, _make_widget_key, ) +from streamlit.runtime.caching.hashing import UserHashError from streamlit.runtime.caching.storage import ( CacheStorage, CacheStorageContext, @@ -177,6 +178,91 @@ def foo(): else: show_warning_mock.assert_not_called() + def test_cached_member_function_with_hash_func(self): + """@st.cache_data can be applied to class member functions + with corresponding hash_func. + """ + + class TestClass: + @st.cache_data( + hash_funcs={ + "tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_cached_member_function_with_hash_func..TestClass": id + } + ) + def member_func(self): + return "member func!" + + @classmethod + @st.cache_data + def class_method(cls): + return "class method!" + + @staticmethod + @st.cache_data + def static_method(): + return "static method!" + + obj = TestClass() + self.assertEqual("member func!", obj.member_func()) + self.assertEqual("class method!", obj.class_method()) + self.assertEqual("static method!", obj.static_method()) + + def test_function_name_does_not_use_hashfuncs(self): + """Hash funcs should only be used on arguments to a function, + and not when computing the key for a function's unique MemCache. + """ + + str_hash_func = Mock(return_value=None) + + @st.cache(hash_funcs={str: str_hash_func}) + def foo(string_arg): + return [] + + # If our str hash_funcs is called multiple times, it's probably because + # it's being used to compute the function's function_key (as opposed to + # the value_key). It should only be used to compute the value_key! 
+ foo("ahoy") + str_hash_func.assert_called_once_with("ahoy") + + def test_user_hash_error(self): + class MyObj: + # we specify __repr__ here, to avoid `MyObj object at 0x1347a3f70` + # in error message + def __repr__(self): + return "MyObj class" + + def bad_hash_func(x): + x += 10 # Throws a TypeError since x has type MyObj. + return x + + @st.cache_data(hash_funcs={MyObj: bad_hash_func}) + def user_hash_error_func(x): + pass + + with self.assertRaises(UserHashError) as ctx: + my_obj = MyObj() + user_hash_error_func(my_obj) + + expected_message = """unsupported operand type(s) for +=: 'MyObj' and 'int' + +This error is likely due to a bug in `bad_hash_func()`, which is a +user-defined hash function that was passed into the `@st.cache_data` decorator of +`user_hash_error_func()`. + +`bad_hash_func()` failed when hashing an object of type +`tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error..MyObj`. If you don't know where that object is coming from, +try looking at the hash chain below for an object that you do recognize, then +pass that to `hash_funcs` instead: + +``` +Object of type tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error..MyObj: MyObj class +Object of type builtins.tuple: ('x', MyObj class) +``` + +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).""" + self.assertEqual(str(ctx.exception), expected_message) + class CacheDataPersistTest(DeltaGeneratorTestCase): """st.cache_data disk persistence tests""" diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py index 0b9c251fefcf..f654eb447934 100644 --- a/lib/tests/streamlit/runtime/caching/hashing_test.py +++ b/lib/tests/streamlit/runtime/caching/hashing_test.py @@ -36,6 +36,7 @@ from streamlit.runtime.caching import cache_data, cache_resource from streamlit.runtime.caching.cache_errors import UnhashableTypeError +from streamlit.runtime.caching.cache_type import CacheType from streamlit.runtime.caching.hashing import ( _NP_SIZE_LARGE, _PANDAS_ROWS_LARGE, @@ -59,10 +60,10 @@ get_main_script_director = MagicMock(return_value=os.getcwd()) -def get_hash(f, hash_funcs=None): +def get_hash(value, hash_funcs=None, cache_type=None): hasher = hashlib.new("md5") - ch = _CacheFuncHasher(MagicMock(), hash_funcs=hash_funcs) - ch.update(hasher, f) + ch = _CacheFuncHasher(cache_type or MagicMock(), hash_funcs=hash_funcs) + ch.update(hasher, value) return hasher.digest() @@ -426,8 +427,27 @@ def test_hash_funcs_acceptable_keys(self): ) def test_hash_funcs_error(self): - with self.assertRaises(UserHashError): - get_hash(1, hash_funcs={int: lambda x: "a" + x}) + with self.assertRaises(UserHashError) as ctx: + get_hash(1, cache_type=CacheType.DATA, hash_funcs={int: lambda x: "a" + x}) + + expected_message = """can only concatenate str (not "int") to str + +This error is likely due to a bug in `()`, which is a +user-defined hash function that was passed into the `@st.cache_data` decorator of +something. + +`()` failed when hashing an object of type +`builtins.int`. 
If you don't know where that object is coming from, +try looking at the hash chain below for an object that you do recognize, then +pass that to `hash_funcs` instead: + +``` +Object of type builtins.int: 1 +``` + +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).""" + self.assertEqual(str(ctx.exception), expected_message) def test_non_hashable(self): """Test user provided hash functions.""" From f433d858c8d6baaaaaec5e41ae029e914f732868 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Tue, 6 Jun 2023 21:07:19 +0400 Subject: [PATCH 09/13] add tests for `hash_funcs` cache_resource --- lib/streamlit/runtime/caching/cache_utils.py | 12 ++- .../runtime/caching/cache_data_api_test.py | 3 +- .../caching/cache_resource_api_test.py | 85 +++++++++++++++++++ .../streamlit/runtime/caching/hashing_test.py | 7 +- 4 files changed, 101 insertions(+), 6 deletions(-) diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py index 3b515f818c65..c80ba5ff876a 100644 --- a/lib/streamlit/runtime/caching/cache_utils.py +++ b/lib/streamlit/runtime/caching/cache_utils.py @@ -392,7 +392,17 @@ def _make_value_key( try: update_hash( - (arg_name, arg_value), + arg_name, + hasher=args_hasher, + cache_type=cache_type, + hash_source=func, + ) + # we call update_hash twice here, first time for `arg name` + # without `hash_funcs`, and second time for `arg_value` with hash_funcs + # to evaluate user defined `hash_funcs` only for computing `arg_value` hash. + + update_hash( + arg_value, hasher=args_hasher, cache_type=cache_type, hash_funcs=hash_funcs, diff --git a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py index ec1b079dc5c3..438c8c2c6b8b 100644 --- a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py @@ -214,7 +214,7 @@ def test_function_name_does_not_use_hashfuncs(self): str_hash_func = Mock(return_value=None) - @st.cache(hash_funcs={str: str_hash_func}) + @st.cache_data(hash_funcs={str: str_hash_func}) def foo(string_arg): return [] @@ -256,7 +256,6 @@ def user_hash_error_func(x): ``` Object of type tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error..MyObj: MyObj class -Object of type builtins.tuple: ('x', MyObj class) ``` If you think this is actually a Streamlit bug, please diff --git a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py index 01e8e90a01f7..7be1c57024e3 100644 --- a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py @@ -29,6 +29,7 @@ ) from streamlit.runtime.caching.cache_type import CacheType from streamlit.runtime.caching.cached_message_replay import MultiCacheResults +from streamlit.runtime.caching.hashing import UserHashError from streamlit.runtime.scriptrunner import add_script_run_ctx from streamlit.runtime.stats import CacheStat from tests.streamlit.runtime.caching.common_cache_test import ( @@ -132,6 +133,90 @@ def foo(): else: show_warning_mock.assert_not_called() + def test_cached_member_function_with_hash_func(self): + """@st.cache_resource can be applied to class member functions + with corresponding hash_func. 
+ """ + + class TestClass: + @st.cache_resource( + hash_funcs={ + "tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_cached_member_function_with_hash_func..TestClass": id + } + ) + def member_func(self): + return "member func!" + + @classmethod + @st.cache_resource + def class_method(cls): + return "class method!" + + @staticmethod + @st.cache_resource + def static_method(): + return "static method!" + + obj = TestClass() + self.assertEqual("member func!", obj.member_func()) + self.assertEqual("class method!", obj.class_method()) + self.assertEqual("static method!", obj.static_method()) + + def test_function_name_does_not_use_hashfuncs(self): + """Hash funcs should only be used on arguments to a function, + and not when computing the key for a function's unique MemCache. + """ + + str_hash_func = Mock(return_value=None) + + @st.cache(hash_funcs={str: str_hash_func}) + def foo(string_arg): + return [] + + # If our str hash_funcs is called multiple times, it's probably because + # it's being used to compute the function's function_key (as opposed to + # the value_key). It should only be used to compute the value_key! + foo("ahoy") + str_hash_func.assert_called_once_with("ahoy") + + def test_user_hash_error(self): + class MyObj: + # we specify __repr__ here, to avoid `MyObj object at 0x1347a3f70` + # in error message + def __repr__(self): + return "MyObj class" + + def bad_hash_func(x): + x += 10 # Throws a TypeError since x has type MyObj. + return x + + @st.cache_resource(hash_funcs={MyObj: bad_hash_func}) + def user_hash_error_func(x): + pass + + with self.assertRaises(UserHashError) as ctx: + my_obj = MyObj() + user_hash_error_func(my_obj) + + expected_message = """unsupported operand type(s) for +=: 'MyObj' and 'int' + +This error is likely due to a bug in `bad_hash_func()`, which is a +user-defined hash function that was passed into the `@st.cache_resource` decorator of +`user_hash_error_func()`. + +`bad_hash_func()` failed when hashing an object of type +`tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_user_hash_error..MyObj`. 
If you don't know where that object is coming from, +try looking at the hash chain below for an object that you do recognize, then +pass that to `hash_funcs` instead: + +``` +Object of type tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_user_hash_error..MyObj: MyObj class +``` + +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).""" + self.assertEqual(str(ctx.exception), expected_message) + class CacheResourceValidateTest(unittest.TestCase): def setUp(self) -> None: diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py index f654eb447934..3237726f1992 100644 --- a/lib/tests/streamlit/runtime/caching/hashing_test.py +++ b/lib/tests/streamlit/runtime/caching/hashing_test.py @@ -41,7 +41,7 @@ _NP_SIZE_LARGE, _PANDAS_ROWS_LARGE, UserHashError, - _CacheFuncHasher, + update_hash, ) try: @@ -62,8 +62,9 @@ def get_hash(value, hash_funcs=None, cache_type=None): hasher = hashlib.new("md5") - ch = _CacheFuncHasher(cache_type or MagicMock(), hash_funcs=hash_funcs) - ch.update(hasher, value) + update_hash( + value, hasher, cache_type=cache_type or MagicMock(), hash_funcs=hash_funcs + ) return hasher.digest() From 2cb941e8217aaebd29dd7756753d4cd94ab652ad Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Tue, 6 Jun 2023 21:30:57 +0400 Subject: [PATCH 10/13] fix typo --- lib/tests/streamlit/runtime/caching/cache_data_api_test.py | 2 +- lib/tests/streamlit/runtime/caching/cache_resource_api_test.py | 2 +- lib/tests/streamlit/runtime/legacy_caching/caching_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py index 438c8c2c6b8b..9de072c2297a 100644 --- a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py @@ -218,7 +218,7 @@ def test_function_name_does_not_use_hashfuncs(self): def foo(string_arg): return [] - # If our str hash_funcs is called multiple times, it's probably because + # If our str hash_func is called multiple times, it's probably because # it's being used to compute the function's function_key (as opposed to # the value_key). It should only be used to compute the value_key! foo("ahoy") diff --git a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py index 7be1c57024e3..0c71efad2581 100644 --- a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py @@ -173,7 +173,7 @@ def test_function_name_does_not_use_hashfuncs(self): def foo(string_arg): return [] - # If our str hash_funcs is called multiple times, it's probably because + # If our str hash_func is called multiple times, it's probably because # it's being used to compute the function's function_key (as opposed to # the value_key). It should only be used to compute the value_key! 
foo("ahoy") diff --git a/lib/tests/streamlit/runtime/legacy_caching/caching_test.py b/lib/tests/streamlit/runtime/legacy_caching/caching_test.py index b55bd7baf36c..d7e2c761604a 100644 --- a/lib/tests/streamlit/runtime/legacy_caching/caching_test.py +++ b/lib/tests/streamlit/runtime/legacy_caching/caching_test.py @@ -408,7 +408,7 @@ def test_function_name_does_not_use_hashfuncs(self): def foo(string_arg): return [] - # If our str hash_funcs is called multiple times, it's probably because + # If our str hash_func is called multiple times, it's probably because # it's being used to compute the function's cache_key (as opposed to # the value_key). It should only be used to compute the value_key! foo("ahoy") From 5696a8c3466f4f2c9cd8d3ed2c7103c41a3c6bdf Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Tue, 6 Jun 2023 23:13:57 +0400 Subject: [PATCH 11/13] remove unused blank line --- lib/streamlit/runtime/caching/cache_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py index c80ba5ff876a..4b4b52e3cc17 100644 --- a/lib/streamlit/runtime/caching/cache_utils.py +++ b/lib/streamlit/runtime/caching/cache_utils.py @@ -400,7 +400,6 @@ def _make_value_key( # we call update_hash twice here, first time for `arg name` # without `hash_funcs`, and second time for `arg_value` with hash_funcs # to evaluate user defined `hash_funcs` only for computing `arg_value` hash. - update_hash( arg_value, hasher=args_hasher, From 10167b116144989952116c0ea2e83fcbe412faed Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Wed, 7 Jun 2023 18:06:23 +0400 Subject: [PATCH 12/13] clean up --- lib/streamlit/runtime/caching/cache_utils.py | 2 +- lib/streamlit/runtime/caching/hashing.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py index 4b4b52e3cc17..95f185431fd4 100644 --- a/lib/streamlit/runtime/caching/cache_utils.py +++ b/lib/streamlit/runtime/caching/cache_utils.py @@ -397,7 +397,7 @@ def _make_value_key( cache_type=cache_type, hash_source=func, ) - # we call update_hash twice here, first time for `arg name` + # we call update_hash twice here, first time for `arg_name` # without `hash_funcs`, and second time for `arg_value` with hash_funcs # to evaluate user defined `hash_funcs` only for computing `arg_value` hash. 
update_hash( diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 9624429daf43..43b4902e9960 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -353,8 +353,6 @@ def _to_bytes(self, obj: Any) -> bytes: hash_func = self._hash_funcs[type_util.get_fqn_type(obj)] try: output = hash_func(obj) - # check `output` type, it should be primitive type - # if not, raise error except Exception as ex: raise UserHashError( ex, obj, hash_func=hash_func, cache_type=self.cache_type From ac31562055db22845ef737a0fe558ee80b14a613 Mon Sep 17 00:00:00 2001 From: Karen Javadyan Date: Thu, 8 Jun 2023 17:30:30 +0400 Subject: [PATCH 13/13] fixes after review --- lib/streamlit/runtime/caching/hashing.py | 6 +++--- lib/tests/streamlit/runtime/caching/hashing_test.py | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py index 43b4902e9960..06ee14619fe9 100644 --- a/lib/streamlit/runtime/caching/hashing.py +++ b/lib/streamlit/runtime/caching/hashing.py @@ -180,10 +180,10 @@ def pop(self): def __contains__(self, val: Any): return id(val) in self._stack - def pretty_print(self): - def to_str(v): + def pretty_print(self) -> str: + def to_str(v: Any) -> str: try: - return "Object of type %s: %s" % (type_util.get_fqn_type(v), str(v)) + return f"Object of type {type_util.get_fqn_type(v)}: {str(v)}" except Exception: return "" diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py index 3237726f1992..1cc9fefb715a 100644 --- a/lib/tests/streamlit/runtime/caching/hashing_test.py +++ b/lib/tests/streamlit/runtime/caching/hashing_test.py @@ -123,6 +123,10 @@ def test_list(self): [("cache_data", cache_data), ("cache_resource", cache_resource)] ) def test_recursive_hash_func(self, _, cache_decorator): + """Test that if user defined hash_func returns the value of the same type + that hash_funcs tries to cache, we break the recursive cycle with predefined + placeholder""" + def hash_int(x): return x @@ -414,14 +418,14 @@ def test_generator_not_hashable(self): get_hash((x for x in range(1))) def test_hash_funcs_acceptable_keys(self): + """Test that hashes are equivalent when hash_func key is supplied both as a + type literal, and as a type name string. + """ test_generator = (x for x in range(1)) with self.assertRaises(UnhashableTypeError): get_hash(test_generator) - # Assert that hashes are equivalent when hash_func key is supplied both as a - # type literal, and as a type name string. - self.assertEqual( get_hash(test_generator, hash_funcs={types.GeneratorType: id}), get_hash(test_generator, hash_funcs={"builtins.generator": id}),
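[Editorial sketch — not part of the patch.] As a closing illustration of the key equivalence `test_hash_funcs_acceptable_keys` asserts, both spellings below normalize to the same internal key (`"builtins.generator"`) and so select the same user hash function. A sketch assuming a Streamlit build with this series applied:

```python
# Sketch only: the type-literal and FQN-string key forms are interchangeable.
import types

import streamlit as st


@st.cache_data(hash_funcs={types.GeneratorType: id})
def first_by_type(gen):
    return next(gen)


@st.cache_data(hash_funcs={"builtins.generator": id})
def first_by_name(gen):
    return next(gen)
```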