add examples on lightgbm and xgboost
sdpython committed Jul 9, 2020
1 parent 806559e commit 1e63cdd
Showing 21 changed files with 314 additions and 0 deletions.
5 changes: 5 additions & 0 deletions doc/conf.py
@@ -148,6 +148,8 @@
'cython': 'https://cython.org/',
'DOT': 'https://www.graphviz.org/doc/info/lang.html',
'ImageNet': 'http://www.image-net.org/',
'LightGBM': 'https://lightgbm.readthedocs.io/en/latest/',
'lightgbm': 'https://lightgbm.readthedocs.io/en/latest/',
'mlprodict':
'http://www.xavierdupre.fr/app/mlprodict/helpsphinx/index.html',
'NMF':
@@ -160,6 +162,7 @@
'https://github.com/onnx/onnx/blob/master/docs/Operators.md',
'ONNX ML operators':
'https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md',
'onnxmltools': 'https://github.com/onnx/onnxmltools',
'OnnxPipeline':
'http://www.xavierdupre.fr/app/mlprodict/helpsphinx/mlprodict/'
'sklapi/onnx_pipeline.html?highlight=onnxpipeline',
@@ -172,6 +175,8 @@
'skorch': 'https://skorch.readthedocs.io/en/stable/',
'sklearn-onnx': 'https://github.com/onnx/sklearn-onnx',
'sphinx-gallery': 'https://github.com/sphinx-gallery/sphinx-gallery',
'xgboost': 'https://xgboost.readthedocs.io/en/latest/',
'XGBoost': 'https://xgboost.readthedocs.io/en/latest/',
}

warnings.filterwarnings("ignore", category=FutureWarning)
30 changes: 30 additions & 0 deletions doc/tutorial.rst
@@ -2,6 +2,8 @@
Tutorial
========

.. index:: tutorial

The tutorial goes from a simple example which
converts a pipeline to a more complex example
involving operator not actually implemented in
@@ -11,5 +13,33 @@ involving operator not actually implemented in
:maxdepth: 2

tutorial_1_simple
tutorial_1-5_external
tutorial_2_new_converter
tutorial_3_new_operator

The tutorial was tested with the following versions:

.. runpython::
    :showcode:

    import numpy
    import scipy
    import sklearn
    import lightgbm
    import onnx
    import onnxmltools
    import onnxruntime
    import xgboost
    import skl2onnx
    import mlprodict
    import onnxcustom
    import pyquickhelper

    mods = [numpy, scipy, sklearn, lightgbm, xgboost,
            onnx, onnxmltools, onnxruntime, onnxcustom,
            skl2onnx, mlprodict, pyquickhelper]
    mods = [(m.__name__, m.__version__) for m in mods]
    mx = max(len(_[0]) for _ in mods) + 1
    for name, vers in sorted(mods):
        print("%s%s%s" % (name, " " * (mx - len(name)), vers))
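
A similar report can be produced without importing every package, which may help when one of them is missing. The sketch below is not part of the tutorial. It assumes Python 3.8+ so that `importlib.metadata` is available, and the helper names `installed_versions` and `format_report` are hypothetical. It expects distribution names (`scikit-learn`), which can differ from import names (`sklearn`).

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for dist in distributions:
        try:
            versions[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[dist] = None
    return versions


def format_report(versions):
    """Return one aligned 'name version' line per distribution, sorted by name."""
    if not versions:
        return ""
    width = max(len(name) for name in versions) + 1
    return "\n".join(
        "%s%s" % (name.ljust(width),
                  vers if vers is not None else "not installed")
        for name, vers in sorted(versions.items()))


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them to your environment.
    names = ["numpy", "scipy", "scikit-learn", "lightgbm", "xgboost",
             "onnx", "onnxmltools", "onnxruntime", "skl2onnx"]
    print(format_report(installed_versions(names)))
```

A missing package shows up as `not installed` instead of raising an `ImportError` halfway through the report.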

15 changes: 15 additions & 0 deletions doc/tutorial_1-5_external.rst
@@ -0,0 +1,15 @@
Using converter from other libraries
====================================

Before writing our own converter, we can use converters
available in libraries other than :epkg:`sklearn-onnx`.
:epkg:`onnxmltools` implements converters for
:epkg:`xgboost` and :epkg:`LightGBM`.
The following examples show how to use these converters
when the models are part of a pipeline.

.. toctree::
:maxdepth: 1

auto_examples/plot_gexternal_lightgbm
auto_examples/plot_gexternal_xgboost
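
Why a converter from another library must be registered before a pipeline can be converted may be easier to see on a toy model. The sketch below is purely illustrative and does not reproduce how :epkg:`sklearn-onnx` is implemented: every name in it (`register`, `convert_pipeline`, `MissingConverterError`, `Scaler`, `ExternalBooster`) is hypothetical. It only assumes that conversion looks up each pipeline step by its type and fails when a type is unknown.

```python
class MissingConverterError(Exception):
    """Raised when no converter is registered for a model type (toy example)."""


# Toy registry: model class -> (alias, shape calculator, converter).
_REGISTRY = {}


def register(model_type, alias, shape_fct, convert_fct):
    """Record how to convert instances of `model_type`."""
    _REGISTRY[model_type] = (alias, shape_fct, convert_fct)


def convert_pipeline(steps):
    """Convert every step in order; fail if any step has no converter."""
    nodes = []
    for name, model in steps:
        try:
            alias, _shape_fct, convert_fct = _REGISTRY[type(model)]
        except KeyError:
            raise MissingConverterError(
                "no converter registered for %s" % type(model).__name__
            ) from None
        nodes.append((name, alias, convert_fct(model)))
    return nodes


class Scaler:
    """Stands in for a natively supported model."""


class ExternalBooster:
    """Stands in for a model implemented by another library."""


register(Scaler, "ToyScaler", None, lambda model: "Scaler node")
pipeline = [("scaler", Scaler()), ("booster", ExternalBooster())]

try:
    convert_pipeline(pipeline)
except MissingConverterError as exc:
    print("before registration:", exc)

# Registering a converter written elsewhere makes the whole pipeline convertible.
register(ExternalBooster, "ToyBooster", None, lambda model: "TreeEnsemble node")
print("after registration:", convert_pipeline(pipeline))
```

The two examples listed above follow the same order: register the external converter first, then convert the whole pipeline.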
Binary file added examples/pipeline_lightgbm.onnx
Binary file not shown.
Binary file added examples/pipeline_xgboost.onnx
Binary file not shown.
2 changes: 2 additions & 0 deletions examples/plot_abegin_convert_pipeline.py
@@ -2,6 +2,8 @@
Train and deploy a scikit-learn pipeline
========================================
.. index:: pipeline, deployment
This program starts from an example in :epkg:`scikit-learn`
documentation: `Plot individual and voting regression predictions
<https://scikit-learn.org/stable/auto_examples/ensemble/plot_voting_regressor.html>`_,
2 changes: 2 additions & 0 deletions examples/plot_bbegin_measure_time.py
@@ -2,6 +2,8 @@
Benchmark ONNX conversion
=========================
.. index:: benchmark
Example :ref:`l-simple-deploy-1` converts a simple model.
This example takes a similar example but on random data
and compares the processing time required by each option
2 changes: 2 additions & 0 deletions examples/plot_cbegin_opset.py
@@ -2,6 +2,8 @@
What is the opset number?
=========================
.. index:: opset, target opset, version
Every library is versioned. :epkg:`scikit-learn` may change
the implementation of a specific model. That happens
for example with the `SVC <https://scikit-learn.org/stable/
2 changes: 2 additions & 0 deletions examples/plot_dbegin_options.py
@@ -2,6 +2,8 @@
One model, many possible conversions with options
=================================================
.. index:: options
There is not one way to convert a model. A new operator
might have been added in a newer version of :epkg:`ONNX`
and that speeds up the converted model. The rational choice
2 changes: 2 additions & 0 deletions examples/plot_ebegin_float_double.py
@@ -2,6 +2,8 @@
Issues when switching to float
==============================
.. index:: float, double, discrepancies
Most models in :epkg:`scikit-learn` compute with double,
not float. Most models in deep learning use float because
that's the most common situation with GPU. ONNX was initially
2 changes: 2 additions & 0 deletions examples/plot_fbegin_investigate.py
@@ -2,6 +2,8 @@
Intermediate results and investigation
======================================
.. index:: investigate, intermediate results
There are many reasons why a user wants more than using
the converted model into ONNX. Intermediate results may be
needed, the output of every node in the graph. The ONNX
2 changes: 2 additions & 0 deletions examples/plot_gbegin_transfer_learning.py
@@ -2,6 +2,8 @@
Transfer Learning with ONNX
===========================
.. index:: transfer learning, deep learning
Transfer learning is common with deep learning.
A deep learning model is used as preprocessing before
the output is sent to a final classifier or regressor.
108 changes: 108 additions & 0 deletions examples/plot_gexternal_lightgbm.py
@@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""
.. _example-lightgbm:

Convert a pipeline with a LightGBM model
========================================

.. index:: LightGBM

:epkg:`sklearn-onnx` only converts :epkg:`scikit-learn` models into *ONNX*
but many libraries implement the :epkg:`scikit-learn` API so that their models
can be included in a :epkg:`scikit-learn` pipeline. This example considers
a pipeline including a :epkg:`LightGBM` model. :epkg:`sklearn-onnx` can convert
the whole pipeline as long as it knows the converter associated with
an *LGBMClassifier*. Let's see how to do it.

.. contents::
    :local:

Train a LightGBM classifier
+++++++++++++++++++++++++++
"""
from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
from mlprodict.onnxrt import OnnxInference
import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes # noqa
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier

data = load_iris()
X = data.data[:, :2]
y = data.target

ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

pipe = Pipeline([('scaler', StandardScaler()),
                 ('lgbm', LGBMClassifier(n_estimators=3))])
pipe.fit(X, y)

######################################
# Register the converter for LGBMClassifier
# +++++++++++++++++++++++++++++++++++++++++
#
# The converter is implemented in :epkg:`onnxmltools`:
# `onnxmltools...LightGbm.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# lightgbm/operator_converters/LightGbm.py>`_.
# and the shape calculator:
# `onnxmltools...Classifier.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# lightgbm/shape_calculators/Classifier.py>`_.

update_registered_converter(
    LGBMClassifier, 'LightGbmLGBMClassifier',
    calculate_linear_classifier_output_shapes, convert_lightgbm,
    options={'nocl': [True, False], 'zipmap': [True, False]})

##################################
# Convert the pipeline
# ++++++++++++++++++++

model_onnx = convert_sklearn(
    pipe, 'pipeline_lightgbm',
    [('input', FloatTensorType([None, 2]))],
    target_opset=12)

# And save.
with open("pipeline_lightgbm.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())

###########################
# Compare the predictions
# +++++++++++++++++++++++
#
# Predictions with LightGBM.

print("predict", pipe.predict(X[:5]))
print("predict_proba", pipe.predict_proba(X[:1]))

##########################
# Predictions with onnxruntime.

sess = rt.InferenceSession("pipeline_lightgbm.onnx")

pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])
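
The two printouts above are compared by eye. A numerical check could look like the following sketch. It assumes the ONNX probabilities come back as a list of `{label: probability}` dictionaries, which is what the default *ZipMap* output of a converted classifier usually looks like. The helper names `probas_to_rows` and `max_abs_diff` are hypothetical, and the values are made up for illustration, not actual model outputs.

```python
def probas_to_rows(onnx_probas, classes):
    """Turn ZipMap-style output (a list of {label: probability} dicts)
    into rows whose columns follow the order of `classes`."""
    return [[float(row[c]) for c in classes] for row in onnx_probas]


def max_abs_diff(expected, got):
    """Largest absolute element-wise difference between two 2D sequences."""
    return max(abs(float(e) - float(g))
               for exp_row, got_row in zip(expected, got)
               for e, g in zip(exp_row, got_row))


# Illustrative values only, not actual model outputs.
expected = [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]]
onnx_out = [{0: 0.70000005, 1: 0.2, 2: 0.09999999},
            {0: 0.1, 1: 0.6, 2: 0.29999998}]

diff = max_abs_diff(expected, probas_to_rows(onnx_out, classes=[0, 1, 2]))
# float32 inference rarely matches double precision exactly,
# so compare against a tolerance rather than testing equality.
assert diff < 1e-5
print("max abs diff:", diff)
```

With the objects of this example, a call such as `max_abs_diff(pipe.predict_proba(X[:5]), probas_to_rows(pred_onx[1], pipe.classes_))` should apply, provided the output really is in the dictionary form assumed here.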

#############################
# Final graph
# +++++++++++


oinf = OnnxInference(model_onnx)
ax = plot_graphviz(oinf.to_dot())
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
127 changes: 127 additions & 0 deletions examples/plot_gexternal_xgboost.py
@@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""
.. _example-xgboost:

Convert a pipeline with an XGBoost model
========================================

.. index:: XGBoost

:epkg:`sklearn-onnx` only converts :epkg:`scikit-learn` models
into :epkg:`ONNX` but many libraries implement the :epkg:`scikit-learn`
API so that their models can be included in a :epkg:`scikit-learn`
pipeline. This example considers a pipeline including an :epkg:`XGBoost`
model. :epkg:`sklearn-onnx` can convert the whole pipeline as long as
it knows the converter associated with an *XGBClassifier*. Let's see
how to do it.

.. contents::
    :local:

Train an XGBoost classifier
+++++++++++++++++++++++++++
"""
from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
from mlprodict.onnxrt import OnnxInference
import numpy
import onnxruntime as rt
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes)
from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
    convert_xgboost)

data = load_iris()
X = data.data[:, :2]
y = data.target

ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

pipe = Pipeline([('scaler', StandardScaler()),
                 ('xgb', XGBClassifier(n_estimators=3))])
pipe.fit(X, y)

# The conversion fails, which is expected at this stage:
# no converter is registered yet for XGBClassifier.

try:
    convert_sklearn(pipe, 'pipeline_xgboost',
                    [('input', FloatTensorType([None, 2]))],
                    target_opset=12)
except Exception as e:
    print(e)

# The error message says that no converter was found
# for :epkg:`XGBoost` models. By default, :epkg:`sklearn-onnx`
# only handles models from :epkg:`scikit-learn`, but it can
# be extended to any model following the :epkg:`scikit-learn`
# API as long as the module knows a converter
# for every model used in the pipeline. That's why
# we need to register a converter.

######################################
# Register the converter for XGBClassifier
# ++++++++++++++++++++++++++++++++++++++++
#
# The converter is implemented in :epkg:`onnxmltools`:
# `onnxmltools...XGBoost.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# xgboost/operator_converters/XGBoost.py>`_.
# and the shape calculator:
# `onnxmltools...Classifier.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# xgboost/shape_calculators/Classifier.py>`_.

update_registered_converter(
    XGBClassifier, 'XGBoostXGBClassifier',
    calculate_linear_classifier_output_shapes, convert_xgboost,
    options={'nocl': [True, False], 'zipmap': [True, False]})

##################################
# Convert again
# +++++++++++++

model_onnx = convert_sklearn(
    pipe, 'pipeline_xgboost',
    [('input', FloatTensorType([None, 2]))],
    target_opset=12)

# And save.
with open("pipeline_xgboost.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())

###########################
# Compare the predictions
# +++++++++++++++++++++++
#
# Predictions with XGBoost.

print("predict", pipe.predict(X[:5]))
print("predict_proba", pipe.predict_proba(X[:1]))

##########################
# Predictions with onnxruntime.

sess = rt.InferenceSession("pipeline_xgboost.onnx")
pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])

#############################
# Final graph
# +++++++++++


oinf = OnnxInference(model_onnx)
ax = plot_graphviz(oinf.to_dot())
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
2 changes: 2 additions & 0 deletions examples/plot_icustom_converter.py
@@ -4,6 +4,8 @@
Implement a new converter
=========================
.. index:: custom converter
By default, :epkg:`sklearn-onnx` assumes that a classifier
has two outputs (label and probabilities), a regressor
has one output (prediction), a transform has one output
