Skip to content

Commit

Permalink
Rearrange + rename "cached message replay" code (streamlit#6012)
Browse files Browse the repository at this point in the history
Moves "cached message replay" logic into its own module, and simplifies/renames a few things for clarity.

- message replay logic is moved out of `cache_utils` and into its own module, `cached_message_replay`. This is the bulk of the PR.
- The `CacheWarningCallStack` and `CacheMessagesCallStack` utilities have been simplified and merged into a single class called `CachedMessageReplayContext`. (This object has just a single context manager to wrapped cached-function calls, which helps with some context manager deep-nesting.)
- `replay_result_messages()` is now called `replay_cached_messages()`
  • Loading branch information
tconkling committed Jan 26, 2023
1 parent 25cd415 commit d140533
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 582 deletions.
32 changes: 17 additions & 15 deletions lib/streamlit/runtime/caching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@

from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_data_api import (
CACHE_DATA_CALL_STACK,
CACHE_DATA_MESSAGE_CALL_STACK,
CACHE_DATA_MESSAGE_REPLAY_CTX,
CacheDataAPI,
_data_caches,
)
from streamlit.runtime.caching.cache_resource_api import (
CACHE_RESOURCE_CALL_STACK,
CACHE_RESOURCE_MESSAGE_CALL_STACK,
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
CacheResourceAPI,
_resource_caches,
)
Expand All @@ -46,10 +44,10 @@ def save_element_message(
be used later to replay the element when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_CALL_STACK.save_element_message(
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_CALL_STACK.save_element_message(
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)

Expand All @@ -64,10 +62,10 @@ def save_block_message(
be used later to replay the block when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_CALL_STACK.save_block_message(
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_CALL_STACK.save_block_message(
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)

Expand All @@ -76,25 +74,29 @@ def save_widget_metadata(metadata: WidgetMetadata[Any]) -> None:
"""Save a widget's metadata to a thread-local callstack, so the widget
can be registered again when that widget is replayed.
"""
CACHE_DATA_MESSAGE_CALL_STACK.save_widget_metadata(metadata)
CACHE_RESOURCE_MESSAGE_CALL_STACK.save_widget_metadata(metadata)
CACHE_DATA_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)


def save_media_data(
image_data: Union[bytes, str], mimetype: str, image_id: str
) -> None:
CACHE_DATA_MESSAGE_CALL_STACK.save_image_data(image_data, mimetype, image_id)
CACHE_RESOURCE_MESSAGE_CALL_STACK.save_image_data(image_data, mimetype, image_id)
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)


def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
CACHE_DATA_CALL_STACK.maybe_show_cached_st_function_warning(dg, st_func_name)
CACHE_RESOURCE_CALL_STACK.maybe_show_cached_st_function_warning(dg, st_func_name)
CACHE_DATA_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)


@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
with CACHE_DATA_CALL_STACK.suppress_cached_st_function_warning(), CACHE_RESOURCE_CALL_STACK.suppress_cached_st_function_warning():
with CACHE_DATA_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning(), CACHE_RESOURCE_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning():
yield


Expand Down
20 changes: 8 additions & 12 deletions lib/streamlit/runtime/caching/cache_data_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFunction,
create_cache_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
CacheMessagesCallStack,
CacheWarningCallStack,
ElementMsgData,
MsgData,
MultiCacheResults,
create_cache_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
Expand All @@ -64,8 +65,7 @@
# (`@st.cache_data` was originally called `@st.memo`)
_CACHED_FILE_EXTENSION = "memo"

CACHE_DATA_CALL_STACK = CacheWarningCallStack(CacheType.DATA)
CACHE_DATA_MESSAGE_CALL_STACK = CacheMessagesCallStack(CacheType.DATA)
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)

# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Union[Literal["disk"], None]
Expand Down Expand Up @@ -97,12 +97,8 @@ def cache_type(self) -> CacheType:
return CacheType.DATA

@property
def warning_call_stack(self) -> CacheWarningCallStack:
return CACHE_DATA_CALL_STACK

@property
def message_call_stack(self) -> CacheMessagesCallStack:
return CACHE_DATA_MESSAGE_CALL_STACK
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_DATA_MESSAGE_REPLAY_CTX

@property
def display_name(self) -> str:
Expand Down
20 changes: 8 additions & 12 deletions lib/streamlit/runtime/caching/cache_resource_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFunction,
create_cache_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
CacheMessagesCallStack,
CacheWarningCallStack,
ElementMsgData,
MsgData,
MultiCacheResults,
create_cache_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
Expand All @@ -51,8 +52,7 @@
_LOGGER = get_logger(__name__)


CACHE_RESOURCE_CALL_STACK = CacheWarningCallStack(CacheType.RESOURCE)
CACHE_RESOURCE_MESSAGE_CALL_STACK = CacheMessagesCallStack(CacheType.RESOURCE)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)

ValidateFunc: TypeAlias = Callable[[Any], bool]

Expand Down Expand Up @@ -168,12 +168,8 @@ def cache_type(self) -> CacheType:
return CacheType.RESOURCE

@property
def warning_call_stack(self) -> CacheWarningCallStack:
return CACHE_RESOURCE_CALL_STACK

@property
def message_call_stack(self) -> CacheMessagesCallStack:
return CACHE_RESOURCE_MESSAGE_CALL_STACK
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX

@property
def display_name(self) -> str:
Expand Down
Loading

0 comments on commit d140533

Please sign in to comment.