Skip to content

Commit

Permalink
explicit update() method for query_params (streamlit#8205)
Browse files Browse the repository at this point in the history
  • Loading branch information
Asaurus1 authored and zyxue committed Apr 16, 2024
1 parent ca9342b commit 00bb6d4
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Expand Up @@ -27,6 +27,7 @@ repos:
# Configure Black to support only syntax supported by the minimum supported Python version in setup.py.
- --target-version=py37
files: \.py$|\.pyi$
exclude: ^e2e_flaky/.*
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
Expand Down Expand Up @@ -171,6 +172,7 @@ repos:
|^vendor/
|^component-lib/declarations/apache-arrow
|^lib/tests/isolated_asyncio_test_case\.py$
|^e2e_flaky/.*
- id: insert-license
name: Add license for all HTML files
files: \.html$
Expand Down
3 changes: 3 additions & 0 deletions lib/streamlit/runtime/forward_msg_queue.py
Expand Up @@ -92,6 +92,9 @@ def flush(self) -> list[ForwardMsg]:
self.clear()
return queue

def __len__(self) -> int:
return len(self._queue)


def _is_composable_message(msg: ForwardMsg) -> bool:
"""True if the ForwardMsg is potentially composable with other ForwardMsgs."""
Expand Down
28 changes: 26 additions & 2 deletions lib/streamlit/runtime/state/query_params.py
Expand Up @@ -15,13 +15,16 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping
from urllib import parse

from streamlit.constants import EMBED_QUERY_PARAMS_KEYS
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg

if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem


@dataclass
class QueryParams(MutableMapping[str, str]):
Expand Down Expand Up @@ -62,6 +65,10 @@ def __getitem__(self, key: str) -> str:
raise KeyError(missing_key_error_message(key))

def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self.__set_item_internal(key, value)
self._send_query_param_msg()

def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
if isinstance(value, dict):
raise StreamlitAPIException(
f"You cannot set a query params key `{key}` to a dictionary."
Expand All @@ -77,7 +84,6 @@ def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self._query_params[key] = [str(item) for item in value]
else:
self._query_params[key] = str(value)
self._send_query_param_msg()

def __delitem__(self, key: str) -> None:
try:
Expand All @@ -88,6 +94,24 @@ def __delitem__(self, key: str) -> None:
except KeyError:
raise KeyError(missing_key_error_message(key))

def update(
self,
other: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str] = (),
/,
**kwds: str,
):
# an update function that only sends one ForwardMsg
# once all keys have been updated.
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
for key in other.keys():
self.__set_item_internal(key, other[key])
else:
for (key, value) in other:
self.__set_item_internal(key, value)
for key, value in kwds.items():
self.__set_item_internal(key, value)
self._send_query_param_msg()

def get_all(self, key: str) -> list[str]:
self._ensure_single_query_api_used()
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
Expand Down
21 changes: 20 additions & 1 deletion lib/streamlit/runtime/state/query_params_proxy.py
Expand Up @@ -14,12 +14,15 @@

from __future__ import annotations

from typing import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping, overload

from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.query_params import missing_key_error_message
from streamlit.runtime.state.session_state_proxy import get_session_state

if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem


class QueryParamsProxy(MutableMapping[str, str]):
"""
Expand Down Expand Up @@ -68,6 +71,22 @@ def __delattr__(self, key: str) -> None:
except KeyError:
raise AttributeError(missing_key_error_message(key))

@overload
def update(self, mapping: SupportsKeysAndGetItem[str, str], /, **kwds: str):
...

@overload
def update(self, keys_and_values: Iterable[tuple[str, str]], /, **kwds: str):
...

@overload
def update(self, **kwds: str):
...

def update(self, other=(), /, **kwds):
with get_session_state().query_params() as qp:
qp.update(other, **kwds)

@gather_metrics("query_params.set_attr")
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
Expand Down
5 changes: 5 additions & 0 deletions lib/tests/streamlit/runtime/state/query_params_proxy_test.py
Expand Up @@ -90,6 +90,11 @@ def test__setattr__sets_entry(self):
self.query_params_proxy.key = "value"
assert self.query_params_proxy["key"] == "value"

def test_update_sets_entries(self):
self.query_params_proxy.update({"key1": "value1", "key2": "value2"})
assert self.query_params_proxy["key1"] == "value1"
assert self.query_params_proxy["key2"] == "value2"

def test__delattr__deletes_entry(self):
del self.query_params_proxy.test
assert "test" not in self.query_params_proxy
Expand Down
28 changes: 26 additions & 2 deletions lib/tests/streamlit/runtime/state/query_params_test.py
Expand Up @@ -124,9 +124,33 @@ def test__setitem__raises_error_with_embed_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params["embed"] = "true"

def test__setitem__raises_error_with_embed_options_key(self):
def test_update_adds_values(self):
self.query_params.update({"foo": "bar"})
assert self.query_params.get("foo") == "bar"
message = self.get_message_from_queue(0)
assert "foo=bar" in message.page_info_changed.query_string

def test_update_raises_error_with_embed_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params.update({"foo": "bar", "embed": "true"})

def test_update_raises_error_with_embed_options_key(self):
with pytest.raises(StreamlitAPIException):
self.query_params["embed_options"] = "disable_scrolling"
self.query_params.update({"foo": "bar", "embed_options": "show_toolbar"})

def test_update_raises_exception_with_dictionary_value(self):
with pytest.raises(StreamlitAPIException):
self.query_params.update({"a_dict": {"test": "test"}})

def test_update_changes_values_in_single_message(self):
self.query_params.set_with_no_forward_msg("foo", "test")
self.query_params.update({"foo": "bar", "baz": "test"})
assert self.query_params.get("foo") == "bar"
assert self.query_params.get("baz") == "test"
assert len(self.forward_msg_queue) == 1
message = self.get_message_from_queue(0)
assert "foo=bar" in message.page_info_changed.query_string
assert "baz=test" in message.page_info_changed.query_string

def test__delitem__removes_existing_key(self):
del self.query_params["foo"]
Expand Down
6 changes: 5 additions & 1 deletion scripts/mypy
Expand Up @@ -29,7 +29,11 @@ import click
import mypy.main as mypy_main

PATHS = ["lib/streamlit/", "scripts/*", "e2e/scripts/*"]
EXCLUDE_FILES = {"scripts/add_license_headers.py", "e2e/scripts/st_reuse_label.py", "scripts/license-template.txt"}
EXCLUDE_FILES = {
os.path.join("scripts", "add_license_headers.py"),
os.path.join("e2e", "scripts", "st_reuse_label.py"),
os.path.join("scripts", "license-template.txt"),
}


def shlex_join(split_command: Iterable[str]):
Expand Down
22 changes: 22 additions & 0 deletions scripts/run_in_subdirectory.py
Expand Up @@ -86,6 +86,23 @@ def fix_arg(subdirectory: str, arg: str) -> str:
return str(arg_path.relative_to(subdirectory))


def try_as_shell(fixed_args: List[str], subdirectory: str):
# Windows doesn't know how to run "yarn" using the CreateProcess
# WINAPI because it's looking for an executable, and yarn is a node script.
# Yarn happens to be the only thing currently run with this patching script,
# so add a fall-back which tries to run the requested command in a shell
# if directly calling the process doesn't work.
import shlex

print("Direct call failed, trying as shell command:")
shell_cmd = shlex.join(fixed_args)
print(shell_cmd)
try:
subprocess.run(shell_cmd, cwd=subdirectory, check=True, shell=True)
except subprocess.CalledProcessError as ex:
sys.exit(ex.returncode)


def main():
subdirectory, subprocess_args = parse_args()

Expand All @@ -94,6 +111,11 @@ def main():
subprocess.run(fixed_args, cwd=subdirectory, check=True)
except subprocess.CalledProcessError as ex:
sys.exit(ex.returncode)
except FileNotFoundError:
if "win32" in sys.platform:
try_as_shell(fixed_args, subdirectory)
else:
sys.exit(1)


main()

0 comments on commit 00bb6d4

Please sign in to comment.