Skip to content

Commit

Permalink
Fix embed params being dropped in page swaps (streamlit#7918)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-wihuang authored and zyxue committed Apr 16, 2024
1 parent ac776e2 commit cfd0ee6
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 22 deletions.
17 changes: 17 additions & 0 deletions e2e_playwright/multipage_apps/mpa_basics_test.py
Expand Up @@ -176,3 +176,20 @@ def test_removes_query_params_when_swapping_pages(page: Page, app_port: int):
wait_for_app_run(page)

assert page.url == f"http://localhost:{app_port}/page3"


def test_removes_non_embed_query_params_when_swapping_pages(page: Page, app_port: int):
"""Test that query params are removed when swapping pages"""

page.goto(
f"http://localhost:{app_port}/page_7?foo=bar&embed=True&embed_options=show_toolbar&embed_options=show_colored_line"
)
wait_for_app_loaded(page)

page.get_by_test_id("stSidebarNav").locator("a").nth(2).click()
wait_for_app_run(page)

assert (
page.url
== f"http://localhost:{app_port}/page3?embed=true&embed_options=show_toolbar&embed_options=show_colored_line"
)
42 changes: 42 additions & 0 deletions frontend/app/src/App.test.tsx
Expand Up @@ -1193,6 +1193,48 @@ describe("App.sendRerunBackMsg", () => {
queryParams: "",
})
})

it("retains embed query params even if the page hash is different", () => {
const embedParams =
"embed=true&embed_options=disable_scrolling&embed_options=show_colored_line"

const prevWindowLocation = window.location
// @ts-expect-error
delete window.location
// @ts-expect-error
window.location = {
assign: jest.fn(),
search: `foo=bar&${embedParams}`,
}

wrapper.setState({
currentPageScriptHash: "current_page_hash",
queryParams: `foo=bar&${embedParams}`,
})
const sendMessageFunc = jest.spyOn(
// @ts-expect-error
instance.hostCommunicationMgr,
"sendMessageToHost"
)

instance.sendRerunBackMsg(undefined, "some_other_page_hash")

// @ts-expect-error
expect(instance.sendBackMsg).toHaveBeenCalledWith({
rerunScript: {
pageScriptHash: "some_other_page_hash",
pageName: "",
queryString: embedParams,
},
})

expect(sendMessageFunc).toHaveBeenCalledWith({
type: "SET_QUERY_PARAM",
queryParams: embedParams,
})

window.location = prevWindowLocation
})
})

// * handlePageNotFound has branching error messages depending on pageName
Expand Down
11 changes: 8 additions & 3 deletions frontend/app/src/App.tsx
Expand Up @@ -124,6 +124,7 @@ import withScreencast, {

// Used to import fonts + responsive reboot items
import "@streamlit/app/src/assets/css/theme.scss"
import { preserveEmbedQueryParams } from "@streamlit/lib/src/util/utils"

export interface Props {
screenCast: ScreenCastHOC
Expand Down Expand Up @@ -865,10 +866,14 @@ export class App extends PureComponent<Props, State> {
// e.g. the case where the user clicks the back button.
// See https://github.com/streamlit/streamlit/pull/6271#issuecomment-1465090690 for the discussion.
if (prevPageName !== newPageName) {
// If embed params need to be changed, make sure to change to other parts of the code that reference preserveEmbedQueryParams
const queryString = preserveEmbedQueryParams()
const qs = queryString ? `?${queryString}` : ""

const basePathPrefix = basePath ? `/${basePath}` : ""

const pagePath = viewingMainPage ? "" : newPageName
const pageUrl = `${basePathPrefix}/${pagePath}`
const pageUrl = `${basePathPrefix}/${pagePath}${qs}`

window.history.pushState({}, "", pageUrl)
}
Expand Down Expand Up @@ -1307,8 +1312,8 @@ export class App extends PureComponent<Props, State> {
// The user specified exactly which page to run. We can simply use this
// value in the BackMsg we send to the server.
if (pageScriptHash != currentPageScriptHash) {
// clear query parameters within a page change
queryString = ""
// clear non-embed query parameters within a page change
queryString = preserveEmbedQueryParams()
this.hostCommunicationMgr.sendMessageToHost({
type: "SET_QUERY_PARAM",
queryParams: queryString,
Expand Down
44 changes: 44 additions & 0 deletions frontend/lib/src/util/utils.test.ts
Expand Up @@ -23,6 +23,7 @@ import {
getLoadingScreenType,
isEmbed,
setCookie,
preserveEmbedQueryParams,
} from "./utils"

describe("getCookie", () => {
Expand Down Expand Up @@ -331,4 +332,47 @@ describe("getLoadingScreenType", () => {

expect(getLoadingScreenType()).toBe(LoadingScreenType.V2)
})

describe("preserveEmbedQueryParams", () => {
let prevWindowLocation: Location
afterEach(() => {
window.location = prevWindowLocation
})

it("should return an empty string if not in embed mode", () => {
// @ts-expect-error
delete window.location
// @ts-expect-error
window.location = {
assign: jest.fn(),
search: "foo=bar",
}
expect(preserveEmbedQueryParams()).toBe("")
})

it("should preserve embed query string even with no embed options and remove foo=bar", () => {
// @ts-expect-error
delete window.location
// @ts-expect-error
window.location = {
assign: jest.fn(),
search: "embed=true&foo=bar",
}
expect(preserveEmbedQueryParams()).toBe("embed=true")
})

it("should preserve embed query string with embed options and remove foo=bar", () => {
// @ts-expect-error
delete window.location
// @ts-expect-error
window.location = {
assign: jest.fn(),
search:
"embed=true&embed_options=option1&embed_options=option2&foo=bar",
}
expect(preserveEmbedQueryParams()).toBe(
"embed=true&embed_options=option1&embed_options=option2"
)
})
})
})
25 changes: 25 additions & 0 deletions frontend/lib/src/util/utils.ts
Expand Up @@ -98,6 +98,31 @@ export function getEmbedUrlParams(embedKey: string): Set<string> {
return embedUrlParams
}

/**
* Returns "embed" and "embed_options" query param options in the url. Returns empty string if not embedded.
* Example:
* returns "embed=true&embed_options=show_loading_screen_v2" if the url is
* http://localhost:3000/test?embed=true&embed_options=show_loading_screen_v2
*/
export function preserveEmbedQueryParams(): string {
if (!isEmbed()) {
return ""
}

const embedOptionsValues = new URLSearchParams(
window.location.search
).getAll(EMBED_OPTIONS_QUERY_PARAM_KEY)

// instantiate multiple key values with an array of string pairs
// https://stackoverflow.com/questions/72571132/urlsearchparams-with-multiple-values
const embedUrlMap: string[][] = []
embedUrlMap.push([EMBED_QUERY_PARAM_KEY, EMBED_TRUE])
embedOptionsValues.forEach((embedValue: string) => {
embedUrlMap.push([EMBED_OPTIONS_QUERY_PARAM_KEY, embedValue])
})
return new URLSearchParams(embedUrlMap).toString()
}

/**
* Returns true if the URL parameters contain ?embed=true (case insensitive).
*/
Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/commands/execution_control.py
Expand Up @@ -157,7 +157,7 @@ def switch_page(page: str) -> NoReturn: # type: ignore[misc]

ctx.script_requests.request_rerun(
RerunData(
query_string="",
query_string=ctx.query_string,
page_script_hash=matched_pages[0]["page_script_hash"],
)
)
Expand Down
9 changes: 5 additions & 4 deletions lib/streamlit/commands/experimental_query_params.py
Expand Up @@ -16,15 +16,16 @@
from typing import Any, Dict, List, Union

from streamlit import util
from streamlit.constants import (
EMBED_OPTIONS_QUERY_PARAM,
EMBED_QUERY_PARAM,
EMBED_QUERY_PARAMS_KEYS,
)
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import get_script_run_ctx

EMBED_QUERY_PARAM = "embed"
EMBED_OPTIONS_QUERY_PARAM = "embed_options"
EMBED_QUERY_PARAMS_KEYS = [EMBED_QUERY_PARAM, EMBED_OPTIONS_QUERY_PARAM]


@gather_metrics("experimental_get_query_params")
def get_query_params() -> Dict[str, List[str]]:
Expand Down
17 changes: 17 additions & 0 deletions lib/streamlit/constants.py
@@ -0,0 +1,17 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

EMBED_QUERY_PARAM = "embed"
EMBED_OPTIONS_QUERY_PARAM = "embed_options"
EMBED_QUERY_PARAMS_KEYS = [EMBED_QUERY_PARAM, EMBED_OPTIONS_QUERY_PARAM]
48 changes: 34 additions & 14 deletions lib/streamlit/runtime/state/query_params.py
Expand Up @@ -14,7 +14,9 @@

from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, MutableMapping, Union
from urllib import parse

from streamlit.constants import EMBED_QUERY_PARAMS_KEYS
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg

Expand All @@ -29,7 +31,12 @@ class QueryParams(MutableMapping[str, str]):

def __iter__(self) -> Iterator[str]:
self._ensure_single_query_api_used()
return iter(self._query_params.keys())

return iter(
key
for key in self._query_params.keys()
if key not in EMBED_QUERY_PARAMS_KEYS
)

def __getitem__(self, key: str) -> str:
"""Retrieves a value for a given key in query parameters.
Expand All @@ -38,6 +45,8 @@ def __getitem__(self, key: str) -> str:
"""
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
value = self._query_params[key]
if isinstance(value, list):
if len(value) == 0:
Expand All @@ -56,6 +65,10 @@ def __setitem__(self, key: str, value: Union[str, Iterable[str]]) -> None:
f"You cannot set a query params key `{key}` to a dictionary."
)

if key in EMBED_QUERY_PARAMS_KEYS:
raise StreamlitAPIException(
"Query param embed and embed_options (case-insensitive) cannot be set programmatically."
)
# Type checking users should handle the string serialization themselves
# We will accept any type for the list and serialize to str just in case
if isinstance(value, Iterable) and not isinstance(value, str):
Expand All @@ -66,25 +79,28 @@ def __setitem__(self, key: str, value: Union[str, Iterable[str]]) -> None:

def __delitem__(self, key: str) -> None:
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
del self._query_params[key]
self._send_query_param_msg()
except KeyError:
raise KeyError(missing_key_error_message(key))

def get_all(self, key: str) -> List[str]:
self._ensure_single_query_api_used()
if key not in self._query_params:
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
return []
value = self._query_params[key]
return value if isinstance(value, list) else [value]

def __len__(self) -> int:
self._ensure_single_query_api_used()
return len(self._query_params)
return len(
{key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS}
)

def _send_query_param_msg(self) -> None:
# Avoid circular imports
from streamlit.commands.experimental_query_params import _ensure_no_embed_params
from streamlit.runtime.scriptrunner import get_script_run_ctx

ctx = get_script_run_ctx()
Expand All @@ -93,27 +109,31 @@ def _send_query_param_msg(self) -> None:
self._ensure_single_query_api_used()

msg = ForwardMsg()
msg.page_info_changed.query_string = _ensure_no_embed_params(
self._query_params, ctx.query_string
msg.page_info_changed.query_string = parse.urlencode(
self._query_params, doseq=True
)
ctx.query_string = msg.page_info_changed.query_string
ctx.enqueue(msg)

def clear(self) -> None:
self._query_params.clear()
new_query_params = {}
for key, value in self._query_params.items():
if key in EMBED_QUERY_PARAMS_KEYS:
new_query_params[key] = value
self._query_params = new_query_params

self._send_query_param_msg()

def to_dict(self) -> Dict[str, str]:
self._ensure_single_query_api_used()
# return the last query param if multiple keys are set
return {key: self[key] for key in self._query_params}
# return the last query param if multiple values are set
return {
key: self[key]
for key in self._query_params
if key not in EMBED_QUERY_PARAMS_KEYS
}

def set_with_no_forward_msg(self, key: str, val: Union[List[str], str]) -> None:
# Avoid circular imports
from streamlit.commands.experimental_query_params import EMBED_QUERY_PARAMS_KEYS

if key.lower() in EMBED_QUERY_PARAMS_KEYS:
return
self._query_params[key] = val

def clear_with_no_forward_msg(self) -> None:
Expand Down

0 comments on commit cfd0ee6

Please sign in to comment.