Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _doc/examples/plot_abegin_convert_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Train classifiers
reg1 = GradientBoostingRegressor(random_state=1)
reg2 = RandomForestRegressor(random_state=1)
reg1 = GradientBoostingRegressor(random_state=1, n_estimators=5)
reg2 = RandomForestRegressor(random_state=1, n_estimators=5)
reg3 = LinearRegression()

ereg = Pipeline(steps=[
Expand Down
3 changes: 2 additions & 1 deletion _doc/examples/plot_bbegin_measure_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
# The same is done with the two ONNX runtime
# available.

onx = to_onnx(ereg, X_train[:1].astype(numpy.float32))
onx = to_onnx(ereg, X_train[:1].astype(numpy.float32),
target_opset=14)
sess = InferenceSession(onx.SerializeToString())
oinf = OnnxInference(onx, runtime="python_compiled")

Expand Down
188 changes: 188 additions & 0 deletions _doc/examples/plot_dbegin_options_zipmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""
.. _l-tutorial-example-zipmap:

Choose appropriate output of a classifier
=========================================

A scikit-learn classifier usually returns a matrix of probabilities.
By default, *sklearn-onnx* converts that matrix
into a list of dictionaries where each probabily is mapped
to its class id or name. That mechanism retains the class names
but is slower. Let's see what other options are available.

.. contents::
:local:

Train a model and convert it
++++++++++++++++++++++++++++

"""
from timeit import repeat
import numpy
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import onnx
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import to_onnx
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(numpy.float32)
y = y * 2 + 10 # to get labels different from [0, 1, 2]
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)

onx = to_onnx(clr, X_train, target_opset=12)

############################
# Default behaviour: zipmap=True
# ++++++++++++++++++++++++++++++
#
# The output type for the probabilities is a list of
# dictionaries.

sess = rt.InferenceSession(onx.SerializeToString())
res = sess.run(None, {'X': X_test})
print(res[1][:2])
print("probabilities type:", type(res[1]))
print("type for the first observations:", type(res[1][0]))

###################################
# Option zipmap=False
# +++++++++++++++++++
#
# Probabilities are now a matrix.

initial_type = [('float_input', FloatTensorType([None, 4]))]
options = {id(clr): {'zipmap': False}}
onx2 = to_onnx(clr, X_train, options=options, target_opset=12)

sess2 = rt.InferenceSession(onx2.SerializeToString())
res2 = sess2.run(None, {'X': X_test})
print(res2[1][:2])
print("probabilities type:", type(res2[1]))
print("type for the first observations:", type(res2[1][0]))

###################################
# Option zipmap='columns'
# +++++++++++++++++++++++
#
# This options removes the final operator ZipMap and splits
# the probabilities into columns. The final model produces
# one output for the label, and one output per class.

options = {id(clr): {'zipmap': 'columns'}}
onx3 = to_onnx(clr, X_train, options=options, target_opset=12)

sess3 = rt.InferenceSession(onx3.SerializeToString())
res3 = sess3.run(None, {'X': X_test})
for i, out in enumerate(sess3.get_outputs()):
print("output: '{}' shape={} values={}...".format(
out.name, res3[i].shape, res3[i][:2]))


###################################
# Let's compare prediction time
# +++++++++++++++++++++++++++++

print("Average time with ZipMap:")
print(sum(repeat(lambda: sess.run(None, {'X': X_test}),
number=100, repeat=10)) / 10)

print("Average time without ZipMap:")
print(sum(repeat(lambda: sess2.run(None, {'X': X_test}),
number=100, repeat=10)) / 10)

print("Average time without ZipMap but with columns:")
print(sum(repeat(lambda: sess3.run(None, {'X': X_test}),
number=100, repeat=10)) / 10)

# The prediction is much faster without ZipMap
# on this example.
# The optimisation is even faster when the classes
# are described with strings and not integers
# as the final result (list of dictionaries) may copy
# many times the same information with onnxruntime.

#######################################
# Option zimpap=False and output_class_labels=True
# ++++++++++++++++++++++++++++++++++++++++++++++++
#
# Option `zipmap=False` seems a better choice because it is
# much faster but labels are lost in the process. Option
# `output_class_labels` can be used to expose the labels
# as a third output.

initial_type = [('float_input', FloatTensorType([None, 4]))]
options = {id(clr): {'zipmap': False, 'output_class_labels': True}}
onx4 = to_onnx(clr, X_train, options=options, target_opset=12)

sess4 = rt.InferenceSession(onx4.SerializeToString())
res4 = sess4.run(None, {'X': X_test})
print(res4[1][:2])
print("probabilities type:", type(res4[1]))
print("class labels:", res4[2])

###########################################
# Processing time.

print("Average time without ZipMap but with output_class_labels:")
print(sum(repeat(lambda: sess4.run(None, {'X': X_test}),
number=100, repeat=10)) / 10)

###########################################
# MultiOutputClassifier
# +++++++++++++++++++++
#
# This model is equivalent to several classifiers, one for every label
# to predict. Instead of returning a matrix of probabilities, it returns
# a sequence of matrices. Let's first modify the labels to get
# a problem for a MultiOutputClassifier.

y = numpy.vstack([y, y + 100]).T
y[::5, 1] = 1000 # Let's a fourth class.
print(y[:5])

########################################
# Let's train a MultiOutputClassifier.

X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = MultiOutputClassifier(LogisticRegression(max_iter=500))
clr.fit(X_train, y_train)
print(clr)

onx5 = to_onnx(clr, X_train, target_opset=12)

sess5 = rt.InferenceSession(onx5.SerializeToString())
res5 = sess5.run(None, {'X': X_test[:3]})
print(res5)

########################################
# Option zipmap is ignored. Labels are missing but they can be
# added back as a third output.

onx6 = to_onnx(clr, X_train, target_opset=12,
options={'zipmap': False, 'output_class_labels': True})

sess6 = rt.InferenceSession(onx6.SerializeToString())
res6 = sess6.run(None, {'X': X_test[:3]})
print("predicted labels", res6[0])
print("predicted probabilies", res6[1])
print("class labels", res6[2])


#################################
# **Versions used for this example**

print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
100 changes: 100 additions & 0 deletions _doc/examples/plot_gconverting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Modify the ONNX graph
=====================

This example shows how to change the default ONNX graph such as
renaming the inputs or outputs names.

.. contents::
:local:

Basic example
+++++++++++++

"""
import numpy
from onnxruntime import InferenceSession
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
from skl2onnx import to_onnx

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

clr = LogisticRegression(solver="liblinear")
clr.fit(X_train, y_train)


onx = to_onnx(clr, X, options={'zipmap': False})

sess = InferenceSession(onx.SerializeToString())
input_names = [i.name for i in sess.get_inputs()]
output_names = [o.name for o in sess.get_outputs()]
print("inputs=%r, outputs=%r" % (input_names, output_names))
print(sess.run(None, {input_names[0]: X_test[:2]}))


####################################
# Changes the input names
# +++++++++++++++++++++++
#
# It is possible to change the input name by using the
# parameter *initial_types*. However, the user must specify the input
# types as well.

onx = to_onnx(clr, X, options={'zipmap': False},
initial_types=[('X56', FloatTensorType([None, X.shape[1]]))])

sess = InferenceSession(onx.SerializeToString())
input_names = [i.name for i in sess.get_inputs()]
output_names = [o.name for o in sess.get_outputs()]
print("inputs=%r, outputs=%r" % (input_names, output_names))
print(sess.run(None, {input_names[0]: X_test[:2]}))


####################################
# Changes the output names
# ++++++++++++++++++++++++
#
# It is possible to change the input name by using the
# parameter *final_types*.

onx = to_onnx(clr, X, options={'zipmap': False},
final_types=[('L', Int64TensorType([None])),
('P', FloatTensorType([None, 3]))])

sess = InferenceSession(onx.SerializeToString())
input_names = [i.name for i in sess.get_inputs()]
output_names = [o.name for o in sess.get_outputs()]
print("inputs=%r, outputs=%r" % (input_names, output_names))
print(sess.run(None, {input_names[0]: X_test[:2]}))

####################################
# Renaming intermediate results
# +++++++++++++++++++++++++++++
#
# It is possible to rename intermediate results by using a prefix
# or by using a function. The result will be post-processed in order
# to unique names. It does not impact the graph inputs or outputs.


def rename_results(proposed_name, existing_names):
result = "_" + proposed_name.upper()
while result in existing_names:
result += "A"
print("changed %r into %r." % (proposed_name, result))
return result


onx = to_onnx(clr, X, options={'zipmap': False},
naming=rename_results)

sess = InferenceSession(onx.SerializeToString())
input_names = [i.name for i in sess.get_inputs()]
output_names = [o.name for o in sess.get_outputs()]
print("inputs=%r, outputs=%r" % (input_names, output_names))
print(sess.run(None, {input_names[0]: X_test[:2]}))
2 changes: 1 addition & 1 deletion _doc/examples/plot_kcustom_converter_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Let's implement a simple custom model using
:epkg:`scikit-learn` API. The model is preprocessing
which decorrelates correlated random variables.
If *X* is a matrix of features, :math:`V=\frac{1}{n}X'X`
If *X* is a matrix of features, :math:`V=\\frac{1}{n}X'X`
is the covariance matrix. We compute :math:`X V^{1/2}`.
"""
from mlprodict.onnxrt import OnnxInference
Expand Down
3 changes: 2 additions & 1 deletion _doc/examples/plot_mcustom_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def decorrelate_transformer_parser(
#############################################
# And conversion.

onx = to_onnx(dec, X.astype(numpy.float32))
onx = to_onnx(dec, X.astype(numpy.float32),
target_opset=14)

sess = InferenceSession(onx.SerializeToString())

Expand Down
25 changes: 25 additions & 0 deletions _doc/examples/plot_woe_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,28 @@
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis('off')

########################################
# Half-line
# +++++++++
#
# An interval may have only one extremity defined and the other
# can be infinite.

intervals = [
[(-numpy.inf, 3., True, True),
(5., numpy.inf, True, True)]]
weights = [[55, 107]]

woe1 = WOETransformer(intervals, onehot=False, weights=weights)
woe1.fit(X)
prd = woe1.transform(X)
df = pd.DataFrame({'X': X.ravel(), 'woe': prd.ravel()})
df

#################################
# And the conversion to ONNX using the same instruction.

onxinf = to_onnx(woe1, X)
sess = InferenceSession(onxinf.SerializeToString())
print(sess.run(None, {'X': X})[0])
2 changes: 2 additions & 0 deletions _doc/sphinxdoc/source/tutorial/tutorial_1_simple.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ used in the ONNX graph.
../gyexamples/plot_cbegin_opset
../gyexamples/plot_dbegin_options
../gyexamples/plot_dbegin_options_list
../gyexamples/plot_dbegin_options_zipmap
../gyexamples/plot_ebegin_float_double
../gyexamples/plot_fbegin_investigate
../gyexamples/plot_gbegin_dataframe
../gyexamples/plot_gbegin_transfer_learning
../gyexamples/plot_gbegin_cst
../gyexamples/plot_gconverting