Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Expire session storage cache on an async timer (#8083)" #8281

Merged
merged 1 commit into from Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/streamlit/runtime/caching/cache_resource_api.py
Expand Up @@ -22,6 +22,7 @@
from datetime import timedelta
from typing import Any, Callable, Final, TypeVar, cast, overload

from cachetools import TTLCache
from typing_extensions import TypeAlias

import streamlit as st
Expand All @@ -47,7 +48,6 @@
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.util import TimedCleanupCache

_LOGGER: Final = get_logger(__name__)

Expand Down Expand Up @@ -472,7 +472,7 @@ def __init__(
super().__init__()
self.key = key
self.display_name = display_name
self._mem_cache: TimedCleanupCache[str, MultiCacheResults] = TimedCleanupCache(
self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
)
self._mem_cache_lock = threading.Lock()
Expand Down
Expand Up @@ -16,6 +16,8 @@
import math
import threading

from cachetools import TTLCache

from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
Expand All @@ -24,7 +26,6 @@
CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat
from streamlit.util import TimedCleanupCache

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
self.function_display_name = context.function_display_name
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
self._mem_cache: TimedCleanupCache[str, bytes] = TimedCleanupCache(
self._mem_cache: TTLCache[str, bytes] = TTLCache(
maxsize=self.max_entries,
ttl=self.ttl_seconds,
timer=cache_utils.TTLCACHE_TIMER,
Expand Down
5 changes: 3 additions & 2 deletions lib/streamlit/runtime/memory_session_storage.py
Expand Up @@ -16,8 +16,9 @@

from typing import MutableMapping

from cachetools import TTLCache

from streamlit.runtime.session_manager import SessionInfo, SessionStorage
from streamlit.util import TimedCleanupCache


class MemorySessionStorage(SessionStorage):
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
inaccessible and will be removed eventually.
"""

self._cache: MutableMapping[str, SessionInfo] = TimedCleanupCache(
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
maxsize=maxsize, ttl=ttl_seconds
)

Expand Down
39 changes: 1 addition & 38 deletions lib/streamlit/util.py
Expand Up @@ -16,16 +16,13 @@

from __future__ import annotations

import asyncio
import dataclasses
import functools
import hashlib
import os
import subprocess
import sys
from typing import Any, Callable, Final, Generic, Iterable, Mapping, TypeVar

from cachetools import TTLCache
from typing import Any, Callable, Final, Iterable, Mapping, TypeVar

from streamlit import env_util

Expand Down Expand Up @@ -202,37 +199,3 @@ def extract_key_query_params(
]
for item in sublist
}


# Type variables parameterizing TimedCleanupCache's key and value types.
K = TypeVar("K")
V = TypeVar("V")


class TimedCleanupCache(TTLCache, Generic[K, V]):
    """A TTLCache whose expired entries are reaped by a background asyncio task.

    The cleanup task is spawned lazily, on the first insert that happens
    while an event loop is running, because no loop may exist yet at
    construction time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Handle to the periodic-expiration task; created lazily by
        # __setitem__ via _ensure_cleanup_task.
        self._task: asyncio.Task[Any] | None = None

    def __setitem__(self, key: K, value: V) -> None:
        self._ensure_cleanup_task()
        super().__setitem__(key, value)

    def _ensure_cleanup_task(self) -> None:
        # Spawn the expiration task on the currently running loop, if any.
        # Can't be done in __init__: that runs once, possibly before any
        # event loop has been started.
        if self._task is not None:
            return
        try:
            self._task = asyncio.create_task(expire_cache(self))
        except RuntimeError:
            # No running event loop yet; retry on a later insert.
            pass

    def __del__(self):
        # Stop the background task when the cache itself is torn down.
        task = self._task
        if task is not None:
            task.cancel()


async def expire_cache(cache: TTLCache) -> None:
    """Periodically evict expired entries from *cache*.

    Runs forever (until cancelled or the cache is garbage collected),
    calling ``cache.expire()`` every 30 seconds.

    Only a weak reference to the cache is retained across the sleep: a
    running event loop strongly references its sleeping tasks, so if this
    coroutine held ``cache`` directly, an otherwise-unreachable cache (and
    every entry in it) would be kept alive for the lifetime of the loop.
    With the weakref, the task exits on its own once the cache is collected.
    """
    import weakref  # local import to keep this fix self-contained

    cache_ref = weakref.ref(cache)
    del cache  # drop the strong reference so the cache can be collected
    while True:
        await asyncio.sleep(30)
        live_cache = cache_ref()
        if live_cache is None:
            # The cache was garbage collected; nothing left to expire.
            return
        live_cache.expire()
        del live_cache  # don't hold a strong ref across the next sleep
25 changes: 0 additions & 25 deletions lib/tests/streamlit/util_test.py
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import gc
import random
import unittest
from typing import Dict, List, Set
Expand Down Expand Up @@ -187,26 +185,3 @@ def test_calc_md5_can_handle_bytes_and_strings(self):
util.calc_md5("eventually bytes"),
util.calc_md5("eventually bytes".encode("utf-8")),
)

def test_timed_cleanup_cache_gc(self):
    """Test that the TimedCleanupCache does not leave behind tasks when
    the cache is not externally reachable.
    """

    async def create_cache():
        cache = util.TimedCleanupCache(maxsize=2, ttl=10)
        cache["foo"] = "bar"

        # Both the expire_cache task and this coroutine's own task run.
        assert len(asyncio.all_tasks()) > 1

    # asyncio.run creates and closes its own event loop, so no manual
    # new_event_loop()/set_event_loop() is needed; the previous version
    # created a loop here that was never used and never closed, leaking
    # it and leaving a stale loop installed after the test.
    asyncio.run(create_cache())

    gc.collect()

    async def check():
        # Only this coroutine's task should remain after collection.
        assert len(asyncio.all_tasks()) == 1

    asyncio.run(check())