From 2a1f69eca191a98d805d8c59d64aa8cb964f7571 Mon Sep 17 00:00:00 2001 From: Asaurus1 Date: Sun, 25 Feb 2024 23:11:21 -0800 Subject: [PATCH] Explicit update() method for query_params --- lib/streamlit/runtime/state/query_params.py | 28 +++++++++++++++++-- .../runtime/state/query_params_proxy.py | 21 +++++++++++++- .../runtime/state/query_params_proxy_test.py | 5 ++++ .../runtime/state/query_params_test.py | 28 +++++++++++++++++-- 4 files changed, 77 insertions(+), 5 deletions(-) diff --git a/lib/streamlit/runtime/state/query_params.py b/lib/streamlit/runtime/state/query_params.py index eef89509acab..5687e96f9d76 100644 --- a/lib/streamlit/runtime/state/query_params.py +++ b/lib/streamlit/runtime/state/query_params.py @@ -15,13 +15,16 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Iterable, Iterator, MutableMapping +from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping from urllib import parse from streamlit.constants import EMBED_QUERY_PARAMS_KEYS from streamlit.errors import StreamlitAPIException from streamlit.proto.ForwardMsg_pb2 import ForwardMsg +if TYPE_CHECKING: + from _typeshed import SupportsKeysAndGetItem + @dataclass class QueryParams(MutableMapping[str, str]): @@ -62,6 +65,10 @@ def __getitem__(self, key: str) -> str: raise KeyError(missing_key_error_message(key)) def __setitem__(self, key: str, value: str | Iterable[str]) -> None: + self.__set_item_internal(key, value) + self._send_query_param_msg() + + def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None: if isinstance(value, dict): raise StreamlitAPIException( f"You cannot set a query params key `{key}` to a dictionary." @@ -77,7 +84,6 @@ def __setitem__(self, key: str, value: str | Iterable[str]) -> None: self._query_params[key] = [str(item) for item in value] else: self._query_params[key] = str(value) - self._send_query_param_msg() def __delitem__(self, key: str) -> None: try: @@ -88,6 +94,24 @@ def __delitem__(self, key: str) -> None: except KeyError: raise KeyError(missing_key_error_message(key)) + def update( + self, + other: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str] = (), + /, + **kwds: str, + ): + # an update function that only sends one ForwardMsg + # once all keys have been updated. + if hasattr(other, "keys") and hasattr(other, "__getitem__"): + for key in other.keys(): + self.__set_item_internal(key, other[key]) + else: + for (key, value) in other: + self.__set_item_internal(key, value) + for key, value in kwds.items(): + self.__set_item_internal(key, value) + self._send_query_param_msg() + def get_all(self, key: str) -> list[str]: self._ensure_single_query_api_used() if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS: diff --git a/lib/streamlit/runtime/state/query_params_proxy.py b/lib/streamlit/runtime/state/query_params_proxy.py index 3207875f350b..69c8444bc016 100644 --- a/lib/streamlit/runtime/state/query_params_proxy.py +++ b/lib/streamlit/runtime/state/query_params_proxy.py @@ -14,12 +14,15 @@ from __future__ import annotations -from typing import Iterable, Iterator, MutableMapping +from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping, overload from streamlit.runtime.metrics_util import gather_metrics from streamlit.runtime.state.query_params import missing_key_error_message from streamlit.runtime.state.session_state_proxy import get_session_state +if TYPE_CHECKING: + from _typeshed import SupportsKeysAndGetItem + class QueryParamsProxy(MutableMapping[str, str]): """ @@ -68,6 +71,22 @@ def __delattr__(self, key: str) -> None: except KeyError: raise AttributeError(missing_key_error_message(key)) + @overload + def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str): + ... + + @overload + def update(self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str): + ... + + @overload + def update(self, **kwds: str): + ... + + def update(self, other=(), /, **kwds): + with get_session_state().query_params() as qp: + qp.update(other, **kwds) + @gather_metrics("query_params.set_attr") def __setattr__(self, key: str, value: str | Iterable[str]) -> None: with get_session_state().query_params() as qp: diff --git a/lib/tests/streamlit/runtime/state/query_params_proxy_test.py b/lib/tests/streamlit/runtime/state/query_params_proxy_test.py index 9e73d2deab35..cb3ba29edd28 100644 --- a/lib/tests/streamlit/runtime/state/query_params_proxy_test.py +++ b/lib/tests/streamlit/runtime/state/query_params_proxy_test.py @@ -90,6 +90,11 @@ def test__setattr__sets_entry(self): self.query_params_proxy.key = "value" assert self.query_params_proxy["key"] == "value" + def test_update_sets_entries(self): + self.query_params_proxy.update({"key1": "value1", "key2": "value2"}) + assert self.query_params_proxy["key1"] == "value1" + assert self.query_params_proxy["key2"] == "value2" + def test__delattr__deletes_entry(self): del self.query_params_proxy.test assert "test" not in self.query_params_proxy diff --git a/lib/tests/streamlit/runtime/state/query_params_test.py b/lib/tests/streamlit/runtime/state/query_params_test.py index d6c1fc47f3a5..e1dfc4f53085 100644 --- a/lib/tests/streamlit/runtime/state/query_params_test.py +++ b/lib/tests/streamlit/runtime/state/query_params_test.py @@ -124,9 +124,33 @@ def test__setitem__raises_error_with_embed_key(self): with pytest.raises(StreamlitAPIException): self.query_params["embed"] = "true" - def test__setitem__raises_error_with_embed_options_key(self): + def test_update_adds_values(self): + self.query_params.update({"foo": "bar"}) + assert self.query_params.get("foo") == "bar" + message = self.get_message_from_queue(0) + assert "foo=bar" in message.page_info_changed.query_string + + def test_update_raises_error_with_embed_key(self): + with pytest.raises(StreamlitAPIException): + self.query_params.update({"foo": "bar", "embed": "true"}) + + def test_update_raises_error_with_embed_options_key(self): with pytest.raises(StreamlitAPIException): - self.query_params["embed_options"] = "disable_scrolling" + self.query_params.update({"foo": "bar", "embed_options": "show_toolbar"}) + + def test_update_raises_exception_with_dictionary_value(self): + with pytest.raises(StreamlitAPIException): + self.query_params.update({"a_dict": {"test": "test"}}) + + def test_update_changes_values_in_single_message(self): + self.query_params.set_with_no_forward_msg("foo", "test") + self.query_params.update({"foo": "bar", "baz": "test"}) + assert self.query_params.get("foo") == "bar" + assert self.query_params.get("baz") == "test" + assert len(self.forward_msg_queue) == 1 + message = self.get_message_from_queue(0) + assert "foo=bar" in message.page_info_changed.query_string + assert "baz=test" in message.page_info_changed.query_string def test__delitem__removes_existing_key(self): del self.query_params["foo"]