Skip to content

Commit

Permalink
Merge pull request #355 from eriknw/fix/curry_module
Browse files Browse the repository at this point in the history
Set `__module__` on curried objects.  This can fix pickling global curried objects
  • Loading branch information
eriknw committed Dec 10, 2016
2 parents 7fb83c8 + 0ef082c commit be3ce86
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 15 deletions.
46 changes: 32 additions & 14 deletions toolz/functoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def __init__(self, *args, **kwargs):

self.__doc__ = getattr(func, '__doc__', None)
self.__name__ = getattr(func, '__name__', '<curry>')
self.__module__ = getattr(func, '__module__', None)
self.__qualname__ = getattr(func, '__qualname__', None)
self._sigspec = None
self._has_unknown_args = None

Expand Down Expand Up @@ -324,27 +326,43 @@ def __get__(self, instance, owner):
def __reduce__(self):
func = self.func
modname = getattr(func, '__module__', None)
funcname = getattr(func, '__name__', None)
if modname and funcname:
module = import_module(modname)
obj = getattr(module, funcname, None)
if obj is self:
return funcname
elif isinstance(obj, curry) and obj.func is func:
func = '%s.%s' % (modname, funcname)
qualname = getattr(func, '__qualname__', None)
if qualname is None: # pragma: py3 no cover
qualname = getattr(func, '__name__', None)
is_decorated = None
if modname and qualname:
attrs = []
obj = import_module(modname)
for attr in qualname.split('.'):
if isinstance(obj, curry): # pragma: py2 no cover
attrs.append('func')
obj = obj.func
obj = getattr(obj, attr, None)
if obj is None:
break
attrs.append(attr)
if isinstance(obj, curry) and obj.func is func:
is_decorated = obj is self
qualname = '.'.join(attrs)
func = '%s:%s' % (modname, qualname)

# functools.partial objects can't be pickled
userdict = tuple((k, v) for k, v in self.__dict__.items()
if k != '_partial')
state = (type(self), func, self.args, self.keywords, userdict)
if k not in ('_partial', '_sigspec'))
state = (type(self), func, self.args, self.keywords, userdict,
is_decorated)
return (_restore_curry, state)


def _restore_curry(cls, func, args, kwargs, userdict):
def _restore_curry(cls, func, args, kwargs, userdict, is_decorated):
if isinstance(func, str):
modname, funcname = func.rsplit('.', 1)
module = import_module(modname)
func = getattr(module, funcname).func
modname, qualname = func.rsplit(':', 1)
obj = import_module(modname)
for attr in qualname.split('.'):
obj = getattr(obj, attr)
if is_decorated:
return obj
func = obj.func
obj = cls(func, *args, **(kwargs or {}))
obj.__dict__.update(userdict)
return obj
Expand Down
12 changes: 11 additions & 1 deletion toolz/tests/test_functoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,26 @@ def foo(a, b, c=1):
def test_curry_attributes_writable():
def foo(a, b, c=1):
return a + b + c

foo.__qualname__ = 'this.is.foo'
f = curry(foo, 1, c=2)
assert f.__qualname__ == 'this.is.foo'
f.__name__ = 'newname'
f.__doc__ = 'newdoc'
f.__module__ = 'newmodule'
f.__qualname__ = 'newqualname'
assert f.__name__ == 'newname'
assert f.__doc__ == 'newdoc'
assert f.__module__ == 'newmodule'
assert f.__qualname__ == 'newqualname'
if hasattr(f, 'func_name'):
assert f.__name__ == f.func_name


def test_curry_module():
from toolz.curried.exceptions import merge
assert merge.__module__ == 'toolz.curried.exceptions'


def test_curry_comparable():
def foo(a, b, c=1):
return a + b + c
Expand Down
138 changes: 138 additions & 0 deletions toolz/tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from toolz import *
import toolz
import toolz.curried.exceptions
import pickle
from toolz.compatibility import PY3, PY33, PY34
from toolz.utils import raises


def test_compose():
Expand Down Expand Up @@ -55,3 +58,138 @@ def test_flip():
g1 = flip(f)(1)
g2 = pickle.loads(pickle.dumps(g1))
assert g1(2) == g2(2) == f(2, 1)


def test_curried_exceptions():
# This tests a global curried object that isn't defined in toolz.functoolz
merge = pickle.loads(pickle.dumps(toolz.curried.exceptions.merge))
assert merge is toolz.curried.exceptions.merge


@toolz.curry
class GlobalCurried(object):
def __init__(self, x, y):
self.x = x
self.y = y

@toolz.curry
def f1(self, a, b):
return self.x + self.y + a + b

def g1(self):
pass

def __reduce__(self):
"""Allow us to serialize instances of GlobalCurried"""
return (GlobalCurried, (self.x, self.y))

@toolz.curry
class NestedCurried(object):
def __init__(self, x, y):
self.x = x
self.y = y

@toolz.curry
def f2(self, a, b):
return self.x + self.y + a + b

def g2(self):
pass

def __reduce__(self):
"""Allow us to serialize instances of NestedCurried"""
return (GlobalCurried.NestedCurried, (self.x, self.y))

class Nested(object):
def __init__(self, x, y):
self.x = x
self.y = y

@toolz.curry
def f3(self, a, b):
return self.x + self.y + a + b

def g3(self):
pass


def test_curried_qualname():
if not PY3:
return

def preserves_identity(obj):
return pickle.loads(pickle.dumps(obj)) is obj

assert preserves_identity(GlobalCurried)
assert preserves_identity(GlobalCurried.func.f1)
assert preserves_identity(GlobalCurried.func.NestedCurried)
assert preserves_identity(GlobalCurried.func.NestedCurried.func.f2)
assert preserves_identity(GlobalCurried.func.Nested.f3)

global_curried1 = GlobalCurried(1)
global_curried2 = pickle.loads(pickle.dumps(global_curried1))
assert global_curried1 is not global_curried2
assert global_curried1(2).f1(3, 4) == global_curried2(2).f1(3, 4) == 10

global_curried3 = global_curried1(2)
global_curried4 = pickle.loads(pickle.dumps(global_curried3))
assert global_curried3 is not global_curried4
assert global_curried3.f1(3, 4) == global_curried4.f1(3, 4) == 10

func1 = global_curried1(2).f1(3)
func2 = pickle.loads(pickle.dumps(func1))
assert func1 is not func2
assert func1(4) == func2(4) == 10

nested_curried1 = GlobalCurried.func.NestedCurried(1)
nested_curried2 = pickle.loads(pickle.dumps(nested_curried1))
assert nested_curried1 is not nested_curried2
assert nested_curried1(2).f2(3, 4) == nested_curried2(2).f2(3, 4) == 10

# If we add `curry.__getattr__` forwarding, the following tests will pass

# if not PY33 and not PY34:
# assert preserves_identity(GlobalCurried.func.g1)
# assert preserves_identity(GlobalCurried.func.NestedCurried.func.g2)
# assert preserves_identity(GlobalCurried.func.Nested)
# assert preserves_identity(GlobalCurried.func.Nested.g3)
#
# # Rely on curry.__getattr__
# assert preserves_identity(GlobalCurried.f1)
# assert preserves_identity(GlobalCurried.NestedCurried)
# assert preserves_identity(GlobalCurried.NestedCurried.f2)
# assert preserves_identity(GlobalCurried.Nested.f3)
# if not PY33 and not PY34:
# assert preserves_identity(GlobalCurried.g1)
# assert preserves_identity(GlobalCurried.NestedCurried.g2)
# assert preserves_identity(GlobalCurried.Nested)
# assert preserves_identity(GlobalCurried.Nested.g3)
#
# nested_curried3 = nested_curried1(2)
# nested_curried4 = pickle.loads(pickle.dumps(nested_curried3))
# assert nested_curried3 is not nested_curried4
# assert nested_curried3.f2(3, 4) == nested_curried4.f2(3, 4) == 10
#
# func1 = nested_curried1(2).f2(3)
# func2 = pickle.loads(pickle.dumps(func1))
# assert func1 is not func2
# assert func1(4) == func2(4) == 10
#
# if not PY33 and not PY34:
# nested3 = GlobalCurried.func.Nested(1, 2)
# nested4 = pickle.loads(pickle.dumps(nested3))
# assert nested3 is not nested4
# assert nested3.f3(3, 4) == nested4.f3(3, 4) == 10
#
# func1 = nested3.f3(3)
# func2 = pickle.loads(pickle.dumps(func1))
# assert func1 is not func2
# assert func1(4) == func2(4) == 10


def test_curried_bad_qualname():
@toolz.curry
class Bad(object):
__qualname__ = 'toolz.functoolz.not.a.valid.path'

assert raises(pickle.PicklingError, lambda: pickle.dumps(Bad))

0 comments on commit be3ce86

Please sign in to comment.