diff --git a/_doc/examples/plot_gexternal_lightgbm_reg.py b/_doc/examples/plot_gexternal_lightgbm_reg.py index 583502f..913a359 100644 --- a/_doc/examples/plot_gexternal_lightgbm_reg.py +++ b/_doc/examples/plot_gexternal_lightgbm_reg.py @@ -29,23 +29,31 @@ :math:`D'(x) = |\\sum_{k=1}^a \\left[\\sum\\right]_{i=1}^{F/a} float(T_{ak + i}(x)) - \\sum_{i=1}^F T_i(x)|`. +In 2022, :epkg:`onnx` and :epkg:`onnxruntime` updated the specifications +of TreeEnsemble operators and they can now support double thresholds +(see `TreeEnsembleRegressor v3 +<https://github.com/onnx/onnx/blob/main/docs/Changelog-ml.md#ai.onnx.ml.TreeEnsembleRegressor-3>`_). +That would be the recommended option to reduce the discrepancies. + .. contents:: :local: Train a LGBMRegressor +++++++++++++++++++++ """ -from distutils.version import StrictVersion import warnings +import time import timeit +from packaging.version import Version import numpy from pandas import DataFrame -# import matplotlib.pyplot as plt +import matplotlib.pyplot as plt from tqdm import tqdm from lightgbm import LGBMRegressor from onnxruntime import InferenceSession -from skl2onnx import to_onnx, update_registered_converter +from skl2onnx import update_registered_converter from skl2onnx.common.shape_calculator import calculate_linear_regressor_output_shapes # noqa +from mlprodict.onnx_conv import to_onnx from onnxmltools import __version__ as oml_version from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa @@ -59,8 +67,8 @@ reg.fit(X, y) ###################################### -# Register the converter for LGBMClassifier -# +++++++++++++++++++++++++++++++++++++++++ +# Register the converter for LGBMRegressor +# ++++++++++++++++++++++++++++++++++++++++ # # The converter is implemented in :epkg:`onnxmltools`: # `onnxmltools...LightGbm.py @@ -75,7 +83,7 @@ def skl2onnx_convert_lightgbm(scope, operator, container): options = scope.get_options(operator.raw_operator) if 'split' in options: - if StrictVersion(oml_version) < StrictVersion('1.9.2'): + if Version(oml_version) < 
Version('1.9.2'): warnings.warn( "Option split was released in version 1.9.2 but %s is " "installed. It will be ignored." % oml_version) @@ -100,11 +108,20 @@ def skl2onnx_convert_lightgbm(scope, operator, container): # trees per node TreeEnsembleRegressor. model_onnx = to_onnx(reg, X[:1].astype(numpy.float32), - target_opset={'': 14, 'ai.onnx.ml': 2}) + target_opset={'': 17, 'ai.onnx.ml': 3}) + model_onnx_split = to_onnx(reg, X[:1].astype(numpy.float32), - target_opset={'': 14, 'ai.onnx.ml': 2}, + target_opset={'': 17, 'ai.onnx.ml': 3}, options={'split': 100}) +#################################### +# We create another model with opset 3 of domain `ai.onnx.ml`. +# Node thresholds are stored as doubles and not as floats anymore. + +model_onnx_64 = to_onnx(reg, X[:1].astype(numpy.float64), + target_opset={'': 17, 'ai.onnx.ml': 3}, + rewrite_ops=True) + ########################## # Discrepancies # +++++++++++++ @@ -114,17 +131,17 @@ def skl2onnx_convert_lightgbm(scope, operator, container): sess_split = InferenceSession(model_onnx_split.SerializeToString(), providers=['CPUExecutionProvider']) -X32 = X.astype(numpy.float32) +X32 = X.astype(numpy.float32)[:500] expected = reg.predict(X32) got = sess.run(None, {'X': X32})[0].ravel() got_split = sess_split.run(None, {'X': X32})[0].ravel() disp = numpy.abs(got - expected).sum() -disp_split = numpy.abs(got_split - expected).sum() +disc_split = numpy.abs(got_split - expected).sum() -print("sum of discrepancies 1 node", disp) -print("sum of discrepancies split node", - disp_split, "ratio:", disp / disp_split) +print(f"sum of discrepancies 1 node: {disp}") +print(f"sum of discrepancies split node: {disc_split}, " + f"ratio: {disp / disc_split}") ###################################### # The sum of the discrepancies were reduced 4, 5 times. 
@@ -136,6 +153,21 @@ def skl2onnx_convert_lightgbm(scope, operator, container): print("max discrepancies 1 node", disc) print("max discrepancies split node", disc_split, "ratio:", disc / disc_split) +####################################### +# Let's compare with the double thresholds. +# The inputs are cast to float first and then to double +# to make sure both models receive the same values. + +sess_64 = InferenceSession(model_onnx_64.SerializeToString(), + providers=['CPUExecutionProvider']) + +X64 = X32.astype(numpy.float64) +expected_64 = reg.predict(X64) +got_64 = sess_64.run(None, {'X': X64})[0].ravel() +disc_64 = numpy.abs(got_64 - expected_64).sum() +disc_max64 = numpy.abs(got_64 - expected_64).max() +print(f"sum of discrepancies with doubles: sum={disc_64}, max={disc_max64}") + ################################################ # Processing time # +++++++++++++++ @@ -145,6 +177,9 @@ def skl2onnx_convert_lightgbm(scope, operator, container): print("processing time no split", timeit.timeit( lambda: sess.run(None, {'X': X32})[0], number=150)) +print("processing time no split with double", + timeit.timeit( + lambda: sess_64.run(None, {'X': X64})[0], number=150)) print("processing time split", timeit.timeit( lambda: sess_split.run(None, {'X': X32})[0], number=150)) @@ -159,22 +194,46 @@ def skl2onnx_convert_lightgbm(scope, operator, container): res = [] for i in tqdm(list(range(20, 170, 20)) + [200, 300, 400, 500]): model_onnx_split = to_onnx(reg, X[:1].astype(numpy.float32), - target_opset={'': 14, 'ai.onnx.ml': 2}, + target_opset={'': 17, 'ai.onnx.ml': 3}, options={'split': i}) - sess_split = InferenceSession(model_onnx_split.SerializeToString(), - providers=['CPUExecutionProvider']) + times = [] + for _ in range(0, 4): + begin = time.perf_counter() + sess_split = InferenceSession(model_onnx_split.SerializeToString(), + providers=['CPUExecutionProvider']) + times.append(time.perf_counter() - begin) + times.sort() got_split = sess_split.run(None, {'X': X32})[0].ravel() 
disc_split = numpy.abs(got_split - expected).max() - res.append(dict(split=i, disc=disc_split)) + res.append(dict(split=i, max_diff=disc_split, time=sum(times[1:3]) / 2)) df = DataFrame(res).set_index('split') df["baseline"] = disc +df["baseline_64"] = disc_max64 print(df) ########################################## # Graph. -ax = df.plot(title="Sum of discrepancies against split\n" - "split = number of tree per node") +fig, ax = plt.subplots(1, 2, figsize=(10, 4)) +df[["max_diff", "baseline", "baseline_64"]].plot( + title="Max discrepancies against split\n" + "split = numbers of tree per node", + ax=ax[0]) +df[["time"]].plot(title="Session creation time against split\n" + "split = numbers of tree per node", + ax=ax[1]) + +########################################## +# Conclusion +# ++++++++++ +# +# The time curve is too noisy to conclude. +# More measurements should be made. +# The double sum reduces the discrepancies +# but increases the processing time. It is a tradeoff. +# The best option is using double for threshold and summation +# but it requires the latest definition of TreeEnsemble `ai.onnx.ml=3`. + # plt.show() diff --git a/_doc/examples/plot_gexternal_lightgbm_reg_per.py b/_doc/examples/plot_gexternal_lightgbm_reg_per.py new file mode 100644 index 0000000..2f7e30e --- /dev/null +++ b/_doc/examples/plot_gexternal_lightgbm_reg_per.py @@ -0,0 +1,180 @@ +""" +.. _example-lightgbm-reg-one-off: + +Batch predictions vs one-off predictions +======================================== + +.. index:: LightGBM + +The goal is to compare the processing time between batch predictions +and one-off predictions for the same number of predictions +on trees. onnxruntime parallelizes the prediction by trees +or by rows. The rule is fixed and cannot be changed but it seems +to have some loopholes. + +.. 
contents:: + :local: + +Train a LGBMRegressor ++++++++++++++++++++++ +""" +import warnings +import time +import os +from packaging.version import Version +import numpy +from pandas import DataFrame +import onnx +import matplotlib.pyplot as plt +from tqdm import tqdm +from lightgbm import LGBMRegressor +from onnxruntime import InferenceSession +from skl2onnx import update_registered_converter, to_onnx +from skl2onnx.common.shape_calculator import calculate_linear_regressor_output_shapes # noqa +from onnxmltools import __version__ as oml_version +from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa + + +N = 1000 +X = numpy.random.randn(N, 1000) +y = (numpy.random.randn(N) + + numpy.random.randn(N) * 100 * numpy.random.randint(0, 2, N)) + +filenames = [f"plot_lightgbm_regressor_1000_{X.shape[1]}.onnx", + f"plot_lightgbm_regressor_10_{X.shape[1]}.onnx", + f"plot_lightgbm_regressor_2_{X.shape[1]}.onnx"] + +if not os.path.exists(filenames[0]): + print(f"training with shape={X.shape}") + reg_1000 = LGBMRegressor(n_estimators=1000) + reg_1000.fit(X, y) + reg_10 = LGBMRegressor(n_estimators=10) + reg_10.fit(X, y) + reg_2 = LGBMRegressor(n_estimators=2) + reg_2.fit(X, y) + print("done.") +else: + print("A model was already trained. Reusing it.") + +###################################### +# Register the converter for LGBMRegressor +# ++++++++++++++++++++++++++++++++++++++++ + + +def skl2onnx_convert_lightgbm(scope, operator, container): + options = scope.get_options(operator.raw_operator) + if 'split' in options: + if Version(oml_version) < Version('1.9.2'): + warnings.warn( + "Option split was released in version 1.9.2 but %s is " + "installed. It will be ignored." 
% oml_version) + operator.split = options['split'] + else: + operator.split = None + convert_lightgbm(scope, operator, container) + + +update_registered_converter( + LGBMRegressor, 'LightGbmLGBMRegressor', + calculate_linear_regressor_output_shapes, + skl2onnx_convert_lightgbm, + options={'split': None}) + +################################## +# Convert +# +++++++ +# +# We convert the three models (1000, 10 and 2 trees) to ONNX and save +# them on disk. If the files already exist, the models are loaded +# instead of being converted again. + +if not os.path.exists(filenames[0]): + model_onnx_1000 = to_onnx(reg_1000, X[:1].astype(numpy.float32), + target_opset={'': 17, 'ai.onnx.ml': 3}) + with open(filenames[0], "wb") as f: + f.write(model_onnx_1000.SerializeToString()) + model_onnx_10 = to_onnx(reg_10, X[:1].astype(numpy.float32), + target_opset={'': 17, 'ai.onnx.ml': 3}) + with open(filenames[1], "wb") as f: + f.write(model_onnx_10.SerializeToString()) + model_onnx_2 = to_onnx(reg_2, X[:1].astype(numpy.float32), + target_opset={'': 17, 'ai.onnx.ml': 3}) + with open(filenames[2], "wb") as f: + f.write(model_onnx_2.SerializeToString()) +else: + with open(filenames[0], "rb") as f: + model_onnx_1000 = onnx.load(f) + with open(filenames[1], "rb") as f: + model_onnx_10 = onnx.load(f) + with open(filenames[2], "rb") as f: + model_onnx_2 = onnx.load(f) + +sess_1000 = InferenceSession(model_onnx_1000.SerializeToString(), + providers=['CPUExecutionProvider']) +sess_10 = InferenceSession(model_onnx_10.SerializeToString(), + providers=['CPUExecutionProvider']) +sess_2 = InferenceSession(model_onnx_2.SerializeToString(), + providers=['CPUExecutionProvider']) + +########################## +# Processing time +# +++++++++++++++ +# + +repeat = 5 +data = [] +for N in tqdm(list(range(10, 100, 10)) + + list(range(100, 1000, 100)) + + list(range(1000, 10001, 1000))): + + X32 = numpy.random.randn(N, X.shape[1]).astype(numpy.float32) + obs = dict(N=N) + for sess, T in 
[(sess_1000, 1000), (sess_10, 10), (sess_2, 2)]: + times = [] + for _ in range(repeat): + begin = time.perf_counter() + sess.run(None, {'X': X32}) + end = time.perf_counter() - begin + times.append(end / X32.shape[0]) + times.sort() + obs[f"batch-{T}"] = sum(times[1:-1]) / (len(times) - 2) + + times = [] + for _ in range(repeat): + begin = time.perf_counter() + for i in range(X32.shape[0]): + sess.run(None, {'X': X32[i: i + 1]}) + end = time.perf_counter() - begin + times.append(end / X32.shape[0]) + times.sort() + obs[f"one-off-{T}"] = sum(times[1:-1]) / (len(times) - 2) + data.append(obs) + +df = DataFrame(data).set_index("N") +print(df) + +######################################## +# Plots. +fig, ax = plt.subplots(1, 3, figsize=(12, 4)) + +df[["batch-1000", "one-off-1000"]].plot( + ax=ax[0], title="Processing time per observation\n1000 Trees", + logy=True, logx=True) +df[["batch-10", "one-off-10"]].plot( + ax=ax[1], title="Processing time per observation\n10 Trees", + logy=True, logx=True) +df[["batch-2", "one-off-2"]].plot( + ax=ax[2], title="Processing time per observation\n2 Trees", + logy=True, logx=True) + +########################################## +# Conclusion +# +# The first graph shows a huge drop in the prediction time by batch. +# It means the parallelization is triggered. It may have been triggered +# sooner on this machine but this decision could be different on another one. +# An approach like the one TVM chose could be a good answer. If the model +# must be fast, then it is worth benchmarking many strategies to parallelize +# until the best one is found on a specific machine. + +# plt.show()