add example on parser
sdpython committed Jul 5, 2020
1 parent 135d28f commit 0110c72
Showing 2 changed files with 169 additions and 1 deletion.
2 changes: 2 additions & 0 deletions examples/plot_custom_options.py
@@ -1,4 +1,6 @@
"""
.. _l-plot-custom-options:
A new converter with options
============================
168 changes: 167 additions & 1 deletion examples/plot_custom_parser.py
@@ -10,5 +10,171 @@
and a custom parser which defines the number of outputs
expected by the converted model.

*to be continued*

Example :ref:`l-plot-custom-options` shows a converter
which selects between two ways to compute the same outputs.
In this one, the converter produces both. That would not
be a very efficient converter but that's just for the sake
of using a parser. By default, a transformer only returns
one output but both are needed.

.. contents::
    :local:

A new transformer
+++++++++++++++++
"""
from mlprodict.onnxrt import OnnxInference
import numpy
from onnxruntime import InferenceSession
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.datasets import load_iris
from skl2onnx import update_registered_converter
from skl2onnx.common.data_types import guess_numpy_type
from skl2onnx.algebra.onnx_ops import (
OnnxSub, OnnxMatMul, OnnxGemm)
from skl2onnx import to_onnx, get_model_alias


class DecorrelateTransformer(TransformerMixin, BaseEstimator):
"""
Decorrelates correlated gaussiance features.
:param alpha: avoids non inversible matrices
*Attributes*
* `self.mean_`: average
* `self.coef_`: square root of the coveriance matrix
"""

def __init__(self, alpha=0.):
BaseEstimator.__init__(self)
TransformerMixin.__init__(self)
self.alpha = alpha

def fit(self, X, y=None, sample_weights=None):
if sample_weights is not None:
raise NotImplementedError(
"sample_weights != None is not implemented.")
self.mean_ = numpy.mean(X, axis=0, keepdims=True)
X = X - self.mean_
V = X.T @ X / X.shape[0]
if self.alpha != 0:
V += numpy.identity(V.shape[0]) * self.alpha
L, P = numpy.linalg.eig(V)
Linv = L ** (-0.5)
diag = numpy.diag(Linv)
root = P @ diag @ P.transpose()
self.coef_ = root
return self

def transform(self, X):
return (X - self.mean_) @ self.coef_


data = load_iris()
X = data.data

dec = DecorrelateTransformer()
dec.fit(X)
pred = dec.transform(X[:5])
print(pred)

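###########################################
# A quick sanity check (not part of the committed example): the
# transformed features should be decorrelated, so their second moment
# matrix should be close to the identity matrix.

Z = dec.transform(X)
print((Z.T @ Z / Z.shape[0]).round(4))
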

############################################
# Conversion into ONNX with two outputs
# +++++++++++++++++++++++++++++++++++++
#
# Let's try to convert it to see what happens.


def decorrelate_transformer_shape_calculator(operator):
op = operator.raw_operator
input_type = operator.inputs[0].type.__class__
input_dim = operator.inputs[0].type.shape[0]
output_type = input_type([input_dim, op.coef_.shape[1]])
operator.outputs[0].type = output_type

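###########################################
# The parser defined later declares two outputs, while this shape
# calculator only types the first one. A possible variation (shown
# as an assumption, not part of the committed code and not registered)
# would type both outputs the same way.


def decorrelate_transformer_shape_calculator2(operator):
    # hypothetical variant: types both declared outputs
    op = operator.raw_operator
    input_type = operator.inputs[0].type.__class__
    input_dim = operator.inputs[0].type.shape[0]
    output_type = input_type([input_dim, op.coef_.shape[1]])
    operator.outputs[0].type = output_type
    operator.outputs[1].type = input_type([input_dim, op.coef_.shape[1]])
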

def decorrelate_transformer_converter(scope, operator, container):
op = operator.raw_operator
opv = container.target_opset
out = operator.outputs

X = operator.inputs[0]

dtype = guess_numpy_type(X.type)

Y1 = OnnxMatMul(
OnnxSub(X, op.mean_.astype(dtype), op_version=opv),
op.coef_.astype(dtype),
op_version=opv, output_names=out[:1])

Y2 = OnnxGemm(X, op.coef_.astype(dtype),
(- op.mean_ @ op.coef_).astype(dtype),
op_version=opv, alpha=1., beta=1.,
output_names=out[1:2])

Y1.add_to(scope, container)
Y2.add_to(scope, container)

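###########################################
# A small numeric check (not part of the committed example) of why the
# Gemm formulation matches the MatMul/Sub one:
# ``X @ coef + (-mean @ coef)`` equals ``(X - mean) @ coef``.

delta = numpy.abs(
    (X - dec.mean_) @ dec.coef_
    - (X @ dec.coef_ + (- dec.mean_) @ dec.coef_))
print(delta.max())
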

def decorrelate_transformer_parser(
scope, model, inputs, custom_parsers=None):
alias = get_model_alias(type(model))
this_operator = scope.declare_local_operator(alias, model)

# inputs
this_operator.inputs.append(inputs[0])

# outputs
cls_type = inputs[0].type.__class__
val_y1 = scope.declare_local_variable('nogemm', cls_type())
val_y2 = scope.declare_local_variable('gemm', cls_type())
this_operator.outputs.append(val_y1)
this_operator.outputs.append(val_y2)

# ends
return this_operator.outputs

###################################
# The registration needs to declare the custom parser
# in addition to the shape calculator and the converter.


update_registered_converter(
DecorrelateTransformer, "SklearnDecorrelateTransformer",
decorrelate_transformer_shape_calculator,
decorrelate_transformer_converter,
parser=decorrelate_transformer_parser)


#############################################
# And conversion.

onx = to_onnx(dec, X.astype(numpy.float32))

sess = InferenceSession(onx.SerializeToString())

exp = dec.transform(X.astype(numpy.float32))
results = sess.run(None, {'X': X.astype(numpy.float32)})
y1 = results[0]
y2 = results[1]


def diff(p1, p2):
p1 = p1.ravel()
p2 = p2.ravel()
d = numpy.abs(p2 - p1)
return d.max(), (d / numpy.abs(p1)).max()


print(diff(exp, y1))
print(diff(exp, y2))


################################
# It works. The final model looks like the following.

oinf = OnnxInference(onx, runtime="python_compiled")
print(oinf)

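#########################################
# The ONNX graph itself confirms the converted model exposes two
# outputs (a quick check, not part of the committed example).

print([o.name for o in onx.graph.output])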