Skip to content

Commit

Permalink
Improve st.connection caching behavior (streamlit#7730)
Browse files Browse the repository at this point in the history
* Hack to compute connection cache function keys based on ttl, max_entries, and dsn

* Fix ttl and connection name behavior for all connections and factory

* Add explanatory comments

* More explanatory comment tweaks

* Avoid adding new `.` characters to function qualnames

---------

Co-authored-by: Vincent Donato <vincent@streamlit.io>
  • Loading branch information
2 people authored and Your Name committed Mar 22, 2024
1 parent a4fde14 commit 135ac0b
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 18 deletions.
17 changes: 13 additions & 4 deletions lib/streamlit/connections/snowflake_connection.py
Expand Up @@ -199,15 +199,24 @@ def query(
),
wait=wait_fixed(1),
)
@cache_data(
show_spinner=show_spinner,
ttl=ttl,
)
def _query(sql: str) -> pd.DataFrame:
    """Execute `sql` on a fresh cursor and return the full result as a DataFrame."""
    cursor = self._instance.cursor()
    # `params` and `kwargs` are captured from the enclosing `query()` call.
    cursor.execute(sql, params=params, **kwargs)
    return cursor.fetch_pandas_all()

# We modify our helper function's `__qualname__` here to work around default
# `@st.cache_data` behavior. Otherwise, `.query()` being called with different
# `ttl` values will reset the cache with each call, and the query caches won't
# be scoped by connection.
ttl_str = str( # Avoid adding extra `.` characters to `__qualname__`
ttl
).replace(".", "_")
_query.__qualname__ = f"{_query.__qualname__}_{self._connection_name}_{ttl_str}"
_query = cache_data(
show_spinner=show_spinner,
ttl=ttl,
)(_query)

return _query(sql)

def write_pandas(
Expand Down
17 changes: 13 additions & 4 deletions lib/streamlit/connections/snowpark_connection.py
Expand Up @@ -144,14 +144,23 @@ def query(
retry=retry_if_exception_type(SnowparkServerException),
wait=wait_fixed(1),
)
@cache_data(
show_spinner="Running `snowpark.query(...)`.",
ttl=ttl,
)
def _query(sql: str) -> pd.DataFrame:
    """Run `sql` through the Snowpark session and return a pandas DataFrame."""
    # Hold the connection lock so concurrent callers serialize access to the
    # shared session object.
    with self._lock:
        result = self._instance.sql(sql)
        return result.to_pandas()

# We modify our helper function's `__qualname__` here to work around default
# `@st.cache_data` behavior. Otherwise, `.query()` being called with different
# `ttl` values will reset the cache with each call, and the query caches won't
# be scoped by connection.
ttl_str = str( # Avoid adding extra `.` characters to `__qualname__`
ttl
).replace(".", "_")
_query.__qualname__ = f"{_query.__qualname__}_{self._connection_name}_{ttl_str}"
_query = cache_data(
show_spinner="Running `snowpark.query(...)`.",
ttl=ttl,
)(_query)

return _query(sql)

@property
Expand Down
17 changes: 13 additions & 4 deletions lib/streamlit/connections/sql_connection.py
Expand Up @@ -200,10 +200,6 @@ def query(
),
wait=wait_fixed(1),
)
@cache_data(
show_spinner=show_spinner,
ttl=ttl,
)
def _query(
sql: str,
index_col=None,
Expand All @@ -221,6 +217,19 @@ def _query(
**kwargs,
)

# We modify our helper function's `__qualname__` here to work around default
# `@st.cache_data` behavior. Otherwise, `.query()` being called with different
# `ttl` values will reset the cache with each call, and the query caches won't
# be scoped by connection.
ttl_str = str( # Avoid adding extra `.` characters to `__qualname__`
ttl
).replace(".", "_")
_query.__qualname__ = f"{_query.__qualname__}_{self._connection_name}_{ttl_str}"
_query = cache_data(
show_spinner=show_spinner,
ttl=ttl,
)(_query)

return _query(
sql,
index_col=index_col,
Expand Down
20 changes: 15 additions & 5 deletions lib/streamlit/runtime/connection_factory.py
Expand Up @@ -76,11 +76,6 @@ def _create_connection(
* Allow the user to specify ttl and max_entries when calling st.connection.
"""

@cache_resource(
max_entries=max_entries,
show_spinner="Running `st.connection(...)`.",
ttl=ttl,
)
def __create_connection(
name: str, connection_class: Type[ConnectionClass], **kwargs
) -> ConnectionClass:
Expand All @@ -91,6 +86,21 @@ def __create_connection(
f"{connection_class} is not a subclass of BaseConnection!"
)

# We modify our helper function's `__qualname__` here to work around default
# `@st.cache_resource` behavior. Otherwise, `st.connection` being called with
# different `ttl` or `max_entries` values will reset the cache with each call.
ttl_str = str(ttl).replace( # Avoid adding extra `.` characters to `__qualname__`
".", "_"
)
__create_connection.__qualname__ = (
f"{__create_connection.__qualname__}_{ttl_str}_{max_entries}"
)
__create_connection = cache_resource(
max_entries=max_entries,
show_spinner="Running `st.connection(...)`.",
ttl=ttl,
)(__create_connection)

return __create_connection(name, connection_class, **kwargs)


Expand Down
46 changes: 45 additions & 1 deletion lib/tests/streamlit/connections/snowflake_connection_test.py
Expand Up @@ -93,7 +93,6 @@ def test_query_caches_value(self):

mock_cursor = MagicMock()
mock_cursor.fetch_pandas_all = MagicMock(return_value="i am a dataframe")

conn = SnowflakeConnection("my_snowflake_connection")
conn._instance.cursor.return_value = mock_cursor

Expand All @@ -103,6 +102,51 @@ def test_query_caches_value(self):
conn._instance.cursor.assert_called_once()
mock_cursor.execute.assert_called_once_with("SELECT 1;", params=None)

@patch(
    "streamlit.connections.snowflake_connection.SnowflakeConnection._connect",
    MagicMock(),
)
def test_does_not_reset_cache_when_ttl_changes(self):
    """Results cached under one ttl must survive queries made with another ttl."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

    cursor_mock = MagicMock()
    cursor_mock.fetch_pandas_all = MagicMock(return_value="i am a dataframe")
    conn = SnowflakeConnection("my_snowflake_connection")
    conn._instance.cursor.return_value = cursor_mock

    # Issue each (sql, ttl) pair twice; the second round should be cache hits.
    for _ in range(2):
        conn.query("SELECT 1;", ttl=10)
        conn.query("SELECT 2;", ttl=20)

    assert conn._instance.cursor.call_count == 2
    assert cursor_mock.execute.call_count == 2

@patch(
    "streamlit.connections.snowflake_connection.SnowflakeConnection._connect",
    MagicMock(),
)
def test_scopes_caches_by_connection_name(self):
    """Identical SQL on differently named connections must not share a cache."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

    shared_cursor = MagicMock()
    shared_cursor.fetch_pandas_all = MagicMock(return_value="i am a dataframe")

    conn1 = SnowflakeConnection("my_snowflake_connection1")
    conn1._instance.cursor.return_value = shared_cursor
    conn2 = SnowflakeConnection("my_snowflake_connection2")
    conn2._instance.cursor.return_value = shared_cursor

    for conn in (conn1, conn2):
        conn.query("SELECT 1;")
        conn.query("SELECT 1;")

    # Both connections sit on one underlying mock, so exactly two cache misses
    # are expected: one per connection name, despite the identical SQL.
    assert conn1._instance.cursor is conn2._instance.cursor
    assert conn1._instance.cursor.call_count == 2
    assert shared_cursor.execute.call_count == 2

@patch(
"streamlit.connections.snowflake_connection.SnowflakeConnection._connect",
MagicMock(),
Expand Down
43 changes: 43 additions & 0 deletions lib/tests/streamlit/connections/snowpark_connection_test.py
Expand Up @@ -100,6 +100,49 @@ def test_query_caches_value(self):
assert conn.query("SELECT 1;") == "i am a dataframe"
conn._instance.sql.assert_called_once()

@patch(
    "streamlit.connections.snowpark_connection.SnowparkConnection._connect",
    MagicMock(),
)
def test_does_not_reset_cache_when_ttl_changes(self):
    """Results cached under one ttl must survive queries made with another ttl."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

    sql_result = MagicMock()
    sql_result.to_pandas = MagicMock(return_value="i am a dataframe")
    conn = SnowparkConnection("my_snowpark_connection")
    conn._instance.sql.return_value = sql_result

    # Issue each (sql, ttl) pair twice; the second round should be cache hits.
    for _ in range(2):
        conn.query("SELECT 1;", ttl=10)
        conn.query("SELECT 2;", ttl=20)

    assert conn._instance.sql.call_count == 2

@patch(
    "streamlit.connections.snowpark_connection.SnowparkConnection._connect",
    MagicMock(),
)
def test_scopes_caches_by_connection_name(self):
    """Identical SQL on differently named connections must not share a cache."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

    shared_result = MagicMock()
    shared_result.to_pandas = MagicMock(return_value="i am a dataframe")

    conn1 = SnowparkConnection("my_snowpark_connection1")
    conn1._instance.sql.return_value = shared_result
    conn2 = SnowparkConnection("my_snowpark_connection2")
    conn2._instance.sql.return_value = shared_result

    for conn in (conn1, conn2):
        conn.query("SELECT 1;")
        conn.query("SELECT 1;")

    # Both connections sit on one underlying mock, so exactly two cache misses
    # are expected: one per connection name, despite the identical SQL.
    assert conn1._instance.sql is conn2._instance.sql
    assert conn1._instance.sql.call_count == 2

@patch(
"streamlit.connections.snowpark_connection.SnowparkConnection._connect",
MagicMock(),
Expand Down
33 changes: 33 additions & 0 deletions lib/tests/streamlit/connections/sql_connection_test.py
Expand Up @@ -141,6 +141,39 @@ def test_query_caches_value(self, patched_read_sql):
assert conn.query("SELECT 1;") == "i am a dataframe"
patched_read_sql.assert_called_once()

@patch("streamlit.connections.sql_connection.SQLConnection._connect", MagicMock())
@patch("streamlit.connections.sql_connection.pd.read_sql")
def test_does_not_reset_cache_when_ttl_changes(self, patched_read_sql):
    """Results cached under one ttl must survive queries made with another ttl."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())
    patched_read_sql.return_value = "i am a dataframe"

    conn = SQLConnection("my_sql_connection")

    # Issue each (sql, ttl) pair twice; the second round should be cache hits.
    for _ in range(2):
        conn.query("SELECT 1;", ttl=10)
        conn.query("SELECT 2;", ttl=20)

    assert patched_read_sql.call_count == 2

@patch("streamlit.connections.sql_connection.SQLConnection._connect", MagicMock())
@patch("streamlit.connections.sql_connection.pd.read_sql")
def test_scopes_caches_by_connection_name(self, patched_read_sql):
    """Identical SQL on differently named connections must not share a cache."""
    # Caching functions rely on an active script run ctx.
    add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())
    patched_read_sql.return_value = "i am a dataframe"

    conn1 = SQLConnection("my_sql_connection1")
    conn2 = SQLConnection("my_sql_connection2")

    for conn in (conn1, conn2):
        conn.query("SELECT 1;")
        conn.query("SELECT 1;")

    # One cache miss per connection name is expected, despite the identical SQL.
    assert patched_read_sql.call_count == 2

@patch("streamlit.connections.sql_connection.SQLConnection._connect", MagicMock())
def test_repr_html_(self):
conn = SQLConnection("my_sql_connection")
Expand Down
22 changes: 22 additions & 0 deletions lib/tests/streamlit/runtime/connection_factory_test.py
Expand Up @@ -184,6 +184,28 @@ def test_caches_connection_instance(self):
conn = connection_factory("my_connection", MockConnection)
assert connection_factory("my_connection", MockConnection) is conn

def test_does_not_clear_cache_when_ttl_changes(self):
    """Connections created with one ttl must survive calls made with another ttl."""
    with patch.object(
        MockConnection, "__init__", return_value=None
    ) as patched_init:
        # Request each (name, ttl) pair twice; the second round should be
        # served from the connection cache.
        for _ in range(2):
            connection_factory("my_connection1", MockConnection, ttl=10)
            connection_factory("my_connection2", MockConnection, ttl=20)

        assert patched_init.call_count == 2

def test_does_not_clear_cache_when_max_entries_changes(self):
    """Cached connections must survive calls made with a different max_entries."""
    with patch.object(
        MockConnection, "__init__", return_value=None
    ) as patched_init:
        # Request each (name, max_entries) pair twice; the second round should
        # be served from the connection cache.
        for _ in range(2):
            connection_factory("my_connection1", MockConnection, max_entries=10)
            connection_factory("my_connection2", MockConnection, max_entries=20)

        assert patched_init.call_count == 2

@parameterized.expand(
[
("MySQLdb", "mysqlclient"),
Expand Down

0 comments on commit 135ac0b

Please sign in to comment.