diff --git a/lib/streamlit/caching/__init__.py b/lib/streamlit/caching/__init__.py index 65885a0dc7da..3e3363491b42 100644 --- a/lib/streamlit/caching/__init__.py +++ b/lib/streamlit/caching/__init__.py @@ -14,8 +14,8 @@ import contextlib from typing import Iterator -from .memo_decorator import MEMO_CALL_STACK -from .singleton_decorator import SINGLETON_CALL_STACK +from .memo_decorator import MEMO_CALL_STACK, MemoCache +from .singleton_decorator import SINGLETON_CALL_STACK, SingletonCache def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None: @@ -29,6 +29,14 @@ def suppress_cached_st_function_warning() -> Iterator[None]: yield +def clear_singleton_cache() -> None: + SingletonCache.clear_all() + + +def clear_memo_cache() -> None: + MemoCache.clear_all() + + # Explicitly export `memo` and `singleton` from .memo_decorator import memo as memo from .singleton_decorator import singleton as singleton diff --git a/lib/streamlit/cli.py b/lib/streamlit/cli.py index 30f168dabcda..90b5a303af58 100644 --- a/lib/streamlit/cli.py +++ b/lib/streamlit/cli.py @@ -229,8 +229,9 @@ def cache(): @cache.command("clear") def cache_clear(): - """Clear the Streamlit on-disk cache.""" + """Clear st.cache, st.memo, and st.singleton caches.""" import streamlit.legacy_caching + import streamlit.caching result = streamlit.legacy_caching.clear_cache() cache_path = streamlit.legacy_caching.get_cache_path() @@ -239,6 +240,9 @@ def cache_clear(): else: print("Nothing to clear at %s." % cache_path) + streamlit.caching.clear_memo_cache() + streamlit.caching.clear_singleton_cache() + # SUBCOMMAND: config diff --git a/lib/streamlit/report_session.py b/lib/streamlit/report_session.py index c16e321521f4..97105c724429 100644 --- a/lib/streamlit/report_session.py +++ b/lib/streamlit/report_session.py @@ -22,7 +22,7 @@ import streamlit.elements.exception as exception_utils import streamlit.server.server_util as server_util -from streamlit import __version__, config, legacy_caching, secrets, url_util +from streamlit import __version__, config, legacy_caching, secrets, url_util, caching from streamlit.case_converters import to_snake_case from streamlit.credentials import Credentials from streamlit.in_memory_file_manager import in_memory_file_manager @@ -509,12 +509,9 @@ def handle_clear_cache_request(self): Because this cache is global, it will be cleared for all users. """ - # Setting verbose=True causes clear_cache to print to stdout. - # Since this command was initiated from the browser, the user - # doesn't need to see the results of the command in their - # terminal. legacy_caching.clear_cache() - + caching.clear_memo_cache() + caching.clear_singleton_cache() self._session_state.clear_state() def handle_set_run_on_save_request(self, new_value): diff --git a/lib/tests/streamlit/caching/common_cache_test.py b/lib/tests/streamlit/caching/common_cache_test.py index a848c2e01384..01c66cbe1530 100644 --- a/lib/tests/streamlit/caching/common_cache_test.py +++ b/lib/tests/streamlit/caching/common_cache_test.py @@ -22,7 +22,12 @@ import streamlit as st from streamlit import report_thread -from streamlit.caching import MEMO_CALL_STACK, SINGLETON_CALL_STACK +from streamlit.caching import ( + MEMO_CALL_STACK, + SINGLETON_CALL_STACK, + clear_memo_cache, + clear_singleton_cache, +) memo = st.experimental_memo singleton = st.experimental_singleton @@ -255,3 +260,41 @@ def thread_test(): # The other thread should not have modified the main thread self.assertEqual(1, get_counter()) + + @parameterized.expand( + [ + ("memo", memo, clear_memo_cache), + ("singleton", singleton, clear_singleton_cache), + ] + ) + def test_clear_all_caches(self, _, cache_decorator, clear_cache_func): + """Calling a cache's global `clear_all` function should remove all + items from all caches of the appropriate type. + """ + foo_vals = [] + + @cache_decorator + def foo(x): + foo_vals.append(x) + return x + + bar_vals = [] + + @cache_decorator + def bar(x): + bar_vals.append(x) + return x + + foo(0), foo(1), foo(2) + bar(0), bar(1), bar(2) + self.assertEqual([0, 1, 2], foo_vals) + self.assertEqual([0, 1, 2], bar_vals) + + # Clear the cache and access our original values again. They + # should be recomputed. + clear_cache_func() + + foo(0), foo(1), foo(2) + bar(0), bar(1), bar(2) + self.assertEqual([0, 1, 2, 0, 1, 2], foo_vals) + self.assertEqual([0, 1, 2, 0, 1, 2], bar_vals) diff --git a/lib/tests/streamlit/caching/memo_test.py b/lib/tests/streamlit/caching/memo_test.py index e1db77021007..06bb013a7a4d 100644 --- a/lib/tests/streamlit/caching/memo_test.py +++ b/lib/tests/streamlit/caching/memo_test.py @@ -20,8 +20,9 @@ import streamlit as st from streamlit import StreamlitAPIException, file_util -from streamlit.caching import memo_decorator +from streamlit.caching import memo_decorator, clear_memo_cache from streamlit.caching.cache_errors import CacheError +from streamlit.caching.memo_decorator import get_cache_path class MemoTest(unittest.TestCase): @@ -139,3 +140,19 @@ def foo(): "Unsupported persist option 'yesplz'. Valid values are 'disk' or None.", str(e.exception), ) + + @patch("shutil.rmtree") + def test_clear_disk_cache(self, mock_rmtree): + """`clear_all` should remove the disk cache directory if it exists.""" + + # If the cache dir exists, we should delete it. + with patch("os.path.isdir", MagicMock(return_value=True)): + clear_memo_cache() + mock_rmtree.assert_called_once_with(get_cache_path()) + + mock_rmtree.reset_mock() + + # If the cache dir does not exist, we shouldn't try to delete it. + with patch("os.path.isdir", MagicMock(return_value=False)): + clear_memo_cache() + mock_rmtree.assert_not_called() diff --git a/lib/tests/streamlit/cli_test.py b/lib/tests/streamlit/cli_test.py index becd69df9931..0febfbf5a43d 100644 --- a/lib/tests/streamlit/cli_test.py +++ b/lib/tests/streamlit/cli_test.py @@ -347,6 +347,18 @@ def test_config_show_command_with_flag_config_options(self): self.assertEqual(kwargs["flag_options"]["server_port"], 8502) self.assertEqual(0, result.exit_code) + @patch("streamlit.legacy_caching.clear_cache") + @patch("streamlit.caching.clear_memo_cache") + @patch("streamlit.caching.clear_singleton_cache") + def test_cache_clear_all_caches( + self, clear_singleton_cache, clear_memo_cache, clear_legacy_cache + ): + """cli.clear_cache should clear st.cache, st.memo and st.singleton""" + self.runner.invoke(cli, ["cache", "clear"]) + clear_singleton_cache.assert_called_once() + clear_memo_cache.assert_called_once() + clear_legacy_cache.assert_called_once() + @patch("builtins.print") def test_cache_clear_command_with_cache(self, mock_print): """Tests clear cache announces that cache is cleared when completed""" diff --git a/lib/tests/streamlit/report_session_test.py b/lib/tests/streamlit/report_session_test.py index cc2836ed3bb7..a241d275dff3 100644 --- a/lib/tests/streamlit/report_session_test.py +++ b/lib/tests/streamlit/report_session_test.py @@ -144,14 +144,25 @@ def test_creates_session_state_on_init(self, _): rs = ReportSession(None, "", "", UploadedFileManager(), None) self.assertTrue(isinstance(rs.session_state, SessionState)) - @patch("streamlit.report_session.legacy_caching.clear_cache") @patch("streamlit.report_session.LocalSourcesWatcher") - def test_clear_cache_resets_session_state(self, _1, _2): + def test_clear_cache_resets_session_state(self, _1): rs = ReportSession(None, "", "", UploadedFileManager(), None) rs._session_state["foo"] = "bar" rs.handle_clear_cache_request() self.assertTrue("foo" not in rs._session_state) + @patch("streamlit.legacy_caching.clear_cache") + @patch("streamlit.caching.clear_memo_cache") + @patch("streamlit.caching.clear_singleton_cache") + def test_clear_cache_all_caches( + self, clear_singleton_cache, clear_memo_cache, clear_legacy_cache + ): + rs = ReportSession(MagicMock(), "", "", UploadedFileManager(), None) + rs.handle_clear_cache_request() + clear_singleton_cache.assert_called_once() + clear_memo_cache.assert_called_once() + clear_legacy_cache.assert_called_once() + @patch("streamlit.report_session.secrets._file_change_listener.connect") @patch("streamlit.report_session.LocalSourcesWatcher") def test_request_rerun_on_secrets_file_change(self, _, patched_connect):