Skip to content

Commit

Permalink
Pickle inner collections.namedtuples and function attributes (#448)
Browse files Browse the repository at this point in the history
* Fix #288: nested namedtuples

* Remove special case for PyPy 2.7 that doesn't exist

__kwdefaults__ and __annotations__ are invalid in PyPy2.7

* Bug fix for __qualname__ on classes

* Fix bug when _postproc is not present, and use _setitems
  • Loading branch information
anivegesana committed Apr 21, 2022
1 parent 914d47f commit 5bd56a8
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 37 deletions.
118 changes: 82 additions & 36 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,9 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None,
fclosure=None, fdict=None, fkwdefaults=None):
# same as FunctionType, but enable passing __dict__ to new function,
# __dict__ is the storehouse for attributes added after function creation
if fdict is None: fdict = dict()
func = FunctionType(fcode, fglobals or dict(), fname, fdefaults, fclosure)
func.__dict__.update(fdict) #XXX: better copy? option to copy?
if fdict is not None:
func.__dict__.update(fdict) #XXX: better copy? option to copy?
if fkwdefaults is not None:
func.__kwdefaults__ = fkwdefaults
# 'recurse' only stores referenced modules/objects in fglobals,
Expand Down Expand Up @@ -1001,14 +1001,23 @@ def _create_dtypemeta(scalar_type):
return NumpyDType
return type(NumpyDType(scalar_type))

def _create_namedtuple(name, fieldnames, modulename):
    """Return the namedtuple class ``modulename.name``, rebuilding it if needed.

    The class is first looked up with ``_import_module(..., safe=True)``; when
    that yields something other than ``None`` the existing class is returned
    unchanged.  Otherwise a fresh ``collections.namedtuple`` is created from
    ``name`` and ``fieldnames`` and its ``__module__`` is set to
    ``modulename``.

    NOTE(review): ``_import_module`` is defined elsewhere in the file;
    presumably ``safe=True`` makes it return ``None`` instead of raising
    when the lookup fails -- confirm.
    """
    class_ = _import_module(modulename + '.' + name, safe=True)
    if class_ is not None:
        return class_
    import collections
    t = collections.namedtuple(name, fieldnames)
    t.__module__ = modulename
    return t
def _create_namedtuple(name, fieldnames, modulename, defaults=None):
    """Return the namedtuple class ``modulename.name``, rebuilding it if needed.

    The class is first looked up with ``_import_module(..., safe=True)``; when
    that yields something other than ``None`` the existing class is returned
    unchanged.  Otherwise a new ``collections.namedtuple`` is built:

    * when ``OLD37`` is true, ``defaults`` is ignored and ``__module__`` is
      assigned after creation;
    * otherwise ``defaults`` and ``modulename`` are forwarded through the
      ``defaults=`` and ``module=`` keyword arguments of ``namedtuple``.

    NOTE(review): ``_import_module`` and ``OLD37`` are defined elsewhere in
    the file; presumably ``safe=True`` returns ``None`` on a failed lookup
    and ``OLD37`` flags interpreters older than Python 3.7 -- confirm.
    """
    class_ = _import_module(modulename + '.' + name, safe=True)
    if class_ is not None:
        return class_
    import collections
    if OLD37:
        # No ``defaults=`` keyword is used on this path; set the module by hand.
        t = collections.namedtuple(name, fieldnames)
        t.__module__ = modulename
        return t
    return collections.namedtuple(
        name, fieldnames, defaults=defaults, module=modulename
    )

def _getattr(objclass, name, repr_str):
# hack to grab the reference directly
Expand Down Expand Up @@ -1058,6 +1067,11 @@ def _locate_function(obj, session=False):
return found is obj


def _setitems(dest, source):
for k, v in source.items():
dest[k] = v


def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None):
if obj is Getattr.NO_DEFAULT:
obj = Reduce(reduction) # pragma: no cover
Expand Down Expand Up @@ -1089,7 +1103,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO
postproc = pickler._postproc.pop(id(obj))
# assert postproc_list == postproc, 'Stack tampered!'
for reduction in reversed(postproc):
if reduction[0] is dict.update and type(reduction[1][0]) is dict:
if reduction[0] is _setitems:
# use the internal machinery of pickle.py to speedup when
# updating a dictionary in postproc
dest, source = reduction[1]
Expand Down Expand Up @@ -1719,10 +1733,14 @@ def save_type(pickler, obj, postproc_list=None):
log.info("T1: %s" % obj)
pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj)
log.info("# T1")
elif issubclass(obj, tuple) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
# special case: namedtuples
log.info("T6: %s" % obj)
pickler.save_reduce(_create_namedtuple, (getattr(obj, "__qualname__", obj.__name__), obj._fields, obj.__module__), obj=obj)
if OLD37 or (not obj._field_defaults):
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj)
else:
defaults = [obj._field_defaults[field] for field in obj._fields]
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj)
log.info("# T6")
return

Expand Down Expand Up @@ -1764,8 +1782,12 @@ def save_type(pickler, obj, postproc_list=None):
#print ("%s\n%s" % (obj.__bases__, obj.__dict__))
for name in _dict.get("__slots__", []):
del _dict[name]
if PY3 and obj_name != obj.__name__:
if postproc_list is None:
postproc_list = []
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))
_save_with_postproc(pickler, (_create_type, (
type(obj), obj_name, obj.__bases__, _dict
type(obj), obj.__name__, obj.__bases__, _dict
)), obj=obj, postproc_list=postproc_list)
log.info("# %s" % _t)
else:
Expand Down Expand Up @@ -1858,42 +1880,66 @@ def save_function(pickler, obj):
glob_ids = {id(g) for g in globs_copy.itervalues()}
for stack_element in _postproc:
if stack_element in glob_ids:
_postproc[stack_element].append((dict.update, (globs, globs_copy)))
_postproc[stack_element].append((_setitems, (globs, globs_copy)))
break
else:
postproc_list.append((dict.update, (globs, globs_copy)))
postproc_list.append((_setitems, (globs, globs_copy)))

if PY3:
closure = obj.__closure__
fkwdefaults = getattr(obj, '__kwdefaults__', None)
state_dict = {}
for fattrname in ('__doc__', '__kwdefaults__', '__annotations__'):
fattr = getattr(obj, fattrname, None)
if fattr is not None:
state_dict[fattrname] = fattr
if obj.__qualname__ != obj.__name__:
state_dict['__qualname__'] = obj.__qualname__
if '__name__' not in globs or obj.__module__ != globs['__name__']:
state_dict['__module__'] = obj.__module__

state = obj.__dict__
if type(state) is not dict:
state_dict['__dict__'] = state
state = None
if state_dict:
state = state, state_dict

_save_with_postproc(pickler, (_create_function, (
obj.__code__, globs, obj.__name__, obj.__defaults__,
closure, obj.__dict__, fkwdefaults
)), obj=obj, postproc_list=postproc_list)
closure
), state), obj=obj, postproc_list=postproc_list)
else:
closure = obj.func_closure
if obj.__doc__ is not None:
postproc_list.append((setattr, (obj, '__doc__', obj.__doc__)))
if '__name__' not in globs or obj.__module__ != globs['__name__']:
postproc_list.append((setattr, (obj, '__module__', obj.__module__)))
if obj.__dict__:
postproc_list.append((setattr, (obj, '__dict__', obj.__dict__)))

_save_with_postproc(pickler, (_create_function, (
obj.func_code, globs, obj.func_name, obj.func_defaults,
closure, obj.__dict__
closure
)), obj=obj, postproc_list=postproc_list)

# Lift closure cell update to earliest function (#458)
topmost_postproc = next(iter(pickler._postproc.values()), None)
if closure and topmost_postproc:
for cell in closure:
possible_postproc = (setattr, (cell, 'cell_contents', obj))
try:
topmost_postproc.remove(possible_postproc)
except ValueError:
continue

# Change the value of the cell
pickler.save_reduce(*possible_postproc)
# pop None created by calling preprocessing step off stack
if PY3:
pickler.write(bytes('0', 'UTF-8'))
else:
pickler.write('0')
if _postproc:
topmost_postproc = next(iter(_postproc.values()), None)
if closure and topmost_postproc:
for cell in closure:
possible_postproc = (setattr, (cell, 'cell_contents', obj))
try:
topmost_postproc.remove(possible_postproc)
except ValueError:
continue

# Change the value of the cell
pickler.save_reduce(*possible_postproc)
# pop None created by calling preprocessing step off stack
if PY3:
pickler.write(bytes('0', 'UTF-8'))
else:
pickler.write('0')

log.info("# F1")
else:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@ def test_namedtuple():
assert Bad._fields == dill.loads(dill.dumps(Bad))._fields
assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi)))

class A:
class B(namedtuple("B", ["one", "two"])):
'''docstring'''
B.__module__ = 'testing'

a = A()
assert dill.copy(a)

assert dill.copy(A.B).__name__ == 'B'
if dill._dill.PY3:
assert dill.copy(A.B).__qualname__.endswith('.<locals>.A.B')
assert dill.copy(A.B).__doc__ == 'docstring'
assert dill.copy(A.B).__module__ == 'testing'

def test_dtype():
try:
import numpy as np
Expand All @@ -127,7 +141,7 @@ def test_dtype():
def test_array_nested():
try:
import numpy as np

x = np.array([1])
y = (x,)
dill.dumps(x)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ def function_c(c, c1=1):


def function_d(d, d1, d2=1):
"""doc string"""
return d + d1 + d2

function_d.__module__ = 'a module'


if is_py3():
exec('''
Expand Down Expand Up @@ -63,6 +66,8 @@ def test_functions():
assert dill.loads(dumped_func_c)(1, 2) == 3

dumped_func_d = dill.dumps(function_d)
assert dill.loads(dumped_func_d).__doc__ == function_d.__doc__
assert dill.loads(dumped_func_d).__module__ == function_d.__module__
assert dill.loads(dumped_func_d)(1, 2) == 4
assert dill.loads(dumped_func_d)(1, 2, 3) == 6
assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6
Expand Down

0 comments on commit 5bd56a8

Please sign in to comment.