Skip to content

Commit

Permalink
Lift closure cell update to earliest function (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed Apr 19, 2022
1 parent 8163e08 commit e2831d0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
32 changes: 27 additions & 5 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ def __init__(self, *args, **kwds):
self._strictio = False #_strictio
self._fmode = settings['fmode'] if _fmode is None else _fmode
self._recurse = settings['recurse'] if _recurse is None else _recurse
self._postproc = {}
from collections import OrderedDict
self._postproc = OrderedDict()

def dump(self, obj): #NOTE: if settings change, need to update attributes
# register if the object is a numpy ufunc
Expand Down Expand Up @@ -1424,14 +1425,14 @@ def save_cell(pickler, obj):
log.info("# Ce3")
return
if is_dill(pickler, child=True):
postproc = pickler._postproc.get(id(f))
postproc = next(iter(pickler._postproc.values()), None)
if postproc is not None:
log.info("Ce2: %s" % obj)
# _CELL_REF is defined in _shims.py to support older versions of
# dill. When breaking changes are made to dill, (_CELL_REF,) can
# be replaced by ()
postproc.append((_shims._setattr, (obj, 'cell_contents', f)))
pickler.save_reduce(_create_cell, (_CELL_REF,), obj=obj)
postproc.append((_shims._setattr, (obj, 'cell_contents', f)))
log.info("# Ce2")
return
log.info("Ce1: %s" % obj)
Expand Down Expand Up @@ -1748,16 +1749,37 @@ def save_function(pickler, obj):
postproc_list.append((dict.update, (globs, globs_copy)))

if PY3:
closure = obj.__closure__
fkwdefaults = getattr(obj, '__kwdefaults__', None)
_save_with_postproc(pickler, (_create_function, (
obj.__code__, globs, obj.__name__, obj.__defaults__,
obj.__closure__, obj.__dict__, fkwdefaults
closure, obj.__dict__, fkwdefaults
)), obj=obj, postproc_list=postproc_list)
else:
closure = obj.func_closure
_save_with_postproc(pickler, (_create_function, (
obj.func_code, globs, obj.func_name, obj.func_defaults,
obj.func_closure, obj.__dict__
closure, obj.__dict__
)), obj=obj, postproc_list=postproc_list)

# Lift closure cell update to earliest function (#458)
topmost_postproc = next(iter(pickler._postproc.values()), None)
if closure and topmost_postproc:
for cell in closure:
possible_postproc = (setattr, (cell, 'cell_contents', obj))
try:
topmost_postproc.remove(possible_postproc)
except ValueError:
continue

# Change the value of the cell
pickler.save_reduce(*possible_postproc)
# pop None created by calling preprocessing step off stack
if PY3:
pickler.write(bytes('0', 'UTF-8'))
else:
pickler.write('0')

log.info("# F1")
else:
log.info("F2: %s" % obj)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ def test_recursive_function():
fib = fib4


def collection_function_recursion():
d = {}
def g():
return d
d['g'] = g
return g


def test_collection_function_recursion():
g = copy(collection_function_recursion())
assert g()['g'] is g


if __name__ == '__main__':
with warnings.catch_warnings():
warnings.simplefilter('error')
Expand All @@ -163,3 +176,4 @@ def test_recursive_function():
test_circular_reference()
test_function_cells()
test_recursive_function()
test_collection_function_recursion()

0 comments on commit e2831d0

Please sign in to comment.