From cfd0ee66386b57997054982b18c7c1d32ae9ae0f Mon Sep 17 00:00:00 2001 From: William Wei Huang Date: Tue, 9 Jan 2024 22:23:05 -0800 Subject: [PATCH] Fix embed params being dropped in page swaps (#7918) --- .../multipage_apps/mpa_basics_test.py | 17 +++ frontend/app/src/App.test.tsx | 42 +++++++ frontend/app/src/App.tsx | 11 +- frontend/lib/src/util/utils.test.ts | 44 ++++++++ frontend/lib/src/util/utils.ts | 25 +++++ lib/streamlit/commands/execution_control.py | 2 +- .../commands/experimental_query_params.py | 9 +- lib/streamlit/constants.py | 17 +++ lib/streamlit/runtime/state/query_params.py | 48 +++++--- .../runtime/state/query_params_test.py | 103 ++++++++++++++++++ 10 files changed, 296 insertions(+), 22 deletions(-) create mode 100644 lib/streamlit/constants.py diff --git a/e2e_playwright/multipage_apps/mpa_basics_test.py b/e2e_playwright/multipage_apps/mpa_basics_test.py index 39abac481b6a..602a8da13d0f 100644 --- a/e2e_playwright/multipage_apps/mpa_basics_test.py +++ b/e2e_playwright/multipage_apps/mpa_basics_test.py @@ -176,3 +176,20 @@ def test_removes_query_params_when_swapping_pages(page: Page, app_port: int): wait_for_app_run(page) assert page.url == f"http://localhost:{app_port}/page3" + + +def test_removes_non_embed_query_params_when_swapping_pages(page: Page, app_port: int): + """Test that query params are removed when swapping pages""" + + page.goto( + f"http://localhost:{app_port}/page_7?foo=bar&embed=True&embed_options=show_toolbar&embed_options=show_colored_line" + ) + wait_for_app_loaded(page) + + page.get_by_test_id("stSidebarNav").locator("a").nth(2).click() + wait_for_app_run(page) + + assert ( + page.url + == f"http://localhost:{app_port}/page3?embed=true&embed_options=show_toolbar&embed_options=show_colored_line" + ) diff --git a/frontend/app/src/App.test.tsx b/frontend/app/src/App.test.tsx index a20fedf894cd..a99e37235749 100644 --- a/frontend/app/src/App.test.tsx +++ b/frontend/app/src/App.test.tsx @@ -1193,6 +1193,48 @@ describe("App.sendRerunBackMsg", () => { queryParams: "", }) }) + + it("retains embed query params even if the page hash is different", () => { + const embedParams = + "embed=true&embed_options=disable_scrolling&embed_options=show_colored_line" + + const prevWindowLocation = window.location + // @ts-expect-error + delete window.location + // @ts-expect-error + window.location = { + assign: jest.fn(), + search: `foo=bar&${embedParams}`, + } + + wrapper.setState({ + currentPageScriptHash: "current_page_hash", + queryParams: `foo=bar&${embedParams}`, + }) + const sendMessageFunc = jest.spyOn( + // @ts-expect-error + instance.hostCommunicationMgr, + "sendMessageToHost" + ) + + instance.sendRerunBackMsg(undefined, "some_other_page_hash") + + // @ts-expect-error + expect(instance.sendBackMsg).toHaveBeenCalledWith({ + rerunScript: { + pageScriptHash: "some_other_page_hash", + pageName: "", + queryString: embedParams, + }, + }) + + expect(sendMessageFunc).toHaveBeenCalledWith({ + type: "SET_QUERY_PARAM", + queryParams: embedParams, + }) + + window.location = prevWindowLocation + }) }) // * handlePageNotFound has branching error messages depending on pageName diff --git a/frontend/app/src/App.tsx b/frontend/app/src/App.tsx index 5039c75c769e..a06740ef6150 100644 --- a/frontend/app/src/App.tsx +++ b/frontend/app/src/App.tsx @@ -124,6 +124,7 @@ import withScreencast, { // Used to import fonts + responsive reboot items import "@streamlit/app/src/assets/css/theme.scss" +import { preserveEmbedQueryParams } from "@streamlit/lib/src/util/utils" export interface Props { screenCast: ScreenCastHOC @@ -865,10 +866,14 @@ export class App extends PureComponent { // e.g. the case where the user clicks the back button. // See https://github.com/streamlit/streamlit/pull/6271#issuecomment-1465090690 for the discussion. if (prevPageName !== newPageName) { + // If embed params need to be changed, make sure to change to other parts of the code that reference preserveEmbedQueryParams + const queryString = preserveEmbedQueryParams() + const qs = queryString ? `?${queryString}` : "" + const basePathPrefix = basePath ? `/${basePath}` : "" const pagePath = viewingMainPage ? "" : newPageName - const pageUrl = `${basePathPrefix}/${pagePath}` + const pageUrl = `${basePathPrefix}/${pagePath}${qs}` window.history.pushState({}, "", pageUrl) } @@ -1307,8 +1312,8 @@ export class App extends PureComponent { // The user specified exactly which page to run. We can simply use this // value in the BackMsg we send to the server. if (pageScriptHash != currentPageScriptHash) { - // clear query parameters within a page change - queryString = "" + // clear non-embed query parameters within a page change + queryString = preserveEmbedQueryParams() this.hostCommunicationMgr.sendMessageToHost({ type: "SET_QUERY_PARAM", queryParams: queryString, diff --git a/frontend/lib/src/util/utils.test.ts b/frontend/lib/src/util/utils.test.ts index dfa4d6a92f0a..81d072d4bb44 100644 --- a/frontend/lib/src/util/utils.test.ts +++ b/frontend/lib/src/util/utils.test.ts @@ -23,6 +23,7 @@ import { getLoadingScreenType, isEmbed, setCookie, + preserveEmbedQueryParams, } from "./utils" describe("getCookie", () => { @@ -331,4 +332,47 @@ describe("getLoadingScreenType", () => { expect(getLoadingScreenType()).toBe(LoadingScreenType.V2) }) + + describe("preserveEmbedQueryParams", () => { + let prevWindowLocation: Location + afterEach(() => { + window.location = prevWindowLocation + }) + + it("should return an empty string if not in embed mode", () => { + // @ts-expect-error + delete window.location + // @ts-expect-error + window.location = { + assign: jest.fn(), + search: "foo=bar", + } + expect(preserveEmbedQueryParams()).toBe("") + }) + + it("should preserve embed query string even with no embed options and remove foo=bar", () => { + // @ts-expect-error + delete window.location + // @ts-expect-error + window.location = { + assign: jest.fn(), + search: "embed=true&foo=bar", + } + expect(preserveEmbedQueryParams()).toBe("embed=true") + }) + + it("should preserve embed query string with embed options and remove foo=bar", () => { + // @ts-expect-error + delete window.location + // @ts-expect-error + window.location = { + assign: jest.fn(), + search: + "embed=true&embed_options=option1&embed_options=option2&foo=bar", + } + expect(preserveEmbedQueryParams()).toBe( + "embed=true&embed_options=option1&embed_options=option2" + ) + }) + }) }) diff --git a/frontend/lib/src/util/utils.ts b/frontend/lib/src/util/utils.ts index ebb35b6d33d3..f1728edd463a 100644 --- a/frontend/lib/src/util/utils.ts +++ b/frontend/lib/src/util/utils.ts @@ -98,6 +98,31 @@ export function getEmbedUrlParams(embedKey: string): Set { return embedUrlParams } +/** + * Returns "embed" and "embed_options" query param options in the url. Returns empty string if not embedded. + * Example: + * returns "embed=true&embed_options=show_loading_screen_v2" if the url is + * http://localhost:3000/test?embed=true&embed_options=show_loading_screen_v2 + */ +export function preserveEmbedQueryParams(): string { + if (!isEmbed()) { + return "" + } + + const embedOptionsValues = new URLSearchParams( + window.location.search + ).getAll(EMBED_OPTIONS_QUERY_PARAM_KEY) + + // instantiate multiple key values with an array of string pairs + // https://stackoverflow.com/questions/72571132/urlsearchparams-with-multiple-values + const embedUrlMap: string[][] = [] + embedUrlMap.push([EMBED_QUERY_PARAM_KEY, EMBED_TRUE]) + embedOptionsValues.forEach((embedValue: string) => { + embedUrlMap.push([EMBED_OPTIONS_QUERY_PARAM_KEY, embedValue]) + }) + return new URLSearchParams(embedUrlMap).toString() +} + /** * Returns true if the URL parameters contain ?embed=true (case insensitive). */ diff --git a/lib/streamlit/commands/execution_control.py b/lib/streamlit/commands/execution_control.py index 3913d11b1a94..07bc676d6d89 100644 --- a/lib/streamlit/commands/execution_control.py +++ b/lib/streamlit/commands/execution_control.py @@ -157,7 +157,7 @@ def switch_page(page: str) -> NoReturn: # type: ignore[misc] ctx.script_requests.request_rerun( RerunData( - query_string="", + query_string=ctx.query_string, page_script_hash=matched_pages[0]["page_script_hash"], ) ) diff --git a/lib/streamlit/commands/experimental_query_params.py b/lib/streamlit/commands/experimental_query_params.py index ad459ea7478a..c8e17913b12e 100644 --- a/lib/streamlit/commands/experimental_query_params.py +++ b/lib/streamlit/commands/experimental_query_params.py @@ -16,15 +16,16 @@ from typing import Any, Dict, List, Union from streamlit import util +from streamlit.constants import ( + EMBED_OPTIONS_QUERY_PARAM, + EMBED_QUERY_PARAM, + EMBED_QUERY_PARAMS_KEYS, +) from streamlit.errors import StreamlitAPIException from streamlit.proto.ForwardMsg_pb2 import ForwardMsg from streamlit.runtime.metrics_util import gather_metrics from streamlit.runtime.scriptrunner import get_script_run_ctx -EMBED_QUERY_PARAM = "embed" -EMBED_OPTIONS_QUERY_PARAM = "embed_options" -EMBED_QUERY_PARAMS_KEYS = [EMBED_QUERY_PARAM, EMBED_OPTIONS_QUERY_PARAM] - @gather_metrics("experimental_get_query_params") def get_query_params() -> Dict[str, List[str]]: diff --git a/lib/streamlit/constants.py b/lib/streamlit/constants.py new file mode 100644 index 000000000000..afa995a93849 --- /dev/null +++ b/lib/streamlit/constants.py @@ -0,0 +1,17 @@ +# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +EMBED_QUERY_PARAM = "embed" +EMBED_OPTIONS_QUERY_PARAM = "embed_options" +EMBED_QUERY_PARAMS_KEYS = [EMBED_QUERY_PARAM, EMBED_OPTIONS_QUERY_PARAM] diff --git a/lib/streamlit/runtime/state/query_params.py b/lib/streamlit/runtime/state/query_params.py index 522b2f4a355c..3525a0c7141a 100644 --- a/lib/streamlit/runtime/state/query_params.py +++ b/lib/streamlit/runtime/state/query_params.py @@ -14,7 +14,9 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Iterator, List, MutableMapping, Union +from urllib import parse +from streamlit.constants import EMBED_QUERY_PARAMS_KEYS from streamlit.errors import StreamlitAPIException from streamlit.proto.ForwardMsg_pb2 import ForwardMsg @@ -29,7 +31,12 @@ class QueryParams(MutableMapping[str, str]): def __iter__(self) -> Iterator[str]: self._ensure_single_query_api_used() - return iter(self._query_params.keys()) + + return iter( + key + for key in self._query_params.keys() + if key not in EMBED_QUERY_PARAMS_KEYS + ) def __getitem__(self, key: str) -> str: """Retrieves a value for a given key in query parameters. @@ -38,6 +45,8 @@ def __getitem__(self, key: str) -> str: """ self._ensure_single_query_api_used() try: + if key in EMBED_QUERY_PARAMS_KEYS: + raise KeyError(missing_key_error_message(key)) value = self._query_params[key] if isinstance(value, list): if len(value) == 0: @@ -56,6 +65,10 @@ def __setitem__(self, key: str, value: Union[str, Iterable[str]]) -> None: f"You cannot set a query params key `{key}` to a dictionary." ) + if key in EMBED_QUERY_PARAMS_KEYS: + raise StreamlitAPIException( + "Query param embed and embed_options (case-insensitive) cannot be set programmatically." + ) # Type checking users should handle the string serialization themselves # We will accept any type for the list and serialize to str just in case if isinstance(value, Iterable) and not isinstance(value, str): @@ -66,6 +79,8 @@ def __setitem__(self, key: str, value: Union[str, Iterable[str]]) -> None: def __delitem__(self, key: str) -> None: try: + if key in EMBED_QUERY_PARAMS_KEYS: + raise KeyError(missing_key_error_message(key)) del self._query_params[key] self._send_query_param_msg() except KeyError: @@ -73,18 +88,19 @@ def __delitem__(self, key: str) -> None: def get_all(self, key: str) -> List[str]: self._ensure_single_query_api_used() - if key not in self._query_params: + if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS: return [] value = self._query_params[key] return value if isinstance(value, list) else [value] def __len__(self) -> int: self._ensure_single_query_api_used() - return len(self._query_params) + return len( + {key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS} + ) def _send_query_param_msg(self) -> None: # Avoid circular imports - from streamlit.commands.experimental_query_params import _ensure_no_embed_params from streamlit.runtime.scriptrunner import get_script_run_ctx ctx = get_script_run_ctx() @@ -93,27 +109,31 @@ def _send_query_param_msg(self) -> None: self._ensure_single_query_api_used() msg = ForwardMsg() - msg.page_info_changed.query_string = _ensure_no_embed_params( - self._query_params, ctx.query_string + msg.page_info_changed.query_string = parse.urlencode( + self._query_params, doseq=True ) ctx.query_string = msg.page_info_changed.query_string ctx.enqueue(msg) def clear(self) -> None: - self._query_params.clear() + new_query_params = {} + for key, value in self._query_params.items(): + if key in EMBED_QUERY_PARAMS_KEYS: + new_query_params[key] = value + self._query_params = new_query_params + self._send_query_param_msg() def to_dict(self) -> Dict[str, str]: self._ensure_single_query_api_used() - # return the last query param if multiple keys are set - return {key: self[key] for key in self._query_params} + # return the last query param if multiple values are set + return { + key: self[key] + for key in self._query_params + if key not in EMBED_QUERY_PARAMS_KEYS + } def set_with_no_forward_msg(self, key: str, val: Union[List[str], str]) -> None: - # Avoid circular imports - from streamlit.commands.experimental_query_params import EMBED_QUERY_PARAMS_KEYS - - if key.lower() in EMBED_QUERY_PARAMS_KEYS: - return self._query_params[key] = val def clear_with_no_forward_msg(self) -> None: diff --git a/lib/tests/streamlit/runtime/state/query_params_test.py b/lib/tests/streamlit/runtime/state/query_params_test.py index 2a34a6d070b6..d6c1fc47f3a5 100644 --- a/lib/tests/streamlit/runtime/state/query_params_test.py +++ b/lib/tests/streamlit/runtime/state/query_params_test.py @@ -20,11 +20,34 @@ class QueryParamsMethodTests(DeltaGeneratorTestCase): + query_params_dict_with_embed_key = { + "foo": "bar", + "two": ["x", "y"], + "embed": "true", + "embed_options": "disable_scrolling", + } + def setUp(self): super().setUp() self.query_params = QueryParams() self.query_params._query_params = {"foo": "bar", "two": ["x", "y"]} + def test__iter__doesnt_include_embed_keys(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + for key in self.query_params.__iter__(): + if key == "embed" or key == "embed_options": + raise KeyError("Cannot iterate through embed or embed_options key") + + def test__getitem__raises_KeyError_for_nonexistent_key_for_embed(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + with pytest.raises(KeyError): + self.query_params["embed"] + + def test__getitem__raises_KeyError_for_nonexistent_key_for_embed_options(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + with pytest.raises(KeyError): + self.query_params["embed_options"] + def test__getitem__raises_KeyError_for_nonexistent_key(self): with pytest.raises(KeyError): self.query_params["nonexistent"] @@ -97,6 +120,14 @@ def test__setitem__raises_exception_for_embed_options_key(self): with pytest.raises(StreamlitAPIException): self.query_params["embed_options"] = "show_toolbar" + def test__setitem__raises_error_with_embed_key(self): + with pytest.raises(StreamlitAPIException): + self.query_params["embed"] = "true" + + def test__setitem__raises_error_with_embed_options_key(self): + with pytest.raises(StreamlitAPIException): + self.query_params["embed_options"] = "disable_scrolling" + def test__delitem__removes_existing_key(self): del self.query_params["foo"] assert "foo" not in self.query_params @@ -108,6 +139,18 @@ def test__delitem__raises_error_for_nonexistent_key(self): with pytest.raises(KeyError): del self.query_params["nonexistent"] + def test__delitem__throws_KeyErrorException_for_embed_key(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + with pytest.raises(KeyError): + del self.query_params["embed"] + assert "embed" in self.query_params._query_params + + def test__delitem__throws_KeyErrorException_for_embed_options_key(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + with pytest.raises(KeyError): + del self.query_params["embed_options"] + assert "embed_options" in self.query_params._query_params + def test_get_all_returns_empty_list_for_nonexistent_key(self): assert self.query_params.get_all("nonexistent") == [] @@ -121,17 +164,51 @@ def test_get_all_handles_mixed_type_values(self): self.query_params["test"] = ["", "a", 1, 1.23] assert self.query_params.get_all("test") == ["", "a", "1", "1.23"] + def test_get_all_returns_empty_array_for_embed_key(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + assert self.query_params.get_all("embed") == [] + + def test_get_all_returns_empty_array_for_embed_options_key(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + assert self.query_params.get_all("embed_options") == [] + + def test__len__doesnt_include_embed_and_embed_options_key(self): + self.query_params._query_params = self.query_params_dict_with_embed_key + assert len(self.query_params) == 2 + def test_clear_removes_all_query_params(self): self.query_params.clear() assert len(self.query_params) == 0 message = self.get_message_from_queue(0) assert "" == message.page_info_changed.query_string + def test_clear_doesnt_remove_embed_query_params(self): + self.query_params._query_params = { + "foo": "bar", + "embed": "true", + "embed_options": ["show_colored_line", "disable_scrolling"], + } + result_dict = { + "embed": "true", + "embed_options": ["show_colored_line", "disable_scrolling"], + } + self.query_params.clear() + assert self.query_params._query_params == result_dict + def test_to_dict(self): self.query_params["baz"] = "" result_dict = {"foo": "bar", "two": "y", "baz": ""} assert self.query_params.to_dict() == result_dict + def test_to_dict_doesnt_include_embed_params(self): + self.query_params._query_params = { + "foo": "bar", + "embed": "true", + "embed_options": ["show_colored_line", "disable_scrolling"], + } + result_dict = {"foo": "bar"} + assert self.query_params.to_dict() == result_dict + def test_set_with_no_forward_msg_sends_no_msg_and_sets_query_params(self): self.query_params.set_with_no_forward_msg("test", "test") assert self.query_params["test"] == "test" @@ -139,6 +216,32 @@ def test_set_with_no_forward_msg_sends_no_msg_and_sets_query_params(self): # no forward message should be sent self.get_message_from_queue(0) + def test_set_with_no_forward_msg_accepts_embed(self): + self.query_params.set_with_no_forward_msg("embed", "true") + assert self.query_params._query_params["embed"] == "true" + with pytest.raises(IndexError): + # no forward message should be sent + self.get_message_from_queue(0) + + def test_set_with_no_forward_msg_accepts_embed_options(self): + self.query_params.set_with_no_forward_msg("embed_options", "disable_scrolling") + assert self.query_params._query_params["embed_options"] == "disable_scrolling" + with pytest.raises(IndexError): + # no forward message should be sent + self.get_message_from_queue(0) + + def test_set_with_no_forward_msg_accepts_multiple_embed_options(self): + self.query_params.set_with_no_forward_msg( + "embed_options", ["disable_scrolling", "show_colored_line"] + ) + assert self.query_params._query_params["embed_options"] == [ + "disable_scrolling", + "show_colored_line", + ] + with pytest.raises(IndexError): + # no forward message should be sent + self.get_message_from_queue(0) + def test_clear_with_no_forward_msg_sends_no_msg_and_clears_query_params(self): self.query_params.clear_with_no_forward_msg() assert len(self.query_params) == 0