From a5bc72e76180ea0a3c596c573c9881c6759eae53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 00:23:06 +0100 Subject: [PATCH 1/6] use bind_ortvalue_input to be faster --- .../test_optimizers_classification.py | 2 +- onnxcustom/training/optimizers.py | 22 +++++++------------ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_training/test_optimizers_classification.py b/_unittests/ut_training/test_optimizers_classification.py index 239f3705..dfc0ad59 100644 --- a/_unittests/ut_training/test_optimizers_classification.py +++ b/_unittests/ut_training/test_optimizers_classification.py @@ -51,7 +51,7 @@ def test_ort_gradient_optimizers_binary(self): train_session = OrtGradientOptimizer( onx_loss, inits, learning_rate=1e-3) self.assertRaise(lambda: train_session.get_state(), AttributeError) - train_session.fit(X_train, y_train.reshape((-1, 1)), use_numpy=True) + train_session.fit(X_train, y_train.reshape((-1, 1)), use_numpy=False) state_tensors = train_session.get_state() self.assertEqual(len(state_tensors), 2) r = repr(train_session) diff --git a/onnxcustom/training/optimizers.py b/onnxcustom/training/optimizers.py index 99f64c6d..eceef8ad 100644 --- a/onnxcustom/training/optimizers.py +++ b/onnxcustom/training/optimizers.py @@ -174,15 +174,9 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue): raise TypeError( # pragma: no cover "Unexpected type %r." % type(bind)) if isinstance(c_ortvalue, C_OrtValue): - # does not work - # bind._iobinding.bind_ortvalue_input(name, c_ortvalue) - dtype = proto_type_to_dtype( - c_ortvalue.proto_type() if hasattr(c_ortvalue, 'proto_type') - else c_ortvalue.data_type()) - bind.bind_input( - name, self.device, dtype, c_ortvalue.shape(), - c_ortvalue.data_ptr()) + bind.bind_ortvalue_input(name, c_ortvalue) elif isinstance(c_ortvalue, numpy.ndarray): + # This fails on linux with int64. bind.bind_input( name, self.device, c_ortvalue.dtype, c_ortvalue.shape, c_ortvalue.__array_interface__['data'][0]) @@ -222,8 +216,8 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight): self.input_names_[2], bind, weight) self.train_session_._sess.run_with_iobinding(bind, None) - outputs = bind.copy_outputs_to_cpu() - if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]): + loss = bind.get_outputs()[0].numpy() + if numpy.isinf(loss) or numpy.isnan(loss): raise ConvergenceError( "Loss is nan, learning_rate=%r, " "the gradient descent has failed " @@ -232,7 +226,7 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight): [float(v[0]) for v in ( actual_losses if len(actual_losses) < 5 else actual_losses[-5:])])) - actual_losses.append(outputs[0] / data.shape[0]) + actual_losses.append(loss / data.shape[0]) else: idx = 3 if sample_weight else 2 self._bind_input_ortvalue(self.input_names_[idx], bind, ort_lr) @@ -242,8 +236,8 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight): for batch_size in data_loader.iter_bind(bind, self.input_names_): self.train_session_._sess.run_with_iobinding(bind, None) # We copy the predicted output as well which is not needed. 
- outputs = bind.copy_outputs_to_cpu() - if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]): + loss = bind.get_outputs()[0].numpy() + if numpy.isinf(loss) or numpy.isnan(loss): raise ConvergenceError( "Loss is nan or infinite, learning_rate=%r, " "the gradient descent has failed " @@ -252,7 +246,7 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy, sample_weight): [float(v[0]) for v in ( actual_losses if len(actual_losses) < 5 else actual_losses[-5:])])) - actual_losses.append(outputs[0] / batch_size) + actual_losses.append(loss / batch_size) return numpy.array(actual_losses).mean() From 63179e447e0c7823791e5904b7bedfdeac558b89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 00:24:37 +0100 Subject: [PATCH 2/6] lint --- onnxcustom/training/optimizers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxcustom/training/optimizers.py b/onnxcustom/training/optimizers.py index eceef8ad..c0902ebe 100644 --- a/onnxcustom/training/optimizers.py +++ b/onnxcustom/training/optimizers.py @@ -7,7 +7,6 @@ TrainingParameters, SessionOptions, TrainingSession) from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 OrtValue as C_OrtValue, SessionIOBinding as C_IOBinding) -from ..utils.onnx_helper import proto_type_to_dtype from ..utils.onnxruntime_helper import ( numpy_to_ort_value, device_to_providers) from .data_loader import OrtDataLoader From 7f19943d1d84b414bb85013ad59f4c535fcb6f47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 17:05:13 +0100 Subject: [PATCH 3/6] lint --- _unittests/ut_training/test_optimizers_classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_training/test_optimizers_classification.py b/_unittests/ut_training/test_optimizers_classification.py index dfc0ad59..c4ae3e59 100644 --- a/_unittests/ut_training/test_optimizers_classification.py +++ b/_unittests/ut_training/test_optimizers_classification.py @@ -2,11 +2,11 @@ @brief test log(time=8s) """ import unittest +import numpy from onnx import TensorProto +from onnx.helper import set_model_props from pyquickhelper.pycode import ( ExtTestCase, get_temp_folder, ignore_warnings) -import numpy -from onnx.helper import set_model_props from sklearn.exceptions import ConvergenceWarning from sklearn.datasets import make_classification from sklearn.model_selection import train_test_split @@ -47,7 +47,7 @@ def test_ort_gradient_optimizers_binary(self): inputs = onx_loss.graph.input self.assertEqual(len(inputs), 2) dt = inputs[1].type.tensor_type.elem_type - self.assertEqual(TensorProto.INT64, dt) + self.assertEqual(TensorProto.INT64, dt) # pylint: disable=E1101 train_session = OrtGradientOptimizer( onx_loss, inits, learning_rate=1e-3) self.assertRaise(lambda: train_session.get_state(), AttributeError) From af232aea5ceaeea49999f20d6ba65511bd28d310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 19:08:33 +0100 Subject: [PATCH 4/6] Adds example about classification --- .../plot_orttraining_benchmark_fwbw_cls.py | 218 ++++++++++++++++++ .../tutorial_6_training_partial.rst | 18 ++ onnxcustom/utils/orttraining_helper.py | 2 +- 3 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 _doc/examples/plot_orttraining_benchmark_fwbw_cls.py diff --git a/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py new file mode 100644 index 00000000..4a196bb4 --- /dev/null +++ 
b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py @@ -0,0 +1,218 @@ +""" + +.. _l-orttraining-benchmark-fwbw-cls: + +Benchmark, comparison sklearn - forward-backward - classification +================================================================= + +The benchmark compares the processing time between :epkg:`scikit-learn` +and :epkg:`onnxruntime-training` on a logistic regression regression +and a neural network for classification. +It replicates the benchmark implemented in :ref:`l-orttraining-benchmark-fwbw`. + +.. contents:: + :local: + +First comparison: neural network +++++++++++++++++++++++++++++++++ + +""" +import warnings +import time +import numpy +import matplotlib.pyplot as plt +from pandas import DataFrame +from onnxruntime import get_device +from pyquickhelper.pycode.profiling import profile, profile2graph +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split +from sklearn.neural_network import MLPClassifier +from mlprodict.onnx_conv import to_onnx +from mlprodict.plotting.text_plot import onnx_simple_text_plot +from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs +from onnxcustom.utils.onnx_helper import onnx_rename_weights +from onnxcustom.training.optimizers_partial import ( + OrtGradientForwardBackwardOptimizer) +from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov +from onnxcustom.training.sgd_learning_loss import NegLogLearningLoss +from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty + + +X, y = make_classification(1000, n_features=100, n_classes=2) +X = X.astype(numpy.float32) +y = y.astype(numpy.int64) +X_train, X_test, y_train, y_test = train_test_split(X, y) + +######################################## +# Benchmark function. + + +def benchmark(X, y, skl_model, train_session, name, verbose=True): + """ + :param skl_model: model from scikit-learn + :param train_session: instance of OrtGradientForwardBackwardOptimizer + :param name: experiment name + :param verbose: to debug + """ + print("[benchmark] %s" % name) + begin = time.perf_counter() + skl_model.fit(X, y) + duration_skl = time.perf_counter() - begin + length_skl = len(skl_model.loss_curve_) + print("[benchmark] skl=%r iterations - %r seconds" % ( + length_skl, duration_skl)) + + begin = time.perf_counter() + train_session.fit(X, y) + duration_ort = time.perf_counter() - begin + length_ort = len(train_session.train_losses_) + print("[benchmark] ort=%r iteration - %r seconds" % ( + length_ort, duration_ort)) + + return dict(skl=duration_skl, ort=duration_ort, name=name, + iter_skl=length_skl, iter_ort=length_ort, + losses_skl=skl_model.loss_curve_, + losses_ort=train_session.train_losses_) + + +######################################## +# Common parameters and model + +batch_size = 15 +max_iter = 100 + +nn = MLPClassifier(hidden_layer_sizes=(50, 10), max_iter=max_iter, + solver='sgd', learning_rate_init=5e-3, alpha=1e-4, + n_iter_no_change=max_iter * 3, batch_size=batch_size, + nesterovs_momentum=True, momentum=0.9, + learning_rate="invscaling") + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + nn.fit(X_train, y_train) + +######################################## +# Conversion to ONNX and trainer initialization +# It is slightly different from a regression model. +# Probabilities usually come from raw scores transformed +# through a function such as the sigmoid function. 
+# The gradient of the loss is computed against the raw scores +# because it is easier to compute than to let onnxruntime +# do it. + +onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15, + options={'zipmap': False, 'nocl': True}) +print(onnx_simple_text_plot(onx)) + +########################################## +# Raw scores are the input of operator *Sigmoid*. + +onx = select_model_inputs_outputs( + onx, outputs=["add_result2"], infer_shapes=True) +print(onnx_simple_text_plot(onx)) + +######################################### +# And the names are renamed to have them follow the +# alphabetical order (see :class:`OrtGradientForwardBackward +# `). + +onx = onnx_rename_weights(onx) +print(onnx_simple_text_plot(onx)) + +################################################ +# We select the log loss (see :class:`NegLogLearningLoss +# `, +# a simple penalty defined with :class:`ElasticLearningPenalty +# `, +# and the Nesterov algorithm to update the weights with +# `LearningRateSGDNesterov +# `. + +train_session = OrtGradientForwardBackwardOptimizer( + onx, device='cpu', warm_start=False, + max_iter=max_iter, batch_size=batch_size, + learning_loss=NegLogLearningLoss(), + learning_rate=LearningRateSGDNesterov( + 1e-5, nesterov=True, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) + + +benches = [benchmark(X_train, y_train, nn, train_session, name='NN-CPU')] + +###################################### +# Profiling +# +++++++++ + + +def clean_name(text): + pos = text.find('onnxruntime') + if pos >= 0: + return text[pos:] + pos = text.find('sklearn') + if pos >= 0: + return text[pos:] + pos = text.find('onnxcustom') + if pos >= 0: + return text[pos:] + pos = text.find('site-packages') + if pos >= 0: + return text[pos:] + return text + + +ps = profile(lambda: benchmark(X_train, y_train, + nn, train_session, name='NN-CPU'))[0] +root, nodes = profile2graph(ps, clean_text=clean_name) +text = root.to_text() +print(text) + +###################################### +# if GPU is available +# +++++++++++++++++++ + +if get_device().upper() == 'GPU': + + train_session = OrtGradientForwardBackwardOptimizer( + onx, device='cuda', warm_start=False, + max_iter=max_iter, batch_size=batch_size, + learning_loss=NegLogLearningLoss(), + learning_rate=LearningRateSGDNesterov( + 1e-5, nesterov=False, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) + + benches.append(benchmark(X_train, y_train, nn, + train_session, name='NN-GPU')) + + +###################################### +# Graphs +# ++++++ +# +# Dataframe first. + +df = DataFrame(benches).set_index('name') +df + +####################################### +# text output + +print(df) + +####################################### +# Graphs. + +fig, ax = plt.subplots(1, 2, figsize=(10, 4)) +df[['skl', 'ort']].plot.bar(title="Processing time", ax=ax[0]) +ax[0].tick_params(axis='x', rotation=30) +for bench in benches: + ax[1].plot(bench['losses_skl'][1:], label='skl-' + bench['name']) + ax[1].plot(bench['losses_ort'][1:], label='ort-' + bench['name']) +ax[1].set_yscale('log') +ax[1].set_title("Losses") +ax[1].legend() + +######################################## +# The gradient update are not exactly the same. +# It should be improved for a fair comprison. 
+
+plt.show()
diff --git a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
index 9f24f9a8..128221b9 100644
--- a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
+++ b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
@@ -127,6 +127,23 @@ used by the three components mentioned above. They cache the binded
 pointers if the method is called again with a different
 `OrtValue` but a same pointer returned by `data_ptr()`.
 
+Binary classification
+++++++++++++++++++++++
+
+Probabilities are computed from raw scores with a function such as the
+`sigmoid function `_.
+A binary classifier produces two probabilities: :math:`sigmoid(s)` and
+:math:`(1 - sigmoid(s))` where *s* is the raw score. The associated loss
+function is usually the log loss: :math:`loss(y, X) =
+-((1-y) \log(1-p(s)) + y \log p(s))` where *y* is the expected class (0 or 1),
+*s=s(X)* is the raw score and *p(s)* is the probability.
+We could compute the gradient of the loss
+against the probability and let :epkg:`onnxruntime-training` handle the
+computation of the gradient from the probability to the input.
+However, the gradient of the loss against the raw score can easily be
+expressed as :math:`grad(loss(y, s)) = p(s) - y`. The second
+option is implemented in example :ref:`l-orttraining-benchmark-fwbw-cls`.
+
 Examples
 ++++++++
 
@@ -143,3 +160,4 @@ with ONNX and :epkg:`onnxruntime-training`.
    ../gyexamples/plot_orttraining_nn_gpu_fwbw
    ../gyexamples/plot_orttraining_nn_gpu_fwbw_nesterov
    ../gyexamples/plot_orttraining_benchmark_fwbw
+   ../gyexamples/plot_orttraining_benchmark_fwbw_cls
diff --git a/onnxcustom/utils/orttraining_helper.py b/onnxcustom/utils/orttraining_helper.py
index c4f9c437..18079c8d 100644
--- a/onnxcustom/utils/orttraining_helper.py
+++ b/onnxcustom/utils/orttraining_helper.py
@@ -498,7 +498,7 @@ def _replace(ens):
     elem = output_onx.type.tensor_type.elem_type
     if elem == 0:
         raise TypeError(  # pragma: no cover
-            "Unable to guess inut tensor type from %r."
+            "Unable to guess input tensor type from %r."
"" % output_onx) shape = [] for d in output_onx.type.tensor_type.shape.dim: From 7fe1ff8978f381fa046fbb62e654a4d455d5e7d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 19:17:51 +0100 Subject: [PATCH 5/6] catch one issue --- onnxcustom/training/data_loader.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/onnxcustom/training/data_loader.py b/onnxcustom/training/data_loader.py index f2b65226..06069e86 100644 --- a/onnxcustom/training/data_loader.py +++ b/onnxcustom/training/data_loader.py @@ -197,12 +197,30 @@ def local_bind(bind, offset, n): shape_X = (n, n_col_x) shape_y = (n, n_col_y) - bind.bind_input( - names[0], self.device, self.desc[0][1], shape_X, - self.X_ort.data_ptr() + offset * n_col_x * size_x) - bind.bind_input( - names[1], self.device, self.desc[1][1], shape_y, - self.y_ort.data_ptr() + offset * n_col_y * size_y) + try: + bind.bind_input( + names[0], self.device, self.desc[0][1], shape_X, + self.X_ort.data_ptr() + offset * n_col_x * size_x) + except RuntimeError as e: + raise RuntimeError( + "Unable to bind data input (X) %r, device=%r desc=%r " + "data_ptr=%r offset=%r n_col_x=%r size_x=%r " + "type(bind)=%r" % ( + names[0], self.device, self.desc[0][1], + self.X_ort.data_ptr(), offset, n_col_x, size_x, + type(bind))) from e + try: + bind.bind_input( + names[1], self.device, self.desc[1][1], shape_y, + self.y_ort.data_ptr() + offset * n_col_y * size_y) + except RuntimeError as e: + raise RuntimeError( + "Unable to bind data input (y) %r, device=%r desc=%r " + "data_ptr=%r offset=%r n_col_y=%r size_y=%r " + "type(bind)=%r" % ( + names[1], self.device, self.desc[1][1], + self.y_ort.data_ptr(), offset, n_col_y, size_y, + type(bind))) from e def local_bindw(bind, offset, n): # This function assumes the data is contiguous. 
From d48fd2a0961aec948cc7988a918c4b2173a2ff96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 23 Jan 2022 22:02:34 +0100 Subject: [PATCH 6/6] Update plot_orttraining_benchmark_fwbw_cls.py --- .../plot_orttraining_benchmark_fwbw_cls.py | 58 +++++++++++++++++-- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py index 4a196bb4..c4398800 100644 --- a/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py +++ b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py @@ -82,7 +82,7 @@ def benchmark(X, y, skl_model, train_session, name, verbose=True): max_iter = 100 nn = MLPClassifier(hidden_layer_sizes=(50, 10), max_iter=max_iter, - solver='sgd', learning_rate_init=5e-3, alpha=1e-4, + solver='sgd', learning_rate_init=1e-1, alpha=1e-4, n_iter_no_change=max_iter * 3, batch_size=batch_size, nesterovs_momentum=True, momentum=0.9, learning_rate="invscaling") @@ -133,7 +133,7 @@ def benchmark(X, y, skl_model, train_session, name, verbose=True): max_iter=max_iter, batch_size=batch_size, learning_loss=NegLogLearningLoss(), learning_rate=LearningRateSGDNesterov( - 1e-5, nesterov=True, momentum=0.9), + 1e-7, nesterov=True, momentum=0.9), learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) @@ -177,12 +177,62 @@ def clean_name(text): max_iter=max_iter, batch_size=batch_size, learning_loss=NegLogLearningLoss(), learning_rate=LearningRateSGDNesterov( - 1e-5, nesterov=False, momentum=0.9), + 1e-7, nesterov=False, momentum=0.9), learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) benches.append(benchmark(X_train, y_train, nn, train_session, name='NN-GPU')) +####################################### +# A simple linear layer +# +++++++++++++++++++++ + +nn = MLPClassifier(hidden_layer_sizes=tuple(), max_iter=max_iter, + solver='sgd', learning_rate_init=1e-1, alpha=1e-4, + n_iter_no_change=max_iter * 3, batch_size=batch_size, + nesterovs_momentum=True, momentum=0.9, + learning_rate="invscaling", activation='identity') + + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + nn.fit(X_train, y_train) + +onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15, + options={'zipmap': False, 'nocl': True}) +print(onnx_simple_text_plot(onx)) + +onx = select_model_inputs_outputs( + onx, outputs=["add_result"], infer_shapes=True) +print(onnx_simple_text_plot(onx)) + +onx = onnx_rename_weights(onx) +print(onnx_simple_text_plot(onx)) + +train_session = OrtGradientForwardBackwardOptimizer( + onx, device='cpu', warm_start=False, + max_iter=max_iter, batch_size=batch_size, + learning_loss=NegLogLearningLoss(), + learning_rate=LearningRateSGDNesterov( + 1e-5, nesterov=True, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) + + +benches.append(benchmark(X_train, y_train, nn, train_session, name='LR-CPU')) + +if get_device().upper() == 'GPU': + + train_session = OrtGradientForwardBackwardOptimizer( + onx, device='cuda', warm_start=False, + max_iter=max_iter, batch_size=batch_size, + learning_loss=NegLogLearningLoss(), + learning_rate=LearningRateSGDNesterov( + 1e-5, nesterov=False, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4)) + + benches.append(benchmark(X_train, y_train, nn, + train_session, name='LR-GPU')) + ###################################### # Graphs @@ -215,4 +265,4 @@ def clean_name(text): # The gradient update are not exactly the same. # It should be improved for a fair comprison. 
-plt.show() +# plt.show()
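
Patch 1 ("use bind_ortvalue_input to be faster") replaces the manual `bind_input` call, which re-derived dtype, shape and raw pointer for every batch, with `bind_ortvalue_input`, and reads the loss through `get_outputs()` instead of `copy_outputs_to_cpu()`, which copies every bound output including the unneeded predictions. A minimal stand-alone sketch of the same pattern with onnxruntime's public API follows; the model path and the input/output names are placeholders, not the ones used by `OrtGradientOptimizer`::

    import numpy as np
    import onnxruntime as ort

    # Placeholder model with one float input "X" and one output "loss";
    # replace the path and names with a real model to run this.
    sess = ort.InferenceSession("model.onnx",
                                providers=["CPUExecutionProvider"])

    X = np.random.randn(32, 10).astype(np.float32)
    x_ort = ort.OrtValue.ortvalue_from_numpy(X, "cpu", 0)

    io = sess.io_binding()

    # Before the patch: device, dtype, shape and raw pointer were passed
    # explicitly, equivalent to:
    # io.bind_input("X", device_type="cpu", device_id=0,
    #               element_type=np.float32, shape=X.shape,
    #               buffer_ptr=x_ort.data_ptr())

    # After the patch: the OrtValue already carries dtype, shape and device.
    io.bind_ortvalue_input("X", x_ort)
    io.bind_output("loss", "cpu")

    sess.run_with_iobinding(io)

    # copy_outputs_to_cpu() would copy every bound output; here only the
    # first output (the loss) is materialized as a numpy array.
    loss = io.get_outputs()[0].numpy()
    if np.isnan(loss).any() or np.isinf(loss).any():
        raise RuntimeError("loss is nan or infinite, training diverged")
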