Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Query Params #7664

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions lib/streamlit/commands/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import urllib.parse as parse
from dataclasses import dataclass, field
from typing import Any, Dict, List

from streamlit import util
Expand Down Expand Up @@ -140,3 +141,46 @@ def _ensure_no_embed_params(
separator = "&" if current_embed_params else ""
return separator.join([query_string, current_embed_params])
return current_embed_params


@dataclass
class QueryParams:
"""A dict-like representation of query params.
The main difference is that it only stores and returns str and list[str].

TODO(willhuang1997): Fill in these docs with examples and fix doc above. Above is just a stub for now.
"""

_query_params: Dict[str, List[Any]] = field(default_factory=dict)

def __init__(self, query_params: Dict[str, List[Any]] = {}):
self._query_params = query_params

def __repr__(self):
return util.repr_(self)

def __contains__(self, key: str) -> bool:
return key in self._query_params

def __len__(self):
return len(self._query_params)

def _send_query_param_msg(self):
ctx = get_script_run_ctx()
if ctx is None:
return
msg = ForwardMsg()
msg.page_info_changed.query_string = _ensure_no_embed_params(
self._query_params, ctx.query_string
)
ctx.query_string = msg.page_info_changed.query_string
ctx.enqueue(msg)

def clear(self):
self._query_params.clear()
self._send_query_param_msg()

def __delitem__(self, key: str) -> None:
if key in self._query_params:
del self._query_params[key]
self._send_query_param_msg()
50 changes: 47 additions & 3 deletions lib/tests/streamlit/commands/query_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import streamlit as st
from streamlit.commands.query_params import QueryParams
from streamlit.errors import StreamlitAPIException
from tests.delta_generator_test_case import DeltaGeneratorTestCase

Expand All @@ -33,15 +36,56 @@ def test_set_query_params_exceptions(self):
with self.assertRaises(StreamlitAPIException):
st.experimental_set_query_params(embed_options="show_colored_line")

def test_get_query_params_after_set_query_params(self):
def test_get_query_params_after_set_query_params_single_element(self):
"""Test valid st.set_query_params sends protobuf message."""
p_set = dict(x=["a"])

p_set = {"x": ["a"]}
st.experimental_set_query_params(**p_set)
p_get = st.experimental_get_query_params()
self.assertEqual(p_get, p_set)

def test_get_query_params_after_set_query_params_list(self):
"""Test valid st.set_query_params sends protobuf message."""

p_set = {"x": ["a", "b"]}
st.experimental_set_query_params(**p_set)
p_get = st.experimental_get_query_params()
self.assertEqual(p_get, p_set)

def test_set_query_params_empty_str(self):
empty_str_params = dict(x=[""])
empty_str_params = {"x": [""]}
st.experimental_set_query_params(**empty_str_params)
params_get = st.experimental_get_query_params()
self.assertEqual(params_get, empty_str_params)


class QueryParamsMethodTests(DeltaGeneratorTestCase):
def setUp(self):
super().setUp()
self.q_params = {"foo": "bar", "a": "a"}
self.query_params = QueryParams(self.q_params)

def test_contains_valid(self):
assert "foo" in self.query_params

def test_contains_invalid(self):
assert "baz" not in self.query_params

def test_clear(self):
self.query_params.clear()
assert len(self.query_params) == 0
message = self.get_message_from_queue(0)
self.assertEqual(message.page_info_changed.query_string, "")

def test_del_valid(self):
del self.query_params["foo"]
assert "foo" not in self.query_params
message = self.get_message_from_queue(0)
self.assertEqual(message.page_info_changed.query_string, "a=a")

def test_del_invalid(self):
del self.query_params["nonexistent"]

# no forward message should be sent
with pytest.raises(IndexError):
self.get_message_from_queue(0)