diff --git a/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py
new file mode 100644
index 00000000..c4398800
--- /dev/null
+++ b/_doc/examples/plot_orttraining_benchmark_fwbw_cls.py
@@ -0,0 +1,268 @@
"""

.. _l-orttraining-benchmark-fwbw-cls:

Benchmark, comparison sklearn - forward-backward - classification
=================================================================

The benchmark compares the processing time between :epkg:`scikit-learn`
and :epkg:`onnxruntime-training` on a logistic regression
and a neural network for classification.
It replicates the benchmark implemented in :ref:`l-orttraining-benchmark-fwbw`.

.. contents::
    :local:

First comparison: neural network
++++++++++++++++++++++++++++++++

"""
import warnings
import time
import numpy
import matplotlib.pyplot as plt
from pandas import DataFrame
from onnxruntime import get_device
from pyquickhelper.pycode.profiling import profile, profile2graph
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from mlprodict.onnx_conv import to_onnx
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs
from onnxcustom.utils.onnx_helper import onnx_rename_weights
from onnxcustom.training.optimizers_partial import (
    OrtGradientForwardBackwardOptimizer)
from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov
from onnxcustom.training.sgd_learning_loss import NegLogLearningLoss
from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty


X, y = make_classification(1000, n_features=100, n_classes=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y)

########################################
# Benchmark function.


def benchmark(X, y, skl_model, train_session, name, verbose=True):
    """
    :param X: features
    :param y: expected labels
    :param skl_model: model from scikit-learn
    :param train_session: instance of OrtGradientForwardBackwardOptimizer
    :param name: experiment name
    :param verbose: to debug
    """
    print("[benchmark] %s" % name)
    begin = time.perf_counter()
    skl_model.fit(X, y)
    duration_skl = time.perf_counter() - begin
    length_skl = len(skl_model.loss_curve_)
    print("[benchmark] skl=%r iterations - %r seconds" % (
        length_skl, duration_skl))

    begin = time.perf_counter()
    train_session.fit(X, y)
    duration_ort = time.perf_counter() - begin
    length_ort = len(train_session.train_losses_)
    print("[benchmark] ort=%r iterations - %r seconds" % (
        length_ort, duration_ort))

    return dict(skl=duration_skl, ort=duration_ort, name=name,
                iter_skl=length_skl, iter_ort=length_ort,
                losses_skl=skl_model.loss_curve_,
                losses_ort=train_session.train_losses_)


########################################
# Common parameters and model.

batch_size = 15
max_iter = 100

nn = MLPClassifier(hidden_layer_sizes=(50, 10), max_iter=max_iter,
                   solver='sgd', learning_rate_init=1e-1, alpha=1e-4,
                   n_iter_no_change=max_iter * 3, batch_size=batch_size,
                   nesterovs_momentum=True, momentum=0.9,
                   learning_rate="invscaling")

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)
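########################################
# A quick sanity check, added here for illustration and not part of the
# original benchmark: the held-out score of the trained network gives a
# baseline before any timing is measured.

print("test score: %r" % nn.score(X_test, y_test))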
########################################
# Conversion to ONNX and trainer initialization.
# It is slightly different from a regression model.
# Probabilities usually come from raw scores transformed
# through a function such as the sigmoid function.
# The gradient of the loss is computed against the raw scores
# because it is easier to compute than letting onnxruntime
# do it.

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15,
              options={'zipmap': False, 'nocl': True})
print(onnx_simple_text_plot(onx))

##########################################
# Raw scores are the input of operator *Sigmoid*.

onx = select_model_inputs_outputs(
    onx, outputs=["add_result2"], infer_shapes=True)
print(onnx_simple_text_plot(onx))

#########################################
# The weights are then renamed so that they follow
# alphabetical order (see :class:`OrtGradientForwardBackward
# `).

onx = onnx_rename_weights(onx)
print(onnx_simple_text_plot(onx))

################################################
# We select the log loss (see :class:`NegLogLearningLoss
# <onnxcustom.training.sgd_learning_loss.NegLogLearningLoss>`),
# a simple penalty defined with :class:`ElasticLearningPenalty
# <onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty>`,
# and the Nesterov algorithm to update the weights with
# :class:`LearningRateSGDNesterov
# <onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov>`.

train_session = OrtGradientForwardBackwardOptimizer(
    onx, device='cpu', warm_start=False,
    max_iter=max_iter, batch_size=batch_size,
    learning_loss=NegLogLearningLoss(),
    learning_rate=LearningRateSGDNesterov(
        1e-7, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))


benches = [benchmark(X_train, y_train, nn, train_session, name='NN-CPU')]

######################################
# Profiling
# +++++++++


def clean_name(text):
    pos = text.find('onnxruntime')
    if pos >= 0:
        return text[pos:]
    pos = text.find('sklearn')
    if pos >= 0:
        return text[pos:]
    pos = text.find('onnxcustom')
    if pos >= 0:
        return text[pos:]
    pos = text.find('site-packages')
    if pos >= 0:
        return text[pos:]
    return text


ps = profile(lambda: benchmark(X_train, y_train,
                               nn, train_session, name='NN-CPU'))[0]
root, nodes = profile2graph(ps, clean_text=clean_name)
text = root.to_text()
print(text)

######################################
# If GPU is available
# +++++++++++++++++++

if get_device().upper() == 'GPU':

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device='cuda', warm_start=False,
        max_iter=max_iter, batch_size=batch_size,
        learning_loss=NegLogLearningLoss(),
        learning_rate=LearningRateSGDNesterov(
            1e-7, nesterov=False, momentum=0.9),
        learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))

    benches.append(benchmark(X_train, y_train, nn,
                             train_session, name='NN-GPU'))

#######################################
# A simple linear layer
# +++++++++++++++++++++

nn = MLPClassifier(hidden_layer_sizes=tuple(), max_iter=max_iter,
                   solver='sgd', learning_rate_init=1e-1, alpha=1e-4,
                   n_iter_no_change=max_iter * 3, batch_size=batch_size,
                   nesterovs_momentum=True, momentum=0.9,
                   learning_rate="invscaling", activation='identity')


with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15,
              options={'zipmap': False, 'nocl': True})
print(onnx_simple_text_plot(onx))

onx = select_model_inputs_outputs(
    onx, outputs=["add_result"], infer_shapes=True)
print(onnx_simple_text_plot(onx))

onx = onnx_rename_weights(onx)
print(onnx_simple_text_plot(onx))
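####################################################
# A small check, added for illustration and not part of the original
# script: after ``onnx_rename_weights``, the initializer names should
# follow the alphabetical order the forward-backward classes rely on.

names = [init.name for init in onx.graph.initializer]
assert names == sorted(names)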
train_session = OrtGradientForwardBackwardOptimizer(
    onx, device='cpu', warm_start=False,
    max_iter=max_iter, batch_size=batch_size,
    learning_loss=NegLogLearningLoss(),
    learning_rate=LearningRateSGDNesterov(
        1e-5, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))


benches.append(benchmark(X_train, y_train, nn, train_session, name='LR-CPU'))

if get_device().upper() == 'GPU':

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, device='cuda', warm_start=False,
        max_iter=max_iter, batch_size=batch_size,
        learning_loss=NegLogLearningLoss(),
        learning_rate=LearningRateSGDNesterov(
            1e-5, nesterov=False, momentum=0.9),
        learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4))

    benches.append(benchmark(X_train, y_train, nn,
                             train_session, name='LR-GPU'))


######################################
# Graphs
# ++++++
#
# Dataframe first.

df = DataFrame(benches).set_index('name')
df

#######################################
# Text output.

print(df)

#######################################
# Graphs.

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
df[['skl', 'ort']].plot.bar(title="Processing time", ax=ax[0])
ax[0].tick_params(axis='x', rotation=30)
for bench in benches:
    ax[1].plot(bench['losses_skl'][1:], label='skl-' + bench['name'])
    ax[1].plot(bench['losses_ort'][1:], label='ort-' + bench['name'])
ax[1].set_yscale('log')
ax[1].set_title("Losses")
ax[1].legend()

########################################
# The gradient updates are not exactly the same.
# They should be improved for a fair comparison.

# plt.show()
diff --git a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
index 9f24f9a8..128221b9 100644
--- a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
+++ b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst
@@ -127,6 +127,23 @@ used by the three components mentioned above. They cache the binded pointers
 if the method is called again with a different `OrtValue` but a same
 pointer returned by `data_ptr()`.
 
+Binary classification
++++++++++++++++++++++
+
+Probabilities are computed from raw scores with a function such as the
+`sigmoid function <https://en.wikipedia.org/wiki/Sigmoid_function>`_.
+A binary classifier produces two probabilities: :math:`sigmoid(s)` and
+:math:`1 - sigmoid(s)`, where *s* is the raw score. The associated loss
+function is usually the log loss :math:`loss(y, s) =
+-(1-y) \log(1-p(s)) - y \log p(s)`, where *y* is the expected class
+(0 or 1), *s=s(X)* is the raw score and *p(s)* is the probability.
+We could compute the gradient of the loss
+against the probability and let :epkg:`onnxruntime-training` handle the
+computation of the gradient from the probability to the input.
+However, the gradient of the loss against the raw score can easily be
+expressed as :math:`grad(loss(y, s)) = p(s) - y`. The second
+option is implemented in example :ref:`l-orttraining-benchmark-fwbw-cls`.
+
 Examples
 ++++++++
 
@@ -143,3 +160,4 @@ with ONNX and :epkg:`onnxruntime-training`.
     ../gyexamples/plot_orttraining_nn_gpu_fwbw
     ../gyexamples/plot_orttraining_nn_gpu_fwbw_nesterov
     ../gyexamples/plot_orttraining_benchmark_fwbw
+    ../gyexamples/plot_orttraining_benchmark_fwbw_cls
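As a side note, the gradient formula above can be checked numerically:
a centered finite difference of the log loss against the raw score
should match :math:`p(s) - y`. The sketch below is an illustration and
not part of the patch; the values of ``s``, ``y`` and the step ``h``
are arbitrary::

    import numpy

    def sigmoid(s):
        return 1. / (1. + numpy.exp(-s))

    def log_loss(y, s):
        # log loss written against the raw score s
        p = sigmoid(s)
        return -(1 - y) * numpy.log(1 - p) - y * numpy.log(p)

    s, y, h = 0.3, 1, 1e-6
    finite_diff = (log_loss(y, s + h) - log_loss(y, s - h)) / (2 * h)
    print(finite_diff, sigmoid(s) - y)  # both close to p(s) - y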
diff --git a/onnxcustom/training/data_loader.py b/onnxcustom/training/data_loader.py
index f2b65226..06069e86 100644
--- a/onnxcustom/training/data_loader.py
+++ b/onnxcustom/training/data_loader.py
@@ -197,12 +197,30 @@ def local_bind(bind, offset, n):
             shape_X = (n, n_col_x)
             shape_y = (n, n_col_y)
 
-            bind.bind_input(
-                names[0], self.device, self.desc[0][1], shape_X,
-                self.X_ort.data_ptr() + offset * n_col_x * size_x)
-            bind.bind_input(
-                names[1], self.device, self.desc[1][1], shape_y,
-                self.y_ort.data_ptr() + offset * n_col_y * size_y)
+            try:
+                bind.bind_input(
+                    names[0], self.device, self.desc[0][1], shape_X,
+                    self.X_ort.data_ptr() + offset * n_col_x * size_x)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    "Unable to bind data input (X) %r, device=%r desc=%r "
+                    "data_ptr=%r offset=%r n_col_x=%r size_x=%r "
+                    "type(bind)=%r" % (
+                        names[0], self.device, self.desc[0][1],
+                        self.X_ort.data_ptr(), offset, n_col_x, size_x,
+                        type(bind))) from e
+            try:
+                bind.bind_input(
+                    names[1], self.device, self.desc[1][1], shape_y,
+                    self.y_ort.data_ptr() + offset * n_col_y * size_y)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    "Unable to bind data input (y) %r, device=%r desc=%r "
+                    "data_ptr=%r offset=%r n_col_y=%r size_y=%r "
+                    "type(bind)=%r" % (
+                        names[1], self.device, self.desc[1][1],
+                        self.y_ort.data_ptr(), offset, n_col_y, size_y,
+                        type(bind))) from e
 
         def local_bindw(bind, offset, n):
             # This function assumes the data is contiguous.
diff --git a/onnxcustom/utils/orttraining_helper.py b/onnxcustom/utils/orttraining_helper.py
index c4f9c437..18079c8d 100644
--- a/onnxcustom/utils/orttraining_helper.py
+++ b/onnxcustom/utils/orttraining_helper.py
@@ -498,7 +498,7 @@ def _replace(ens):
     elem = output_onx.type.tensor_type.elem_type
     if elem == 0:
         raise TypeError(  # pragma: no cover
-            "Unable to guess inut tensor type from %r."
+            "Unable to guess input tensor type from %r."
             "" % output_onx)
     shape = []
     for d in output_onx.type.tensor_type.shape.dim:
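The try/except blocks added to ``data_loader.py`` follow the standard
``raise ... from e`` idiom: the original exception is kept as the cause
while the message gains the full binding context. A minimal standalone
sketch of the same pattern, with illustrative names::

    def bind_input_or_explain(bind, name, device, dtype, shape, ptr):
        # Wrap bind_input and re-raise with the binding context so a
        # failure on a specific input can be diagnosed from the message.
        try:
            bind.bind_input(name, device, dtype, shape, ptr)
        except RuntimeError as e:
            raise RuntimeError(
                "Unable to bind input %r, device=%r dtype=%r shape=%r "
                "ptr=%r type(bind)=%r" % (
                    name, device, dtype, shape, ptr, type(bind))) from e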