From 44227adc2fda53e7ccb00a7099b23902e6fb1604 Mon Sep 17 00:00:00 2001 From: Amanda Walker Date: Wed, 7 Feb 2024 15:42:07 -0800 Subject: [PATCH] Expire session storage cache on an async timer (#8083) ## Describe your changes To reduce the tendency of expired sessions to stick around for a long time for lower traffic apps, and potentially consume lots of memory, add an async task to periodically expire the TTLCache used in the default session storage implementation. ## GitHub Issue Link (if applicable) ## Testing Plan Was manually tested, and unit testing might not be that helpful. --- **Contribution License Agreement** By submitting this pull request you agree that all contributions to this project are made under the Apache 2.0 license. --- .../runtime/caching/cache_resource_api.py | 4 +- .../in_memory_cache_storage_wrapper.py | 5 +- .../runtime/memory_session_storage.py | 6 +-- lib/streamlit/util.py | 49 ++++++++++++++++++- lib/tests/streamlit/util_test.py | 25 ++++++++++ 5 files changed, 80 insertions(+), 9 deletions(-) diff --git a/lib/streamlit/runtime/caching/cache_resource_api.py b/lib/streamlit/runtime/caching/cache_resource_api.py index 6969a56b43ef..e6280ee5ae8f 100644 --- a/lib/streamlit/runtime/caching/cache_resource_api.py +++ b/lib/streamlit/runtime/caching/cache_resource_api.py @@ -22,7 +22,6 @@ from datetime import timedelta from typing import Any, Callable, TypeVar, cast, overload -from cachetools import TTLCache from typing_extensions import TypeAlias import streamlit as st @@ -48,6 +47,7 @@ from streamlit.runtime.metrics_util import gather_metrics from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats +from streamlit.util import TimedCleanupCache from streamlit.vendor.pympler.asizeof import asizeof _LOGGER = get_logger(__name__) @@ -473,7 +473,7 @@ def __init__( super().__init__() self.key = key self.display_name = display_name - self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache( + self._mem_cache: TimedCleanupCache[str, MultiCacheResults] = TimedCleanupCache( maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER ) self._mem_cache_lock = threading.Lock() diff --git a/lib/streamlit/runtime/caching/storage/in_memory_cache_storage_wrapper.py b/lib/streamlit/runtime/caching/storage/in_memory_cache_storage_wrapper.py index bb0d19e78294..dc750c394676 100644 --- a/lib/streamlit/runtime/caching/storage/in_memory_cache_storage_wrapper.py +++ b/lib/streamlit/runtime/caching/storage/in_memory_cache_storage_wrapper.py @@ -16,8 +16,6 @@ import math import threading -from cachetools import TTLCache - from streamlit.logger import get_logger from streamlit.runtime.caching import cache_utils from streamlit.runtime.caching.storage.cache_storage_protocol import ( @@ -26,6 +24,7 @@ CacheStorageKeyNotFoundError, ) from streamlit.runtime.stats import CacheStat +from streamlit.util import TimedCleanupCache _LOGGER = get_logger(__name__) @@ -62,7 +61,7 @@ def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext): self.function_display_name = context.function_display_name self._ttl_seconds = context.ttl_seconds self._max_entries = context.max_entries - self._mem_cache: TTLCache[str, bytes] = TTLCache( + self._mem_cache: TimedCleanupCache[str, bytes] = TimedCleanupCache( maxsize=self.max_entries, ttl=self.ttl_seconds, timer=cache_utils.TTLCACHE_TIMER, diff --git a/lib/streamlit/runtime/memory_session_storage.py b/lib/streamlit/runtime/memory_session_storage.py index c9ee7511bf68..5e7088d6eeda 100644 --- a/lib/streamlit/runtime/memory_session_storage.py +++ b/lib/streamlit/runtime/memory_session_storage.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, MutableMapping, Optional -from cachetools import TTLCache +from typing import List, MutableMapping, Optional from streamlit.runtime.session_manager import SessionInfo, SessionStorage +from streamlit.util import TimedCleanupCache class MemorySessionStorage(SessionStorage): @@ -55,7 +55,7 @@ def __init__( inaccessible and will be removed eventually. """ - self._cache: MutableMapping[str, SessionInfo] = TTLCache( + self._cache: MutableMapping[str, SessionInfo] = TimedCleanupCache( maxsize=maxsize, ttl=ttl_seconds ) diff --git a/lib/streamlit/util.py b/lib/streamlit/util.py index 4b4f3383309d..2f4bc126cdc4 100644 --- a/lib/streamlit/util.py +++ b/lib/streamlit/util.py @@ -16,14 +16,27 @@ from __future__ import annotations +import asyncio import dataclasses import functools import hashlib import os import subprocess import sys -from typing import Any, Dict, Iterable, List, Mapping, Set, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + TypeVar, + Union, +) +from cachetools import TTLCache from typing_extensions import Final from streamlit import env_util @@ -203,3 +216,37 @@ def extract_key_query_params( for item in sublist ] ) + + +K = TypeVar("K") +V = TypeVar("V") + + +class TimedCleanupCache(TTLCache, Generic[K, V]): + """A TTLCache that asynchronously expires its entries.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._task: Optional[asyncio.Task[Any]] = None + + def __setitem__(self, key: K, value: V) -> None: + # Set an expiration task to run periodically + # Can't be created in init because that only runs once and + # the event loop might not exist yet. + if self._task is None: + try: + self._task = asyncio.create_task(expire_cache(self)) + except RuntimeError: + # Just continue if the event loop isn't started yet. + pass + super().__setitem__(key, value) + + def __del__(self): + if self._task is not None: + self._task.cancel() + + +async def expire_cache(cache: TTLCache) -> None: + while True: + await asyncio.sleep(30) + cache.expire() diff --git a/lib/tests/streamlit/util_test.py b/lib/tests/streamlit/util_test.py index fd26c6fd0182..628ea3893149 100644 --- a/lib/tests/streamlit/util_test.py +++ b/lib/tests/streamlit/util_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import gc import random import unittest from typing import Dict, List, Set @@ -185,3 +187,26 @@ def test_calc_md5_can_handle_bytes_and_strings(self): util.calc_md5("eventually bytes"), util.calc_md5("eventually bytes".encode("utf-8")), ) + + def test_timed_cleanup_cache_gc(self): + """Test that the TimedCleanupCache does not leave behind tasks when + the cache is not externally reachable""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def create_cache(): + cache = util.TimedCleanupCache(maxsize=2, ttl=10) + cache["foo"] = "bar" + + # expire_cache and create_cache + assert len(asyncio.all_tasks()) > 1 + + asyncio.run(create_cache()) + + gc.collect() + + async def check(): + # Only has this function running + assert len(asyncio.all_tasks()) == 1 + + asyncio.run(check())