forked from streamlit/streamlit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
query_params.py
183 lines (153 loc) · 6.55 KB
/
query_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Iterable, Iterator, MutableMapping
from urllib import parse
from streamlit.constants import EMBED_QUERY_PARAMS_KEYS
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
@dataclass
class QueryParams(MutableMapping[str, str]):
    """A lightweight wrapper of a dict that sends forwardMsgs when state changes.

    It stores str keys with str and List[str] values. Reads through the
    mapping interface return a single string (the last value for list-valued
    keys); ``get_all`` returns every value for a key.

    Keys in ``EMBED_QUERY_PARAMS_KEYS`` are kept in the backing dict, so they
    are still encoded into the query string sent to the frontend and they
    survive ``clear()``. They are hidden from this mapping, though: iteration,
    ``len()``, ``to_dict()`` and ``get_all()`` skip them, reading or deleting
    them raises ``KeyError``, and setting them raises
    ``StreamlitAPIException``.
    """

    # Backing store. A value is either a single string or a list of strings;
    # lists are encoded as repeated keys by ``urlencode(..., doseq=True)``.
    _query_params: dict[str, list[str] | str] = field(default_factory=dict)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the stored keys, skipping embed-related keys."""
        self._ensure_single_query_api_used()
        return iter(
            key
            for key in self._query_params.keys()
            if key not in EMBED_QUERY_PARAMS_KEYS
        )

    def __getitem__(self, key: str) -> str:
        """Retrieves a value for a given key in query parameters.

        Returns the last item in a list or an empty string if empty.

        Raises:
            KeyError: If ``key`` is not present or is an embed-related key.
        """
        self._ensure_single_query_api_used()
        if key in EMBED_QUERY_PARAMS_KEYS:
            # Embed keys are hidden from readers, so report them as missing.
            raise KeyError(missing_key_error_message(key))
        try:
            value = self._query_params[key]
        except KeyError:
            # Replace the bare dict KeyError with a user-facing message. The
            # original exception adds nothing, so drop it from the chain.
            raise KeyError(missing_key_error_message(key)) from None
        if isinstance(value, list):
            # Return the last value to mimic Tornado's behavior
            # https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
            return value[-1] if value else ""
        return value

    def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
        """Store ``value`` under ``key`` and send the updated query string.

        Raises:
            StreamlitAPIException: If ``value`` is a dict or ``key`` is an
                embed-related key.
        """
        self.__set_item_internal(key, value)
        self._send_query_param_msg()

    def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
        """Validate and store ``key``/``value`` without sending a ForwardMsg.

        A non-string iterable is stored as a list of ``str(item)``. Any other
        value is stored as ``str(value)``.

        Raises:
            StreamlitAPIException: If ``value`` is a dict or ``key`` is an
                embed-related key.
        """
        if isinstance(value, dict):
            raise StreamlitAPIException(
                f"You cannot set a query params key `{key}` to a dictionary."
            )

        # NOTE(review): the message below says "case-insensitive", but this
        # membership test is case-sensitive. A differently-cased key is only
        # rejected if that exact spelling is in EMBED_QUERY_PARAMS_KEYS.
        # Confirm whether ``key.lower()`` was intended.
        if key in EMBED_QUERY_PARAMS_KEYS:
            raise StreamlitAPIException(
                "Query param embed and embed_options (case-insensitive) cannot be set programmatically."
            )

        # Type checking users should handle the string serialization themselves.
        # We accept any item type and serialize it to str just in case.
        if isinstance(value, Iterable) and not isinstance(value, str):
            self._query_params[key] = [str(item) for item in value]
        else:
            self._query_params[key] = str(value)

    def __delitem__(self, key: str) -> None:
        """Remove ``key`` and send the updated query string.

        Raises:
            KeyError: If ``key`` is not present or is an embed-related key.
        """
        if key in EMBED_QUERY_PARAMS_KEYS:
            # Embed keys cannot be removed through this API.
            raise KeyError(missing_key_error_message(key))
        try:
            del self._query_params[key]
        except KeyError:
            raise KeyError(missing_key_error_message(key)) from None
        # Sending happens outside the try so that an unrelated KeyError raised
        # while sending is not mislabelled as a missing query-param key.
        self._send_query_param_msg()

    def update(
        self,
        other: Iterable[tuple[str, str]] | SupportsKeysAndGetItem[str, str] = (),
        /,
        **kwds: str,
    ) -> None:
        """Set several keys at once, sending a single ForwardMsg at the end.

        Accepts the same argument shapes as ``dict.update``:
        - a mapping-like object (anything with ``keys`` and ``__getitem__``);
        - an iterable of ``(key, value)`` pairs;
        - keyword arguments.

        Raises:
            StreamlitAPIException: As for ``__setitem__``. Keys stored before
                the failing one remain set, but no ForwardMsg is sent for them.
        """
        if hasattr(other, "keys") and hasattr(other, "__getitem__"):
            for key in other.keys():
                self.__set_item_internal(key, other[key])
        else:
            for key, value in other:
                self.__set_item_internal(key, value)
        for key, value in kwds.items():
            self.__set_item_internal(key, value)
        self._send_query_param_msg()

    def get_all(self, key: str) -> list[str]:
        """Return every value stored for ``key`` as a new list.

        Returns ``[]`` when ``key`` is absent or is an embed-related key. A
        single-string value is wrapped in a one-element list.
        """
        self._ensure_single_query_api_used()
        if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
            return []
        value = self._query_params[key]
        # Copy the stored list. Otherwise a caller mutating the result would
        # change our state without a ForwardMsg being sent.
        return list(value) if isinstance(value, list) else [value]

    def __len__(self) -> int:
        """Return the number of keys, excluding embed-related keys."""
        self._ensure_single_query_api_used()
        return sum(
            1 for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS
        )

    def __str__(self) -> str:
        """Return ``str()`` of the backing dict (embed keys included)."""
        self._ensure_single_query_api_used()
        return str(self._query_params)

    def _send_query_param_msg(self) -> None:
        """Enqueue a ForwardMsg carrying the URL-encoded query string.

        The whole backing dict, embed keys included, is encoded with
        ``doseq=True``. The encoded string is also stored on the script-run
        context's ``query_string``. Does nothing when there is no script-run
        context.
        """
        # Avoid circular imports
        from streamlit.runtime.scriptrunner import get_script_run_ctx

        ctx = get_script_run_ctx()
        if ctx is None:
            return
        self._ensure_single_query_api_used()

        msg = ForwardMsg()
        msg.page_info_changed.query_string = parse.urlencode(
            self._query_params, doseq=True
        )
        ctx.query_string = msg.page_info_changed.query_string
        ctx.enqueue(msg)

    def clear(self) -> None:
        """Remove all user keys (embed keys are kept) and send the update."""
        self._query_params = {
            key: value
            for key, value in self._query_params.items()
            if key in EMBED_QUERY_PARAMS_KEYS
        }
        self._send_query_param_msg()

    def to_dict(self) -> dict[str, str]:
        """Return a plain dict of the keys, excluding embed-related keys.

        Each key maps to the value ``self[key]`` returns, i.e. the last value
        for list-valued keys.
        """
        self._ensure_single_query_api_used()
        # return the last query param if multiple values are set
        return {
            key: self[key]
            for key in self._query_params
            if key not in EMBED_QUERY_PARAMS_KEYS
        }

    def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
        """Store ``val`` under ``key`` as-is, without validation or a ForwardMsg."""
        self._query_params[key] = val

    def clear_with_no_forward_msg(self) -> None:
        """Empty the backing dict (embed keys included) without a ForwardMsg."""
        self._query_params.clear()

    def _ensure_single_query_api_used(self) -> None:
        """Mark the current script-run context as having used st.query_params.

        Does nothing when there is no script-run context.
        NOTE(review): presumably the context uses this mark to detect mixing
        with another query-params API; confirm against the context class.
        """
        # Avoid circular imports
        from streamlit.runtime.scriptrunner import get_script_run_ctx

        ctx = get_script_run_ctx()
        if ctx is None:
            return
        ctx.mark_production_query_params_used()
def missing_key_error_message(key: str) -> str:
    """Build the error text used when a query-param key cannot be found.

    Args:
        key: The key that was requested.

    Returns:
        A user-facing message naming the missing key.
    """
    message = f'st.query_params has no key "{key}". Did you forget to initialize it?'
    return message