Skip to content

Commit

Permalink
Optimize the numpy hook (#542)
Browse files Browse the repository at this point in the history
* optimize the numpy hook

* update numpy tests

* Remove the try-except clauses that caught ReferenceError

- Only type(obj) is passed to the numpy helper functions, and calling type()
    on a proxy of a deleted weakly-referenced object is safe.
- Similarly, using "%r" for string substitution with a proxy of a deleted
    object is safe, so there's no need for an 'R3' case in save_weakproxy()
- Removed test cases from test_weakref.py related to Python 2 old-style classes
- Fixed a bad use of 'NoReturn' instead of 'None' in logger.py
  • Loading branch information
leogama committed Aug 13, 2022
1 parent 43fd2ca commit 01bab78
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 127 deletions.
143 changes: 54 additions & 89 deletions dill/_dill.py
Expand Up @@ -91,54 +91,24 @@ def __hook__():
from numpy import dtype as NumpyDType
return True
if NumpyArrayType: # then has numpy
def ndarraysubclassinstance(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is ndarray, and elif is subclass of ndarray
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
elif 'numpy.ndarray' not in str(getattr(cls, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
def ndarraysubclassinstance(obj_type):
if all((c.__module__, c.__name__) != ('numpy', 'ndarray') for c in obj_type.__mro__):
return False
# anything below here is a numpy array (or subclass) instance
__hook__() # import numpy (so the following works!!!)
# verify that __reduce__ has not been overridden
NumpyInstance = NumpyArrayType((0,),'int8')
if id(obj.__reduce_ex__) == id(NumpyInstance.__reduce_ex__) and \
id(obj.__reduce__) == id(NumpyInstance.__reduce__): return True
return False
def numpyufunc(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is ufunc
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
if 'numpy.ufunc' not in str(getattr(cls, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
# anything below here is a numpy ufunc
if obj_type.__reduce_ex__ is not NumpyArrayType.__reduce_ex__ \
or obj_type.__reduce__ is not NumpyArrayType.__reduce__:
return False
return True
def numpydtype(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is dtype
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
if 'numpy.dtype' not in str(getattr(obj, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
# Report whether obj_type looks like numpy's ufunc class (or a subclass of it):
# True when any class in obj_type.__mro__ has __module__ 'numpy' and __name__ 'ufunc'.
# The argument is a type -- the caller in save() passes type(obj) -- not an instance.
def numpyufunc(obj_type):
return any((c.__module__, c.__name__) == ('numpy', 'ufunc') for c in obj_type.__mro__)
def numpydtype(obj_type):
if all((c.__module__, c.__name__) != ('numpy', 'dtype') for c in obj_type.__mro__):
return False
# anything below here is a numpy dtype
__hook__() # import numpy (so the following works!!!)
return type(obj) is type(NumpyDType) # handles subclasses
return obj_type is type(NumpyDType) # handles subclasses
else:
def ndarraysubclassinstance(obj): return False
def numpyufunc(obj): return False
Expand Down Expand Up @@ -373,42 +343,44 @@ def __init__(self, file, *args, **kwds):
def save(self, obj, save_persistent_id=True):
# register if the object is a numpy ufunc
# thanks to Paul Kienzle for pointing out ufuncs didn't pickle
if NumpyUfuncType and numpyufunc(obj):
@register(type(obj))
def save_numpy_ufunc(pickler, obj):
logger.trace(pickler, "Nu: %s", obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
logger.trace(pickler, "# Nu")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def udump(f): return f.__name__
# def uload(name): return getattr(numpy, name)
# copy_reg.pickle(NumpyUfuncType, udump, uload)
# register if the object is a numpy dtype
if NumpyDType and numpydtype(obj):
@register(type(obj))
def save_numpy_dtype(pickler, obj):
logger.trace(pickler, "Dt: %s", obj)
pickler.save_reduce(_create_dtypemeta, (obj.type,), obj=obj)
logger.trace(pickler, "# Dt")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def uload(name): return type(NumpyDType(name))
# def udump(f): return uload, (f.type,)
# copy_reg.pickle(NumpyDTypeType, udump, uload)
# register if the object is a subclassed numpy array instance
if NumpyArrayType and ndarraysubclassinstance(obj):
@register(type(obj))
def save_numpy_array(pickler, obj):
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
npdict = getattr(obj, '__dict__', None)
f, args, state = obj.__reduce__()
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
logger.trace(pickler, "# Nu")
return
obj_type = type(obj)
if NumpyArrayType and not (obj_type is type or obj_type in Pickler.dispatch):
if NumpyUfuncType and numpyufunc(obj_type):
@register(obj_type)
def save_numpy_ufunc(pickler, obj):
logger.trace(pickler, "Nu: %s", obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
logger.trace(pickler, "# Nu")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def udump(f): return f.__name__
# def uload(name): return getattr(numpy, name)
# copy_reg.pickle(NumpyUfuncType, udump, uload)
# register if the object is a numpy dtype
if NumpyDType and numpydtype(obj_type):
@register(obj_type)
def save_numpy_dtype(pickler, obj):
logger.trace(pickler, "Dt: %s", obj)
pickler.save_reduce(_create_dtypemeta, (obj.type,), obj=obj)
logger.trace(pickler, "# Dt")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def uload(name): return type(NumpyDType(name))
# def udump(f): return uload, (f.type,)
# copy_reg.pickle(NumpyDTypeType, udump, uload)
# register if the object is a subclassed numpy array instance
if NumpyArrayType and ndarraysubclassinstance(obj_type):
@register(obj_type)
def save_numpy_array(pickler, obj):
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
npdict = getattr(obj, '__dict__', None)
f, args, state = obj.__reduce__()
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
logger.trace(pickler, "# Nu")
return
# end hack
if GENERATOR_FAIL and type(obj) == GeneratorType:
msg = "Can't pickle %s: attribute lookup builtins.generator failed" % GeneratorType
Expand Down Expand Up @@ -1604,18 +1576,11 @@ def save_weakref(pickler, obj):
@register(ProxyType)
@register(CallableProxyType)
def save_weakproxy(pickler, obj):
# Must do string substitution here and use %r to avoid ReferenceError.
logger.trace(pickler, "R2: %r" % obj)
refobj = _locate_object(_proxy_helper(obj))
try:
_t = "R2"
logger.trace(pickler, "%s: %s", _t, obj)
except ReferenceError:
_t = "R3"
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1])
#callable = bool(getattr(refobj, '__call__', None))
if type(obj) is CallableProxyType: callable = True
else: callable = False
pickler.save_reduce(_create_weakproxy, (refobj, callable), obj=obj)
logger.trace(pickler, "# %s", _t)
pickler.save_reduce(_create_weakproxy, (refobj, callable(obj)), obj=obj)
logger.trace(pickler, "# R2")
return

def _is_builtin_module(module):
Expand Down
4 changes: 2 additions & 2 deletions dill/logger.py
Expand Up @@ -50,7 +50,7 @@
import math
import os
from functools import partial
from typing import NoReturn, TextIO, Union
from typing import TextIO, Union

import dill

Expand Down Expand Up @@ -214,7 +214,7 @@ def format(self, record):
stderr_handler = logging._StderrHandler()
adapter.addHandler(stderr_handler)

def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> NoReturn:
def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> None:
"""print a trace through the stack when pickling; useful for debugging
With a single boolean argument, enable or disable the tracing.
Expand Down
7 changes: 3 additions & 4 deletions dill/tests/test_classdef.py
Expand Up @@ -128,8 +128,8 @@ def test_dtype():
import numpy as np

dti = np.dtype('int')
assert np.dtype == dill.loads(dill.dumps(np.dtype))
assert dti == dill.loads(dill.dumps(dti))
assert np.dtype == dill.copy(np.dtype)
assert dti == dill.copy(dti)
except ImportError: pass


Expand All @@ -139,8 +139,7 @@ def test_array_nested():

x = np.array([1])
y = (x,)
dill.dumps(x)
assert y == dill.loads(dill.dumps(y))
assert y == dill.copy(y)

except ImportError: pass

Expand Down
47 changes: 15 additions & 32 deletions dill/tests/test_weakref.py
Expand Up @@ -14,15 +14,7 @@ class _class:
def _method(self):
pass

class _class2:
def __call__(self):
pass

class _newclass(object):
def _method(self):
pass

class _newclass2(object):
class _callable_class:
def __call__(self):
pass

Expand All @@ -32,42 +24,33 @@ def _function():

def test_weakref():
o = _class()
oc = _class2()
n = _newclass()
nc = _newclass2()
oc = _callable_class()
f = _function
z = _class
x = _newclass
x = _class

# ReferenceType
r = weakref.ref(o)
dr = weakref.ref(_class())
p = weakref.proxy(o)
dp = weakref.proxy(_class())
c = weakref.proxy(oc)
dc = weakref.proxy(_class2())
d_r = weakref.ref(_class())
fr = weakref.ref(f)
xr = weakref.ref(x)

m = weakref.ref(n)
dm = weakref.ref(_newclass())
t = weakref.proxy(n)
dt = weakref.proxy(_newclass())
d = weakref.proxy(nc)
dd = weakref.proxy(_newclass2())
# ProxyType
p = weakref.proxy(o)
d_p = weakref.proxy(_class())

fr = weakref.ref(f)
# CallableProxyType
cp = weakref.proxy(oc)
d_cp = weakref.proxy(_callable_class())
fp = weakref.proxy(f)
#zr = weakref.ref(z) #XXX: weakrefs not allowed for classobj objects
#zp = weakref.proxy(z) #XXX: weakrefs not allowed for classobj objects
xr = weakref.ref(x)
xp = weakref.proxy(x)

objlist = [r,dr,m,dm,fr,xr, p,dp,t,dt, c,dc,d,dd, fp,xp]
objlist = [r,d_r,fr,xr, p,d_p, cp,d_cp,fp,xp]
#dill.detect.trace(True)

for obj in objlist:
res = dill.detect.errors(obj)
if res:
print ("%s" % res)
#print ("%s:\n %s" % (obj, res))
print ("%r:\n %s" % (obj, res))
# else:
# print ("PASS: %s" % obj)
assert not res
Expand Down

0 comments on commit 01bab78

Please sign in to comment.