Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 78 additions & 19 deletions _doc/examples/plot_gexternal_lightgbm_reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,31 @@
:math:`D'(x) = |\\sum_{k=1}^a \\left[\\sum\\right]_{i=1}^{F/a}
float(T_{ak + i}(x)) - \\sum_{i=1}^F T_i(x)|`.

In 2022, :epkg:`onnx` and :epkg:`onnxruntime` updated the specifications
of TreeEnsemble operators and they can now support double thresholds
(see `TreeEnsembleRegressor v3
<https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsembleRegressor.html>`_).
That would be the recommended option to reduce the discrepancies.

.. contents::
:local:

Train a LGBMRegressor
+++++++++++++++++++++
"""
from distutils.version import StrictVersion
import warnings
import time
import timeit
from packaging.version import Version
import numpy
from pandas import DataFrame
# import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from tqdm import tqdm
from lightgbm import LGBMRegressor
from onnxruntime import InferenceSession
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx import update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_regressor_output_shapes # noqa
from mlprodict.onnx_conv import to_onnx
from onnxmltools import __version__ as oml_version
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa

Expand All @@ -59,8 +67,8 @@
reg.fit(X, y)

######################################
# Register the converter for LGBMClassifier
# +++++++++++++++++++++++++++++++++++++++++
# Register the converter for LGBMRegressor
# ++++++++++++++++++++++++++++++++++++++++
#
# The converter is implemented in :epkg:`onnxmltools`:
# `onnxmltools...LightGbm.py
Expand All @@ -75,7 +83,7 @@
def skl2onnx_convert_lightgbm(scope, operator, container):
options = scope.get_options(operator.raw_operator)
if 'split' in options:
if StrictVersion(oml_version) < StrictVersion('1.9.2'):
if Version(oml_version) < Version('1.9.2'):
warnings.warn(
"Option split was released in version 1.9.2 but %s is "
"installed. It will be ignored." % oml_version)
Expand All @@ -100,11 +108,20 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
# trees per node TreeEnsembleRegressor.

model_onnx = to_onnx(reg, X[:1].astype(numpy.float32),
target_opset={'': 14, 'ai.onnx.ml': 2})
target_opset={'': 17, 'ai.onnx.ml': 3})

model_onnx_split = to_onnx(reg, X[:1].astype(numpy.float32),
target_opset={'': 14, 'ai.onnx.ml': 2},
target_opset={'': 17, 'ai.onnx.ml': 3},
options={'split': 100})

####################################
# We create another model using the opset `ai.onnx.ml == 3`.
# Node thresholds are stored in doubles and not in floats anymore.

model_onnx_64 = to_onnx(reg, X[:1].astype(numpy.float64),
target_opset={'': 17, 'ai.onnx.ml': 3},
rewrite_ops=True)

##########################
# Discrepancies
# +++++++++++++
Expand All @@ -114,17 +131,17 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
sess_split = InferenceSession(model_onnx_split.SerializeToString(),
providers=['CPUExecutionProvider'])

X32 = X.astype(numpy.float32)
X32 = X.astype(numpy.float32)[:500]
expected = reg.predict(X32)
got = sess.run(None, {'X': X32})[0].ravel()
got_split = sess_split.run(None, {'X': X32})[0].ravel()

disp = numpy.abs(got - expected).sum()
disp_split = numpy.abs(got_split - expected).sum()
disc_split = numpy.abs(got_split - expected).sum()

print("sum of discrepancies 1 node", disp)
print("sum of discrepancies split node",
disp_split, "ratio:", disp / disp_split)
print(f"sum of discrepancies 1 node: {disp}")
print(f"sum of discrepancies split node: {disc_split}, "
f"ratio: {disp / disc_split}")

######################################
# The sum of the discrepancies was reduced by a factor of 4 to 5.
Expand All @@ -136,6 +153,21 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
print("max discrepancies 1 node", disc)
print("max discrepancies split node", disc_split, "ratio:", disc / disc_split)

#######################################
# Let's compare with the double thresholds.
# We cast the inputs into float first and then into double
# to make sure they are the same.

sess_64 = InferenceSession(model_onnx_64.SerializeToString(),
providers=['CPUExecutionProvider'])

X64 = X32.astype(numpy.float64)
expected_64 = reg.predict(X64)
got_64 = sess_64.run(None, {'X': X64})[0].ravel()
disc_64 = numpy.abs(got_64 - expected_64).sum()
disc_max64 = numpy.abs(got_64 - expected_64).max()
print(f"sum of discrepancies with doubles: sum={disc_64}, max={disc_max64}")

################################################
# Processing time
# +++++++++++++++
Expand All @@ -145,6 +177,9 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
print("processing time no split",
timeit.timeit(
lambda: sess.run(None, {'X': X32})[0], number=150))
print("processing time no split with double",
timeit.timeit(
lambda: sess_64.run(None, {'X': X64})[0], number=150))
print("processing time split",
timeit.timeit(
lambda: sess_split.run(None, {'X': X32})[0], number=150))
Expand All @@ -159,22 +194,46 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
res = []
for i in tqdm(list(range(20, 170, 20)) + [200, 300, 400, 500]):
model_onnx_split = to_onnx(reg, X[:1].astype(numpy.float32),
target_opset={'': 14, 'ai.onnx.ml': 2},
target_opset={'': 17, 'ai.onnx.ml': 3},
options={'split': i})
sess_split = InferenceSession(model_onnx_split.SerializeToString(),
providers=['CPUExecutionProvider'])
times = []
for _ in range(0, 4):
begin = time.perf_counter()
sess_split = InferenceSession(model_onnx_split.SerializeToString(),
providers=['CPUExecutionProvider'])
times.append(time.perf_counter() - begin)
times.sort()
got_split = sess_split.run(None, {'X': X32})[0].ravel()
disc_split = numpy.abs(got_split - expected).max()
res.append(dict(split=i, disc=disc_split))
res.append(dict(split=i, max_diff=disc_split, time=sum(times[1:3]) / 2))

df = DataFrame(res).set_index('split')
df["baseline"] = disc
df["baseline_64"] = disc_max64
print(df)

##########################################
# Graph.

ax = df.plot(title="Sum of discrepancies against split\n"
"split = number of tree per node")
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
df[["max_diff", "baseline", "baseline_64"]].plot(
title="Sum of discrepancies against split\n"
"split = numbers of tree per node",
ax=ax[0])
df[["time"]].plot(title="Processing time against split\n"
"split = numbers of tree per node",
ax=ax[1])

##########################################
# Conclusion
# ++++++++++
#
# The time curve is too noisy to conclude.
# More measurements should be made.
# The double sum reduces the discrepancies
# but increases the processing time. It is a tradeoff.
# The best option is using double for threshold and summation
# but it requires the latest definition of TreeEnsemble `ai.onnx.ml=3`.


# plt.show()
180 changes: 180 additions & 0 deletions _doc/examples/plot_gexternal_lightgbm_reg_per.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
.. _example-lightgbm-reg-one-off:

Batch predictions vs one-off predictions
========================================

.. index:: LightGBM

The goal is to compare the processing time between batch predictions
and one-off predictions for the same number of predictions
on trees. onnxruntime parallelizes the prediction by trees
or by rows. The rule is fixed and cannot be changed but it seems
to have some loopholes.

.. contents::
:local:

Train a LGBMRegressor
+++++++++++++++++++++
"""
import warnings
import time
import os
from packaging.version import Version
import numpy
from pandas import DataFrame
import onnx
import matplotlib.pyplot as plt
from tqdm import tqdm
from lightgbm import LGBMRegressor
from onnxruntime import InferenceSession
from skl2onnx import update_registered_converter, to_onnx
from skl2onnx.common.shape_calculator import calculate_linear_regressor_output_shapes # noqa
from onnxmltools import __version__ as oml_version
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa


# Synthetic regression problem: N observations, 1000 gaussian features.
N = 1000
X = numpy.random.randn(N, 1000)
# The target is gaussian noise plus a much larger noise term applied to a
# random subset of the rows. The upper bound of ``randint`` is exclusive:
# ``randint(0, 1, N)`` only ever returns zeros and silently cancels the
# second term, ``randint(0, 2, N)`` draws the intended 0/1 mask.
y = (numpy.random.randn(N) +
     numpy.random.randn(N) * 100 * numpy.random.randint(0, 2, N))

# One ONNX file per model (1000, 10 and 2 trees); the number of features
# is part of the name.
filenames = [f"plot_lightgbm_regressor_1000_{X.shape[1]}.onnx",
             f"plot_lightgbm_regressor_10_{X.shape[1]}.onnx",
             f"plot_lightgbm_regressor_2_{X.shape[1]}.onnx"]

# Training is skipped when the first ONNX file already exists on disk.
# Only the first file name is tested, the two others are assumed to
# come with it.
if not os.path.exists(filenames[0]):
    print(f"training with shape={X.shape}")
    reg_1000 = LGBMRegressor(n_estimators=1000)
    reg_1000.fit(X, y)
    reg_10 = LGBMRegressor(n_estimators=10)
    reg_10.fit(X, y)
    reg_2 = LGBMRegressor(n_estimators=2)
    reg_2.fit(X, y)
    print("done.")
else:
    print("A model was already trained. Reusing it.")

######################################
# Register the converter for LGBMRegressor
# ++++++++++++++++++++++++++++++++++++++++


def skl2onnx_convert_lightgbm(scope, operator, container):
    """
    Converter registered in skl2onnx for a LightGBM model.

    It reads the conversion options attached to ``operator.raw_operator``,
    copies option *split* into ``operator.split`` (``None`` if the option
    is missing) and then calls ``convert_lightgbm`` from onnxmltools.
    A warning is raised if *split* is specified while the installed
    version of onnxmltools is older than 1.9.2.

    :param scope: object giving access to the options (``get_options``)
    :param operator: operator to convert, attribute ``split`` is set on it
    :param container: forwarded unchanged to ``convert_lightgbm``
    """
    converter_options = scope.get_options(operator.raw_operator)
    has_split = 'split' in converter_options

    # The version is only checked when the user asked for a split.
    if has_split and Version(oml_version) < Version('1.9.2'):
        warnings.warn(
            "Option split was released in version 1.9.2 but %s is "
            "installed. It will be ignored." % oml_version)

    operator.split = converter_options['split'] if has_split else None
    convert_lightgbm(scope, operator, container)


# Registers skl2onnx_convert_lightgbm as the converter skl2onnx uses for
# LGBMRegressor, under the alias 'LightGbmLGBMRegressor'. The output shapes
# are computed by calculate_linear_regressor_output_shapes.
# NOTE(review): options={'split': None} presumably declares *split* as an
# allowed conversion option with no restriction on its values -- confirm
# against the skl2onnx documentation.
update_registered_converter(
    LGBMRegressor, 'LightGbmLGBMRegressor',
    calculate_linear_regressor_output_shapes,
    skl2onnx_convert_lightgbm,
    options={'split': None})

##################################
# Convert
# +++++++
#
# We convert the three models (1000, 10 and 2 trees) into ONNX,
# each of them with the default conversion (option *split* is not used
# in this example). The ONNX files are saved on disk and loaded again
# if the first one already exists.

# Either the regressors were just trained and they are converted and
# saved on disk, or the ONNX files already exist and they are loaded.
# Only the first file name is tested, the same way as for the training.
if not os.path.exists(filenames[0]):
    converted = []
    # Every regressor is paired with its own file so that a file cannot
    # receive another model than the one converted in the same iteration
    # (the 2-tree file used to receive the 10-tree model).
    for reg, filename in zip((reg_1000, reg_10, reg_2), filenames):
        onx = to_onnx(reg, X[:1].astype(numpy.float32),
                      target_opset={'': 17, 'ai.onnx.ml': 3})
        with open(filename, "wb") as f:
            f.write(onx.SerializeToString())
        converted.append(onx)
else:
    # NOTE: a 2-tree file written before the fix above contains the
    # 10-tree model, it must be deleted to be generated again.
    converted = []
    for filename in filenames:
        with open(filename, "rb") as f:
            converted.append(onnx.load(f))

# Same order as in filenames: 1000, 10 and 2 trees.
model_onnx_1000, model_onnx_10, model_onnx_2 = converted

# One CPU inference session per model.
sess_1000 = InferenceSession(model_onnx_1000.SerializeToString(),
                             providers=['CPUExecutionProvider'])
sess_10 = InferenceSession(model_onnx_10.SerializeToString(),
                           providers=['CPUExecutionProvider'])
sess_2 = InferenceSession(model_onnx_2.SerializeToString(),
                          providers=['CPUExecutionProvider'])

##########################
# Processing time
# +++++++++++++++
#

# Every measure is repeated `repeat` times; the fastest and the slowest
# runs are dropped and the remaining ones are averaged.
repeat = 5
data = []
sessions = [(sess_1000, 1000), (sess_10, 10), (sess_2, 2)]
batch_sizes = [*range(10, 100, 10),
               *range(100, 1000, 100),
               *range(1000, 10001, 1000)]

for N in tqdm(batch_sizes):

    X32 = numpy.random.randn(N, X.shape[1]).astype(numpy.float32)
    n_rows = X32.shape[0]
    obs = {"N": N}
    for sess, T in sessions:
        # Batch prediction: a single call for all the rows,
        # the duration is divided by the number of rows.
        batch_times = []
        for _ in range(repeat):
            begin = time.perf_counter()
            sess.run(None, {'X': X32})
            duration = time.perf_counter() - begin
            batch_times.append(duration / n_rows)
        kept = sorted(batch_times)[1:-1]
        obs[f"batch-{T}"] = sum(kept) / (len(batch_times) - 2)

        # One-off predictions: one call per row,
        # the duration is divided by the number of rows as well.
        single_times = []
        for _ in range(repeat):
            begin = time.perf_counter()
            for i in range(n_rows):
                sess.run(None, {'X': X32[i: i + 1]})
            duration = time.perf_counter() - begin
            single_times.append(duration / n_rows)
        kept = sorted(single_times)[1:-1]
        obs[f"one-off-{T}"] = sum(kept) / (len(single_times) - 2)
    data.append(obs)

df = DataFrame(data).set_index("N")
print(df)

########################################
# Plots.
# One graph per model, batch against one-off predictions,
# both axes use a logarithmic scale.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

for axis, n_trees in zip(ax, (1000, 10, 2)):
    columns = [f"batch-{n_trees}", f"one-off-{n_trees}"]
    df[columns].plot(
        ax=axis,
        title=f"Processing time per observation\n{n_trees} Trees",
        logy=True, logx=True)

##########################################
# Conclusion
#
# The first graph shows a huge drop in the prediction time by batch.
# It means the parallelization is triggered. It may have been triggered
# sooner on this machine but this decision could be different on another one.
# An approach like the one TVM chose could be a good answer. If the model
# must be fast, then it is worth benchmarking many strategies to parallelize
# until the best one is found on a specific machine.

# plt.show()