Skip to content

Commit

Permalink
Refactor page management to a PagesManager class (#8639)
Browse files Browse the repository at this point in the history
## Describe your changes

Migrates page management, which primarily lived in the `source_util` module,
to a `PagesManager` class. Because V2 can have a dynamic set of pages per
session, the `PagesManager` instance will live on the `AppSession`. For
V1, we will leverage static variables/methods for page watching.

## Testing Plan

- Original Unit tests are applied (with patches adjusted)
- Unit tests are applied for the PagesManager component

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
kmcgrady committed May 20, 2024
1 parent 2f87d23 commit baf3c96
Show file tree
Hide file tree
Showing 30 changed files with 567 additions and 352 deletions.
8 changes: 5 additions & 3 deletions e2e_playwright/hello_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
np.random.seed(0)

# This is a trick to setup the MPA hello app programmatically
from streamlit.runtime.scriptrunner import get_script_run_ctx

source_util._cached_pages = None
source_util._cached_pages = source_util.get_pages(Hello.__file__)
source_util._on_pages_changed.send()
ctx = get_script_run_ctx()
if ctx:
ctx.pages_manager._cached_pages = source_util.get_pages(Hello.__file__)
ctx.pages_manager._on_pages_changed.send()

# TODO(lukasmasuch): Once we migrate the hello app to the new programmatic
# MPA API, we can remove this workaround.
Expand Down
3 changes: 1 addition & 2 deletions lib/streamlit/commands/execution_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Final, NoReturn

import streamlit as st
from streamlit import source_util
from streamlit.deprecation_util import make_deprecated_name_warning
from streamlit.errors import NoSessionContext, StreamlitAPIException
from streamlit.file_util import get_main_script_directory, normalize_path_join
Expand Down Expand Up @@ -143,7 +142,7 @@ def switch_page(page: str) -> NoReturn: # type: ignore[misc]

main_script_directory = get_main_script_directory(ctx.main_script_path)
requested_page = os.path.realpath(normalize_path_join(main_script_directory, page))
all_app_pages = source_util.get_pages(ctx.main_script_path).values()
all_app_pages = ctx.pages_manager.get_pages().values()

matched_pages = [p for p in all_app_pages if p["script_path"] == requested_page]

Expand Down
7 changes: 4 additions & 3 deletions lib/streamlit/elements/widgets/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from typing_extensions import TypeAlias

from streamlit import runtime, source_util
from streamlit import runtime
from streamlit.elements.form import current_form_id, is_in_form
from streamlit.errors import StreamlitAPIException
from streamlit.file_util import get_main_script_directory, normalize_path_join
Expand Down Expand Up @@ -694,17 +694,18 @@ def _page_link(

ctx = get_script_run_ctx()
ctx_main_script = ""
all_app_pages = {}
if ctx:
ctx_main_script = ctx.main_script_path
all_app_pages = ctx.pages_manager.get_pages()

main_script_directory = get_main_script_directory(ctx_main_script)
requested_page = os.path.realpath(
normalize_path_join(main_script_directory, page)
)
all_app_pages = source_util.get_pages(ctx_main_script).values()

# Handle retrieving the page_script_hash & page
for page_data in all_app_pages:
for page_data in all_app_pages.values():
full_path = page_data["script_path"]
page_name = page_data["page_name"]
if requested_page == full_path:
Expand Down
35 changes: 17 additions & 18 deletions lib/streamlit/runtime/app_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Callable, Final

import streamlit.elements.exception as exception_utils
from streamlit import config, runtime, source_util
from streamlit import config, runtime
from streamlit.case_converters import to_snake_case
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
Expand All @@ -40,6 +40,7 @@
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
from streamlit.runtime.fragment import FragmentStorage, MemoryFragmentStorage
from streamlit.runtime.metrics_util import Installation
from streamlit.runtime.pages_manager import PagesManager
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
self._script_data = script_data
self._uploaded_file_mgr = uploaded_file_manager
self._script_cache = script_cache
self._pages_manager = PagesManager(script_data.main_script_path)

# The browser queue contains messages that haven't yet been
# delivered to the browser. Periodically, the server flushes
Expand Down Expand Up @@ -181,17 +183,15 @@ def register_file_watchers(self) -> None:
to.
"""
if self._local_sources_watcher is None:
self._local_sources_watcher = LocalSourcesWatcher(
self._script_data.main_script_path
)
self._local_sources_watcher = LocalSourcesWatcher(self._pages_manager)

self._local_sources_watcher.register_file_change_callback(
self._on_source_file_changed
)
self._stop_config_listener = config.on_config_parsed(
self._on_source_file_changed, force_connect=True
)
self._stop_pages_listener = source_util.register_pages_changed_callback(
self._stop_pages_listener = self._pages_manager.register_pages_changed_callback(
self._on_pages_changed
)
secrets_singleton.file_change_listener.connect(self._on_secrets_file_changed)
Expand Down Expand Up @@ -407,6 +407,7 @@ def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
initial_rerun_data=initial_rerun_data,
user_info=self._user_info,
fragment_storage=self._fragment_storage,
pages_manager=self._pages_manager,
)
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
self._scriptrunner.start()
Expand All @@ -416,8 +417,7 @@ def session_state(self) -> SessionState:
return self._session_state

def _should_rerun_on_file_change(self, filepath: str) -> bool:
main_script_path = self._script_data.main_script_path
pages = source_util.get_pages(main_script_path)
pages = self._pages_manager.get_pages()

changed_page_script_hash = next(
filter(lambda k: pages[k]["script_path"] == filepath, pages),
Expand Down Expand Up @@ -454,7 +454,7 @@ def _on_secrets_file_changed(self, _) -> None:

def _on_pages_changed(self, _) -> None:
msg = ForwardMsg()
_populate_app_pages(msg.pages_changed, self._script_data.main_script_path)
self._populate_app_pages(msg.pages_changed)
self._enqueue_forward_msg(msg)

if self._local_sources_watcher is not None:
Expand Down Expand Up @@ -678,7 +678,7 @@ def _create_new_session_message(
if fragment_ids_this_run:
msg.new_session.fragment_ids_this_run.extend(fragment_ids_this_run)

_populate_app_pages(msg.new_session, self._script_data.main_script_path)
self._populate_app_pages(msg.new_session)
_populate_config_msg(msg.new_session.config)
_populate_theme_msg(msg.new_session.custom_theme)

Expand Down Expand Up @@ -829,6 +829,14 @@ def _handle_file_urls_request(self, file_urls_request: FileURLsRequest) -> None:

self._enqueue_forward_msg(msg)

def _populate_app_pages(self, msg: NewSession | PagesChanged) -> None:
for page_script_hash, page_info in self._pages_manager.get_pages().items():
page_proto = msg.app_pages.add()

page_proto.page_script_hash = page_script_hash
page_proto.page_name = page_info["page_name"]
page_proto.icon = page_info["icon"]


# Config.ToolbarMode.ValueType does not exist at runtime (only in the pyi stubs), so
# we need to use quotes.
Expand Down Expand Up @@ -913,12 +921,3 @@ def _populate_theme_msg(msg: CustomThemeConfig) -> None:
def _populate_user_info_msg(msg: UserInfo) -> None:
msg.installation_id = Installation.instance().installation_id
msg.installation_id_v3 = Installation.instance().installation_id_v3


def _populate_app_pages(msg: NewSession | PagesChanged, main_script_path: str) -> None:
for page_script_hash, page_info in source_util.get_pages(main_script_path).items():
page_proto = msg.app_pages.add()

page_proto.page_script_hash = page_script_hash
page_proto.page_name = page_info["page_name"]
page_proto.icon = page_info["icon"]
148 changes: 148 additions & 0 deletions lib/streamlit/runtime/pages_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import threading
from pathlib import Path
from typing import Callable, Final

from blinker import Signal

from streamlit.logger import get_logger
from streamlit.source_util import PageInfo, get_pages
from streamlit.util import calc_md5
from streamlit.watcher import watch_dir

_LOGGER: Final = get_logger(__name__)


class PagesManagerV1:
    """Static helper that starts the watcher on the app's ``pages`` directory.

    The state lives on the class (not on instances) so that the directory
    watcher is registered at most once, no matter how many ``PagesManager``
    objects are created.
    """

    # True once the watcher on the ``pages`` directory has been registered.
    is_watching_pages_dir: bool = False

    # Shared by every caller of ``watch_pages_dir``. The lock must live at
    # class level: a lock created inside the method would be private to each
    # call and would therefore never block a concurrent caller.
    pages_watcher_lock = threading.Lock()

    # This is a static method because we only want to watch the pages directory
    # once on initial load.
    @staticmethod
    def watch_pages_dir(pages_manager: PagesManager) -> None:
        """Register a watcher on the ``pages`` directory next to the main script.

        Does nothing if a watcher has already been registered. Otherwise, any
        change to a ``*.py`` file in that directory invalidates the pages cache
        of ``pages_manager``.

        Parameters
        ----------
        pages_manager
            The manager whose ``main_script_path`` locates the ``pages``
            directory and whose cache is invalidated on changes.
        """
        # Hold the shared lock across the check *and* the registration so two
        # threads cannot both observe ``is_watching_pages_dir == False``.
        with PagesManagerV1.pages_watcher_lock:
            if PagesManagerV1.is_watching_pages_dir:
                return

            def _on_pages_changed(_path: str) -> None:
                pages_manager.invalidate_pages_cache()

            main_script_path = Path(pages_manager.main_script_path)
            pages_dir = main_script_path.parent / "pages"
            watch_dir(
                str(pages_dir),
                _on_pages_changed,
                glob_pattern="*.py",
                allow_nonexistent=True,
            )
            PagesManagerV1.is_watching_pages_dir = True


class PagesManager:
    """Holds the set of pages for an app and tracks the currently running page.

    The page mapping is computed lazily via ``source_util.get_pages`` and
    cached until ``invalidate_pages_cache`` is called. Interested parties can
    subscribe to cache invalidations with ``register_pages_changed_callback``.
    """

    def __init__(self, main_script_path, **kwargs):
        """Create a manager for the app rooted at ``main_script_path``.

        Parameters
        ----------
        main_script_path
            Path of the app's main script.
        **kwargs
            ``setup_watcher`` (default ``True``): when truthy, ask
            ``PagesManagerV1`` to watch the pages directory.
        """
        self._main_script_path: str = main_script_path
        self._main_script_hash: str = calc_md5(main_script_path)
        # Until told otherwise, the current page is the main script.
        self._current_page_hash: str = self._main_script_hash

        self._cached_pages: dict[str, PageInfo] | None = None
        self._pages_cache_lock = threading.RLock()
        self._on_pages_changed = Signal(doc="Emitted when the set of pages has changed")

        should_watch = kwargs.get("setup_watcher", True)
        if should_watch:
            PagesManagerV1.watch_pages_dir(self)

    @property
    def main_script_path(self) -> str:
        """Path of the app's main script, as given to the constructor."""
        return self._main_script_path

    def get_main_page(self) -> PageInfo:
        """Return the main script's path and hash as a page-info mapping."""
        main_page = {
            "script_path": self._main_script_path,
            "page_script_hash": self._main_script_hash,
        }
        return main_page

    def get_current_page_script_hash(self) -> str:
        """Return the script hash recorded as the current page."""
        return self._current_page_hash

    def set_current_page_script_hash(self, page_hash: str) -> None:
        """Record ``page_hash`` as the current page's script hash."""
        self._current_page_hash = page_hash

    def get_active_script(self, page_script_hash: str, page_name: str):
        """Resolve which page a run request refers to.

        Lookup order:

        1. A non-empty ``page_script_hash`` is looked up directly; the result
           is ``None`` if no page has that hash.
        2. Otherwise a non-empty ``page_name`` selects the first page with
           that name, or ``None`` if there is no such page.
        3. With neither given, the first page in the mapping is returned.
        """
        known_pages = self.get_pages()

        if page_script_hash:
            return known_pages.get(page_script_hash, None)

        if page_name:
            # If a user navigates directly to a non-main page of an app, we get
            # the first script run request before the list of pages has been
            # sent to the frontend. In this case, we choose the first script
            # with a name matching the requested page name.
            for candidate in known_pages.values():
                # The truthiness guard on `candidate` works around a mypy
                # quirk where it believes the value may be None.
                if candidate and (candidate["page_name"] == page_name):
                    return candidate
            return None

        # If no information about what page to run is given, default to
        # running the main page.
        # Safe because pages will at least contain the app's main page.
        return list(known_pages.values())[0]

    def get_pages(self) -> dict[str, PageInfo]:
        """Return the mapping of page script hash to page info.

        The mapping is computed on first use (or after an invalidation) and
        cached afterwards.
        """
        # Fast path: skip the lock entirely when the cache is populated.
        snapshot = self._cached_pages
        if snapshot is not None:
            return snapshot

        with self._pages_cache_lock:
            # Another thread may have filled the cache while we waited for
            # the lock, so look again before doing the work.
            snapshot = self._cached_pages
            if snapshot is None:
                snapshot = get_pages(self.main_script_path)
                self._cached_pages = snapshot

        return snapshot

    def invalidate_pages_cache(self) -> None:
        """Drop the cached pages and emit the pages-changed signal."""
        _LOGGER.debug("Set of pages have changed. Invalidating cache.")
        with self._pages_cache_lock:
            self._cached_pages = None

        # Signal subscribers only after the lock has been released.
        self._on_pages_changed.send()

    def register_pages_changed_callback(
        self,
        callback: Callable[[str], None],
    ) -> Callable[[], None]:
        """Subscribe ``callback`` to the pages-changed signal.

        Returns
        -------
        Callable[[], None]
            A function that, when called, unsubscribes ``callback``.
        """
        # weak=False so that we have control of when the pages changed
        # callback is deregistered.
        self._on_pages_changed.connect(callback, weak=False)

        def disconnect():
            self._on_pages_changed.disconnect(callback)

        return disconnect
9 changes: 7 additions & 2 deletions lib/streamlit/runtime/scriptrunner/script_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from streamlit.runtime.fragment import FragmentStorage
from streamlit.runtime.pages_manager import PagesManager

_LOGGER: Final = get_logger(__name__)

Expand Down Expand Up @@ -60,9 +61,9 @@ class ScriptRunContext:
session_state: SafeSessionState
uploaded_file_mgr: UploadedFileManager
main_script_path: str
page_script_hash: str
user_info: UserInfo
fragment_storage: "FragmentStorage"
pages_manager: "PagesManager"

gather_usage_stats: bool = False
command_tracking_deactivated: bool = False
Expand All @@ -87,6 +88,10 @@ class ScriptRunContext:
_experimental_query_params_used = False
_production_query_params_used = False

@property
def page_script_hash(self):
return self.pages_manager.get_current_page_script_hash()

def reset(
self,
query_string: str = "",
Expand All @@ -98,7 +103,7 @@ def reset(
self.widget_user_keys_this_run = set()
self.form_ids_this_run = set()
self.query_string = query_string
self.page_script_hash = page_script_hash
self.pages_manager.set_current_page_script_hash(page_script_hash)
# Permit set_page_config when the ScriptRunContext is reused on a rerun
self._set_page_config_allowed = True
self._has_script_started = False
Expand Down
Loading

0 comments on commit baf3c96

Please sign in to comment.