diff --git a/lib/streamlit/runtime/state/query_params.py b/lib/streamlit/runtime/state/query_params.py index 5687e96f9d76..907207b7937f 100644 --- a/lib/streamlit/runtime/state/query_params.py +++ b/lib/streamlit/runtime/state/query_params.py @@ -65,6 +65,7 @@ def __getitem__(self, key: str) -> str: raise KeyError(missing_key_error_message(key)) def __setitem__(self, key: str, value: str | Iterable[str]) -> None: + self._ensure_single_query_api_used() self.__set_item_internal(key, value) self._send_query_param_msg() @@ -86,6 +87,7 @@ def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None: self._query_params[key] = str(value) def __delitem__(self, key: str) -> None: + self._ensure_single_query_api_used() try: if key in EMBED_QUERY_PARAMS_KEYS: raise KeyError(missing_key_error_message(key)) @@ -100,13 +102,14 @@ def update( /, **kwds: str, ): - # an update function that only sends one ForwardMsg - # once all keys have been updated. + # This overrides the `update` provided by MutableMapping + # to ensure only one one ForwardMsg is sent. + self._ensure_single_query_api_used() if hasattr(other, "keys") and hasattr(other, "__getitem__"): for key in other.keys(): self.__set_item_internal(key, other[key]) else: - for (key, value) in other: + for key, value in other: self.__set_item_internal(key, value) for key, value in kwds.items(): self.__set_item_internal(key, value) @@ -146,12 +149,8 @@ def _send_query_param_msg(self) -> None: ctx.enqueue(msg) def clear(self) -> None: - new_query_params = {} - for key, value in self._query_params.items(): - if key in EMBED_QUERY_PARAMS_KEYS: - new_query_params[key] = value - self._query_params = new_query_params - + self._ensure_single_query_api_used() + self.clear_with_no_forward_msg(preserve_embed=True) self._send_query_param_msg() def to_dict(self) -> dict[str, str]: @@ -163,11 +162,29 @@ def to_dict(self) -> dict[str, str]: if key not in EMBED_QUERY_PARAMS_KEYS } + def from_dict( + self, + _dict: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str], + ): + self._ensure_single_query_api_used() + old_value = self._query_params.copy() + self.clear_with_no_forward_msg(preserve_embed=True) + try: + self.update(_dict) + except StreamlitAPIException: + # restore the original from before we made any changes. + self._query_params = old_value + raise + def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None: self._query_params[key] = val - def clear_with_no_forward_msg(self) -> None: - self._query_params.clear() + def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None: + self._query_params = { + key: value + for key, value in self._query_params.items() + if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed + } def _ensure_single_query_api_used(self): # Avoid circular imports diff --git a/lib/streamlit/runtime/state/query_params_proxy.py b/lib/streamlit/runtime/state/query_params_proxy.py index 69c8444bc016..430ccca452f9 100644 --- a/lib/streamlit/runtime/state/query_params_proxy.py +++ b/lib/streamlit/runtime/state/query_params_proxy.py @@ -72,18 +72,33 @@ def __delattr__(self, key: str) -> None: raise AttributeError(missing_key_error_message(key)) @overload - def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str): + def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str) -> None: ... @overload - def update(self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str): + def update( + self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str + ) -> None: ... @overload - def update(self, **kwds: str): + def update(self, **kwds: str) -> None: ... def update(self, other=(), /, **kwds): + """ + Update one or more values in query_params at once from a dictionary or + dictionary-like object. + + See `Mapping.update()` from Python's `collections` library. + + Parameters + ---------- + other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + A dictionary or mapping of strings to strings. + **kwds: str + Additional key/value pairs to update passed as keyword arguments. + """ with get_session_state().query_params() as qp: qp.update(other, **kwds) @@ -146,3 +161,36 @@ def to_dict(self) -> dict[str, str]: """ with get_session_state().query_params() as qp: return qp.to_dict() + + @overload + def from_dict(self, keys_and_values: Iterable[tuple[str, str]]) -> None: + ... + + @overload + def from_dict(self, mapping: SupportsKeysAndGetItem[str, str]) -> None: + ... + + @gather_metrics("query_params.from_dict") + def from_dict(self, other): + """ + Set all of the query parameters from a dictionary or dictionary-like object. + + This method primarily exists for advanced users who want to be able to control + multiple query string parameters in a single update. To set individual + query string parameters you should still use `st.query_params["parameter"] = "value"` + or `st.query_params.parameter = "value"`. + + `embed` and `embed_options` may not be set via this method and may not be keys in the + `other` dictionary. + + Note that this method is NOT a direct inverse of `st.query_params.to_dict()` when + the URL query string contains multiple values for a single key. A true inverse + operation for from_dict is `{key: st.query_params.get_all(key) for key st.query_params}`. + + Parameters + ------- + other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]] + A dictionary used to replace the current query_params. + """ + with get_session_state().query_params() as qp: + return qp.from_dict(other) diff --git a/lib/tests/streamlit/runtime/state/query_params_proxy_test.py b/lib/tests/streamlit/runtime/state/query_params_proxy_test.py index cb3ba29edd28..3e9acd03f0c0 100644 --- a/lib/tests/streamlit/runtime/state/query_params_proxy_test.py +++ b/lib/tests/streamlit/runtime/state/query_params_proxy_test.py @@ -106,3 +106,18 @@ def test__getattr__raises_Attribute_exception(self): def test__delattr__raises_Attribute_exception(self): with pytest.raises(AttributeError): del self.query_params_proxy.nonexistent + + def test_to_dict(self): + self.query_params_proxy["test_multi"] = ["value1", "value2"] + assert self.query_params_proxy.to_dict() == { + "test": "value", + "test_multi": "value2", + } + + def test_from_dict(self): + new_dict = {"test_new": "value_new", "test_multi": ["value1", "value2"]} + self.query_params_proxy.from_dict(new_dict) + assert self.query_params_proxy.test_new == "value_new" + assert self.query_params_proxy["test_multi"] == "value2" + assert self.query_params_proxy.get_all("test_multi") == ["value1", "value2"] + assert len(self.query_params_proxy) == 2 diff --git a/lib/tests/streamlit/runtime/state/query_params_test.py b/lib/tests/streamlit/runtime/state/query_params_test.py index e1dfc4f53085..6aadb649d740 100644 --- a/lib/tests/streamlit/runtime/state/query_params_test.py +++ b/lib/tests/streamlit/runtime/state/query_params_test.py @@ -233,6 +233,79 @@ def test_to_dict_doesnt_include_embed_params(self): result_dict = {"foo": "bar"} assert self.query_params.to_dict() == result_dict + def test_from_dict(self): + result_dict = {"hello": "world"} + self.query_params.from_dict(result_dict) + assert self.query_params.to_dict() == result_dict + + def test_from_dict_iterable(self): + self.query_params.from_dict((("key1", 5), ("key2", 6))) + assert self.query_params._query_params == {"key1": "5", "key2": "6"} + + def test_from_dict_mixed_values(self): + result_dict = {"hello": ["world", "janice", "amy"], "snow": "flake"} + self.query_params.from_dict(result_dict) + + # self.query_params.to_dict() has behavior consistent with fetching values using + # self.query_params["some_key"]. That is, if the value is an array, the last + # element of the array is returned rather than the array in its entirety. + assert self.query_params.to_dict() == {"hello": "amy", "snow": "flake"} + + result_as_list = {"hello": ["world", "janice", "amy"], "snow": ["flake"]} + qp_as_list = {key: self.query_params.get_all(key) for key in self.query_params} + assert result_as_list == qp_as_list + + def test_from_dict_preserves_embed_keys(self): + self.query_params._query_params.update( + {"embed_options": ["disable_scrolling", "show_colored_line"]} + ) + self.query_params.from_dict({"a": "b", "c": "d"}) + assert self.query_params._query_params == { + "a": "b", + "c": "d", + "embed_options": ["disable_scrolling", "show_colored_line"], + } + + def test_from_dict_preserves_last_value_on_error(self): + old_value = self.query_params._query_params + with pytest.raises(StreamlitAPIException): + self.query_params.from_dict({"a": "b", "embed": False}) + assert self.query_params._query_params == old_value + + def test_from_dict_changes_values_in_single_message(self): + self.query_params.set_with_no_forward_msg("hello", "world") + self.query_params.from_dict({"foo": "bar", "baz": "test"}) + assert self.query_params.get("foo") == "bar" + assert self.query_params.get("baz") == "test" + assert len(self.forward_msg_queue) == 1 + message = self.get_message_from_queue(0) + assert message.page_info_changed.query_string == "foo=bar&baz=test" + + def test_from_dict_raises_error_with_embed_key(self): + with pytest.raises(StreamlitAPIException): + self.query_params.from_dict({"foo": "bar", "embed": "true"}) + + def test_from_dict_raises_error_with_embed_options_key(self): + with pytest.raises(StreamlitAPIException): + self.query_params.from_dict({"foo": "bar", "embed_options": "show_toolbar"}) + + def test_from_dict_raises_exception_with_dictionary_value(self): + with pytest.raises(StreamlitAPIException): + self.query_params.from_dict({"a_dict": {"test": "test"}}) + + def test_from_dict_inverse(self): + self.query_params.from_dict({"a": "b", "c": "d"}) + assert self.query_params._query_params == {"a": "b", "c": "d"} + message = self.get_message_from_queue(0) + assert message.page_info_changed.query_string == "a=b&c=d" + from_dict_inverse = { + key: self.query_params.get_all(key) for key in self.query_params + } + self.query_params.from_dict(from_dict_inverse) + assert self.query_params._query_params == {"a": ["b"], "c": ["d"]} + message = self.get_message_from_queue(0) + assert message.page_info_changed.query_string == "a=b&c=d" + def test_set_with_no_forward_msg_sends_no_msg_and_sets_query_params(self): self.query_params.set_with_no_forward_msg("test", "test") assert self.query_params["test"] == "test" @@ -267,8 +340,23 @@ def test_set_with_no_forward_msg_accepts_multiple_embed_options(self): self.get_message_from_queue(0) def test_clear_with_no_forward_msg_sends_no_msg_and_clears_query_params(self): + self.query_params._query_params.update( + {"embed_options": ["disable_scrolling", "show_colored_line"]} + ) self.query_params.clear_with_no_forward_msg() assert len(self.query_params) == 0 + assert len(self.query_params._query_params) == 0 with pytest.raises(IndexError): # no forward message should be sent self.get_message_from_queue(0) + + def test_clear_with_no_forward_msg_preserve_embed_keys(self): + self.query_params._query_params.update( + {"embed_options": ["disable_scrolling", "show_colored_line"]} + ) + self.query_params.clear_with_no_forward_msg(preserve_embed=True) + assert len(self.query_params) == 0 + assert len(self.query_params._query_params) == 1 + assert self.query_params._query_params["embed_options"] == ( + ["disable_scrolling", "show_colored_line"] + )