Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions _unittests/ut_npy/test_function_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,17 @@ def test_function_transformer(self):
self.assertEqualArray(y_exp, y_onx['variable'])

@ignore_warnings((DeprecationWarning, RuntimeWarning))
@unittest.skipIf(True, reason="pickling not implemented yet")
def test_function_transformer_pickle(self):
x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
tr = FunctionTransformer(custom_fct)
tr.fit(x)
y_exp = tr.transform(x)
st = BytesIO()
pickle.dump(tr, st)
# import cloudpickle as pkl
pkl = pickle
pkl.dump(tr, st)
cp = BytesIO(st.getvalue())
tr2 = pickle.load(cp)
tr2 = pkl.load(cp)
y_exp2 = tr2.transform(x)
self.assertEqualArray(y_exp, y_exp2)

Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_npy/test_onnx_variable_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_abs_reshape(x: NDArray[Any, numpy.float32],
return nxnp.abs(x).reshape((-1, 1))


@onnxnumpy(op_version=11)
@onnxnumpy(op_version=11, runtime='onnxruntime1')
def test_abs_reshape_11(x: NDArray[Any, numpy.float32],
) -> NDArray[Any, numpy.float32]:
"onnx numpy reshape with opset 11"
Expand Down
36 changes: 33 additions & 3 deletions mlprodict/npy/onnx_numpy_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,34 @@ def __init__(self, fct, op_version=None, runtime=None, signature=None,
signature=signature, version=version)
inputs, outputs, kwargs, n_input_range = self._parse_annotation(
signature=signature, version=version)
n_opt = 0 if signature is None else signature.n_optional
args, kwargs2 = get_args_kwargs(self.fct_, n_opt)
self.meta_ = dict(op_version=op_version, runtime=runtime,
signature=signature, version=version,
inputs=inputs, outputs=outputs,
kwargs=kwargs, n_input_range=n_input_range)
kwargs=kwargs, n_input_range=n_input_range,
args=args, kwargs2=kwargs2,
annotations=self.fct_.__annotations__)

def __getstate__(self):
"""
Serializes everything but function `fct_`.
Function `fct_` is used to build the onnx graph
and is not needed anymore.
"""
return dict(onnx_=self.onnx_, meta_=self.meta_)

def __setstate__(self, state):
"""
Restores serialized data.
"""
for k, v in state.items():
setattr(self, k, v)
self.runtime_ = self._build_runtime(
op_version=self.meta_['op_version'],
runtime=self.meta_['runtime'],
signature=self.meta_['signature'],
version=self.meta_['version'])

def __repr__(self):
"usual"
Expand Down Expand Up @@ -169,7 +193,10 @@ def _parse_annotation(self, signature, version):
*kwargs* is the list of additional parameters
"""
n_opt = 0 if signature is None else signature.n_optional
args, kwargs = get_args_kwargs(self.fct_, n_opt)
if hasattr(self, 'meta_'):
args, kwargs = self.meta_['args'], self.meta_['kwargs2']
else:
args, kwargs = get_args_kwargs(self.fct_, n_opt)
if isinstance(version, tuple):
nv = len(version) - len(args) - n_opt
if (signature is not None and not
Expand Down Expand Up @@ -205,7 +232,10 @@ def _possible_names():
for i in range(0, 10000): # pragma: no cover
yield 'o%d' % i

annotations = self.fct_.__annotations__
if hasattr(self, 'meta_'):
annotations = self.meta_['annotations']
else:
annotations = self.fct_.__annotations__
inputs = []
outputs = []
for a in args:
Expand Down
70 changes: 65 additions & 5 deletions mlprodict/npy/onnx_numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
from .onnx_numpy_compiler import OnnxNumpyCompiler


class _created_classes:
"""
Class to store all dynamic classes created by wrappers.
"""

def __init__(self):
self.stored = {}

def append(self, name, cl):
"""
Adds a class into `globals()` to enable pickling on dynamic
classes.
"""
if name in self.stored:
raise RuntimeError(
"Class %r already exists in\n%r\n---\n%r" % (
name, ", ".join(sorted(self.stored)), cl))
self.stored[name] = cl
globals()[name] = cl


_created_classes_inst = _created_classes()


class wrapper_onnxnumpy:
"""
Intermediate wrapper to store a pointer
Expand All @@ -27,6 +51,20 @@ def __call__(self, *args, **kwargs):
"""
return self.compiled(*args, **kwargs)

def __getstate__(self):
"""
Serializes everything but the function which generates
the ONNX graph, not needed anymore.
"""
return dict(compiled=self.compiled)

def __setstate__(self, state):
"""
Serializes everything but the function which generates
the ONNX graph, not needed anymore.
"""
self.compiled = state['compiled']


def onnxnumpy(op_version=None, runtime=None, signature=None):
"""
Expand All @@ -47,10 +85,10 @@ def decorator_fct(fct):
compiled = OnnxNumpyCompiler(
fct, op_version=op_version, runtime=runtime,
signature=signature)
name = "onnxnumpy_%s_%s_%s" % (fct.__name__, str(op_version), runtime)
newclass = type(
"onnxnumpy_%s_%s_%s" % (fct.__name__, str(op_version), runtime),
(wrapper_onnxnumpy,), {'__doc__': fct.__doc__})

name, (wrapper_onnxnumpy,), {'__doc__': fct.__doc__})
_created_classes_inst.append(name, newclass)
return newclass(compiled)
return decorator_fct

Expand Down Expand Up @@ -86,6 +124,23 @@ def __init__(self, **kwargs):
self.data = kwargs
self.signed_compiled = {}

def __getstate__(self):
"""
Serializes everything but the function which generates
the ONNX graph, not needed anymore.
"""
data_copy = {k: v for k, v in self.data.items() if k != 'fct'}
return dict(signature=self.signature, args=self.args,
kwargs=self.kwargs, data=data_copy,
signed_compiled=self.signed_compiled)

def __setstate__(self, state):
"""
Restores serialized data.
"""
for k, v in state.items():
setattr(self, k, v)

def __getitem__(self, dtype):
"""
Returns the instance of @see cl wrapper_onnxnumpy
Expand Down Expand Up @@ -157,9 +212,14 @@ def onnxnumpy_np(op_version=None, runtime=None, signature=None):
.. versionadded:: 0.6
"""
def decorator_fct(fct):
name = "onnxnumpy_nb_%s_%s_%s" % (
fct.__name__, str(op_version), runtime)
newclass = type(
"onnxnumpy_nb_%s_%s_%s" % (fct.__name__, str(op_version), runtime),
(wrapper_onnxnumpy_np,), {'__doc__': fct.__doc__})
name, (wrapper_onnxnumpy_np,), {
'__doc__': fct.__doc__,
'__getstate__': wrapper_onnxnumpy_np.__getstate__,
'__setstate__': wrapper_onnxnumpy_np.__setstate__})
_created_classes_inst.append(name, newclass)
return newclass(
fct=fct, op_version=op_version, runtime=runtime,
signature=signature)
Expand Down