Skip to content

Commit

Permalink
Add fixes for type stub generation
Browse files Browse the repository at this point in the history
  • Loading branch information
wch committed Nov 28, 2023
1 parent 70764fc commit c941552
Show file tree
Hide file tree
Showing 16 changed files with 57 additions and 42 deletions.
4 changes: 2 additions & 2 deletions shiny/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _server(inputs: Inputs, outputs: Outputs, session: Session):
cast("Tag | TagList", ui), lib_prefix=self.lib_prefix
)

def init_starlette_app(self):
def init_starlette_app(self) -> starlette.applications.Starlette:
routes: list[starlette.routing.BaseRoute] = [
starlette.routing.WebSocketRoute("/websocket/", self._on_connect_cb),
starlette.routing.Route("/", self._on_root_request_cb, methods=["GET"]),
Expand Down Expand Up @@ -400,7 +400,7 @@ def _render_page_from_file(self, file: Path, lib_prefix: str) -> RenderedHTML:
return rendered


def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]):
def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]) -> bool:
if (
isinstance(x, Path)
or isinstance(x, Tag)
Expand Down
6 changes: 3 additions & 3 deletions shiny/_namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from contextlib import contextmanager
from contextvars import ContextVar, Token
from typing import Pattern, Union, overload
from typing import Generator, Pattern, Union, overload


class ResolvedId(str):
Expand Down Expand Up @@ -82,7 +82,7 @@ def resolve_id_or_none(id: Id | None) -> ResolvedId | None:
re_valid_id: Pattern[str] = re.compile("^\\.?\\w+$")


def validate_id(id: str):
def validate_id(id: str) -> None:
if not re_valid_id.match(id):
raise ValueError(
f"The string '{id}' is not a valid id; only letters, numbers, and "
Expand All @@ -97,7 +97,7 @@ def validate_id(id: str):


@contextmanager
def namespace_context(id: Id | None):
def namespace_context(id: Id | None) -> Generator[None, None, None]:
namespace = resolve_id(id) if id else Root
token: Token[ResolvedId | None] = _current_namespace.set(namespace)
try:
Expand Down
4 changes: 2 additions & 2 deletions shiny/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import secrets
import socketserver
import tempfile
from typing import Any, Awaitable, Callable, Optional, TypeVar, cast
from typing import Any, Awaitable, Callable, Generator, Optional, TypeVar, cast

from ._typing_extensions import ParamSpec, TypeGuard

Expand Down Expand Up @@ -200,7 +200,7 @@ def private_random_int(min: int, max: int) -> str:


@contextlib.contextmanager
def private_seed():
def private_seed() -> Generator[None, None, None]:
state = random.getstate()
global own_random_state
try:
Expand Down
8 changes: 4 additions & 4 deletions shiny/express/_is_express.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __init__(self):
super().__init__()
self.found_shiny_express_import = False

def visit_Import(self, node: ast.Import):
def visit_Import(self, node: ast.Import) -> None:
if any(alias.name == "shiny.express" for alias in node.names):
self.found_shiny_express_import = True

def visit_ImportFrom(self, node: ast.ImportFrom):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module == "shiny.express":
self.found_shiny_express_import = True
elif node.module == "shiny" and any(
Expand All @@ -69,9 +69,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
self.found_shiny_express_import = True

# Visit top-level nodes.
def visit_Module(self, node: ast.Module):
def visit_Module(self, node: ast.Module) -> None:
super().generic_visit(node)

# Don't recurse into any nodes, so the we'll only ever look at top-level nodes.
def generic_visit(self, node: ast.AST):
def generic_visit(self, node: ast.AST) -> None:
pass
4 changes: 2 additions & 2 deletions shiny/express/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import sys
from contextlib import AbstractContextManager
from typing import Callable, TypeVar, cast, overload
from typing import Callable, Generator, TypeVar, cast, overload

from .. import ui
from .._typing_extensions import ParamSpec
Expand Down Expand Up @@ -109,7 +109,7 @@ def suspend_display(


@contextlib.contextmanager
def suspend_display_ctxmgr():
def suspend_display_ctxmgr() -> Generator[None, None, None]:
oldhook = sys.displayhook
sys.displayhook = null_displayhook
try:
Expand Down
2 changes: 1 addition & 1 deletion shiny/express/_recall_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self.args: list[object] = list(args)
self.kwargs: dict[str, object] = dict(kwargs)

def append_arg(self, value: object):
def append_arg(self, value: object) -> None:
if isinstance(value, (Tag, TagList, Tagifiable)):
self.args.append(value)
elif hasattr(value, "_repr_html_"):
Expand Down
4 changes: 2 additions & 2 deletions shiny/express/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ def set_result(x: object):
_top_level_recall_context_manager_has_been_replaced = False


def reset_top_level_recall_context_manager():
def reset_top_level_recall_context_manager() -> None:
global _top_level_recall_context_manager
global _top_level_recall_context_manager_has_been_replaced
_top_level_recall_context_manager = RecallContextManager(_DEFAULT_PAGE_FUNCTION)
_top_level_recall_context_manager_has_been_replaced = False


def get_top_level_recall_context_manager():
def get_top_level_recall_context_manager() -> RecallContextManager[Tag]:
return _top_level_recall_context_manager


Expand Down
3 changes: 2 additions & 1 deletion shiny/express/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from pathlib import Path

from .._app import App
from ._run import wrap_express_app
from ._utils import unescape_from_var_name


# If someone requests shiny.express.app:_2f_path_2f_to_2f_app_2e_py, then we will call
# wrap_express_app(Path("/path/to/app.py")) and return the result.
def __getattr__(name: str):
def __getattr__(name: str) -> App:
name = unescape_from_var_name(name)
return wrap_express_app(Path(name))
8 changes: 5 additions & 3 deletions shiny/express/display_decorator/_display_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def unwrap(fn: TFunc) -> TFunc:
display_body_attr = "__display_body__"


def display_body_unwrap_inplace():
def display_body_unwrap_inplace() -> Callable[[TFunc], TFunc]:
"""
Like `display_body`, but far more violent. This will attempt to traverse any
decorators between this one and the function, and then modify the function _in
Expand Down Expand Up @@ -73,7 +73,7 @@ def decorator(fn: TFunc) -> TFunc:
return decorator


def display_body():
def display_body() -> Callable[[TFunc], TFunc]:
def decorator(fn: TFunc) -> TFunc:
if fn.__code__ in code_cache:
fcode = code_cache[fn.__code__]
Expand Down Expand Up @@ -176,7 +176,9 @@ def _transform_function_ast(node: ast.AST) -> ast.AST:
return func_node


def compare_decorated_code_objects(func_ast: ast.FunctionDef):
def compare_decorated_code_objects(
func_ast: ast.FunctionDef,
) -> Callable[[types.CodeType, types.CodeType], bool]:
linenos = [*[x.lineno for x in func_ast.decorator_list], func_ast.lineno]

def comparator(candidate: types.CodeType, target: types.CodeType) -> bool:
Expand Down
28 changes: 16 additions & 12 deletions shiny/express/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from .. import ui
from ..types import MISSING, MISSING_TYPE
from ..ui._accordion import AccordionPanel
from ..ui._navs import Nav, NavSet, NavSetCard
from ..ui.css import CssUnit
from . import _run
from ._recall_context import RecallContextManager, wrap_recall_context_manager
Expand All @@ -33,7 +35,7 @@
# ======================================================================================
# Page functions
# ======================================================================================
def set_page(page_fn: RecallContextManager[Tag]):
def set_page(page_fn: RecallContextManager[Tag]) -> None:
"""Set the page function for the current Shiny express app."""
_run.replace_top_level_recall_context_manager(page_fn, force=True)

Expand Down Expand Up @@ -151,7 +153,7 @@ def layout_column_wrap(
gap: Optional[CssUnit] = None,
class_: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
A grid-like, column-first layout
Expand Down Expand Up @@ -221,7 +223,9 @@ def layout_column_wrap(
)


def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
def column(
width: int, *, offset: int = 0, **kwargs: TagAttrValue
) -> RecallContextManager[Tag]:
"""
Responsive row-column based layout
Expand Down Expand Up @@ -256,7 +260,7 @@ def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
)


def row(**kwargs: TagAttrValue):
def row(**kwargs: TagAttrValue) -> RecallContextManager[Tag]:
"""
Responsive row-column based layout
Expand Down Expand Up @@ -294,7 +298,7 @@ def card(
fill: bool = True,
class_: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
A Bootstrap card component
Expand Down Expand Up @@ -354,7 +358,7 @@ def accordion(
width: Optional[CssUnit] = None,
height: Optional[CssUnit] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Create a vertically collapsing accordion.
Expand Down Expand Up @@ -410,7 +414,7 @@ def accordion_panel(
value: Optional[str] | MISSING_TYPE = MISSING,
icon: Optional[TagChild] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[AccordionPanel]:
"""
Single accordion panel.
Expand Down Expand Up @@ -455,7 +459,7 @@ def navset_tab(
selected: Optional[str] = None,
header: TagChild = None,
footer: TagChild = None,
):
) -> RecallContextManager[NavSet]:
"""
Render nav items as a tabset.
Expand Down Expand Up @@ -493,7 +497,7 @@ def navset_card_tab(
sidebar: Optional[ui.Sidebar] = None,
header: TagChild = None,
footer: TagChild = None,
):
) -> RecallContextManager[NavSetCard]:
"""
Render nav items as a tabset inside a card container.
Expand Down Expand Up @@ -549,7 +553,7 @@ def nav(
*,
value: Optional[str] = None,
icon: TagChild = None,
):
) -> RecallContextManager[Nav]:
"""
Create a nav item pointing to some internal content.
Expand Down Expand Up @@ -627,7 +631,7 @@ def page_fillable(
title: Optional[str] = None,
lang: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Creates a fillable page.
Expand Down Expand Up @@ -678,7 +682,7 @@ def page_sidebar(
window_title: str | MISSING_TYPE = MISSING,
lang: Optional[str] = None,
**kwargs: TagAttrValue,
):
) -> RecallContextManager[Tag]:
"""
Create a page with a sidebar and a title.
Expand Down
2 changes: 1 addition & 1 deletion shiny/http_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class StaticFiles:
def __init__(self, *, directory: str | os.PathLike[str]):
self.dir = pathlib.Path(os.path.realpath(os.path.normpath(directory)))

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
raise AssertionError("StaticFiles can't handle non-http request")
path = scope["path"]
Expand Down
2 changes: 1 addition & 1 deletion shiny/input_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _(func: InputHandlerType):

return _

def remove(self, type: str):
def remove(self, type: str) -> None:
del self[type]

def _process_value(self, type: str, value: Any, name: str, session: Session) -> Any:
Expand Down
6 changes: 3 additions & 3 deletions shiny/reactive/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing
import warnings
from contextvars import ContextVar
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar
from typing import TYPE_CHECKING, Awaitable, Callable, Generator, Optional, TypeVar

from .. import _utils
from .._datastructures import PriorityQueueFIFO
Expand Down Expand Up @@ -188,7 +188,7 @@ def add_pending_flush(self, ctx: Context, priority: int) -> None:
self._pending_flush_queue.put(priority, ctx)

@contextlib.contextmanager
def isolate(self):
def isolate(self) -> Generator[None, None, None]:
token = self._current_context.set(Context())
try:
yield
Expand All @@ -201,7 +201,7 @@ def isolate(self):

@add_example()
@contextlib.contextmanager
def isolate():
def isolate() -> Generator[None, None, None]:
"""
Create a non-reactive scope within a reactive scope.
Expand Down
6 changes: 3 additions & 3 deletions shiny/render/_try_render_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import base64
import io
import warnings
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast

from ..types import ImgData, PlotnineFigure
from ._coordmap import get_coordmap, get_coordmap_plotnine

TryPlotResult = Tuple[bool, "ImgData| None"]
TryPlotResult = Tuple[bool, Union[ImgData, None]]


if TYPE_CHECKING:
Expand Down Expand Up @@ -376,7 +376,7 @@ def try_render_plotnine(
#
# One negative consequence of this logic: if the user intentionally set the dpi to
# rcParam * device_pixel_ratio, we're going to ignore it.
def get_desired_dpi_from_fig(fig: Figure):
def get_desired_dpi_from_fig(fig: Figure) -> float:
ppi_out = fig.get_dpi()

if fig.canvas.device_pixel_ratio != 1 and hasattr(fig, "_original_dpi"):
Expand Down
6 changes: 6 additions & 0 deletions shiny/render/transformer/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Generic,
NamedTuple,
Optional,
Type,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -572,6 +573,11 @@ class OutputTransformer(Generic[IT, OT, P]):
* :class:`~shiny.render.transformer.OutputRenderer`
"""

fn: OutputTransformerFn[IT, P, OT]
ValueFn: Type[ValueFn[IT]]
OutputRenderer: Type[OutputRenderer[OT]]
OutputRendererDecorator: Type[OutputRendererDecorator[IT, OT]]

def params(
self,
*args: P.args,
Expand Down
Loading

0 comments on commit c941552

Please sign in to comment.