Skip to content

Commit

Permalink
from_dict for query params (#8470)
Browse files Browse the repository at this point in the history
## Describe your changes
Implements an atomic way to overwrite all query_params

## GitHub Issue Link (if applicable)
Closes #8407

## Testing Plan
Unit tests for query_params and query_params_proxy.

---------

Co-authored-by: Vincent Donato <vincent@streamlit.io>
  • Loading branch information
Asaurus1 and vdonato committed Apr 13, 2024
1 parent 2e87989 commit fc01c1f
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 14 deletions.
39 changes: 28 additions & 11 deletions lib/streamlit/runtime/state/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __getitem__(self, key: str) -> str:
raise KeyError(missing_key_error_message(key))

def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self._ensure_single_query_api_used()
self.__set_item_internal(key, value)
self._send_query_param_msg()

Expand All @@ -86,6 +87,7 @@ def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
self._query_params[key] = str(value)

def __delitem__(self, key: str) -> None:
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
Expand All @@ -100,13 +102,14 @@ def update(
/,
**kwds: str,
):
# an update function that only sends one ForwardMsg
# once all keys have been updated.
# This overrides the `update` provided by MutableMapping
# to ensure only one one ForwardMsg is sent.
self._ensure_single_query_api_used()
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
for key in other.keys():
self.__set_item_internal(key, other[key])
else:
for (key, value) in other:
for key, value in other:
self.__set_item_internal(key, value)
for key, value in kwds.items():
self.__set_item_internal(key, value)
Expand Down Expand Up @@ -146,12 +149,8 @@ def _send_query_param_msg(self) -> None:
ctx.enqueue(msg)

def clear(self) -> None:
new_query_params = {}
for key, value in self._query_params.items():
if key in EMBED_QUERY_PARAMS_KEYS:
new_query_params[key] = value
self._query_params = new_query_params

self._ensure_single_query_api_used()
self.clear_with_no_forward_msg(preserve_embed=True)
self._send_query_param_msg()

def to_dict(self) -> dict[str, str]:
Expand All @@ -163,11 +162,29 @@ def to_dict(self) -> dict[str, str]:
if key not in EMBED_QUERY_PARAMS_KEYS
}

def from_dict(
self,
_dict: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str],
):
self._ensure_single_query_api_used()
old_value = self._query_params.copy()
self.clear_with_no_forward_msg(preserve_embed=True)
try:
self.update(_dict)
except StreamlitAPIException:
# restore the original from before we made any changes.
self._query_params = old_value
raise

def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
self._query_params[key] = val

def clear_with_no_forward_msg(self) -> None:
self._query_params.clear()
def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
self._query_params = {
key: value
for key, value in self._query_params.items()
if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed
}

def _ensure_single_query_api_used(self):
# Avoid circular imports
Expand Down
54 changes: 51 additions & 3 deletions lib/streamlit/runtime/state/query_params_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,33 @@ def __delattr__(self, key: str) -> None:
raise AttributeError(missing_key_error_message(key))

@overload
def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str):
def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str) -> None:
...

@overload
def update(self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str):
def update(
self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str
) -> None:
...

@overload
def update(self, **kwds: str):
def update(self, **kwds: str) -> None:
...

def update(self, other=(), /, **kwds):
"""
Update one or more values in query_params at once from a dictionary or
dictionary-like object.
See `Mapping.update()` from Python's `collections` library.
Parameters
----------
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
A dictionary or mapping of strings to strings.
**kwds: str
Additional key/value pairs to update passed as keyword arguments.
"""
with get_session_state().query_params() as qp:
qp.update(other, **kwds)

Expand Down Expand Up @@ -146,3 +161,36 @@ def to_dict(self) -> dict[str, str]:
"""
with get_session_state().query_params() as qp:
return qp.to_dict()

@overload
def from_dict(self, keys_and_values: Iterable[tuple[str, str]]) -> None:
...

@overload
def from_dict(self, mapping: SupportsKeysAndGetItem[str, str]) -> None:
...

@gather_metrics("query_params.from_dict")
def from_dict(self, other):
"""
Set all of the query parameters from a dictionary or dictionary-like object.
This method primarily exists for advanced users who want to be able to control
multiple query string parameters in a single update. To set individual
query string parameters you should still use `st.query_params["parameter"] = "value"`
or `st.query_params.parameter = "value"`.
`embed` and `embed_options` may not be set via this method and may not be keys in the
`other` dictionary.
Note that this method is NOT a direct inverse of `st.query_params.to_dict()` when
the URL query string contains multiple values for a single key. A true inverse
operation for from_dict is `{key: st.query_params.get_all(key) for key st.query_params}`.
Parameters
-------
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
A dictionary used to replace the current query_params.
"""
with get_session_state().query_params() as qp:
return qp.from_dict(other)
15 changes: 15 additions & 0 deletions lib/tests/streamlit/runtime/state/query_params_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,18 @@ def test__getattr__raises_Attribute_exception(self):
def test__delattr__raises_Attribute_exception(self):
with pytest.raises(AttributeError):
del self.query_params_proxy.nonexistent

def test_to_dict(self):
self.query_params_proxy["test_multi"] = ["value1", "value2"]
assert self.query_params_proxy.to_dict() == {
"test": "value",
"test_multi": "value2",
}

def test_from_dict(self):
new_dict = {"test_new": "value_new", "test_multi": ["value1", "value2"]}
self.query_params_proxy.from_dict(new_dict)
assert self.query_params_proxy.test_new == "value_new"
assert self.query_params_proxy["test_multi"] == "value2"
assert self.query_params_proxy.get_all("test_multi") == ["value1", "value2"]
assert len(self.query_params_proxy) == 2
88 changes: 88 additions & 0 deletions lib/tests/streamlit/runtime/state/query_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,79 @@ def test_to_dict_doesnt_include_embed_params(self):
result_dict = {"foo": "bar"}
assert self.query_params.to_dict() == result_dict

def test_from_dict(self):
result_dict = {"hello": "world"}
self.query_params.from_dict(result_dict)
assert self.query_params.to_dict() == result_dict

def test_from_dict_iterable(self):
self.query_params.from_dict((("key1", 5), ("key2", 6)))
assert self.query_params._query_params == {"key1": "5", "key2": "6"}

def test_from_dict_mixed_values(self):
result_dict = {"hello": ["world", "janice", "amy"], "snow": "flake"}
self.query_params.from_dict(result_dict)

# self.query_params.to_dict() has behavior consistent with fetching values using
# self.query_params["some_key"]. That is, if the value is an array, the last
# element of the array is returned rather than the array in its entirety.
assert self.query_params.to_dict() == {"hello": "amy", "snow": "flake"}

result_as_list = {"hello": ["world", "janice", "amy"], "snow": ["flake"]}
qp_as_list = {key: self.query_params.get_all(key) for key in self.query_params}
assert result_as_list == qp_as_list

def test_from_dict_preserves_embed_keys(self):
self.query_params._query_params.update(
{"embed_options": ["disable_scrolling", "show_colored_line"]}
)
self.query_params.from_dict({"a": "b", "c": "d"})
assert self.query_params._query_params == {
"a": "b",
"c": "d",
"embed_options": ["disable_scrolling", "show_colored_line"],
}

def test_from_dict_preserves_last_value_on_error(self):
old_value = self.query_params._query_params
with pytest.raises(StreamlitAPIException):
self.query_params.from_dict({"a": "b", "embed": False})
assert self.query_params._query_params == old_value

def test_from_dict_changes_values_in_single_message(self):
self.query_params.set_with_no_forward_msg("hello", "world")
self.query_params.from_dict({"foo": "bar", "baz": "test"})
assert self.query_params.get("foo") == "bar"
assert self.query_params.get("baz") == "test"
assert len(self.forward_msg_queue) == 1
message = self.get_message_from_queue(0)
assert message.page_info_changed.query_string == "foo=bar&baz=test"

def test_from_dict_raises_error_with_embed_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params.from_dict({"foo": "bar", "embed": "true"})

def test_from_dict_raises_error_with_embed_options_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params.from_dict({"foo": "bar", "embed_options": "show_toolbar"})

def test_from_dict_raises_exception_with_dictionary_value(self):
with pytest.raises(StreamlitAPIException):
self.query_params.from_dict({"a_dict": {"test": "test"}})

def test_from_dict_inverse(self):
self.query_params.from_dict({"a": "b", "c": "d"})
assert self.query_params._query_params == {"a": "b", "c": "d"}
message = self.get_message_from_queue(0)
assert message.page_info_changed.query_string == "a=b&c=d"
from_dict_inverse = {
key: self.query_params.get_all(key) for key in self.query_params
}
self.query_params.from_dict(from_dict_inverse)
assert self.query_params._query_params == {"a": ["b"], "c": ["d"]}
message = self.get_message_from_queue(0)
assert message.page_info_changed.query_string == "a=b&c=d"

def test_set_with_no_forward_msg_sends_no_msg_and_sets_query_params(self):
self.query_params.set_with_no_forward_msg("test", "test")
assert self.query_params["test"] == "test"
Expand Down Expand Up @@ -267,8 +340,23 @@ def test_set_with_no_forward_msg_accepts_multiple_embed_options(self):
self.get_message_from_queue(0)

def test_clear_with_no_forward_msg_sends_no_msg_and_clears_query_params(self):
self.query_params._query_params.update(
{"embed_options": ["disable_scrolling", "show_colored_line"]}
)
self.query_params.clear_with_no_forward_msg()
assert len(self.query_params) == 0
assert len(self.query_params._query_params) == 0
with pytest.raises(IndexError):
# no forward message should be sent
self.get_message_from_queue(0)

def test_clear_with_no_forward_msg_preserve_embed_keys(self):
self.query_params._query_params.update(
{"embed_options": ["disable_scrolling", "show_colored_line"]}
)
self.query_params.clear_with_no_forward_msg(preserve_embed=True)
assert len(self.query_params) == 0
assert len(self.query_params._query_params) == 1
assert self.query_params._query_params["embed_options"] == (
["disable_scrolling", "show_colored_line"]
)

0 comments on commit fc01c1f

Please sign in to comment.