Skip to content

Commit

Permalink
Address Pages Manager lock behavior and types (#8650)
Browse files Browse the repository at this point in the history
## Describe your changes

From #8639, @raethlein made a good claim about the lock as well as some
types to make things clear, so I addressed those here.

## Testing Plan

- Type checks should pass
- Existing tests should pass with lock behavior (we don't usually manage
concurrency testing at the moment)

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
kmcgrady committed May 20, 2024
1 parent baf3c96 commit f0b78be
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
29 changes: 17 additions & 12 deletions lib/streamlit/runtime/pages_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from blinker import Signal

from streamlit.logger import get_logger
from streamlit.source_util import PageInfo, get_pages
from streamlit.source_util import PageHash, PageInfo, PageName, ScriptPath, get_pages
from streamlit.util import calc_md5
from streamlit.watcher import watch_dir

Expand All @@ -30,13 +30,13 @@

class PagesManagerV1:
is_watching_pages_dir: bool = False
pages_watcher_lock = threading.Lock()

# This is a static method because we only want to watch the pages directory
# once on initial load.
@staticmethod
def watch_pages_dir(pages_manager: PagesManager):
lock = threading.Lock()
with lock:
with PagesManagerV1.pages_watcher_lock:
if PagesManagerV1.is_watching_pages_dir:
return

Expand All @@ -56,18 +56,18 @@ def _on_pages_changed(_path: str) -> None:

class PagesManager:
def __init__(self, main_script_path, **kwargs):
self._cached_pages: dict[str, PageInfo] | None = None
self._cached_pages: dict[PageHash, PageInfo] | None = None
self._pages_cache_lock = threading.RLock()
self._on_pages_changed = Signal(doc="Emitted when the set of pages has changed")
self._main_script_path: str = main_script_path
self._main_script_hash: str = calc_md5(main_script_path)
self._current_page_hash: str = self._main_script_hash
self._main_script_path: ScriptPath = main_script_path
self._main_script_hash: PageHash = calc_md5(main_script_path)
self._current_page_hash: PageHash = self._main_script_hash

if kwargs.get("setup_watcher", True):
PagesManagerV1.watch_pages_dir(self)

@property
def main_script_path(self) -> str:
def main_script_path(self) -> ScriptPath:
return self._main_script_path

def get_main_page(self) -> PageInfo:
Expand All @@ -76,13 +76,13 @@ def get_main_page(self) -> PageInfo:
"page_script_hash": self._main_script_hash,
}

def get_current_page_script_hash(self) -> str:
def get_current_page_script_hash(self) -> PageHash:
return self._current_page_hash

def set_current_page_script_hash(self, page_hash: str) -> None:
def set_current_page_script_hash(self, page_hash: PageHash) -> None:
self._current_page_hash = page_hash

def get_active_script(self, page_script_hash: str, page_name: str):
def get_active_script(self, page_script_hash: PageHash, page_name: PageName):
pages = self.get_pages()

if page_script_hash:
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_active_script(self, page_script_hash: str, page_name: str):
main_page_info = list(pages.values())[0]
return main_page_info

def get_pages(self) -> dict[str, PageInfo]:
def get_pages(self) -> dict[PageHash, PageInfo]:
# Avoid taking the lock if the pages cache hasn't been invalidated.
pages = self._cached_pages
if pages is not None:
Expand Down Expand Up @@ -138,6 +138,11 @@ def register_pages_changed_callback(
self,
callback: Callable[[str], None],
) -> Callable[[], None]:
"""Register a callback to be called when the set of pages changes.
The callback will be called with the path changed.
"""

def disconnect():
self._on_pages_changed.disconnect(callback)

Expand Down
15 changes: 10 additions & 5 deletions lib/streamlit/source_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,22 @@
from pathlib import Path
from typing import Any, TypedDict, cast

from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypeAlias

from streamlit.string_util import extract_leading_emoji
from streamlit.util import calc_md5

PageHash: TypeAlias = str
PageName: TypeAlias = str
ScriptPath: TypeAlias = str
Icon: TypeAlias = str


class PageInfo(TypedDict):
script_path: str
page_script_hash: str
icon: NotRequired[str]
page_name: NotRequired[str]
script_path: ScriptPath
page_script_hash: PageHash
icon: NotRequired[Icon]
page_name: NotRequired[PageName]


def open_python_file(filename: str):
Expand Down

0 comments on commit f0b78be

Please sign in to comment.