diff --git a/magicgui/_magicgui.py b/magicgui/_magicgui.py index 363efaeed..328ceda6e 100644 --- a/magicgui/_magicgui.py +++ b/magicgui/_magicgui.py @@ -2,7 +2,6 @@ import inspect from functools import partial -from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -13,7 +12,6 @@ Union, overload, ) -from warnings import warn from typing_extensions import Literal @@ -209,23 +207,6 @@ def __new__(cls, function, *args, magic_class=FunctionGui, **keywords): "MagicFactory missing required positional argument 'function'" ) - # if someone uses `@magic_factory` *inside* of another function (i.e., not in - # the module-level scope), *and* they try to use the "self-reference trick", - # (wherein they use the function name in the body of the function in order to - # access the resulting FunctionGui instance)... it will not work. - # here we detect that type of usage and give a warning. - if isinstance(function, FunctionType): - # this tells us the function has not been defined at the module level - if "" in function.__qualname__: - # this tells us they are accessing an undefined variable *inside* of the - # function that has the same name as the function. - # https://docs.python.org/3/library/inspect.html?highlight=co_freevars - if function.__name__ in function.__code__.co_freevars: - warn( - "Self-reference detected in MagicFactory function created " - "in a local scope. FunctionGui references will not work." - ) - # we want function first for the repr keywords = {"function": function, **keywords} return super().__new__(cls, magic_class, *args, **keywords) # type: ignore diff --git a/magicgui/widgets/_function_gui.py b/magicgui/widgets/_function_gui.py index f39be1bc1..187516388 100644 --- a/magicgui/widgets/_function_gui.py +++ b/magicgui/widgets/_function_gui.py @@ -411,20 +411,38 @@ def _function_name_pointing_to_widget(function_gui: FunctionGui): """ function = function_gui._function if not isinstance(function, FunctionType): + # it's not a function object, so we don't know how to patch it... yield return func_name = function.__name__ - # function.__globals__ here points to the module-level globals in which the function - # was defined. This means that this will NOT work for factories defined inside - # other functions. we use `_UNSET` just in case the function name has somehow been - # deleted or does not exist in the function module's globals() - original_value = function.__globals__.get(func_name, _UNSET) - function.__globals__[func_name] = function_gui - try: - yield - finally: - if original_value is _UNSET: - del function.__globals__[func_name] - else: + # see https://docs.python.org/3/library/inspect.html for details on code objects + code = function.__code__ + + if func_name in code.co_names: + # This indicates that the function name was used inside the body of the + # function, and points to some object in the module's global namespace. + # function.__globals__ here points to the module-level globals in which the + # function was defined. + original_value = function.__globals__.get(func_name) + function.__globals__[func_name] = function_gui + try: + yield + finally: function.__globals__[func_name] = original_value + + elif function.__closure__ and func_name in code.co_freevars: + # This indicates that the function name was used inside the body of the + # function, and points to some object defined in a local scope (closure), rather + # than the module's global namespace. + # the position of the function name in code.co_freevars tells us where to look + # for the value in the function.__closure__ tuple. + idx = code.co_freevars.index(func_name) + original_value = function.__closure__[idx].cell_contents + function.__closure__[idx].cell_contents = function_gui + try: + yield + finally: + function.__closure__[idx].cell_contents = original_value + else: + yield diff --git a/tests/test_factory.py b/tests/test_factory.py index 309323d77..b99eee76c 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -75,14 +75,11 @@ def test_magic_factory_self_reference(): def test_magic_local_factory_self_reference(): - """Test that self-referential factories DON'T work in local scopes, but warn.""" + """Test that self-referential factories work in local scopes.""" - with pytest.warns(UserWarning) as wrn: + @magic_factory + def local_self_referencing_factory(x: int = 1): + return local_self_referencing_factory - @magic_factory - def local_self_referencing_factory(x: int = 1): - return local_self_referencing_factory - - assert "Self-reference detected" in str(wrn[0]) widget = local_self_referencing_factory() - assert isinstance(widget(), MagicFactory) + assert isinstance(widget(), FunctionGui)