diff --git a/lib/streamlit/runtime/caching/cache_data_api.py b/lib/streamlit/runtime/caching/cache_data_api.py
index b46a04616fe3..8197e534f734 100644
--- a/lib/streamlit/runtime/caching/cache_data_api.py
+++ b/lib/streamlit/runtime/caching/cache_data_api.py
@@ -44,6 +44,7 @@
     MsgData,
     MultiCacheResults,
 )
+from streamlit.runtime.caching.hashing import HashFuncsDict
 from streamlit.runtime.caching.storage import (
     CacheStorage,
     CacheStorageContext,
@@ -80,11 +81,13 @@ def __init__(
         max_entries: int | None,
         ttl: float | timedelta | str | None,
         allow_widgets: bool,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         super().__init__(
             func,
             show_spinner=show_spinner,
             allow_widgets=allow_widgets,
+            hash_funcs=hash_funcs,
         )
         self.persist = persist
         self.max_entries = max_entries
@@ -365,6 +368,7 @@ def __call__(
         show_spinner: bool | str = True,
         persist: CachePersistType | bool = None,
         experimental_allow_widgets: bool = False,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         return self._decorator(
             func,
@@ -373,6 +377,7 @@ def __call__(
             persist=persist,
             show_spinner=show_spinner,
             experimental_allow_widgets=experimental_allow_widgets,
+            hash_funcs=hash_funcs,
         )
 
     def _decorator(
@@ -384,6 +389,7 @@ def _decorator(
         show_spinner: bool | str,
         persist: CachePersistType | bool,
         experimental_allow_widgets: bool,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         """Decorator to cache functions that return data (e.g. dataframe transforms,
         database queries, ML inference).
@@ -529,6 +535,7 @@ def wrapper(f):
                     max_entries=max_entries,
                     ttl=ttl,
                     allow_widgets=experimental_allow_widgets,
+                    hash_funcs=hash_funcs,
                 )
             )
 
@@ -543,6 +550,7 @@ def wrapper(f):
                 max_entries=max_entries,
                 ttl=ttl,
                 allow_widgets=experimental_allow_widgets,
+                hash_funcs=hash_funcs,
             )
         )
 
diff --git a/lib/streamlit/runtime/caching/cache_resource_api.py b/lib/streamlit/runtime/caching/cache_resource_api.py
index 02348549da1f..4d7a48bdadc1 100644
--- a/lib/streamlit/runtime/caching/cache_resource_api.py
+++ b/lib/streamlit/runtime/caching/cache_resource_api.py
@@ -45,6 +45,7 @@
     MsgData,
     MultiCacheResults,
 )
+from streamlit.runtime.caching.hashing import HashFuncsDict
 from streamlit.runtime.metrics_util import gather_metrics
 from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
 from streamlit.runtime.stats import CacheStat, CacheStatsProvider
@@ -153,11 +154,13 @@ def __init__(
         ttl: float | timedelta | str | None,
         validate: ValidateFunc | None,
         allow_widgets: bool,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         super().__init__(
             func,
             show_spinner=show_spinner,
             allow_widgets=allow_widgets,
+            hash_funcs=hash_funcs,
         )
         self.max_entries = max_entries
         self.ttl = ttl
@@ -245,6 +248,7 @@ def __call__(
         show_spinner: bool | str = True,
         validate: ValidateFunc | None = None,
         experimental_allow_widgets: bool = False,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         return self._decorator(
             func,
@@ -253,6 +257,7 @@ def __call__(
             show_spinner=show_spinner,
             validate=validate,
             experimental_allow_widgets=experimental_allow_widgets,
+            hash_funcs=hash_funcs,
         )
 
     def _decorator(
@@ -264,6 +269,7 @@ def _decorator(
         show_spinner: bool | str,
         validate: ValidateFunc | None,
         experimental_allow_widgets: bool,
+        hash_funcs: HashFuncsDict | None = None,
     ):
         """Decorator to cache functions that return global resources (e.g. database
         connections, ML models).
@@ -389,6 +395,7 @@ def _decorator(
                     ttl=ttl,
                     validate=validate,
                     allow_widgets=experimental_allow_widgets,
+                    hash_funcs=hash_funcs,
                 )
             )
 
@@ -400,6 +407,7 @@ def _decorator(
                 ttl=ttl,
                 validate=validate,
                 allow_widgets=experimental_allow_widgets,
+                hash_funcs=hash_funcs,
             )
         )
 
diff --git a/lib/streamlit/runtime/caching/cache_utils.py b/lib/streamlit/runtime/caching/cache_utils.py
index 4a8597715c35..95f185431fd4 100644
--- a/lib/streamlit/runtime/caching/cache_utils.py
+++ b/lib/streamlit/runtime/caching/cache_utils.py
@@ -50,7 +50,7 @@
     MsgData,
     replay_cached_messages,
 )
-from streamlit.runtime.caching.hashing import update_hash
+from streamlit.runtime.caching.hashing import HashFuncsDict, update_hash
 
 _LOGGER = get_logger(__name__)
 
@@ -169,10 +169,12 @@ def __init__(
         func: types.FunctionType,
         show_spinner: bool | str,
         allow_widgets: bool,
+        hash_funcs: HashFuncsDict | None,
     ):
         self.func = func
         self.show_spinner = show_spinner
         self.allow_widgets = allow_widgets
+        self.hash_funcs = hash_funcs
 
     @property
     def cache_type(self) -> CacheType:
@@ -254,6 +256,7 @@ def _get_or_create_cached_value(
             func=self._info.func,
             func_args=func_args,
             func_kwargs=func_kwargs,
+            hash_funcs=self._info.hash_funcs,
         )
 
         try:
@@ -351,6 +354,7 @@ def _make_value_key(
     func: types.FunctionType,
     func_args: tuple[Any, ...],
     func_kwargs: dict[str, Any],
+    hash_funcs: HashFuncsDict | None,
 ) -> str:
     """Create the key for a value within a cache.
 
@@ -388,9 +392,20 @@ def _make_value_key(
         try:
             update_hash(
-                (arg_name, arg_value),
+                arg_name,
                 hasher=args_hasher,
                 cache_type=cache_type,
+                hash_source=func,
+            )
+            # We call update_hash twice here: first for `arg_name` without
+            # `hash_funcs`, then for `arg_value` with `hash_funcs`, so that
+            # user-defined hash functions are applied only to the `arg_value` hash.
+            update_hash(
+                arg_value,
+                hasher=args_hasher,
+                cache_type=cache_type,
+                hash_funcs=hash_funcs,
+                hash_source=func,
             )
         except UnhashableTypeError as exc:
             raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
@@ -417,6 +432,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
         (func.__module__, func.__qualname__),
         hasher=func_hasher,
         cache_type=cache_type,
+        hash_source=func,
     )
 
     # Include the function's source code in its hash. If the source code can't
@@ -432,9 +448,7 @@ def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
         source_code = func.__code__.co_code
 
     update_hash(
-        source_code,
-        hasher=func_hasher,
-        cache_type=cache_type,
+        source_code, hasher=func_hasher, cache_type=cache_type, hash_source=func
     )
 
     cache_key = func_hasher.hexdigest()
 
diff --git a/lib/streamlit/runtime/caching/hashing.py b/lib/streamlit/runtime/caching/hashing.py
index 8c995675a537..06ee14619fe9 100644
--- a/lib/streamlit/runtime/caching/hashing.py
+++ b/lib/streamlit/runtime/caching/hashing.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Hashing for st.memo and st.singleton.""" +"""Hashing for st.cache_data and st.cache_resource.""" import collections import dataclasses import functools @@ -28,9 +28,10 @@ import uuid import weakref from enum import Enum -from typing import Any, Dict, List, Optional, Pattern +from typing import Any, Callable, Dict, List, Optional, Pattern, Type, Union from streamlit import type_util, util +from streamlit.errors import StreamlitAPIException from streamlit.runtime.caching.cache_errors import UnhashableTypeError from streamlit.runtime.caching.cache_type import CacheType from streamlit.runtime.uploaded_file_manager import UploadedFile @@ -43,17 +44,109 @@ _NP_SIZE_LARGE = 1000000 _NP_SAMPLE_SIZE = 100000 +HashFuncsDict = Dict[Union[str, Type[Any]], Callable[[Any], Any]] + # Arbitrary item to denote where we found a cycle in a hashed object. # This allows us to hash self-referencing lists, dictionaries, etc. _CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE" -def update_hash(val: Any, hasher, cache_type: CacheType) -> None: +class UserHashError(StreamlitAPIException): + def __init__( + self, + orig_exc, + object_to_hash, + hash_func, + cache_type: Optional[CacheType] = None, + ): + self.alternate_name = type(orig_exc).__name__ + self.hash_func = hash_func + self.cache_type = cache_type + + msg = self._get_message_from_func(orig_exc, object_to_hash) + + super().__init__(msg) + self.with_traceback(orig_exc.__traceback__) + + def _get_message_from_func(self, orig_exc, cached_func): + args = self._get_error_message_args(orig_exc, cached_func) + + return ( + """ +%(orig_exception_desc)s + +This error is likely due to a bug in %(hash_func_name)s, which is a +user-defined hash function that was passed into the `%(cache_primitive)s` decorator of +%(object_desc)s. + +%(hash_func_name)s failed when hashing an object of type +`%(failed_obj_type_str)s`. If you don't know where that object is coming from, +try looking at the hash chain below for an object that you do recognize, then +pass that to `hash_funcs` instead: + +``` +%(hash_stack)s +``` + +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose). +""" + % args + ).strip("\n") + + def _get_error_message_args( + self, + orig_exc: BaseException, + failed_obj: Any, + ) -> Dict[str, Any]: + hash_source = hash_stacks.current.hash_source + + failed_obj_type_str = type_util.get_fqn_type(failed_obj) + + if hash_source is None: + object_desc = "something" + else: + if hasattr(hash_source, "__name__"): + object_desc = f"`{hash_source.__name__}()`" + else: + object_desc = "a function" + + decorator_name = "" + if self.cache_type is CacheType.RESOURCE: + decorator_name = "@st.cache_resource" + elif self.cache_type is CacheType.DATA: + decorator_name = "@st.cache_data" + + if hasattr(self.hash_func, "__name__"): + hash_func_name = f"`{self.hash_func.__name__}()`" + else: + hash_func_name = "a function" + + return { + "orig_exception_desc": str(orig_exc), + "failed_obj_type_str": failed_obj_type_str, + "hash_stack": hash_stacks.current.pretty_print(), + "object_desc": object_desc, + "cache_primitive": decorator_name, + "hash_func_name": hash_func_name, + } + + +def update_hash( + val: Any, + hasher, + cache_type: CacheType, + hash_source: Optional[Callable[..., Any]] = None, + hash_funcs: Optional[HashFuncsDict] = None, +) -> None: """Updates a hashlib hasher with the hash of val. This is the main entrypoint to hashing.py. 
""" - ch = _CacheFuncHasher(cache_type) + + hash_stacks.current.hash_source = hash_source + + ch = _CacheFuncHasher(cache_type, hash_funcs) ch.update(hasher, val) @@ -71,6 +164,9 @@ class _HashStack: def __init__(self): self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict() + # A function that we decorate with streamlit cache + # primitive (st.cache_data or st.cache_resource). + self.hash_source: Optional[Callable[..., Any]] = None def __repr__(self) -> str: return util.repr_(self) @@ -84,6 +180,15 @@ def pop(self): def __contains__(self, val: Any): return id(val) in self._stack + def pretty_print(self) -> str: + def to_str(v: Any) -> str: + try: + return f"Object of type {type_util.get_fqn_type(v)}: {str(v)}" + except Exception: + return "" + + return "\n".join(to_str(x) for x in reversed(self._stack.values())) + class _HashStacks: """Stacks of what has been hashed, with at most 1 stack per thread.""" @@ -161,7 +266,24 @@ def is_simple(obj): class _CacheFuncHasher: """A hasher that can hash objects with cycles.""" - def __init__(self, cache_type: CacheType): + def __init__( + self, cache_type: CacheType, hash_funcs: Optional[HashFuncsDict] = None + ): + # Can't use types as the keys in the internal _hash_funcs because + # we always remove user-written modules from memory when rerunning a + # script in order to reload it and grab the latest code changes. + # (See LocalSourcesWatcher.py:on_file_changed) This causes + # the type object to refer to different underlying class instances each run, + # so type-based comparisons fail. To solve this, we use the types converted + # to fully-qualified strings as keys in our internal dict. + self._hash_funcs: HashFuncsDict + if hash_funcs: + self._hash_funcs = { + k if isinstance(k, str) else type_util.get_fqn(k): v + for k, v in hash_funcs.items() + } + else: + self._hash_funcs = {} self._hashes: Dict[Any, bytes] = {} # The number of the bytes in the hash. @@ -226,6 +348,17 @@ def _to_bytes(self, obj: Any) -> bytes: elif isinstance(obj, bytes) or isinstance(obj, bytearray): return obj + elif type_util.get_fqn_type(obj) in self._hash_funcs: + # Escape hatch for unsupported objects + hash_func = self._hash_funcs[type_util.get_fqn_type(obj)] + try: + output = hash_func(obj) + except Exception as ex: + raise UserHashError( + ex, obj, hash_func=hash_func, cache_type=self.cache_type + ) from ex + return self.to_bytes(output) + elif isinstance(obj, str): return obj.encode() diff --git a/lib/streamlit/runtime/legacy_caching/hashing.py b/lib/streamlit/runtime/legacy_caching/hashing.py index 4a0f04741d30..1c2b48bdd5cf 100644 --- a/lib/streamlit/runtime/legacy_caching/hashing.py +++ b/lib/streamlit/runtime/legacy_caching/hashing.py @@ -867,8 +867,8 @@ def _get_message_from_func(self, orig_exc, cached_func, hash_func): %(hash_stack)s ``` -If you think this is actually a Streamlit bug, please [file a bug report here.] -(https://github.com/streamlit/streamlit/issues/new/choose) +If you think this is actually a Streamlit bug, please +[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose). 
""" % args ).strip("\n") diff --git a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py index 5ae59f5ce255..9de072c2297a 100644 --- a/lib/tests/streamlit/runtime/caching/cache_data_api_test.py +++ b/lib/tests/streamlit/runtime/caching/cache_data_api_test.py @@ -41,6 +41,7 @@ MultiCacheResults, _make_widget_key, ) +from streamlit.runtime.caching.hashing import UserHashError from streamlit.runtime.caching.storage import ( CacheStorage, CacheStorageContext, @@ -177,6 +178,90 @@ def foo(): else: show_warning_mock.assert_not_called() + def test_cached_member_function_with_hash_func(self): + """@st.cache_data can be applied to class member functions + with corresponding hash_func. + """ + + class TestClass: + @st.cache_data( + hash_funcs={ + "tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_cached_member_function_with_hash_func..TestClass": id + } + ) + def member_func(self): + return "member func!" + + @classmethod + @st.cache_data + def class_method(cls): + return "class method!" + + @staticmethod + @st.cache_data + def static_method(): + return "static method!" + + obj = TestClass() + self.assertEqual("member func!", obj.member_func()) + self.assertEqual("class method!", obj.class_method()) + self.assertEqual("static method!", obj.static_method()) + + def test_function_name_does_not_use_hashfuncs(self): + """Hash funcs should only be used on arguments to a function, + and not when computing the key for a function's unique MemCache. + """ + + str_hash_func = Mock(return_value=None) + + @st.cache_data(hash_funcs={str: str_hash_func}) + def foo(string_arg): + return [] + + # If our str hash_func is called multiple times, it's probably because + # it's being used to compute the function's function_key (as opposed to + # the value_key). It should only be used to compute the value_key! + foo("ahoy") + str_hash_func.assert_called_once_with("ahoy") + + def test_user_hash_error(self): + class MyObj: + # we specify __repr__ here, to avoid `MyObj object at 0x1347a3f70` + # in error message + def __repr__(self): + return "MyObj class" + + def bad_hash_func(x): + x += 10 # Throws a TypeError since x has type MyObj. + return x + + @st.cache_data(hash_funcs={MyObj: bad_hash_func}) + def user_hash_error_func(x): + pass + + with self.assertRaises(UserHashError) as ctx: + my_obj = MyObj() + user_hash_error_func(my_obj) + + expected_message = """unsupported operand type(s) for +=: 'MyObj' and 'int' + +This error is likely due to a bug in `bad_hash_func()`, which is a +user-defined hash function that was passed into the `@st.cache_data` decorator of +`user_hash_error_func()`. + +`bad_hash_func()` failed when hashing an object of type +`tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error..MyObj`. 
+`tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error.<locals>.MyObj`. If you don't know where that object is coming from,
+try looking at the hash chain below for an object that you do recognize, then
+pass that to `hash_funcs` instead:
+
+```
+Object of type tests.streamlit.runtime.caching.cache_data_api_test.CacheDataTest.test_user_hash_error.<locals>.MyObj: MyObj class
+```
+
+If you think this is actually a Streamlit bug, please
+[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose)."""
+        self.assertEqual(str(ctx.exception), expected_message)
+
 
 class CacheDataPersistTest(DeltaGeneratorTestCase):
     """st.cache_data disk persistence tests"""
 
diff --git a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py
index 01e8e90a01f7..0c71efad2581 100644
--- a/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py
+++ b/lib/tests/streamlit/runtime/caching/cache_resource_api_test.py
@@ -29,6 +29,7 @@
 )
 from streamlit.runtime.caching.cache_type import CacheType
 from streamlit.runtime.caching.cached_message_replay import MultiCacheResults
+from streamlit.runtime.caching.hashing import UserHashError
 from streamlit.runtime.scriptrunner import add_script_run_ctx
 from streamlit.runtime.stats import CacheStat
 from tests.streamlit.runtime.caching.common_cache_test import (
@@ -132,6 +133,90 @@ def foo():
         else:
             show_warning_mock.assert_not_called()
 
+    def test_cached_member_function_with_hash_func(self):
+        """@st.cache_resource can be applied to class member functions
+        with a corresponding hash_func.
+        """
+
+        class TestClass:
+            @st.cache_resource(
+                hash_funcs={
+                    "tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_cached_member_function_with_hash_func.<locals>.TestClass": id
+                }
+            )
+            def member_func(self):
+                return "member func!"
+
+            @classmethod
+            @st.cache_resource
+            def class_method(cls):
+                return "class method!"
+
+            @staticmethod
+            @st.cache_resource
+            def static_method():
+                return "static method!"
+
+        obj = TestClass()
+        self.assertEqual("member func!", obj.member_func())
+        self.assertEqual("class method!", obj.class_method())
+        self.assertEqual("static method!", obj.static_method())
+
+    def test_function_name_does_not_use_hashfuncs(self):
+        """Hash funcs should only be used on arguments to a function,
+        and not when computing the key for a function's unique MemCache.
+        """
+
+        str_hash_func = Mock(return_value=None)
+
+        @st.cache_resource(hash_funcs={str: str_hash_func})
+        def foo(string_arg):
+            return []
+
+        # If our str hash_func is called multiple times, it's probably because
+        # it's being used to compute the function's function_key (as opposed to
+        # the value_key). It should only be used to compute the value_key!
+        foo("ahoy")
+        str_hash_func.assert_called_once_with("ahoy")
+
+    def test_user_hash_error(self):
+        class MyObj:
+            # We specify __repr__ here to avoid `MyObj object at 0x1347a3f70`
+            # in the error message.
+            def __repr__(self):
+                return "MyObj class"
+
+        def bad_hash_func(x):
+            x += 10  # Throws a TypeError since x has type MyObj.
+            return x
+
+        @st.cache_resource(hash_funcs={MyObj: bad_hash_func})
+        def user_hash_error_func(x):
+            pass
+
+        with self.assertRaises(UserHashError) as ctx:
+            my_obj = MyObj()
+            user_hash_error_func(my_obj)
+
+        expected_message = """unsupported operand type(s) for +=: 'MyObj' and 'int'
+
+This error is likely due to a bug in `bad_hash_func()`, which is a
+user-defined hash function that was passed into the `@st.cache_resource` decorator of
+`user_hash_error_func()`.
+
+`bad_hash_func()` failed when hashing an object of type
+`tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_user_hash_error.<locals>.MyObj`. If you don't know where that object is coming from,
+try looking at the hash chain below for an object that you do recognize, then
+pass that to `hash_funcs` instead:
+
+```
+Object of type tests.streamlit.runtime.caching.cache_resource_api_test.CacheResourceTest.test_user_hash_error.<locals>.MyObj: MyObj class
+```
+
+If you think this is actually a Streamlit bug, please
+[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose)."""
+        self.assertEqual(str(ctx.exception), expected_message)
+
 
 class CacheResourceValidateTest(unittest.TestCase):
     def setUp(self) -> None:
 
diff --git a/lib/tests/streamlit/runtime/caching/hashing_test.py b/lib/tests/streamlit/runtime/caching/hashing_test.py
index 2e067fa68ad7..1cc9fefb715a 100644
--- a/lib/tests/streamlit/runtime/caching/hashing_test.py
+++ b/lib/tests/streamlit/runtime/caching/hashing_test.py
@@ -19,6 +19,7 @@
 import os
 import re
 import tempfile
+import time
 import types
 import unittest
 import uuid
@@ -33,11 +34,14 @@
 from parameterized import parameterized
 from PIL import Image
 
+from streamlit.runtime.caching import cache_data, cache_resource
 from streamlit.runtime.caching.cache_errors import UnhashableTypeError
+from streamlit.runtime.caching.cache_type import CacheType
 from streamlit.runtime.caching.hashing import (
     _NP_SIZE_LARGE,
     _PANDAS_ROWS_LARGE,
-    _CacheFuncHasher,
+    UserHashError,
+    update_hash,
 )
 
 try:
@@ -56,10 +60,11 @@
 get_main_script_director = MagicMock(return_value=os.getcwd())
 
 
-def get_hash(f):
+def get_hash(value, hash_funcs=None, cache_type=None):
     hasher = hashlib.new("md5")
-    ch = _CacheFuncHasher(MagicMock())
-    ch.update(hasher, f)
+    update_hash(
+        value, hasher, cache_type=cache_type or MagicMock(), hash_funcs=hash_funcs
+    )
     return hasher.digest()
 
 
@@ -114,6 +119,27 @@ def test_list(self):
         b.append(b)
         self.assertEqual(get_hash(a), get_hash(b))
 
+    @parameterized.expand(
+        [("cache_data", cache_data), ("cache_resource", cache_resource)]
+    )
+    def test_recursive_hash_func(self, _, cache_decorator):
+        """Test that if a user-defined hash_func returns a value of the same type
+        that it is meant to hash, we break the recursive cycle with a predefined
+        placeholder."""
+
+        def hash_int(x):
+            return x
+
+        @cache_decorator(hash_funcs={int: hash_int})
+        def foo(x):
+            return x
+
+        self.assertEqual(foo(1), foo(1))
+        # Note: We're able to break the recursive cycle caused by the identity
+        # hash func, but it causes all cycles to hash to the same thing.
+        # https://github.com/streamlit/streamlit/issues/1659
+        # self.assertNotEqual(foo(2), foo(1))
+
     def test_tuple(self):
         self.assertEqual(get_hash((1, 2)), get_hash((1, 2)))
         self.assertNotEqual(get_hash((1, 2)), get_hash((2, 2)))
@@ -391,6 +417,72 @@ def test_generator_not_hashable(self):
         with self.assertRaises(UnhashableTypeError):
             get_hash((x for x in range(1)))
 
+    def test_hash_funcs_acceptable_keys(self):
+        """Test that hashes are equivalent when the hash_func key is supplied both
+        as a type literal and as a type name string.
+        """
+        test_generator = (x for x in range(1))
+
+        with self.assertRaises(UnhashableTypeError):
+            get_hash(test_generator)
+
+        self.assertEqual(
+            get_hash(test_generator, hash_funcs={types.GeneratorType: id}),
+            get_hash(test_generator, hash_funcs={"builtins.generator": id}),
+        )
+
+    def test_hash_funcs_error(self):
+        with self.assertRaises(UserHashError) as ctx:
+            get_hash(1, cache_type=CacheType.DATA, hash_funcs={int: lambda x: "a" + x})
+
+        expected_message = """can only concatenate str (not "int") to str
+
+This error is likely due to a bug in `<lambda>()`, which is a
+user-defined hash function that was passed into the `@st.cache_data` decorator of
+something.
+
+`<lambda>()` failed when hashing an object of type
+`builtins.int`. If you don't know where that object is coming from,
+try looking at the hash chain below for an object that you do recognize, then
+pass that to `hash_funcs` instead:
+
+```
+Object of type builtins.int: 1
+```
+
+If you think this is actually a Streamlit bug, please
+[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose)."""
+        self.assertEqual(str(ctx.exception), expected_message)
+
+    def test_non_hashable(self):
+        """Test user-provided hash functions."""
+
+        g = (x for x in range(1))
+
+        # An unhashable object raises an error
+        with self.assertRaises(UnhashableTypeError):
+            get_hash(g)
+
+        id_hash_func = {types.GeneratorType: id}
+
+        self.assertEqual(
+            get_hash(g, hash_funcs=id_hash_func),
+            get_hash(g, hash_funcs=id_hash_func),
+        )
+
+        unique_hash_func = {types.GeneratorType: lambda x: time.time()}
+
+        self.assertNotEqual(
+            get_hash(g, hash_funcs=unique_hash_func),
+            get_hash(g, hash_funcs=unique_hash_func),
+        )
+
+    def test_override_streamlit_hash_func(self):
+        """Test that a user-provided hash function takes priority over a Streamlit one."""
+
+        hash_funcs = {int: lambda x: "hello"}
+        self.assertNotEqual(get_hash(1), get_hash(1, hash_funcs=hash_funcs))
+
     def test_function_not_hashable(self):
         def foo():
             pass
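A minimal usage sketch of the `hash_funcs` parameter this diff adds (illustrative only; `DBConnection` and `get_data` are hypothetical names, not part of the diff):

```python
import streamlit as st


class DBConnection:
    """Hypothetical stand-in for an unhashable object passed to a cached function."""

    def __init__(self, url: str):
        self.url = url


# The custom hash function's return value is hashed in place of the object
# itself (see _CacheFuncHasher._to_bytes above), so two connections with the
# same URL share a cache entry.
@st.cache_data(hash_funcs={DBConnection: lambda conn: conn.url})
def get_data(conn: DBConnection, query: str) -> list:
    return []  # run `query` against `conn` here


# Without the hash_funcs entry, calling get_data(conn, ...) would raise
# UnhashableParamError for the DBConnection argument.
```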
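Because `_CacheFuncHasher.__init__` converts type keys to fully-qualified name strings internally (type objects change identity when Streamlit reloads user modules; their names do not), the string-key form behaves identically to the type-literal form. A sketch under the same assumption of a hypothetical `Settings` class:

```python
import streamlit as st


class Settings:
    """Hypothetical config object used as a cached-function argument."""

    def __init__(self, db_url: str):
        self.db_url = db_url


# String keys are matched against type_util.get_fqn_type(obj), so the
# fully-qualified name works the same as passing the Settings type itself.
@st.cache_resource(hash_funcs={"__main__.Settings": lambda s: s.db_url})
def make_client(settings: Settings):
    return object()  # stand-in for an expensive shared resource
```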