From 186531dde85675179547845611ad0643fc9dac55 Mon Sep 17 00:00:00 2001 From: Winston Chang Date: Tue, 19 Dec 2023 13:58:22 -0600 Subject: [PATCH] Add fixes for type stub generation --- shiny/_app.py | 4 +-- shiny/_namespaces.py | 6 ++-- shiny/_utils.py | 4 +-- shiny/express/_is_express.py | 8 ++--- shiny/express/_output.py | 4 +-- shiny/express/_run.py | 4 +-- shiny/express/app.py | 3 +- .../display_decorator/_display_body.py | 8 +++-- shiny/express/layout.py | 30 +++++++++++-------- shiny/http_staticfiles.py | 2 +- shiny/input_handler.py | 2 +- shiny/reactive/_core.py | 6 ++-- shiny/render/_try_render_plot.py | 6 ++-- shiny/render/transformer/_transformer.py | 6 ++++ shiny/session/_session.py | 6 ++-- 15 files changed, 57 insertions(+), 42 deletions(-) diff --git a/shiny/_app.py b/shiny/_app.py index 664a2238a..3e4f075a6 100644 --- a/shiny/_app.py +++ b/shiny/_app.py @@ -166,7 +166,7 @@ def _server(inputs: Inputs, outputs: Outputs, session: Session): cast("Tag | TagList", ui), lib_prefix=self.lib_prefix ) - def init_starlette_app(self): + def init_starlette_app(self) -> starlette.applications.Starlette: routes: list[starlette.routing.BaseRoute] = [ starlette.routing.WebSocketRoute("/websocket/", self._on_connect_cb), starlette.routing.Route("/", self._on_root_request_cb, methods=["GET"]), @@ -400,7 +400,7 @@ def _render_page_from_file(self, file: Path, lib_prefix: str) -> RenderedHTML: return rendered -def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]): +def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]) -> bool: if ( isinstance(x, Path) or isinstance(x, Tag) diff --git a/shiny/_namespaces.py b/shiny/_namespaces.py index d031d52b0..b52584b46 100644 --- a/shiny/_namespaces.py +++ b/shiny/_namespaces.py @@ -5,7 +5,7 @@ import re from contextlib import contextmanager from contextvars import ContextVar, Token -from typing import Pattern, Union, overload +from typing import Generator, Pattern, Union, overload class ResolvedId(str): @@ -82,7 +82,7 @@ def resolve_id_or_none(id: Id | None) -> ResolvedId | None: re_valid_id: Pattern[str] = re.compile("^\\.?\\w+$") -def validate_id(id: str): +def validate_id(id: str) -> None: if not re_valid_id.match(id): raise ValueError( f"The string '{id}' is not a valid id; only letters, numbers, and " @@ -97,7 +97,7 @@ def validate_id(id: str): @contextmanager -def namespace_context(id: Id | None): +def namespace_context(id: Id | None) -> Generator[None, None, None]: namespace = resolve_id(id) if id else Root token: Token[ResolvedId | None] = _current_namespace.set(namespace) try: diff --git a/shiny/_utils.py b/shiny/_utils.py index 2ec1ebc16..3aa6a5517 100644 --- a/shiny/_utils.py +++ b/shiny/_utils.py @@ -11,7 +11,7 @@ import secrets import socketserver import tempfile -from typing import Any, Awaitable, Callable, Optional, TypeVar, cast +from typing import Any, Awaitable, Callable, Generator, Optional, TypeVar, cast from ._typing_extensions import ParamSpec, TypeGuard @@ -200,7 +200,7 @@ def private_random_int(min: int, max: int) -> str: @contextlib.contextmanager -def private_seed(): +def private_seed() -> Generator[None, None, None]: state = random.getstate() global own_random_state try: diff --git a/shiny/express/_is_express.py b/shiny/express/_is_express.py index ce36a6eaa..f6694a8ad 100644 --- a/shiny/express/_is_express.py +++ b/shiny/express/_is_express.py @@ -56,11 +56,11 @@ def __init__(self): super().__init__() self.found_shiny_express_import = False - def visit_Import(self, node: ast.Import): + def visit_Import(self, node: ast.Import) -> None: if any(alias.name == "shiny.express" for alias in node.names): self.found_shiny_express_import = True - def visit_ImportFrom(self, node: ast.ImportFrom): + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.module == "shiny.express": self.found_shiny_express_import = True elif node.module == "shiny" and any( @@ -69,9 +69,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom): self.found_shiny_express_import = True # Visit top-level nodes. - def visit_Module(self, node: ast.Module): + def visit_Module(self, node: ast.Module) -> None: super().generic_visit(node) # Don't recurse into any nodes, so the we'll only ever look at top-level nodes. - def generic_visit(self, node: ast.AST): + def generic_visit(self, node: ast.AST) -> None: pass diff --git a/shiny/express/_output.py b/shiny/express/_output.py index 8733b07ec..6c72e6f7e 100644 --- a/shiny/express/_output.py +++ b/shiny/express/_output.py @@ -3,7 +3,7 @@ import contextlib import sys from contextlib import AbstractContextManager -from typing import Callable, TypeVar, cast, overload +from typing import Callable, Generator, TypeVar, cast, overload from .. import ui from .._typing_extensions import ParamSpec @@ -109,7 +109,7 @@ def suspend_display( @contextlib.contextmanager -def suspend_display_ctxmgr(): +def suspend_display_ctxmgr() -> Generator[None, None, None]: oldhook = sys.displayhook sys.displayhook = null_displayhook try: diff --git a/shiny/express/_run.py b/shiny/express/_run.py index 001c3139a..e843e558c 100644 --- a/shiny/express/_run.py +++ b/shiny/express/_run.py @@ -136,14 +136,14 @@ def set_result(x: object): _top_level_recall_context_manager_has_been_replaced = False -def reset_top_level_recall_context_manager(): +def reset_top_level_recall_context_manager() -> None: global _top_level_recall_context_manager global _top_level_recall_context_manager_has_been_replaced _top_level_recall_context_manager = RecallContextManager(_DEFAULT_PAGE_FUNCTION) _top_level_recall_context_manager_has_been_replaced = False -def get_top_level_recall_context_manager(): +def get_top_level_recall_context_manager() -> RecallContextManager[Tag]: return _top_level_recall_context_manager diff --git a/shiny/express/app.py b/shiny/express/app.py index f309e9262..9612d0892 100644 --- a/shiny/express/app.py +++ b/shiny/express/app.py @@ -2,12 +2,13 @@ from pathlib import Path +from .._app import App from ._run import wrap_express_app from ._utils import unescape_from_var_name # If someone requests shiny.express.app:_2f_path_2f_to_2f_app_2e_py, then we will call # wrap_express_app(Path("/path/to/app.py")) and return the result. -def __getattr__(name: str): +def __getattr__(name: str) -> App: name = unescape_from_var_name(name) return wrap_express_app(Path(name)) diff --git a/shiny/express/display_decorator/_display_body.py b/shiny/express/display_decorator/_display_body.py index 36e17e6e1..526d01ee3 100644 --- a/shiny/express/display_decorator/_display_body.py +++ b/shiny/express/display_decorator/_display_body.py @@ -48,7 +48,7 @@ def unwrap(fn: TFunc) -> TFunc: display_body_attr = "__display_body__" -def display_body_unwrap_inplace(): +def display_body_unwrap_inplace() -> Callable[[TFunc], TFunc]: """ Like `display_body`, but far more violent. This will attempt to traverse any decorators between this one and the function, and then modify the function _in @@ -76,7 +76,7 @@ def decorator(fn: TFunc) -> TFunc: return decorator -def display_body(): +def display_body() -> Callable[[TFunc], TFunc]: def decorator(fn: TFunc) -> TFunc: if fn.__code__ in code_cache: fcode = code_cache[fn.__code__] @@ -197,7 +197,9 @@ def _transform_function_ast(node: ast.AST) -> ast.AST: return func_node -def compare_decorated_code_objects(func_ast: ast.FunctionDef): +def compare_decorated_code_objects( + func_ast: ast.FunctionDef, +) -> Callable[[types.CodeType, types.CodeType], bool]: linenos = [*[x.lineno for x in func_ast.decorator_list], func_ast.lineno] def comparator(candidate: types.CodeType, target: types.CodeType) -> bool: diff --git a/shiny/express/layout.py b/shiny/express/layout.py index 3ea390868..e65a8dc87 100644 --- a/shiny/express/layout.py +++ b/shiny/express/layout.py @@ -7,7 +7,9 @@ from .. import ui from ..types import MISSING, MISSING_TYPE +from ..ui._accordion import AccordionPanel from ..ui._layout_columns import BreakpointsUser +from ..ui._navs import NavPanel, NavSet, NavSetCard from ..ui.css import CssUnit from . import _run from ._recall_context import RecallContextManager, wrap_recall_context_manager @@ -39,7 +41,7 @@ # ====================================================================================== # Page functions # ====================================================================================== -def set_page(page_fn: RecallContextManager[Tag]): +def set_page(page_fn: RecallContextManager[Tag]) -> None: """Set the page function for the current Shiny express app.""" _run.replace_top_level_recall_context_manager(page_fn, force=True) @@ -162,7 +164,7 @@ def layout_column_wrap( gap: Optional[CssUnit] = None, class_: Optional[str] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ A grid-like, column-first layout @@ -252,7 +254,7 @@ def layout_columns( class_: Optional[str] = None, height: Optional[CssUnit] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ Create responsive, column-based grid layouts, based on a 12-column grid. @@ -346,7 +348,9 @@ def layout_columns( ) -def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue): +def column( + width: int, *, offset: int = 0, **kwargs: TagAttrValue +) -> RecallContextManager[Tag]: """ Responsive row-column based layout @@ -381,7 +385,7 @@ def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue): ) -def row(**kwargs: TagAttrValue): +def row(**kwargs: TagAttrValue) -> RecallContextManager[Tag]: """ Responsive row-column based layout @@ -419,7 +423,7 @@ def card( fill: bool = True, class_: Optional[str] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ A Bootstrap card component @@ -481,7 +485,7 @@ def accordion( width: Optional[CssUnit] = None, height: Optional[CssUnit] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ Create a vertically collapsing accordion. @@ -537,7 +541,7 @@ def accordion_panel( value: Optional[str] | MISSING_TYPE = MISSING, icon: Optional[TagChild] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[AccordionPanel]: """ Single accordion panel. @@ -583,7 +587,7 @@ def navset( selected: Optional[str] = None, header: TagChild = None, footer: TagChild = None, -): +) -> RecallContextManager[NavSet]: """ Render a set of nav items @@ -635,7 +639,7 @@ def navset_card( sidebar: Optional[ui.Sidebar] = None, header: TagChild = None, footer: TagChild = None, -): +) -> RecallContextManager[NavSetCard]: """ Render a set of nav items inside a card container. @@ -687,7 +691,7 @@ def nav_panel( *, value: Optional[str] = None, icon: TagChild = None, -): +) -> RecallContextManager[NavPanel]: """ Create a nav item pointing to some internal content. @@ -803,7 +807,7 @@ def page_fillable( title: Optional[str] = None, lang: Optional[str] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ Creates a fillable page. @@ -854,7 +858,7 @@ def page_sidebar( window_title: str | MISSING_TYPE = MISSING, lang: Optional[str] = None, **kwargs: TagAttrValue, -): +) -> RecallContextManager[Tag]: """ Create a page with a sidebar and a title. diff --git a/shiny/http_staticfiles.py b/shiny/http_staticfiles.py index 3390d5eb8..db38f543c 100644 --- a/shiny/http_staticfiles.py +++ b/shiny/http_staticfiles.py @@ -50,7 +50,7 @@ class StaticFiles: def __init__(self, *, directory: str | os.PathLike[str]): self.dir = pathlib.Path(os.path.realpath(os.path.normpath(directory))) - async def __call__(self, scope: Scope, receive: Receive, send: Send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": raise AssertionError("StaticFiles can't handle non-http request") path = scope["path"] diff --git a/shiny/input_handler.py b/shiny/input_handler.py index abd48cdb2..982271480 100644 --- a/shiny/input_handler.py +++ b/shiny/input_handler.py @@ -28,7 +28,7 @@ def _(func: InputHandlerType): return _ - def remove(self, type: str): + def remove(self, type: str) -> None: del self[type] def _process_value(self, type: str, value: Any, name: str, session: Session) -> Any: diff --git a/shiny/reactive/_core.py b/shiny/reactive/_core.py index 40576c527..699d557d4 100644 --- a/shiny/reactive/_core.py +++ b/shiny/reactive/_core.py @@ -18,7 +18,7 @@ import typing import warnings from contextvars import ContextVar -from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar +from typing import TYPE_CHECKING, Awaitable, Callable, Generator, Optional, TypeVar from .. import _utils from .._datastructures import PriorityQueueFIFO @@ -188,7 +188,7 @@ def add_pending_flush(self, ctx: Context, priority: int) -> None: self._pending_flush_queue.put(priority, ctx) @contextlib.contextmanager - def isolate(self): + def isolate(self) -> Generator[None, None, None]: token = self._current_context.set(Context()) try: yield @@ -201,7 +201,7 @@ def isolate(self): @add_example() @contextlib.contextmanager -def isolate(): +def isolate() -> Generator[None, None, None]: """ Create a non-reactive scope within a reactive scope. diff --git a/shiny/render/_try_render_plot.py b/shiny/render/_try_render_plot.py index 8e77c91a7..75f22b383 100644 --- a/shiny/render/_try_render_plot.py +++ b/shiny/render/_try_render_plot.py @@ -3,12 +3,12 @@ import base64 import io import warnings -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast from ..types import ImgData, PlotnineFigure from ._coordmap import get_coordmap, get_coordmap_plotnine -TryPlotResult = Tuple[bool, "ImgData| None"] +TryPlotResult = Tuple[bool, Union[ImgData, None]] if TYPE_CHECKING: @@ -376,7 +376,7 @@ def try_render_plotnine( # # One negative consequence of this logic: if the user intentionally set the dpi to # rcParam * device_pixel_ratio, we're going to ignore it. -def get_desired_dpi_from_fig(fig: Figure): +def get_desired_dpi_from_fig(fig: Figure) -> float: ppi_out = fig.get_dpi() if fig.canvas.device_pixel_ratio != 1 and hasattr(fig, "_original_dpi"): diff --git a/shiny/render/transformer/_transformer.py b/shiny/render/transformer/_transformer.py index 0392f4d97..31528859c 100644 --- a/shiny/render/transformer/_transformer.py +++ b/shiny/render/transformer/_transformer.py @@ -31,6 +31,7 @@ Generic, NamedTuple, Optional, + Type, TypeVar, Union, cast, @@ -572,6 +573,11 @@ class OutputTransformer(Generic[IT, OT, P]): * :class:`~shiny.render.transformer.OutputRenderer` """ + fn: OutputTransformerFn[IT, P, OT] + ValueFn: Type[ValueFn[IT]] + OutputRenderer: Type[OutputRenderer[OT]] + OutputRendererDecorator: Type[OutputRendererDecorator[IT, OT]] + def params( self, *args: P.args, diff --git a/shiny/session/_session.py b/shiny/session/_session.py index b46eec44c..1bfeaf453 100644 --- a/shiny/session/_session.py +++ b/shiny/session/_session.py @@ -25,6 +25,7 @@ Iterable, Optional, TypeVar, + Union, cast, overload, ) @@ -107,7 +108,8 @@ class ClientMessageOther(ClientMessage): # # (Not currently supported is Awaitable[str], could be added easily enough if needed.) DownloadHandler = Callable[ - [], "str | Iterable[bytes | str] | AsyncIterable[bytes | str]" + [], + Union[str, Iterable[Union[bytes, str]], AsyncIterable[Union[bytes, str]]], ] DynamicRouteHandler = Callable[[Request], ASGIApp] @@ -1068,7 +1070,7 @@ async def output_obs(): else: return set_renderer(renderer_fn) - def remove(self, id: Id): + def remove(self, id: Id) -> None: output_name = self._ns(id) if output_name in self._effects: self._effects[output_name].destroy()