diff --git a/Makefile b/Makefile index e01bc95..d6cc7e0 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,4 @@ +SHELL= /bin/bash PYTHON ?= python inplace: diff --git a/cytoolz/_signatures.py b/cytoolz/_signatures.py index cea1de6..cc5a3c4 100644 --- a/cytoolz/_signatures.py +++ b/cytoolz/_signatures.py @@ -46,16 +46,20 @@ lambda func, x: None], excepts=[ lambda exc, func, handler=None: None], - flip=[ # XXX: these are optional, but not keywords! - lambda func=None, a=None, b=None: None], + flip=[ + lambda: None, + lambda func: None, + lambda func, a: None, + lambda func, a, b: None], _flip=[ lambda func, a, b: None], identity=[ lambda x: None], juxt=[ lambda *funcs: None], - memoize=[ # XXX: func is optional, but not a keyword! - lambda func=None, cache=None, key=None: None], + memoize=[ + lambda cache=None, key=None: None, + lambda func, cache=None, key=None: None], _memoize=[ lambda func, cache=None, key=None: None], pipe=[ diff --git a/cytoolz/functoolz.pxd b/cytoolz/functoolz.pxd index 0eee035..766aca1 100644 --- a/cytoolz/functoolz.pxd +++ b/cytoolz/functoolz.pxd @@ -15,6 +15,8 @@ cdef class curry: cdef readonly dict keywords cdef public object __doc__ cdef public object __name__ + cdef public object __module__ + cdef public object __qualname__ cdef class memoize: cdef object func diff --git a/cytoolz/functoolz.pyx b/cytoolz/functoolz.pyx index 54a2ba4..2f9449b 100644 --- a/cytoolz/functoolz.pyx +++ b/cytoolz/functoolz.pyx @@ -4,7 +4,7 @@ from functools import partial from operator import attrgetter from textwrap import dedent from cytoolz.utils import no_default -from cytoolz.compatibility import PY3, PY34, filter as ifilter, map as imap, reduce, import_module +from cytoolz.compatibility import PY3, PY33, PY34, filter as ifilter, map as imap, reduce, import_module import cytoolz._signatures as _sigs from toolz.functoolz import (InstanceProperty, instanceproperty, is_arity, @@ -194,6 +194,8 @@ cdef class curry: self.keywords = kwargs if kwargs else _empty_kwargs() self.__doc__ = getattr(func, '__doc__', None) self.__name__ = getattr(func, '__name__', '') + self.__module__ = getattr(func, '__module__', None) + self.__qualname__ = getattr(func, '__qualname__', None) self._sigspec = None self._has_unknown_args = None @@ -292,7 +294,13 @@ cdef class curry: property __signature__: def __get__(self): - sig = inspect.signature(self.func) + try: + sig = inspect.signature(self.func) + except TypeError: + if PY33 and (getattr(self.func, '__module__') or '').startswith('cytoolz.'): + raise ValueError('callable %r is not supported by signature' % self.func) + raise + args = self.args or () keywords = self.keywords or {} if is_partial_args(self.func, args, keywords, sig) is False: @@ -331,29 +339,43 @@ cdef class curry: def __reduce__(self): func = self.func modname = getattr(func, '__module__', None) - funcname = getattr(func, '__name__', None) - if modname and funcname: - module = import_module(modname) - obj = getattr(module, funcname, None) - if obj is self: - return funcname - elif isinstance(obj, curry) and obj.func is func: - func = '%s.%s' % (modname, funcname) - - state = (type(self), func, self.args, self.keywords) + qualname = getattr(func, '__qualname__', None) + if qualname is None: + qualname = getattr(func, '__name__', None) + is_decorated = None + if modname and qualname: + attrs = [] + obj = import_module(modname) + for attr in qualname.split('.'): + if isinstance(obj, curry): + attrs.append('func') + obj = obj.func + obj = getattr(obj, attr, None) + if obj is None: + break + attrs.append(attr) + if isinstance(obj, curry) and obj.func is func: + is_decorated = obj is self + qualname = '.'.join(attrs) + func = '%s:%s' % (modname, qualname) + + state = (type(self), func, self.args, self.keywords, is_decorated) return (_restore_curry, state) -cpdef object _restore_curry(cls, func, args, kwargs): +cpdef object _restore_curry(cls, func, args, kwargs, is_decorated): if isinstance(func, str): - modname, funcname = func.rsplit('.', 1) - module = import_module(modname) - func = getattr(module, funcname).func + modname, qualname = func.rsplit(':', 1) + obj = import_module(modname) + for attr in qualname.split('.'): + obj = getattr(obj, attr) + if is_decorated: + return obj + func = obj.func obj = cls(func, *args, **(kwargs or {})) return obj - cdef class memoize: """ memoize(func, cache=None, key=None) diff --git a/cytoolz/tests/test_functoolz.py b/cytoolz/tests/test_functoolz.py index 5aa0723..b6ddcf1 100644 --- a/cytoolz/tests/test_functoolz.py +++ b/cytoolz/tests/test_functoolz.py @@ -285,16 +285,26 @@ def foo(a, b, c=1): def test_curry_attributes_writable(): def foo(a, b, c=1): return a + b + c - + foo.__qualname__ = 'this.is.foo' f = curry(foo, 1, c=2) + assert f.__qualname__ == 'this.is.foo' f.__name__ = 'newname' f.__doc__ = 'newdoc' + f.__module__ = 'newmodule' + f.__qualname__ = 'newqualname' assert f.__name__ == 'newname' assert f.__doc__ == 'newdoc' + assert f.__module__ == 'newmodule' + assert f.__qualname__ == 'newqualname' if hasattr(f, 'func_name'): assert f.__name__ == f.func_name +def test_curry_module(): + from cytoolz.curried.exceptions import merge + assert merge.__module__ == 'cytoolz.curried.exceptions' + + def test_curry_comparable(): def foo(a, b, c=1): return a + b + c diff --git a/cytoolz/tests/test_serialization.py b/cytoolz/tests/test_serialization.py index f5522b4..be216f8 100644 --- a/cytoolz/tests/test_serialization.py +++ b/cytoolz/tests/test_serialization.py @@ -1,6 +1,9 @@ from cytoolz import * import cytoolz +import cytoolz.curried.exceptions import pickle +from cytoolz.compatibility import PY3, PY33, PY34 +from cytoolz.utils import raises def test_compose(): @@ -55,3 +58,138 @@ def test_flip(): g1 = flip(f)(1) g2 = pickle.loads(pickle.dumps(g1)) assert g1(2) == g2(2) == f(2, 1) + + +def test_curried_exceptions(): + # This tests a global curried object that isn't defined in cytoolz.functoolz + merge = pickle.loads(pickle.dumps(cytoolz.curried.exceptions.merge)) + assert merge is cytoolz.curried.exceptions.merge + + +@cytoolz.curry +class GlobalCurried(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @cytoolz.curry + def f1(self, a, b): + return self.x + self.y + a + b + + def g1(self): + pass + + def __reduce__(self): + """Allow us to serialize instances of GlobalCurried""" + return (GlobalCurried, (self.x, self.y)) + + @cytoolz.curry + class NestedCurried(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @cytoolz.curry + def f2(self, a, b): + return self.x + self.y + a + b + + def g2(self): + pass + + def __reduce__(self): + """Allow us to serialize instances of NestedCurried""" + return (GlobalCurried.NestedCurried, (self.x, self.y)) + + class Nested(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @cytoolz.curry + def f3(self, a, b): + return self.x + self.y + a + b + + def g3(self): + pass + + +def test_curried_qualname(): + if not PY3: + return + + def preserves_identity(obj): + return pickle.loads(pickle.dumps(obj)) is obj + + assert preserves_identity(GlobalCurried) + assert preserves_identity(GlobalCurried.func.f1) + assert preserves_identity(GlobalCurried.func.NestedCurried) + assert preserves_identity(GlobalCurried.func.NestedCurried.func.f2) + assert preserves_identity(GlobalCurried.func.Nested.f3) + + global_curried1 = GlobalCurried(1) + global_curried2 = pickle.loads(pickle.dumps(global_curried1)) + assert global_curried1 is not global_curried2 + assert global_curried1(2).f1(3, 4) == global_curried2(2).f1(3, 4) == 10 + + global_curried3 = global_curried1(2) + global_curried4 = pickle.loads(pickle.dumps(global_curried3)) + assert global_curried3 is not global_curried4 + assert global_curried3.f1(3, 4) == global_curried4.f1(3, 4) == 10 + + func1 = global_curried1(2).f1(3) + func2 = pickle.loads(pickle.dumps(func1)) + assert func1 is not func2 + assert func1(4) == func2(4) == 10 + + nested_curried1 = GlobalCurried.func.NestedCurried(1) + nested_curried2 = pickle.loads(pickle.dumps(nested_curried1)) + assert nested_curried1 is not nested_curried2 + assert nested_curried1(2).f2(3, 4) == nested_curried2(2).f2(3, 4) == 10 + + # If we add `curry.__getattr__` forwarding, the following tests will pass + + # if not PY33 and not PY34: + # assert preserves_identity(GlobalCurried.func.g1) + # assert preserves_identity(GlobalCurried.func.NestedCurried.func.g2) + # assert preserves_identity(GlobalCurried.func.Nested) + # assert preserves_identity(GlobalCurried.func.Nested.g3) + # + # # Rely on curry.__getattr__ + # assert preserves_identity(GlobalCurried.f1) + # assert preserves_identity(GlobalCurried.NestedCurried) + # assert preserves_identity(GlobalCurried.NestedCurried.f2) + # assert preserves_identity(GlobalCurried.Nested.f3) + # if not PY33 and not PY34: + # assert preserves_identity(GlobalCurried.g1) + # assert preserves_identity(GlobalCurried.NestedCurried.g2) + # assert preserves_identity(GlobalCurried.Nested) + # assert preserves_identity(GlobalCurried.Nested.g3) + # + # nested_curried3 = nested_curried1(2) + # nested_curried4 = pickle.loads(pickle.dumps(nested_curried3)) + # assert nested_curried3 is not nested_curried4 + # assert nested_curried3.f2(3, 4) == nested_curried4.f2(3, 4) == 10 + # + # func1 = nested_curried1(2).f2(3) + # func2 = pickle.loads(pickle.dumps(func1)) + # assert func1 is not func2 + # assert func1(4) == func2(4) == 10 + # + # if not PY33 and not PY34: + # nested3 = GlobalCurried.func.Nested(1, 2) + # nested4 = pickle.loads(pickle.dumps(nested3)) + # assert nested3 is not nested4 + # assert nested3.f3(3, 4) == nested4.f3(3, 4) == 10 + # + # func1 = nested3.f3(3) + # func2 = pickle.loads(pickle.dumps(func1)) + # assert func1 is not func2 + # assert func1(4) == func2(4) == 10 + + +def test_curried_bad_qualname(): + @cytoolz.curry + class Bad(object): + __qualname__ = 'cytoolz.functoolz.not.a.valid.path' + + assert raises(pickle.PicklingError, lambda: pickle.dumps(Bad))