Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create return matcher for extensible return widget creation #355

Merged
merged 12 commits into from
Jan 26, 2022
103 changes: 102 additions & 1 deletion magicgui/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MissingWidget(RuntimeError):


_RETURN_CALLBACKS: DefaultDict[type, list[ReturnCallback]] = defaultdict(list)
_RETURN_MATCHERS: list[TypeMatcher] = list()
_TYPE_MATCHERS: list[TypeMatcher] = list()
_TYPE_DEFS: dict[type, WidgetTuple] = dict()

Expand Down Expand Up @@ -112,6 +113,19 @@ def type_matcher(func: TypeMatcher) -> TypeMatcher:
return func


def return_matcher(func: TypeMatcher) -> TypeMatcher:
"""Add function to the set of return matchers.

Example
-------
>>> @return_matcher
... def default_return_widget(value, annotation):
... return widgets.LineEdit, {}
"""
_RETURN_MATCHERS.append(func)
return func


_SIMPLE_ANNOTATIONS = {
PathLike: widgets.FileEdit,
}
Expand Down Expand Up @@ -225,10 +239,91 @@ def pick_widget_type(
return widgets.EmptyWidget, {"visible": False}


_SIMPLE_RETURN_TYPES = [
bool,
int,
float,
str,
pathlib.Path,
datetime.time,
datetime.date,
datetime.datetime,
range,
slice,
]


@return_matcher
def default_return_matcher(value, annotation) -> WidgetTuple | None:
"""Checks for 'simple' types that fit in a LineEdit."""
dtype, optional = _normalize_type(value, annotation)
if dtype in _SIMPLE_TYPES:
return widgets.LineEdit, {"gui_only": True}
else:
return None


@return_matcher
def tabular_return_matcher(value, annotation) -> WidgetTuple | None:
"""Checks for tabular data."""
# TODO: is this correct?
if annotation == inspect._empty:
return None
dtype, optional = _normalize_type(value, annotation)
args = [_evaluate_forwardref(a) for a in get_args(widgets._table.TableData)]
if dtype in args:
return widgets.Table, {}

return None


def return_widget_type(
value: Any = None,
annotation: type | None = None,
options: WidgetOptions | None = None,
) -> WidgetTuple:
"""Pick the appropriate widget type for ``value`` with ``annotation``."""
options = options or {}
annotation = _evaluate_forwardref(annotation)
dtype, optional = _normalize_type(value, annotation)
if optional:
options.setdefault("nullable", True)
choices = options.get("choices") or (isinstance(dtype, EnumMeta) and dtype)

if "widget_type" in options:
widget_type = options.pop("widget_type")
if choices:
if widget_type == "RadioButton":
widget_type = "RadioButtons"
warnings.warn(
f"widget_type of 'RadioButton' (with dtype {dtype}) is being "
"coerced to 'RadioButtons' due to choices or Enum type.",
stacklevel=2,
)
options.setdefault("choices", choices)
return widget_type, options

# look for subclasses
for registered_type in _TYPE_DEFS:
if dtype == registered_type or _is_subclass(dtype, registered_type):
_cls, opts = _TYPE_DEFS[registered_type]
return _cls, {**options, **opts} # type: ignore

for matcher in _RETURN_MATCHERS:
_widget_type = matcher(value, annotation)
if _widget_type:
_cls, opts = _widget_type
return _cls, {**options, **opts} # type: ignore

# Chosen for backwards/test compatibility
return widgets.LineEdit, {"gui_only": True}


def get_widget_class(
value: Any = None,
annotation: type | None = None,
options: WidgetOptions | None = None,
is_result: bool = False,
) -> tuple[WidgetClass, WidgetOptions]:
"""Return a WidgetClass appropriate for the given parameters.

Expand All @@ -241,6 +336,9 @@ def get_widget_class(
A type annotation, by default None
options : WidgetOptions, optional
Options to pass when constructing the widget, by default {}
is_result : bool, optional
Identifies whether the returned widget should be tailored to
an input or to an output.

Returns
-------
Expand All @@ -249,7 +347,10 @@ def get_widget_class(
may be different than the options passed in.
"""
_options = cast(WidgetOptions, options)
widget_type, _options = pick_widget_type(value, annotation, _options)
if is_result:
widget_type, _options = return_widget_type(value, annotation, _options)
else:
widget_type, _options = pick_widget_type(value, annotation, _options)

if isinstance(widget_type, str):
widget_class: WidgetClass = _import_class(widget_type)
Expand Down
3 changes: 3 additions & 0 deletions magicgui/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#: A function that takes a ``(value, annotation)`` argument and returns an optional
#: :attr:`WidgetTuple`
TypeMatcher = Callable[[Any, Optional[Type]], Optional[WidgetTuple]]
#: A function that takes a ``(value, annotation)`` argument and returns an optional
#: :attr:`WidgetTuple`
ReturnMatcher = Callable[[Any, Optional[Type]], Optional[WidgetTuple]]
tlambert03 marked this conversation as resolved.
Show resolved Hide resolved
#: An iterable that can be used as a valid argument for widget ``choices``
ChoicesIterable = Union[Iterable[Tuple[str, Any]], Iterable[Any]]
#: An callback that can be used as a valid argument for widget ``choices``. It takes
Expand Down
9 changes: 7 additions & 2 deletions magicgui/widgets/_bases/create_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def create_widget(
app=None,
widget_type: str | type[_protocols.WidgetProtocol] | None = None,
options: WidgetOptions = dict(),
is_result: bool = False,
):
"""Create and return appropriate widget subclass.

Expand Down Expand Up @@ -58,6 +59,9 @@ def create_widget(
autodetermined from ``value`` and/or ``annotation`` above.
options : WidgetOptions, optional
Dict of options to pass to the Widget constructor, by default dict()
is_result : boolean, optional
Whether the widget belongs to an input or an output. By defult, an input
is assumed.

Returns
-------
Expand All @@ -71,8 +75,9 @@ def create_widget(
widget protocols from widgets._protocols.
"""
options = options.copy()
kwargs = locals()
kwargs = locals().copy()
_kind = kwargs.pop("param_kind", None)
_is_result = kwargs.pop("is_result", None)
_app = use_app(kwargs.pop("app"))
assert _app.native
if isinstance(widget_type, _protocols.WidgetProtocol):
Expand All @@ -82,7 +87,7 @@ def create_widget(

if widget_type:
options["widget_type"] = widget_type
wdg_class, opts = get_widget_class(value, annotation, options)
wdg_class, opts = get_widget_class(value, annotation, options, is_result)

if issubclass(wdg_class, Widget):
opts.update(kwargs.pop("options"))
Expand Down
15 changes: 11 additions & 4 deletions magicgui/widgets/_function_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from magicgui.application import AppRef
from magicgui.events import Signal
from magicgui.signature import MagicSignature, magic_signature
from magicgui.widgets import Container, LineEdit, MainWindow, ProgressBar, PushButton
from magicgui.widgets import Container, MainWindow, ProgressBar, PushButton
from magicgui.widgets._bases.value_widget import ValueWidget
from magicgui.widgets._protocols import ContainerProtocol, MainWindowProtocol

if TYPE_CHECKING:
Expand Down Expand Up @@ -202,10 +203,16 @@ def _disable_button_and_call():

self.append(self._call_button)

self._result_widget: LineEdit | None = None
self._result_widget: ValueWidget | None = None
if result_widget:
self._result_widget = LineEdit(gui_only=True, name="result")
self._result_widget.enabled = False
from magicgui.widgets._bases import create_widget

self._result_widget = create_widget(
value=None,
annotation=self._return_annotation,
gui_only=True,
is_result=True,
)
self.append(self._result_widget)

if persist:
Expand Down
25 changes: 14 additions & 11 deletions magicgui/widgets/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@
from magicgui.widgets._protocols import TableWidgetProtocol

if TYPE_CHECKING:
import numpy as np
import pandas as pd
import numpy
import pandas


TblKey = Any
_KT = TypeVar("_KT") # Key type
_KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers.
_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers.
TableData = Union[dict, "pd.DataFrame", list, "np.ndarray", tuple, None]
TableData = Union[dict, "pandas.DataFrame", list, "numpy.ndarray", tuple, None]
IndexKey = Union[int, slice]


Expand Down Expand Up @@ -70,7 +70,7 @@ def normalize_table_data(data: TableData) -> tuple[Collection[Collection], list,
_columns = data[2] if data_len > 2 else []
return _data, _index, _columns
if _is_dataframe(data):
data = cast("pd.DataFrame", data)
data = cast("pandas.DataFrame", data)
return data.values, data.index, data.columns
if isinstance(data, list):
if data:
Expand Down Expand Up @@ -479,7 +479,7 @@ def _assert_col(self, col):

# #### EXPORT METHODS #####

def to_dataframe(self) -> pd.DataFrame:
def to_dataframe(self) -> pandas.DataFrame:
"""Convert TableData to dataframe."""
try:
import pandas
Expand All @@ -504,7 +504,7 @@ def to_dict(self, orient: Literal['records']) -> list[dict[TblKey, Any]]: ... #
@overload
def to_dict(self, orient: Literal['index']) -> dict[TblKey, dict[TblKey, list]]: ... # noqa
@overload
def to_dict(self, orient: Literal['series']) -> dict[TblKey, pd.Series]: ... # noqa
def to_dict(self, orient: Literal['series']) -> dict[TblKey, pandas.Series]: ... # noqa
# fmt: on

def to_dict(self, orient: str = "dict") -> list | dict:
Expand Down Expand Up @@ -734,9 +734,9 @@ def _from_nested_column_dict(data: dict) -> tuple[list[list], list]:
_index = {frozenset(i) for i in data.values()}
if len(_index) > 1:
try:
import pandas as pd
import pandas

df = pd.DataFrame(data)
df = pandas.DataFrame(data)
return df.values, df.index
except ImportError:
raise ValueError(
Expand Down Expand Up @@ -766,7 +766,10 @@ def _from_dict(data: dict, dtype=None) -> tuple[list[list], list, list]:
if isinstance(list(data.values())[0], dict):
_data, index = _from_nested_column_dict(data)
else:
_data = list(list(x) for x in zip(*data.values()))
try:
_data = list(list(x) for x in zip(*data.values()))
except TypeError:
raise ValueError("All values in the dict must be iterable (e.g. a list).")
index = []
return _data, index, columns

Expand All @@ -778,9 +781,9 @@ def _from_records(data: list[dict[TblKey, Any]]) -> tuple[list[list], list, list
_columns = {frozenset(i) for i in data}
if len(_columns) > 1:
try:
import pandas as pd
import pandas

df = pd.DataFrame(data)
df = pandas.DataFrame(data)
return df.values, df.index, df.columns
except ImportError:
raise ValueError(
Expand Down
27 changes: 17 additions & 10 deletions tests/test_magicgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,19 +444,26 @@ class Sub(Base):
register_type(int, return_callback=check_value)
register_type(Base, return_callback=check_value)

@magicgui
def func(a=1) -> int:
return a
try:

func()
with pytest.raises(AssertionError):
func(3)
@magicgui
def func(a=1) -> int:
return a

@magicgui
def func2(a=1) -> Sub:
return a
func()
with pytest.raises(AssertionError):
func(3)

@magicgui
def func2(a=1) -> Sub:
return a

func2()
finally:
from magicgui.type_map import _RETURN_CALLBACKS

func2()
_RETURN_CALLBACKS.pop(int)
_RETURN_CALLBACKS.pop(Base)


# @pytest.mark.skip(reason="need to rethink how to test this")
Expand Down
Loading