diff --git a/toolz/functoolz.py b/toolz/functoolz.py index ed37c0a1..2bd80a4d 100644 --- a/toolz/functoolz.py +++ b/toolz/functoolz.py @@ -200,6 +200,8 @@ def __init__(self, *args, **kwargs): self.__doc__ = getattr(func, '__doc__', None) self.__name__ = getattr(func, '__name__', '') + self.__module__ = getattr(func, '__module__', None) + self.__qualname__ = getattr(func, '__qualname__', None) self._sigspec = None self._has_unknown_args = None @@ -324,27 +326,43 @@ def __get__(self, instance, owner): def __reduce__(self): func = self.func modname = getattr(func, '__module__', None) - funcname = getattr(func, '__name__', None) - if modname and funcname: - module = import_module(modname) - obj = getattr(module, funcname, None) - if obj is self: - return funcname - elif isinstance(obj, curry) and obj.func is func: - func = '%s.%s' % (modname, funcname) + qualname = getattr(func, '__qualname__', None) + if qualname is None: # pragma: py3 no cover + qualname = getattr(func, '__name__', None) + is_decorated = None + if modname and qualname: + attrs = [] + obj = import_module(modname) + for attr in qualname.split('.'): + if isinstance(obj, curry): # pragma: py2 no cover + attrs.append('func') + obj = obj.func + obj = getattr(obj, attr, None) + if obj is None: + break + attrs.append(attr) + if isinstance(obj, curry) and obj.func is func: + is_decorated = obj is self + qualname = '.'.join(attrs) + func = '%s:%s' % (modname, qualname) # functools.partial objects can't be pickled userdict = tuple((k, v) for k, v in self.__dict__.items() - if k != '_partial') - state = (type(self), func, self.args, self.keywords, userdict) + if k not in ('_partial', '_sigspec')) + state = (type(self), func, self.args, self.keywords, userdict, + is_decorated) return (_restore_curry, state) -def _restore_curry(cls, func, args, kwargs, userdict): +def _restore_curry(cls, func, args, kwargs, userdict, is_decorated): if isinstance(func, str): - modname, funcname = func.rsplit('.', 1) - module = import_module(modname) - func = getattr(module, funcname).func + modname, qualname = func.rsplit(':', 1) + obj = import_module(modname) + for attr in qualname.split('.'): + obj = getattr(obj, attr) + if is_decorated: + return obj + func = obj.func obj = cls(func, *args, **(kwargs or {})) obj.__dict__.update(userdict) return obj diff --git a/toolz/tests/test_functoolz.py b/toolz/tests/test_functoolz.py index 6c7cb0ea..deda6068 100644 --- a/toolz/tests/test_functoolz.py +++ b/toolz/tests/test_functoolz.py @@ -285,16 +285,26 @@ def foo(a, b, c=1): def test_curry_attributes_writable(): def foo(a, b, c=1): return a + b + c - + foo.__qualname__ = 'this.is.foo' f = curry(foo, 1, c=2) + assert f.__qualname__ == 'this.is.foo' f.__name__ = 'newname' f.__doc__ = 'newdoc' + f.__module__ = 'newmodule' + f.__qualname__ = 'newqualname' assert f.__name__ == 'newname' assert f.__doc__ == 'newdoc' + assert f.__module__ == 'newmodule' + assert f.__qualname__ == 'newqualname' if hasattr(f, 'func_name'): assert f.__name__ == f.func_name +def test_curry_module(): + from toolz.curried.exceptions import merge + assert merge.__module__ == 'toolz.curried.exceptions' + + def test_curry_comparable(): def foo(a, b, c=1): return a + b + c diff --git a/toolz/tests/test_serialization.py b/toolz/tests/test_serialization.py index a49f7a87..afee159f 100644 --- a/toolz/tests/test_serialization.py +++ b/toolz/tests/test_serialization.py @@ -1,6 +1,9 @@ from toolz import * import toolz +import toolz.curried.exceptions import pickle +from toolz.compatibility import PY3, PY33, PY34 +from toolz.utils import raises def test_compose(): @@ -55,3 +58,138 @@ def test_flip(): g1 = flip(f)(1) g2 = pickle.loads(pickle.dumps(g1)) assert g1(2) == g2(2) == f(2, 1) + + +def test_curried_exceptions(): + # This tests a global curried object that isn't defined in toolz.functoolz + merge = pickle.loads(pickle.dumps(toolz.curried.exceptions.merge)) + assert merge is toolz.curried.exceptions.merge + + +@toolz.curry +class GlobalCurried(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @toolz.curry + def f1(self, a, b): + return self.x + self.y + a + b + + def g1(self): + pass + + def __reduce__(self): + """Allow us to serialize instances of GlobalCurried""" + return (GlobalCurried, (self.x, self.y)) + + @toolz.curry + class NestedCurried(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @toolz.curry + def f2(self, a, b): + return self.x + self.y + a + b + + def g2(self): + pass + + def __reduce__(self): + """Allow us to serialize instances of NestedCurried""" + return (GlobalCurried.NestedCurried, (self.x, self.y)) + + class Nested(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @toolz.curry + def f3(self, a, b): + return self.x + self.y + a + b + + def g3(self): + pass + + +def test_curried_qualname(): + if not PY3: + return + + def preserves_identity(obj): + return pickle.loads(pickle.dumps(obj)) is obj + + assert preserves_identity(GlobalCurried) + assert preserves_identity(GlobalCurried.func.f1) + assert preserves_identity(GlobalCurried.func.NestedCurried) + assert preserves_identity(GlobalCurried.func.NestedCurried.func.f2) + assert preserves_identity(GlobalCurried.func.Nested.f3) + + global_curried1 = GlobalCurried(1) + global_curried2 = pickle.loads(pickle.dumps(global_curried1)) + assert global_curried1 is not global_curried2 + assert global_curried1(2).f1(3, 4) == global_curried2(2).f1(3, 4) == 10 + + global_curried3 = global_curried1(2) + global_curried4 = pickle.loads(pickle.dumps(global_curried3)) + assert global_curried3 is not global_curried4 + assert global_curried3.f1(3, 4) == global_curried4.f1(3, 4) == 10 + + func1 = global_curried1(2).f1(3) + func2 = pickle.loads(pickle.dumps(func1)) + assert func1 is not func2 + assert func1(4) == func2(4) == 10 + + nested_curried1 = GlobalCurried.func.NestedCurried(1) + nested_curried2 = pickle.loads(pickle.dumps(nested_curried1)) + assert nested_curried1 is not nested_curried2 + assert nested_curried1(2).f2(3, 4) == nested_curried2(2).f2(3, 4) == 10 + + # If we add `curry.__getattr__` forwarding, the following tests will pass + + # if not PY33 and not PY34: + # assert preserves_identity(GlobalCurried.func.g1) + # assert preserves_identity(GlobalCurried.func.NestedCurried.func.g2) + # assert preserves_identity(GlobalCurried.func.Nested) + # assert preserves_identity(GlobalCurried.func.Nested.g3) + # + # # Rely on curry.__getattr__ + # assert preserves_identity(GlobalCurried.f1) + # assert preserves_identity(GlobalCurried.NestedCurried) + # assert preserves_identity(GlobalCurried.NestedCurried.f2) + # assert preserves_identity(GlobalCurried.Nested.f3) + # if not PY33 and not PY34: + # assert preserves_identity(GlobalCurried.g1) + # assert preserves_identity(GlobalCurried.NestedCurried.g2) + # assert preserves_identity(GlobalCurried.Nested) + # assert preserves_identity(GlobalCurried.Nested.g3) + # + # nested_curried3 = nested_curried1(2) + # nested_curried4 = pickle.loads(pickle.dumps(nested_curried3)) + # assert nested_curried3 is not nested_curried4 + # assert nested_curried3.f2(3, 4) == nested_curried4.f2(3, 4) == 10 + # + # func1 = nested_curried1(2).f2(3) + # func2 = pickle.loads(pickle.dumps(func1)) + # assert func1 is not func2 + # assert func1(4) == func2(4) == 10 + # + # if not PY33 and not PY34: + # nested3 = GlobalCurried.func.Nested(1, 2) + # nested4 = pickle.loads(pickle.dumps(nested3)) + # assert nested3 is not nested4 + # assert nested3.f3(3, 4) == nested4.f3(3, 4) == 10 + # + # func1 = nested3.f3(3) + # func2 = pickle.loads(pickle.dumps(func1)) + # assert func1 is not func2 + # assert func1(4) == func2(4) == 10 + + +def test_curried_bad_qualname(): + @toolz.curry + class Bad(object): + __qualname__ = 'toolz.functoolz.not.a.valid.path' + + assert raises(pickle.PicklingError, lambda: pickle.dumps(Bad))