Skip to content

Commit

Permalink
Merge pull request #280 from legion-platform/master
Browse files Browse the repository at this point in the history
Fix comparing algorithm for serialization process
  • Loading branch information
mmckerns committed Sep 9, 2018
2 parents 52a7bc1 + 796dd11 commit c7281de
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def save_function(pickler, obj):
# if isinstance(value, stacktypes) and id(value) in stack:
# del globs[key]
# ABORT: if self-references, use _recurse=False
if obj in globs.values(): # or id(obj) in stack:
if id(obj) in stack:
globs = obj.__globals__ if PY3 else obj.func_globals
else:
globs = obj.__globals__ if PY3 else obj.func_globals
Expand Down
4 changes: 2 additions & 2 deletions dill/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def globalvars(func, recurse=True, builtin=False):
# find globals for all entries of func
for key in func.copy(): #XXX: unnecessary...?
nested_func = globs.get(key)
if nested_func == orig_func:
if nested_func is orig_func:
#func.remove(key) if key in func else None
continue #XXX: globalvars(func, False)?
func.update(globalvars(nested_func, True, builtin))
Expand All @@ -228,7 +228,7 @@ def globalvars(func, recurse=True, builtin=False):
func = set(nestedglobals(func))
# find globals for all entries of func
for key in func.copy(): #XXX: unnecessary...?
if key == orig_func:
if key is orig_func:
#func.remove(key) if key in func else None
continue #XXX: globalvars(func, False)?
nested_func = globs.get(key)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_restricted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python
#
# Author: Kirill Makhonin (@kirillmakhonin)
# Copyright (c) 2008-2016 California Institute of Technology.
# Copyright (c) 2016-2018 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

import dill

class RestrictedType:
def __bool__(*args, **kwargs):
raise Exception('Restricted function')

__eq__ = __lt__ = __le__ = __ne__ = __gt__ = __ge__ = __hash__ = __bool__

glob_obj = RestrictedType()

def restricted_func():
a = glob_obj

def test_function_with_restricted_object():
deserialized = dill.loads(dill.dumps(restricted_func, recurse=True))


if __name__ is '__main__':
test_function_with_restricted_object()

0 comments on commit c7281de

Please sign in to comment.