Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: preserve magicgui-decorated function parameter hints with ParamSpec #600

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions src/magicgui/type_map/_magicgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@
from magicgui.widgets import FunctionGui, MainFunctionGui

if TYPE_CHECKING:
from typing_extensions import ParamSpec

from magicgui.application import AppRef

_P = ParamSpec("_P")

__all__ = ["magicgui", "magic_factory", "MagicFactory"]


_R = TypeVar("_R")
_T = TypeVar("_T", bound=FunctionGui)
_FGuiVar = TypeVar("_FGuiVar", bound=FunctionGui)


@overload
def magicgui(
function: Callable[..., _R],
function: Callable[_P, _R],
*,
layout: str = "horizontal",
scrollable: bool = False,
Expand All @@ -40,7 +45,7 @@ def magicgui(
persist: bool = False,
raise_on_unknown: bool = False,
**param_options: dict,
) -> FunctionGui[_R]:
) -> FunctionGui[_P, _R]:
...


Expand All @@ -60,13 +65,13 @@ def magicgui(
persist: bool = False,
raise_on_unknown: bool = False,
**param_options: dict,
) -> Callable[[Callable[..., _R]], FunctionGui[_R]]:
) -> Callable[[Callable[_P, _R]], FunctionGui[_P, _R]]:
...


@overload
def magicgui(
function: Callable[..., _R],
function: Callable[_P, _R],
*,
layout: str = "horizontal",
scrollable: bool = False,
Expand All @@ -80,7 +85,7 @@ def magicgui(
persist: bool = False,
raise_on_unknown: bool = False,
**param_options: dict,
) -> MainFunctionGui[_R]:
) -> MainFunctionGui[_P, _R]:
...


Expand All @@ -100,7 +105,7 @@ def magicgui(
persist: bool = False,
raise_on_unknown: bool = False,
**param_options: dict,
) -> Callable[[Callable[..., _R]], MainFunctionGui[_R]]:
) -> Callable[[Callable[_P, _R]], MainFunctionGui[_P, _R]]:
...


Expand Down Expand Up @@ -206,7 +211,7 @@ def magicgui(

@overload
def magic_factory(
function: Callable[..., _R],
function: Callable[_P, _R],
*,
layout: str = "horizontal",
scrollable: bool = False,
Expand All @@ -221,7 +226,7 @@ def magic_factory(
widget_init: Callable[[FunctionGui], None] | None = None,
raise_on_unknown: bool = False,
**param_options: dict,
) -> MagicFactory[_R, FunctionGui]:
) -> MagicFactory[FunctionGui[_P, _R]]:
...


Expand All @@ -242,13 +247,13 @@ def magic_factory(
widget_init: Callable[[FunctionGui], None] | None = None,
raise_on_unknown: bool = False,
**param_options: dict,
) -> Callable[[Callable[..., _R]], MagicFactory[_R, FunctionGui]]:
) -> Callable[[Callable[_P, _R]], MagicFactory[FunctionGui[_P, _R]]]:
...


@overload
def magic_factory(
function: Callable[..., _R],
function: Callable[_P, _R],
*,
layout: str = "horizontal",
scrollable: bool = False,
Expand All @@ -263,7 +268,7 @@ def magic_factory(
widget_init: Callable[[FunctionGui], None] | None = None,
raise_on_unknown: bool = False,
**param_options: dict,
) -> MagicFactory[_R, MainFunctionGui]:
) -> MagicFactory[MainFunctionGui[_P, _R]]:
...


Expand All @@ -284,7 +289,7 @@ def magic_factory(
widget_init: Callable[[FunctionGui], None] | None = None,
raise_on_unknown: bool = False,
**param_options: dict,
) -> Callable[[Callable[..., _R]], MagicFactory[_R, MainFunctionGui]]:
) -> Callable[[Callable[_P, _R]], MagicFactory[MainFunctionGui[_P, _R]]]:
...


Expand Down Expand Up @@ -418,7 +423,7 @@ def magic_factory(

# _R is the return type of the decorated function
# _T is the type of the FunctionGui instance (FunctionGui or MainFunctionGui)
class MagicFactory(partial, Generic[_R, _T]):
class MagicFactory(partial, Generic[_FGuiVar]):
"""Factory function that returns a FunctionGui instance.

While this can be used directly, (see example below) the preferred usage is
Expand All @@ -436,15 +441,17 @@ class MagicFactory(partial, Generic[_R, _T]):
>>> widget2 = factory(auto_call=True, labels=True)
"""

_widget_init: Callable[[_T], None] | None = None
func: Callable[..., _T]
_widget_init: Callable[[_FGuiVar], None] | None = None
# func here is the function that will be called to create the widget
# i.e. it will be either the FunctionGui or MainFunctionGui class
func: Callable[..., _FGuiVar]

def __new__(
cls,
function: Callable[..., _R],
function: Callable,
*args: Any,
magic_class: type[_T] = FunctionGui, # type: ignore
widget_init: Callable[[_T], None] | None = None,
magic_class: type[_FGuiVar] = FunctionGui, # type: ignore
widget_init: Callable[[_FGuiVar], None] | None = None,
**keywords: Any,
) -> MagicFactory:
"""Create new MagicFactory."""
Expand Down Expand Up @@ -477,7 +484,9 @@ def __repr__(self) -> str:
]
return f"MagicFactory({', '.join(args)})"

def __call__(self, *args: Any, **kwargs: Any) -> _T:
# TODO: annotate args and kwargs here so that
# calling a MagicFactory instance gives proper mypy hints
def __call__(self, *args: Any, **kwargs: Any) -> _FGuiVar:
"""Call the wrapped _magicgui and return a FunctionGui."""
if args:
raise ValueError("MagicFactory instance only accept keyword arguments")
Expand Down
2 changes: 1 addition & 1 deletion src/magicgui/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ChoicesDict(TypedDict):
#: be provided an instance of a
#: [~magicgui.widgets.FunctionGui][magicgui.widgets.FunctionGui],
#: the result of the function that was called, and the return annotation itself.
ReturnCallback = Callable[["FunctionGui[Any]", Any, type], None]
ReturnCallback = Callable[["FunctionGui", Any, type], None]
#: A valid file path type
PathLike = Union[Path, str, bytes]

Expand Down
21 changes: 16 additions & 5 deletions src/magicgui/widgets/_function_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@
if TYPE_CHECKING:
from pathlib import Path

from typing_extensions import ParamSpec

from magicgui.application import Application, AppRef # noqa: F401
from magicgui.widgets import TextEdit
from magicgui.widgets.protocols import ContainerProtocol, MainWindowProtocol

_P = ParamSpec("_P")
else:
_P = TypeVar("_P") # easier runtime dependency than ParamSpec


def _inject_tooltips_from_docstrings(
docstring: str | None, sig: MagicSignature
Expand Down Expand Up @@ -70,7 +76,7 @@ def _inject_tooltips_from_docstrings(
_VT = TypeVar("_VT")


class FunctionGui(Container, Generic[_R]):
class FunctionGui(Container, Generic[_P, _R]):
"""Wrapper for a container of widgets representing a callable object.

Parameters
Expand Down Expand Up @@ -129,7 +135,7 @@ class FunctionGui(Container, Generic[_R]):

def __init__(
self,
function: Callable[..., _R],
function: Callable[_P, _R],
call_button: bool | str | None = None,
layout: str = "vertical",
scrollable: bool = False,
Expand Down Expand Up @@ -276,9 +282,12 @@ def __signature__(self) -> MagicSignature:
"""Return a MagicSignature object representing the current state of the gui."""
return super().__signature__.replace(return_annotation=self.return_annotation)

def __call__(self, *args: Any, update_widget: bool = False, **kwargs: Any) -> _R:
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Call the original function with the current parameter values from the Gui.

You may pass a `update_widget=True` keyword argument to update the widget
values to match the current parameter values before calling the function.

It is also possible to override the current parameter values from the GUI by
providing args/kwargs to the function call. Only those provided will override
the ones from the gui. A `called` signal will also be emitted with the results.
Expand All @@ -298,6 +307,8 @@ def __call__(self, *args: Any, update_widget: bool = False, **kwargs: Any) -> _R
gui() # calls the original function with the current parameters
```
"""
update_widget: bool = bool(kwargs.pop("update_widget", False))

sig = self.__signature__
try:
bound = sig.bind(*args, **kwargs)
Expand Down Expand Up @@ -441,12 +452,12 @@ def _load(self, path: str | Path | None = None, quiet: bool = False) -> None:
super()._load(path or self._dump_path, quiet=quiet)


class MainFunctionGui(FunctionGui[_R], MainWindow):
class MainFunctionGui(FunctionGui[_P, _R], MainWindow):
"""Container of widgets as a Main Application Window."""

_widget: MainWindowProtocol

def __init__(self, function: Callable, *args: Any, **kwargs: Any) -> None:
def __init__(self, function: Callable[_P, _R], *args: Any, **kwargs: Any) -> None:
super().__init__(function, *args, **kwargs)
self.create_menu_item("Help", "Documentation", callback=self._show_docs)
self._help_text_edit: TextEdit | None = None
Expand Down