Skip to content

Commit

Permalink
Explicit update() method for query_params
Browse files Browse the repository at this point in the history
  • Loading branch information
Asaurus1 committed Mar 4, 2024
1 parent ee05bfc commit 2a1f69e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
28 changes: 26 additions & 2 deletions lib/streamlit/runtime/state/query_params.py
Expand Up @@ -15,13 +15,16 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping
from urllib import parse

from streamlit.constants import EMBED_QUERY_PARAMS_KEYS
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg

if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem


@dataclass
class QueryParams(MutableMapping[str, str]):
Expand Down Expand Up @@ -62,6 +65,10 @@ def __getitem__(self, key: str) -> str:
raise KeyError(missing_key_error_message(key))

def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self.__set_item_internal(key, value)
self._send_query_param_msg()

def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
if isinstance(value, dict):
raise StreamlitAPIException(
f"You cannot set a query params key `{key}` to a dictionary."
Expand All @@ -77,7 +84,6 @@ def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self._query_params[key] = [str(item) for item in value]
else:
self._query_params[key] = str(value)
self._send_query_param_msg()

def __delitem__(self, key: str) -> None:
try:
Expand All @@ -88,6 +94,24 @@ def __delitem__(self, key: str) -> None:
except KeyError:
raise KeyError(missing_key_error_message(key))

def update(
self,
other: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str] = (),
/,
**kwds: str,
):
# an update function that only sends one ForwardMsg
# once all keys have been updated.
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
for key in other.keys():
self.__set_item_internal(key, other[key])
else:
for (key, value) in other:
self.__set_item_internal(key, value)
for key, value in kwds.items():
self.__set_item_internal(key, value)
self._send_query_param_msg()

def get_all(self, key: str) -> list[str]:
self._ensure_single_query_api_used()
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
Expand Down
21 changes: 20 additions & 1 deletion lib/streamlit/runtime/state/query_params_proxy.py
Expand Up @@ -14,12 +14,15 @@

from __future__ import annotations

from typing import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping, overload

from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.query_params import missing_key_error_message
from streamlit.runtime.state.session_state_proxy import get_session_state

if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem


class QueryParamsProxy(MutableMapping[str, str]):
"""
Expand Down Expand Up @@ -68,6 +71,22 @@ def __delattr__(self, key: str) -> None:
except KeyError:
raise AttributeError(missing_key_error_message(key))

@overload
def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str):
...

@overload
def update(self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str):
...

@overload
def update(self, **kwds: str):
...

def update(self, other=(), /, **kwds):
with get_session_state().query_params() as qp:
qp.update(other, **kwds)

@gather_metrics("query_params.set_attr")
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
Expand Down
5 changes: 5 additions & 0 deletions lib/tests/streamlit/runtime/state/query_params_proxy_test.py
Expand Up @@ -90,6 +90,11 @@ def test__setattr__sets_entry(self):
self.query_params_proxy.key = "value"
assert self.query_params_proxy["key"] == "value"

def test_update_sets_entries(self):
self.query_params_proxy.update({"key1": "value1", "key2": "value2"})
assert self.query_params_proxy["key1"] == "value1"
assert self.query_params_proxy["key2"] == "value2"

def test__delattr__deletes_entry(self):
del self.query_params_proxy.test
assert "test" not in self.query_params_proxy
Expand Down
28 changes: 26 additions & 2 deletions lib/tests/streamlit/runtime/state/query_params_test.py
Expand Up @@ -124,9 +124,33 @@ def test__setitem__raises_error_with_embed_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params["embed"] = "true"

def test__setitem__raises_error_with_embed_options_key(self):
def test_update_adds_values(self):
self.query_params.update({"foo": "bar"})
assert self.query_params.get("foo") == "bar"
message = self.get_message_from_queue(0)
assert "foo=bar" in message.page_info_changed.query_string

def test_update_raises_error_with_embed_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params.update({"foo": "bar", "embed": "true"})

def test_update_raises_error_with_embed_options_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params["embed_options"] = "disable_scrolling"
self.query_params.update({"foo": "bar", "embed_options": "show_toolbar"})

def test_update_raises_exception_with_dictionary_value(self):
with pytest.raises(StreamlitAPIException):
self.query_params.update({"a_dict": {"test": "test"}})

def test_update_changes_values_in_single_message(self):
self.query_params.set_with_no_forward_msg("foo", "test")
self.query_params.update({"foo": "bar", "baz": "test"})
assert self.query_params.get("foo") == "bar"
assert self.query_params.get("baz") == "test"
assert len(self.forward_msg_queue) == 1
message = self.get_message_from_queue(0)
assert "foo=bar" in message.page_info_changed.query_string
assert "baz=test" in message.page_info_changed.query_string

def test__delitem__removes_existing_key(self):
del self.query_params["foo"]
Expand Down

0 comments on commit 2a1f69e

Please sign in to comment.