From e27d78ea70169acb0c295246ac5c3d2ac22be2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 25 Nov 2021 17:10:41 +0100 Subject: [PATCH 1/3] Adds function plot_onnxs --- .../plot_orttraining_linear_regression.py | 586 +++--------------- .../plot_orttraining_linear_regression_cpu.py | 516 +++++++++++++++ .../plot_orttraining_linear_regression_gpu.py | 29 +- _doc/examples/plot_orttraining_nn_gpu.py | 66 +- .../source/tutorial/tutorial_6_training.rst | 19 +- _unittests/ut_plotting/test_plotting_onnx.py | 92 +++ onnxcustom/plotting/__init__.py | 5 + onnxcustom/plotting/plotting_onnx.py | 61 ++ 8 files changed, 784 insertions(+), 590 deletions(-) create mode 100644 _doc/examples/plot_orttraining_linear_regression_cpu.py create mode 100644 _unittests/ut_plotting/test_plotting_onnx.py create mode 100644 onnxcustom/plotting/__init__.py create mode 100644 onnxcustom/plotting/plotting_onnx.py diff --git a/_doc/examples/plot_orttraining_linear_regression.py b/_doc/examples/plot_orttraining_linear_regression.py index 6242538b..91ee9dc3 100644 --- a/_doc/examples/plot_orttraining_linear_regression.py +++ b/_doc/examples/plot_orttraining_linear_regression.py @@ -1,6 +1,6 @@ """ -.. _l-orttraining-linreg-cpu: +.. _l-orttraining-linreg: Train a linear regression with onnxruntime-training =================================================== @@ -8,7 +8,7 @@ This example explores how :epkg:`onnxruntime-training` can be used to train a simple linear regression using a gradient descent. It compares the results with those obtained by -:class:`sklearn.linear_model.LinearRegression`. +:class:`sklearn.linear_model.SGDRegressor` .. contents:: :local: @@ -20,51 +20,73 @@ from pprint import pprint import numpy from pandas import DataFrame -import matplotlib.pyplot as plt -from onnx import helper, numpy_helper, TensorProto from onnxruntime import ( - InferenceSession, __version__ as ort_version, - TrainingParameters, SessionOptions, TrainingSession, - get_device) + InferenceSession, get_device) from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split -from sklearn.linear_model import LinearRegression +from sklearn.linear_model import SGDRegressor +from sklearn.neural_network import MLPRegressor from mlprodict.onnx_conv import to_onnx -from mlprodict.plotting.plotting_onnx import plot_onnx +from onnxcustom.plotting.plotting_onnx import plot_onnxs from onnxcustom.training import add_loss_output, get_train_initializer from onnxcustom.training.optimizers import OrtGradientOptimizer -from tqdm import tqdm X, y = make_regression(n_features=2, bias=2) X = X.astype(numpy.float32) y = y.astype(numpy.float32) X_train, X_test, y_train, y_test = train_test_split(X, y) -lr = LinearRegression() +lr = SGDRegressor(l1_ratio=0, max_iter=200, eta0=5e-2) lr.fit(X, y) print(lr.predict(X[:5])) ################################## -# Simplified training code -# ++++++++++++++++++++++++ +# The trained coefficients are: +print("trained coefficients:", lr.coef_, lr.intercept_) + +############################################ +# However this model does not show the training curve. +# We switch to a :class:`sklearn.neural_network.MLPRegressor`. + +lr = MLPRegressor(hidden_layer_sizes=tuple(), + activation='identity', max_iter=200, + batch_size=10, solver='sgd', + alpha=0, learning_rate_init=1e-2, + n_iter_no_change=200) +lr.fit(X, y) +print(lr.predict(X[:5])) + +################################## +# The trained coefficients are: +print("trained coefficients:", lr.coefs_, lr.intercepts_) + +################################## +# ONNX graph +# ++++++++++ # -# Next lines illustrates how to train a linear regression -# with :epkg:`onnxruntime`. It includes the conversion of -# a linear regression into ONNX, the computation of the gradient -# graph and the implementation of a simple stochastic gradient -# descent. This section does not explain how it works yet but -# shows how it could look like written with :epkg:`scikit-learn` -# design. +# Training with :pekg:`onnxruntime-training` starts with an ONNX +# graph which defines the model to learn. It is obtained by simply +# converting the previous linear regression into ONNX. onx = to_onnx(lr, X_train[:1].astype(numpy.float32), target_opset=15, black_op={'LinearRegressor'}) ############################################### -# The loss function is the square function. We use function -# :func:`add_loss_output `. +# Choosing a loss +# +++++++++++++++ +# +# The training requires a loss function. By default, it +# is the square function but it could be the absolute error or +# include penalties. Function +# :func:`add_loss_output ` +# appends the loss function to the ONNX graph. onx_train = add_loss_output(onx) +plot_onnxs(onx, onx_train, + title=['Linear Regression', + 'Linear Regression + Loss with ONNX']) + ##################################### # Let's check inference is working. @@ -73,522 +95,64 @@ print("onnx loss=%r" % (res[0][0, 0] / X_test.shape[0])) ##################################### -# Let's retrieve the constants, the weights to optimize. -# We remove initializer which cannot be optimized. +# Weights +# +++++++ +# +# Every initializer is a set of weights which can be trained +# and a gradient will be computed for it. +# However an initializer used to modify a shape or to +# extract a subpart of a tensor does not need training. +# Let's remove them from the list of initializer to train. inits = get_train_initializer(onx) weights = {k: v for k, v in inits.items() if k != "shape_tensor"} pprint(list((k, v[0].shape) for k, v in weights.items())) ##################################### -# Train on CPU or GPU if available. +# Train on CPU or GPU if available +# ++++++++++++++++++++++++++++++++ device = "cuda" if get_device() == 'GPU' else 'cpu' print("device=%r get_device()=%r" % (device, get_device())) ####################################### +# Stochastic Gradient Descent +# +++++++++++++++++++++++++++ +# # The training logic is hidden in class # :class:`OrtGradientOptimizer # `. # It follows :epkg:`scikit-learn` API (see `SGDRegressor # `_. +# The gradient graph is not available at this stage. train_session = OrtGradientOptimizer( - onx_train, list(weights), device=device, verbose=1, eta0=1e-4, + onx_train, list(weights), device=device, verbose=1, eta0=1e-2, warm_start=False, max_iter=200, batch_size=10) train_session.fit(X, y) -# " +###################################### # And the trained coefficient are... state_tensors = train_session.get_state() -print("train losses:", train_session.train_losses_) +pprint(["trained coefficients:", state_tensors]) +print("last_losses:", train_session.train_losses_[-5:]) -df = DataFrame({'losses': train_session.train_losses_}) +min_length = min(len(train_session.train_losses_), len(lr.loss_curve_)) +df = DataFrame({'ort losses': train_session.train_losses_[:min_length], + 'skl losses': lr.loss_curve_[:min_length]}) df.plot(title="Train loss against iterations") -# Let's see know what is behind these short lines of codes. -# -################################### -# An equivalent ONNX graph. -# +++++++++++++++++++++++++ -# -# This graph can be obtained with *sklearn-onnx` as we need to -# modify it for training, it is easier to create an explicit one. - - -def onnx_linear_regression(coefs, intercept): - if len(coefs.shape) == 1: - coefs = coefs.reshape((1, -1)) - coefs = coefs.T - - # input and output - X = helper.make_tensor_value_info( - 'X', TensorProto.FLOAT, [None, coefs.shape[0]]) - Y = helper.make_tensor_value_info( - 'Y', TensorProto.FLOAT, [None, coefs.shape[1]]) - - # inference - node_matmul = helper.make_node('MatMul', ['X', 'coefs'], ['y1'], name='N1') - node_add = helper.make_node('Add', ['y1', 'intercept'], ['Y'], name='N2') - - # initializer - init_coefs = numpy_helper.from_array(coefs, name="coefs") - init_intercept = numpy_helper.from_array(intercept, name="intercept") - - # graph - graph_def = helper.make_graph( - [node_matmul, node_add], 'lr', [X], [Y], - [init_coefs, init_intercept]) - model_def = helper.make_model( - graph_def, producer_name='orttrainer', ir_version=7, - producer_version=ort_version, - opset_imports=[helper.make_operatorsetid('', 14)]) - return model_def - - -onx = onnx_linear_regression(lr.coef_.astype(numpy.float32), - lr.intercept_.astype(numpy.float32)) - -######################################## -# Let's visualize it. - -plot_onnx(onx) - -################################### -# We check it produces the same outputs. - -sess = InferenceSession(onx.SerializeToString()) -print(sess.run(None, {'X': X[:5]})[0]) - -##################################### -# It works. - -##################################### -# Training with onnxruntime-training -# ++++++++++++++++++++++++++++++++++ -# -# It is possible only if the graph to train has a gradient. -# Then the model can be trained with a gradient descent algorithm. -# The previous graph only predicts, a new graph needs to be created -# to compute the loss as well. In our case, it is a square loss. -# The new graph requires another input for the label -# and another output for the loss value. - - -def onnx_linear_regression_training(coefs, intercept): - if len(coefs.shape) == 1: - coefs = coefs.reshape((1, -1)) - coefs = coefs.T - - # input - X = helper.make_tensor_value_info( - 'X', TensorProto.FLOAT, [None, coefs.shape[0]]) - - # expected input - label = helper.make_tensor_value_info( - 'label', TensorProto.FLOAT, [None, coefs.shape[1]]) - - # output - Y = helper.make_tensor_value_info( - 'Y', TensorProto.FLOAT, [None, coefs.shape[1]]) - - # loss - loss = helper.make_tensor_value_info('loss', TensorProto.FLOAT, []) - - # inference - node_matmul = helper.make_node('MatMul', ['X', 'coefs'], ['y1'], name='N1') - node_add = helper.make_node('Add', ['y1', 'intercept'], ['Y'], name='N2') - - # loss - node_diff = helper.make_node('Sub', ['Y', 'label'], ['diff'], name='L1') - node_square = helper.make_node( - 'Mul', ['diff', 'diff'], ['diff2'], name='L2') - node_square_sum = helper.make_node( - 'ReduceSum', ['diff2'], ['loss'], name='L3') - - # initializer - init_coefs = numpy_helper.from_array(coefs, name="coefs") - init_intercept = numpy_helper.from_array(intercept, name="intercept") - - # graph - graph_def = helper.make_graph( - [node_matmul, node_add, node_diff, node_square, node_square_sum], - 'lrt', [X, label], [loss, Y], [init_coefs, init_intercept]) - model_def = helper.make_model( - graph_def, producer_name='orttrainer', ir_version=7, - producer_version=ort_version, - opset_imports=[helper.make_operatorsetid('', 14)]) - return model_def - ####################################### -# We create a graph with random coefficients. - - -onx_train = onnx_linear_regression_training( - numpy.random.randn(*lr.coef_.shape).astype(numpy.float32), - numpy.random.randn( - *lr.intercept_.reshape((-1, )).shape).astype(numpy.float32)) - -plot_onnx(onx_train) - -################################################ -# DataLoader -# ++++++++++ -# -# Next class draws consecutive random observations from a dataset -# by batch. It iterates over the datasets by drawing *n* consecutive -# observations. - - -class DataLoader: - """ - Draws consecutive random observations from a dataset - by batch. It iterates over the datasets by drawing - *batch_size* consecutive observations. - - :param X: features - :param y: labels - :param batch_size: batch size (consecutive observations) - """ - - def __init__(self, X, y, batch_size=20): - self.X, self.y = X, y - self.batch_size = batch_size - if len(self.y.shape) == 1: - self.y = self.y.reshape((-1, 1)) - if self.X.shape[0] != self.y.shape[0]: - raise ValueError( - "Shape mismatch X.shape=%r, y.shape=%r." % ( - self.X.shape, self.y.shape)) - - def __len__(self): - "Returns the number of observations." - return self.X.shape[0] - - def __iter__(self): - """ - Iterates over the datasets by drawing - *batch_size* consecutive observations. - """ - N = 0 - b = len(self) - self.batch_size - while N < len(self): - i = numpy.random.randint(0, b) - N += self.batch_size - yield (self.X[i:i + self.batch_size], - self.y[i:i + self.batch_size]) - - @property - def data(self): - "Returns a tuple of the datasets." - return self.X, self.y - - -data_loader = DataLoader(X_train, y_train, batch_size=2) - - -for i, batch in enumerate(data_loader): - if i >= 2: - break - print("batch %r: %r" % (i, batch)) - - -######################################### -# First iterations of training -# ++++++++++++++++++++++++++++ -# -# Prediction needs an instance of class *InferenceSession*, -# the training needs an instance of class *TrainingSession*. -# Next function creates this one. - - -def create_training_session( - training_onnx, weights_to_train, loss_output_name='loss', - training_optimizer_name='SGDOptimizer'): - """ - Creates an instance of class `TrainingSession`. - - :param training_onnx: ONNX graph used to train - :param weights_to_train: names of initializers to be optimized - :param loss_output_name: name of the loss output - :param training_optimizer_name: optimizer name - :return: instance of `TrainingSession` - """ - ort_parameters = TrainingParameters() - ort_parameters.loss_output_name = loss_output_name - ort_parameters.use_mixed_precision = False - # ort_parameters.world_rank = -1 - # ort_parameters.world_size = 1 - # ort_parameters.gradient_accumulation_steps = 1 - # ort_parameters.allreduce_post_accumulation = False - # ort_parameters.deepspeed_zero_stage = 0 - # ort_parameters.enable_grad_norm_clip = False - # ort_parameters.set_gradients_as_graph_outputs = False - # ort_parameters.use_memory_efficient_gradient = False - # ort_parameters.enable_adasum = False - - output_types = {} - for output in training_onnx.graph.output: - output_types[output.name] = output.type.tensor_type - - ort_parameters.weights_to_train = set(weights_to_train) - ort_parameters.training_optimizer_name = training_optimizer_name - # ort_parameters.lr_params_feed_name = lr_params_feed_name - - ort_parameters.optimizer_attributes_map = { - name: {} for name in weights_to_train} - ort_parameters.optimizer_int_attributes_map = { - name: {} for name in weights_to_train} - - session_options = SessionOptions() - session_options.use_deterministic_compute = True - - session = TrainingSession( - training_onnx.SerializeToString(), ort_parameters, session_options) - return session - - -train_session = create_training_session(onx_train, ['coefs', 'intercept']) -print(train_session) - -###################################### -# Let's look into the expected inputs and outputs. - -for i in train_session.get_inputs(): - print("+input: %s (%s%s)" % (i.name, i.type, i.shape)) -for o in train_session.get_outputs(): - print("output: %s (%s%s)" % (o.name, o.type, o.shape)) - -###################################### -# A third parameter `Learning_Rate` was added. -# The training updates the weight with a gradient multiplied -# by this parameter. Let's see now how to -# retrieve the trained coefficients. - -state_tensors = train_session.get_state() -pprint(state_tensors) - -###################################### -# We can now check the coefficients are updated after one iteration. - -inputs = {'X': X_train[:1], - 'label': y_train[:1].reshape((-1, 1)), - 'Learning_Rate': numpy.array([0.001], dtype=numpy.float32)} - -train_session.run(None, inputs) -state_tensors = train_session.get_state() -pprint(state_tensors) - -###################################### -# They changed. Another iteration to be sure. - -inputs = {'X': X_train[:1], - 'label': y_train[:1].reshape((-1, 1)), - 'Learning_Rate': numpy.array([0.001], dtype=numpy.float32)} -res = train_session.run(None, inputs) -state_tensors = train_session.get_state() -pprint(state_tensors) - -##################################### -# It works. The training loss can be obtained by looking into the results. - -pprint(res) - -###################################### -# Training -# ++++++++ -# -# We need to implement a gradient descent. -# Let's wrap this into a class similar following scikit-learn's API. - - -class CustomTraining: - """ - Implements a simple :epkg:`Stochastic Gradient Descent`. - - :param model_onnx: ONNX graph to train - :param weights_to_train: list of initializers to train - :param loss_output_name: name of output loss - :param max_iter: number of training iterations - :param training_optimizer_name: optimizing algorithm - :param batch_size: batch size (see class *DataLoader*) - :param eta0: initial learning rate for the `'constant'`, `'invscaling'` - or `'adaptive'` schedules. - :param alpha: constant that multiplies the regularization term, - the higher the value, the stronger the regularization. - Also used to compute the learning rate when set to *learning_rate* - is set to `'optimal'`. - :param power_t: exponent for inverse scaling learning rate - :param learning_rate: learning rate schedule: - * `'constant'`: `eta = eta0` - * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen - by a heuristic proposed by Leon Bottou. - * `'invscaling'`: `eta = eta0 / pow(t, power_t)` - :param verbose: use :epkg:`tqdm` to display the training progress - """ - - def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', - max_iter=100, training_optimizer_name='SGDOptimizer', - batch_size=10, eta0=0.01, alpha=0.0001, power_t=0.25, - learning_rate='invscaling', verbose=0): - # See https://scikit-learn.org/stable/modules/generated/ - # sklearn.linear_model.SGDRegressor.html - self.model_onnx = model_onnx - self.batch_size = batch_size - self.weights_to_train = weights_to_train - self.loss_output_name = loss_output_name - self.training_optimizer_name = training_optimizer_name - self.verbose = verbose - self.max_iter = max_iter - self.eta0 = eta0 - self.alpha = alpha - self.power_t = power_t - self.learning_rate = learning_rate.lower() - - def _init_learning_rate(self): - self.eta0_ = self.eta0 - if self.learning_rate == "optimal": - typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha)) - self.eta0_ = typw / max(1.0, (1 + typw) * 2) - self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha) - else: - self.eta0_ = self.eta0 - return self.eta0_ - - def _update_learning_rate(self, t, eta): - if self.learning_rate == "optimal": - eta = 1.0 / (self.alpha * (self.optimal_init_ + t)) - elif self.learning_rate == "invscaling": - eta = self.eta0_ / numpy.power(t + 1, self.power_t) - return eta - - def fit(self, X, y): - """ - Trains the model. - :param X: features - :param y: expected output - :return: self - """ - self.train_session_ = create_training_session( - self.model_onnx, self.weights_to_train, - loss_output_name=self.loss_output_name, - training_optimizer_name=self.training_optimizer_name) - - data_loader = DataLoader(X, y, batch_size=self.batch_size) - lr = self._init_learning_rate() - self.input_names_ = [i.name for i in self.train_session_.get_inputs()] - self.output_names_ = [ - o.name for o in self.train_session_.get_outputs()] - self.loss_index_ = self.output_names_.index(self.loss_output_name) - - loop = ( - tqdm(range(self.max_iter)) - if self.verbose else range(self.max_iter)) - train_losses = [] - for it in loop: - loss = self._iteration(data_loader, lr) - lr = self._update_learning_rate(it, lr) - if self.verbose > 1: - loop.set_description("loss=%1.3g lr=%1.3g" % (loss, lr)) - train_losses.append(loss) - self.train_losses_ = train_losses - self.trained_coef_ = self.train_session_.get_state() - return self - - def _iteration(self, data_loader, learning_rate): - """ - Processes one gradient iteration. - - :param data_lower: instance of class `DataLoader` - :return: loss - """ - actual_losses = [] - lr = numpy.array([learning_rate], dtype=numpy.float32) - for batch_idx, (data, target) in enumerate(data_loader): - if len(target.shape) == 1: - target = target.reshape((-1, 1)) - - inputs = {self.input_names_[0]: data, - self.input_names_[1]: target, - self.input_names_[2]: lr} - res = self.train_session_.run(None, inputs) - actual_losses.append(res[self.loss_index_]) - return numpy.array(actual_losses).mean() - -########################################### -# Let's now train the model in a very similar way -# that it would be done with *scikit-learn*. - - -trainer = CustomTraining(onx_train, ['coefs', 'intercept'], verbose=1, - max_iter=10) -trainer.fit(X, y) -print("training losses:", trainer.train_losses_) - -df = DataFrame({"iteration": numpy.arange(len(trainer.train_losses_)), - "loss": trainer.train_losses_}) -df.set_index('iteration').plot(title="Training loss", logy=True) - -###################################################### -# Let's compare scikit-learn trained coefficients and the coefficients -# obtained with onnxruntime and check they are very close. - -print("scikit-learn", lr.coef_, lr.intercept_) -print("onnxruntime", trainer.trained_coef_) - -#################################################### -# It works. We could stop here or we could update the weights -# in the training model or the first model. That requires to -# update the constants in an ONNX graph. -# -# Update weights in an ONNX graph -# +++++++++++++++++++++++++++++++ -# -# Let's first check the output of the first model in ONNX. - -sess = InferenceSession(onx.SerializeToString()) -before = sess.run(None, {'X': X[:5]})[0] -print(before) - -################################# -# Let's replace the initializer. - - -def update_onnx_graph(model_onnx, new_weights): - replace_weights = [] - replace_indices = [] - for i, w in enumerate(model_onnx.graph.initializer): - if w.name in new_weights: - replace_weights.append( - numpy_helper.from_array(new_weights[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del model_onnx.graph.initializer[w_i] - model_onnx.graph.initializer.extend(replace_weights) - - -update_onnx_graph(onx, trainer.trained_coef_) - -######################################## -# Let's compare with the previous output. - -sess = InferenceSession(onx.SerializeToString()) -after = sess.run(None, {'X': X[:5]})[0] -print(after) - -###################################### -# It looks almost the same but slighly different. - -print(after - before) - - -################################################ -# Next example will show how to train a linear regression on GPU: -# :ref:`l-orttraining-linreg-gpu`. - - -plt.show() +# The convergence speed is not the same but both gradient descents +# do not update the gradient multiplier the same way. +# :epkg:`onnxruntime-training` does not implement any gradient descent, +# it just computes the gradient. +# That's the purpose of :class:`OrtGradientOptimizer +# `. Next example +# digs into the implementation details. + +# import matplotlib.pyplot as plt +# plt.show() diff --git a/_doc/examples/plot_orttraining_linear_regression_cpu.py b/_doc/examples/plot_orttraining_linear_regression_cpu.py new file mode 100644 index 00000000..e9081b3d --- /dev/null +++ b/_doc/examples/plot_orttraining_linear_regression_cpu.py @@ -0,0 +1,516 @@ +""" + +.. _l-orttraining-linreg-cpu: + +Train a linear regression with onnxruntime-training +=================================================== + +:epkg:`onnxruntime-training` only computes the gradient values. +A gradient descent can then use it to train a model. +This example goes step by step from the gradient computation +to the gradient descent to get a trained linear regression. + +.. contents:: + :local: + +A simple linear regression with scikit-learn +++++++++++++++++++++++++++++++++++++++++++++ + +""" +from pprint import pprint +import numpy +from pandas import DataFrame +import matplotlib.pyplot as plt +from onnx import helper, numpy_helper, TensorProto +from onnxruntime import ( + InferenceSession, __version__ as ort_version, + TrainingParameters, SessionOptions, TrainingSession) +from sklearn.datasets import make_regression +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LinearRegression +from onnxcustom.plotting.plotting_onnx import plot_onnxs +from tqdm import tqdm + +X, y = make_regression(n_features=2, bias=2) +X = X.astype(numpy.float32) +y = y.astype(numpy.float32) +X_train, X_test, y_train, y_test = train_test_split(X, y) + +lr = LinearRegression() +lr.fit(X, y) +print(lr.predict(X[:5])) + +################################### +# An equivalent ONNX graph. +# +++++++++++++++++++++++++ +# +# This graph can be obtained with *sklearn-onnx` +# (see :ref:`l-orttraining-linreg`). +# For clarity, this step is replaced by +# next function which builds the exact same graph. It +# implements a linear regression :math:`y = AX + B` with onnx operators. + + +def onnx_linear_regression(coefs, intercept): + if len(coefs.shape) == 1: + coefs = coefs.reshape((1, -1)) + coefs = coefs.T + + # input and output + X = helper.make_tensor_value_info( + 'X', TensorProto.FLOAT, [None, coefs.shape[0]]) + Y = helper.make_tensor_value_info( + 'Y', TensorProto.FLOAT, [None, coefs.shape[1]]) + + # inference + node_matmul = helper.make_node('MatMul', ['X', 'coefs'], ['y1'], name='N1') + node_add = helper.make_node('Add', ['y1', 'intercept'], ['Y'], name='N2') + + # initializer + init_coefs = numpy_helper.from_array(coefs, name="coefs") + init_intercept = numpy_helper.from_array(intercept, name="intercept") + + # graph + graph_def = helper.make_graph( + [node_matmul, node_add], 'lr', [X], [Y], + [init_coefs, init_intercept]) + model_def = helper.make_model( + graph_def, producer_name='orttrainer', ir_version=7, + producer_version=ort_version, + opset_imports=[helper.make_operatorsetid('', 14)]) + return model_def + + +onx = onnx_linear_regression(lr.coef_.astype(numpy.float32), + lr.intercept_.astype(numpy.float32)) + +######################################## +# Let's visualize it. + +plot_onnxs(onx, title="Linear Regression") + +################################### +# We check it produces the same outputs. + +sess = InferenceSession(onx.SerializeToString()) +print(sess.run(None, {'X': X[:5]})[0]) + +##################################### +# It works. + +##################################### +# Training with onnxruntime-training +# ++++++++++++++++++++++++++++++++++ +# +# The model can be trained with a gradient descent algorithm. +# The previous graph only predicts. A new graph needs to be created +# to compute the loss as function of the inputs and the expected outputs. +# In our case, it is a square loss. +# The new graph then requires two inputs, the features and the labels. +# It has two outputs, the predicted values and the loss. + + +def onnx_linear_regression_training(coefs, intercept): + if len(coefs.shape) == 1: + coefs = coefs.reshape((1, -1)) + coefs = coefs.T + + # input + X = helper.make_tensor_value_info( + 'X', TensorProto.FLOAT, [None, coefs.shape[0]]) + + # expected input + label = helper.make_tensor_value_info( + 'label', TensorProto.FLOAT, [None, coefs.shape[1]]) + + # output + Y = helper.make_tensor_value_info( + 'Y', TensorProto.FLOAT, [None, coefs.shape[1]]) + + # loss + loss = helper.make_tensor_value_info('loss', TensorProto.FLOAT, []) + + # inference + node_matmul = helper.make_node('MatMul', ['X', 'coefs'], ['y1'], name='N1') + node_add = helper.make_node('Add', ['y1', 'intercept'], ['Y'], name='N2') + + # loss + node_diff = helper.make_node('Sub', ['Y', 'label'], ['diff'], name='L1') + node_square = helper.make_node( + 'Mul', ['diff', 'diff'], ['diff2'], name='L2') + node_square_sum = helper.make_node( + 'ReduceSum', ['diff2'], ['loss'], name='L3') + + # initializer + init_coefs = numpy_helper.from_array(coefs, name="coefs") + init_intercept = numpy_helper.from_array(intercept, name="intercept") + + # graph + graph_def = helper.make_graph( + [node_matmul, node_add, node_diff, node_square, node_square_sum], + 'lrt', [X, label], [loss, Y], [init_coefs, init_intercept]) + model_def = helper.make_model( + graph_def, producer_name='orttrainer', ir_version=7, + producer_version=ort_version, + opset_imports=[helper.make_operatorsetid('', 14)]) + return model_def + +####################################### +# We create a graph with random coefficients. + + +onx_train = onnx_linear_regression_training( + numpy.random.randn(*lr.coef_.shape).astype(numpy.float32), + numpy.random.randn( + *lr.intercept_.reshape((-1, )).shape).astype(numpy.float32)) + +plot_onnxs(onx_train, "Linear Regression with a loss") + +################################################ +# DataLoader +# ++++++++++ +# +# Next class draws consecutive random observations from a dataset +# by batch. It iterates over the datasets by drawing *n* consecutive +# observations. This class is equivalent to +# :class:`OrtDataLoader `. + + +class DataLoader: + """ + Draws consecutive random observations from a dataset + by batch. It iterates over the datasets by drawing + *batch_size* consecutive observations. + + :param X: features + :param y: labels + :param batch_size: batch size (consecutive observations) + """ + + def __init__(self, X, y, batch_size=20): + self.X, self.y = X, y + self.batch_size = batch_size + if len(self.y.shape) == 1: + self.y = self.y.reshape((-1, 1)) + if self.X.shape[0] != self.y.shape[0]: + raise ValueError( + "Shape mismatch X.shape=%r, y.shape=%r." % ( + self.X.shape, self.y.shape)) + + def __len__(self): + "Returns the number of observations." + return self.X.shape[0] + + def __iter__(self): + """ + Iterates over the datasets by drawing + *batch_size* consecutive observations. + """ + N = 0 + b = len(self) - self.batch_size + while N < len(self): + i = numpy.random.randint(0, b) + N += self.batch_size + yield (self.X[i:i + self.batch_size], + self.y[i:i + self.batch_size]) + + @property + def data(self): + "Returns a tuple of the datasets." + return self.X, self.y + + +data_loader = DataLoader(X_train, y_train, batch_size=2) + + +for i, batch in enumerate(data_loader): + if i >= 2: + break + print("batch %r: %r" % (i, batch)) + + +######################################### +# First iterations of training +# ++++++++++++++++++++++++++++ +# +# Prediction needs an instance of class *InferenceSession*, +# the training needs an instance of class *TrainingSession*. +# Next function creates this one. + + +def create_training_session( + training_onnx, weights_to_train, loss_output_name='loss', + training_optimizer_name='SGDOptimizer'): + """ + Creates an instance of class `TrainingSession`. + + :param training_onnx: ONNX graph used to train + :param weights_to_train: names of initializers to be optimized + :param loss_output_name: name of the loss output + :param training_optimizer_name: optimizer name + :return: instance of `TrainingSession` + """ + ort_parameters = TrainingParameters() + ort_parameters.loss_output_name = loss_output_name + + output_types = {} + for output in training_onnx.graph.output: + output_types[output.name] = output.type.tensor_type + + ort_parameters.weights_to_train = set(weights_to_train) + ort_parameters.training_optimizer_name = training_optimizer_name + + ort_parameters.optimizer_attributes_map = { + name: {} for name in weights_to_train} + ort_parameters.optimizer_int_attributes_map = { + name: {} for name in weights_to_train} + + session_options = SessionOptions() + session_options.use_deterministic_compute = True + + session = TrainingSession( + training_onnx.SerializeToString(), ort_parameters, session_options) + return session + + +train_session = create_training_session(onx_train, ['coefs', 'intercept']) +print(train_session) + +###################################### +# Let's look into the expected inputs and outputs. + +for i in train_session.get_inputs(): + print("+input: %s (%s%s)" % (i.name, i.type, i.shape)) +for o in train_session.get_outputs(): + print("output: %s (%s%s)" % (o.name, o.type, o.shape)) + +###################################### +# A third parameter `Learning_Rate` was automatically added. +# The training updates the weight with a gradient multiplied +# by this parameter. Let's see now how to +# retrieve the trained coefficients. + +state_tensors = train_session.get_state() +pprint(state_tensors) + +###################################### +# We can now check the coefficients are updated after one iteration. + +inputs = {'X': X_train[:1], + 'label': y_train[:1].reshape((-1, 1)), + 'Learning_Rate': numpy.array([0.001], dtype=numpy.float32)} + +train_session.run(None, inputs) +state_tensors = train_session.get_state() +pprint(state_tensors) + +###################################### +# They changed. Another iteration to be sure. + +inputs = {'X': X_train[:1], + 'label': y_train[:1].reshape((-1, 1)), + 'Learning_Rate': numpy.array([0.001], dtype=numpy.float32)} +res = train_session.run(None, inputs) +state_tensors = train_session.get_state() +pprint(state_tensors) + +##################################### +# It works. The training loss can be obtained by looking into the results. + +pprint(res) + +###################################### +# Training +# ++++++++ +# +# We need to implement a gradient descent. +# Let's wrap this into a class similar following scikit-learn's API. + + +class CustomTraining: + """ + Implements a simple :epkg:`Stochastic Gradient Descent`. + + :param model_onnx: ONNX graph to train + :param weights_to_train: list of initializers to train + :param loss_output_name: name of output loss + :param max_iter: number of training iterations + :param training_optimizer_name: optimizing algorithm + :param batch_size: batch size (see class *DataLoader*) + :param eta0: initial learning rate for the `'constant'`, `'invscaling'` + or `'adaptive'` schedules. + :param alpha: constant that multiplies the regularization term, + the higher the value, the stronger the regularization. + Also used to compute the learning rate when set to *learning_rate* + is set to `'optimal'`. + :param power_t: exponent for inverse scaling learning rate + :param learning_rate: learning rate schedule: + * `'constant'`: `eta = eta0` + * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen + by a heuristic proposed by Leon Bottou. + * `'invscaling'`: `eta = eta0 / pow(t, power_t)` + :param verbose: use :epkg:`tqdm` to display the training progress + """ + + def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', + max_iter=100, training_optimizer_name='SGDOptimizer', + batch_size=10, eta0=0.01, alpha=0.0001, power_t=0.25, + learning_rate='invscaling', verbose=0): + # See https://scikit-learn.org/stable/modules/generated/ + # sklearn.linear_model.SGDRegressor.html + self.model_onnx = model_onnx + self.batch_size = batch_size + self.weights_to_train = weights_to_train + self.loss_output_name = loss_output_name + self.training_optimizer_name = training_optimizer_name + self.verbose = verbose + self.max_iter = max_iter + self.eta0 = eta0 + self.alpha = alpha + self.power_t = power_t + self.learning_rate = learning_rate.lower() + + def _init_learning_rate(self): + self.eta0_ = self.eta0 + if self.learning_rate == "optimal": + typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha)) + self.eta0_ = typw / max(1.0, (1 + typw) * 2) + self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha) + else: + self.eta0_ = self.eta0 + return self.eta0_ + + def _update_learning_rate(self, t, eta): + if self.learning_rate == "optimal": + eta = 1.0 / (self.alpha * (self.optimal_init_ + t)) + elif self.learning_rate == "invscaling": + eta = self.eta0_ / numpy.power(t + 1, self.power_t) + return eta + + def fit(self, X, y): + """ + Trains the model. + :param X: features + :param y: expected output + :return: self + """ + self.train_session_ = create_training_session( + self.model_onnx, self.weights_to_train, + loss_output_name=self.loss_output_name, + training_optimizer_name=self.training_optimizer_name) + + data_loader = DataLoader(X, y, batch_size=self.batch_size) + lr = self._init_learning_rate() + self.input_names_ = [i.name for i in self.train_session_.get_inputs()] + self.output_names_ = [ + o.name for o in self.train_session_.get_outputs()] + self.loss_index_ = self.output_names_.index(self.loss_output_name) + + loop = ( + tqdm(range(self.max_iter)) + if self.verbose else range(self.max_iter)) + train_losses = [] + for it in loop: + loss = self._iteration(data_loader, lr) + lr = self._update_learning_rate(it, lr) + if self.verbose > 1: + loop.set_description("loss=%1.3g lr=%1.3g" % (loss, lr)) + train_losses.append(loss) + self.train_losses_ = train_losses + self.trained_coef_ = self.train_session_.get_state() + return self + + def _iteration(self, data_loader, learning_rate): + """ + Processes one gradient iteration. + + :param data_lower: instance of class `DataLoader` + :return: loss + """ + actual_losses = [] + lr = numpy.array([learning_rate], dtype=numpy.float32) + for batch_idx, (data, target) in enumerate(data_loader): + if len(target.shape) == 1: + target = target.reshape((-1, 1)) + + inputs = {self.input_names_[0]: data, + self.input_names_[1]: target, + self.input_names_[2]: lr} + res = self.train_session_.run(None, inputs) + actual_losses.append(res[self.loss_index_]) + return numpy.array(actual_losses).mean() + +########################################### +# Let's now train the model in a very similar way +# that it would be done with *scikit-learn*. + + +trainer = CustomTraining(onx_train, ['coefs', 'intercept'], verbose=1, + max_iter=10) +trainer.fit(X, y) +print("training losses:", trainer.train_losses_) + +df = DataFrame({"iteration": numpy.arange(len(trainer.train_losses_)), + "loss": trainer.train_losses_}) +df.set_index('iteration').plot(title="Training loss", logy=True) + +###################################################### +# Let's compare scikit-learn trained coefficients and the coefficients +# obtained with onnxruntime and check they are very close. + +print("scikit-learn", lr.coef_, lr.intercept_) +print("onnxruntime", trainer.trained_coef_) + +#################################################### +# It works. We could stop here or we could update the weights +# in the training model or the first model. That requires to +# update the constants in an ONNX graph. +# +# Update weights in an ONNX graph +# +++++++++++++++++++++++++++++++ +# +# Let's first check the output of the first model in ONNX. + +sess = InferenceSession(onx.SerializeToString()) +before = sess.run(None, {'X': X[:5]})[0] +print(before) + +################################# +# Let's replace the initializer. + + +def update_onnx_graph(model_onnx, new_weights): + replace_weights = [] + replace_indices = [] + for i, w in enumerate(model_onnx.graph.initializer): + if w.name in new_weights: + replace_weights.append( + numpy_helper.from_array(new_weights[w.name], w.name)) + replace_indices.append(i) + replace_indices.sort(reverse=True) + for w_i in replace_indices: + del model_onnx.graph.initializer[w_i] + model_onnx.graph.initializer.extend(replace_weights) + + +update_onnx_graph(onx, trainer.trained_coef_) + +######################################## +# Let's compare with the previous output. + +sess = InferenceSession(onx.SerializeToString()) +after = sess.run(None, {'X': X[:5]})[0] +print(after) + +###################################### +# It looks almost the same but slighly different. + +print(after - before) + + +################################################ +# Next example will show how to train a linear regression on GPU: +# :ref:`l-orttraining-linreg-gpu`. + + +plt.show() diff --git a/_doc/examples/plot_orttraining_linear_regression_gpu.py b/_doc/examples/plot_orttraining_linear_regression_gpu.py index 9814a62a..aa9e08af 100644 --- a/_doc/examples/plot_orttraining_linear_regression_gpu.py +++ b/_doc/examples/plot_orttraining_linear_regression_gpu.py @@ -8,7 +8,8 @@ This example follows the same steps introduced in example :ref:`l-orttraining-linreg-cpu` but on GPU. This example works on CPU and GPU but automatically chooses GPU if it is -available. +available. The main change in this example is the parameter `device` +which indicates where the computation takes place, on CPU or GPU. .. contents:: :local: @@ -16,7 +17,7 @@ A simple linear regression with scikit-learn ++++++++++++++++++++++++++++++++++++++++++++ -This code begins like example :ref:`l-orttraining-linreg-gpu`. +This code begins like example :ref:`l-orttraining-linreg-cpu`. It creates a graph to train a linear regression initialized with random coefficients. """ @@ -29,7 +30,7 @@ TrainingParameters, SessionOptions, TrainingSession) from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split -from mlprodict.plotting.plotting_onnx import plot_onnx +from mlprodict.plotting.plotting_onnx import plot_onnxs from tqdm import tqdm X, y = make_regression(n_features=2, bias=2) @@ -47,7 +48,7 @@ def onnx_linear_regression_training(coefs, intercept): X = helper.make_tensor_value_info( 'X', TensorProto.FLOAT, [None, coefs.shape[0]]) - # expected output + # expected input label = helper.make_tensor_value_info( 'label', TensorProto.FLOAT, [None, coefs.shape[1]]) @@ -55,7 +56,7 @@ def onnx_linear_regression_training(coefs, intercept): Y = helper.make_tensor_value_info( 'Y', TensorProto.FLOAT, [None, coefs.shape[1]]) - # loss output + # loss loss = helper.make_tensor_value_info('loss', TensorProto.FLOAT, []) # inference @@ -76,9 +77,7 @@ def onnx_linear_regression_training(coefs, intercept): # graph graph_def = helper.make_graph( [node_matmul, node_add, node_diff, node_square, node_square_sum], - 'lrt', - [X, label], [loss, Y], - [init_coefs, init_intercept]) + 'lrt', [X, label], [loss, Y], [init_coefs, init_intercept]) model_def = helper.make_model( graph_def, producer_name='orttrainer', ir_version=7, producer_version=ort_version, @@ -90,7 +89,7 @@ def onnx_linear_regression_training(coefs, intercept): numpy.random.randn(2).astype(numpy.float32), numpy.random.randn(1).astype(numpy.float32)) -plot_onnx(onx_train) +plot_onnxs(onx_train, title="Graph with Loss") ######################################### @@ -124,16 +123,6 @@ def create_training_session( """ ort_parameters = TrainingParameters() ort_parameters.loss_output_name = loss_output_name - ort_parameters.use_mixed_precision = False - # ort_parameters.world_rank = -1 - # ort_parameters.world_size = 1 - # ort_parameters.gradient_accumulation_steps = 1 - # ort_parameters.allreduce_post_accumulation = False - # ort_parameters.deepspeed_zero_stage = 0 - # ort_parameters.enable_grad_norm_clip = False - # ort_parameters.set_gradients_as_graph_outputs = False - # ort_parameters.use_memory_efficient_gradient = False - # ort_parameters.enable_adasum = False output_types = {} for output in training_onnx.graph.output: @@ -141,7 +130,6 @@ def create_training_session( ort_parameters.weights_to_train = set(weights_to_train) ort_parameters.training_optimizer_name = training_optimizer_name - # ort_parameters.lr_params_feed_name = lr_params_feed_name ort_parameters.optimizer_attributes_map = { name: {} for name in weights_to_train} @@ -200,6 +188,7 @@ def create_training_session( # # We still need to implement a gradient descent. # Let's wrap this into a class similar following scikit-learn's API. +# It needs to have an extra parameter *device*. class DataLoaderDevice: diff --git a/_doc/examples/plot_orttraining_nn_gpu.py b/_doc/examples/plot_orttraining_nn_gpu.py index e4f961b1..990fa20f 100644 --- a/_doc/examples/plot_orttraining_nn_gpu.py +++ b/_doc/examples/plot_orttraining_nn_gpu.py @@ -6,7 +6,9 @@ ==================================================================== This example leverages example :ref:`l-orttraining-linreg-gpu` to -train a neural network from :epkg:`scikit-learn` on GPU. +train a neural network from :epkg:`scikit-learn` on GPU. However, the code +is using classes implemented in this module, following the pattern +introduced in exemple :ref:`l-orttraining-linreg`. .. contents:: :local: @@ -24,9 +26,8 @@ from sklearn.model_selection import train_test_split from sklearn.neural_network import MLPRegressor from sklearn.metrics import mean_squared_error -from mlprodict.plotting.plotting_onnx import plot_onnx +from mlprodict.plotting.plotting_onnx import plot_onnxs from mlprodict.onnx_conv import to_onnx -from mlprodict.tools import measure_time from onnxcustom.training import add_loss_output, get_train_initializer from onnxcustom.training.optimizers import OrtGradientOptimizer @@ -55,7 +56,7 @@ # ++++++++++++++++++ onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15) -plot_onnx(onx) +plot_onnxs(onx) ####################################### # Training graph @@ -67,7 +68,7 @@ # :ref:`l-orttraining-linreg-cpu`. onx_train = add_loss_output(onx) -plot_onnx(onx_train) +plot_onnxs(onx_train) ##################################### # Let's check inference is working. @@ -100,7 +101,7 @@ # The training session. train_session = OrtGradientOptimizer( - onx_train, list(weights), device=device, verbose=1, eta0=1e-4, + onx_train, list(weights), device=device, verbose=1, eta0=5e-4, warm_start=False, max_iter=200, batch_size=10) train_session.fit(X, y) @@ -108,54 +109,9 @@ print(train_session.train_losses_) -df = DataFrame({'losses': train_session.train_losses_}) +df = DataFrame({'ort losses': train_session.train_losses_, + 'skl losses:': nn.loss_curve_}) df.plot(title="Train loss against iterations", logy=True) - -################################################ -# Benchmark -# +++++++++ -# -# The last part compares the speed between the two training. - -nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200, - solver='sgd', learning_rate_init=1e-4, - n_iter_no_change=1000, batch_size=10) - - -def skl_train(): - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - nn.fit(X_train, y_train) - - -obs = [] -res = measure_time("skl_train()", context=dict(skl_train=skl_train), - repeat=1, number=1) -res['framework'] = ['skl'] -pprint(res) -obs.append(res) - -train_session = OrtGradientOptimizer( - onx_train, list(weights), device=device, verbose=0, eta0=1e-4, - warm_start=False, max_iter=200, batch_size=10) - - -def ort_train(): - train_session.fit(X, y) - - -res = measure_time("ort_train()", context=dict(ort_train=ort_train), - repeat=1, number=1) -res['framework'] = ['ort'] -pprint(res) -obs.append(res) - -df = DataFrame(obs) -df = df[['average', 'framework']] -print(df) - -# " -# Graph. - -df.set_index('framework').plot.hist() +# import matplotlib.pyplot as plt +# plt.show() diff --git a/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst b/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst index 22378609..2fca39d0 100644 --- a/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst +++ b/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst @@ -8,14 +8,25 @@ It builds a graph equivalent to the gradient function also based on onnx operators and specific gradient operators. Initializers are weights that can be trained. The gradient graph has as many as outputs as initializers. -The first example explains how to train a linear model -with onnxruntime. The second one modifies the first example -to do it with GPU. The third example extends that experiment -with a neural network built by :epkg:`scikit-learn`. +The first example compares a linear regression trained with +:epkg:`scikit-learn` and another one trained with +:epkg:`onnxruntime-training`. + +The two next examples explains in details how the training +with :epkg:`onnxruntime-training`. They dig into class +:class:`OrtGradientOptimizer +`. + +The fourth example replicates what was done with the linear regression +but with a neural network built by :epkg:`scikit-learn`. +It trains the network on CPU or GPU +if it is available. The last example benchmarks the different +approaches. .. toctree:: :maxdepth: 1 ../gyexamples/plot_orttraining_linear_regression + ../gyexamples/plot_orttraining_linear_regression_cpu ../gyexamples/plot_orttraining_linear_regression_gpu ../gyexamples/plot_orttraining_nn_gpu diff --git a/_unittests/ut_plotting/test_plotting_onnx.py b/_unittests/ut_plotting/test_plotting_onnx.py new file mode 100644 index 00000000..2866ec93 --- /dev/null +++ b/_unittests/ut_plotting/test_plotting_onnx.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=3s) +""" +import os +import warnings +import unittest +import numpy +from pyquickhelper.pycode import ( + ExtTestCase, skipif_travis, skipif_circleci, get_temp_folder) +from skl2onnx.algebra.onnx_ops import OnnxConcat # pylint: disable=E0611 +from skl2onnx.common.data_types import FloatTensorType +from onnxcustom.plotting.plotting_onnx import plot_onnxs + + +class TestPlotOnnx(ExtTestCase): + + @skipif_travis('graphviz is not installed') + @skipif_circleci('graphviz is not installed') + def test_plot_onnx(self): + + cst = numpy.array([[1, 2]], dtype=numpy.float32) + onx = OnnxConcat('X', 'Y', cst, output_names=['Z'], + op_version=12) + X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float64) + Y = numpy.array([[8, 9], [10, 11], [12, 13]], dtype=numpy.float64) + model_def = onx.to_onnx({'X': X.astype(numpy.float32), + 'Y': Y.astype(numpy.float32)}, + outputs=[('Z', FloatTensorType([2]))], + target_opset=12) + + import matplotlib.pyplot as plt + self.assertRaise(lambda: plot_onnxs(*[]), ValueError) + + try: + ax = plot_onnxs(model_def, title="GRAPH") + except FileNotFoundError as e: + if "No such file or directory: 'dot'" in str(e): + warnings.warn( + "Unable to test the dot syntax, dot is mssing", UserWarning) + return + raise e + self.assertNotEmpty(ax) + plt.close('all') + + @skipif_travis('graphviz is not installed') + @skipif_circleci('graphviz is not installed') + def test_plot_onnx2(self): + + cst = numpy.array([[1, 2]], dtype=numpy.float32) + onx = OnnxConcat('X', 'Y', cst, output_names=['Z'], + op_version=12) + X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float64) + Y = numpy.array([[8, 9], [10, 11], [12, 13]], dtype=numpy.float64) + model_def = onx.to_onnx({'X': X.astype(numpy.float32), + 'Y': Y.astype(numpy.float32)}, + outputs=[('Z', FloatTensorType([2]))], + target_opset=12) + + import matplotlib.pyplot as plt + ax = numpy.array([0]) + self.assertRaise( + lambda: plot_onnxs(model_def, model_def, ax=ax), ValueError) + + try: + ax = plot_onnxs(model_def, model_def, title=["GRAPH1", "GRAPH2"]) + except FileNotFoundError as e: + if "No such file or directory: 'dot'" in str(e): + warnings.warn( + "Unable to test the dot syntax, dot is mssing", UserWarning) + return + raise e + self.assertNotEmpty(ax) + try: + ax = plot_onnxs(model_def, model_def, title="GRAPH1") + except FileNotFoundError as e: + if "No such file or directory: 'dot'" in str(e): + warnings.warn( + "Unable to test the dot syntax, dot is mssing", UserWarning) + return + raise e + self.assertNotEmpty(ax) + if __name__ == "__main__": + temp = get_temp_folder(__file__, "temp_plot_onnx2") + img = os.path.join(temp, "img.png") + plt.savefig(img) + plt.show() + plt.close('all') + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxcustom/plotting/__init__.py b/onnxcustom/plotting/__init__.py new file mode 100644 index 00000000..9e2441f2 --- /dev/null +++ b/onnxcustom/plotting/__init__.py @@ -0,0 +1,5 @@ +# flake8: noqa: F401 +""" +@file +@brief Shortcuts to plotting. +""" diff --git a/onnxcustom/plotting/plotting_onnx.py b/onnxcustom/plotting/plotting_onnx.py new file mode 100644 index 00000000..42e4260c --- /dev/null +++ b/onnxcustom/plotting/plotting_onnx.py @@ -0,0 +1,61 @@ +# flake8: noqa: F401 +""" +@file +@brief Shortcuts to plotting. +""" +from mlprodict.plotting.plotting_onnx import plot_onnx + + +def plot_onnxs(*onx, ax=None, dpi=300, temp_dot=None, temp_img=None, + show=False, title=None): + """ + Plots one or several ONNX graph into a :epkg:`matplotlib` graph. + + :param onx: ONNX objects + :param ax: existing axes + :param dpi: resolution + :param temp_dot: temporary file, + if None, a file is created and removed + :param temp_img: temporary image, + if None, a file is created and removed + :param show: calls `plt.show()` + :return: axes + """ + if len(onx) == 1: + if ax is None: + import matplotlib.pyplot as plt # pylint: disable=C0415 + ax = plt.gca() + elif isinstance(ax, str) and ax == 'new': + import matplotlib.pyplot as plt # pylint: disable=C0415 + _, ax = plt.subplots(1, 1) + ax = plot_onnx(onx[0], ax=ax, dpi=dpi, temp_dot=temp_dot, + temp_img=temp_img) + if title is not None: + ax.set_title(title) + return ax + + if len(onx) == 0: + raise ValueError( + "Empty list of graph to plot.") + + if ax is None: + import matplotlib.pyplot as plt # pylint: disable=C0415 + fig, ax = plt.subplots(1, len(onx)) + else: + fig = None + if ax.shape[0] != len(onx): + raise ValueError( + "ax must be an array of shape (%d, )." % len(onx)) + for i, ox in enumerate(onx): + plot_onnx(ox, ax=ax[i], dpi=dpi, temp_dot=temp_dot, + temp_img=temp_img) + if title is None or isinstance(title, str): + continue + if i < len(title): + ax[i].set_title(title[i]) + if isinstance(title, str): + if fig is None: + raise ValueError( + "Main title cannot be set if ax is defined.") + fig.suptitle(title) + return ax From d0a6940002f5100297a6e1e12dda76cbc21bed22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 25 Nov 2021 20:09:21 +0100 Subject: [PATCH 2/3] add benchmark --- _doc/examples/plot_orttraining_benchmark.py | 177 ++++++++++++++++++ .../plot_orttraining_linear_regression_gpu.py | 2 +- _doc/examples/plot_orttraining_nn_gpu.py | 2 +- _doc/sphinxdoc/source/api/plotting.rst | 5 + .../source/tutorial/tutorial_6_training.rst | 1 + 5 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 _doc/examples/plot_orttraining_benchmark.py create mode 100644 _doc/sphinxdoc/source/api/plotting.rst diff --git a/_doc/examples/plot_orttraining_benchmark.py b/_doc/examples/plot_orttraining_benchmark.py new file mode 100644 index 00000000..cb14d63d --- /dev/null +++ b/_doc/examples/plot_orttraining_benchmark.py @@ -0,0 +1,177 @@ +""" + +.. _l-orttraining-benchmark: + +Benchmark, comparison scikit-learn - onnxruntime-training +========================================================= + +The benchmark compares the processing time between :epkg:`scikit-learn` +and :epkg:`onnxruntime-training` on a linear regression and a neural network. +It uses the model trained in :ref:`l-orttraining-nn-gpu`. + + +.. contents:: + :local: + +First comparison: neural network +++++++++++++++++++++++++++++++++ + +""" +import warnings +from pprint import pprint +import time +import numpy +from pandas import DataFrame +from onnxruntime import get_device +from pyquickhelper.pycode.profiling import profile, profile2graph +from sklearn.datasets import make_regression +from sklearn.model_selection import train_test_split +from sklearn.neural_network import MLPRegressor +from mlprodict.onnx_conv import to_onnx +from onnxcustom.training import add_loss_output, get_train_initializer +from onnxcustom.training.optimizers import OrtGradientOptimizer + + +X, y = make_regression(2000, n_features=100, bias=2) +X = X.astype(numpy.float32) +y = y.astype(numpy.float32) +X_train, X_test, y_train, y_test = train_test_split(X, y) + +######################################## +# Common parameters and model + +batch_size = 15 +max_iter = 200 + +nn = MLPRegressor(hidden_layer_sizes=(50, 10), max_iter=batch_size, + solver='sgd', learning_rate_init=1e-4, + n_iter_no_change=max_iter * 3, batch_size=batch_size) + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + nn.fit(X_train, y_train) + +######################################## +# Conversion to ONNX and trainer initialization + +onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15) +onx_train = add_loss_output(onx) + +inits = get_train_initializer(onx) +weights = {k: v for k, v in inits.items() if k != "shape_tensor"} +pprint(list((k, v[0].shape) for k, v in weights.items())) + +train_session = OrtGradientOptimizer( + onx_train, list(weights), device='cpu', eta0=5e-4, + warm_start=False, max_iter=max_iter, batch_size=batch_size) + + +def benchmark(skl_model, train_session, name, verbose=True): + + print("[benchmark] %s" % name) + begin = time.perf_counter() + skl_model.fit(X, y) + duration_skl = time.perf_counter() - begin + length_skl = len(skl_model.loss_curve_) + print("[benchmark] skl=%r iterations - %r seconds" % ( + length_skl, duration_skl)) + + begin = time.perf_counter() + train_session.fit(X, y) + duration_ort = time.perf_counter() - begin + length_ort = len(skl_model.loss_curve_) + print("[benchmark] ort=%r iteration - %r seconds" % ( + length_ort, duration_ort)) + + return dict(skl=duration_skl, ort=duration_ort, name=name, + iter_skl=length_skl, iter_ort=length_ort) + + +benches = [benchmark(nn, train_session, name='NN-CPU')] + +###################################### +# Profiling +# +++++++++ + + +def clean_name(text): + pos = text.find('onnxruntime') + if pos >= 0: + return text[pos:] + pos = text.find('onnxcustom') + if pos >= 0: + return text[pos:] + pos = text.find('site-packages') + if pos >= 0: + return text[pos:] + return text + + +ps = profile(lambda: benchmark(nn, train_session, name='NN-CPU'))[0] +root, nodes = profile2graph(ps, clean_text=clean_name) +text = root.to_text() +print(text) + +###################################### +# if GPU is available +# +++++++++++++++++++ + +if get_device() == 'GPU': + + train_session = OrtGradientOptimizer( + onx_train, list(weights), device='cuda', eta0=5e-4, + warm_start=False, max_iter=200, batch_size=batch_size) + + benches.append(benchmark(nn, train_session, name='NN-GPU')) + +###################################### +# Linear Regression +# +++++++++++++++++ + +lr = MLPRegressor(hidden_layer_sizes=tuple(), max_iter=batch_size, + solver='sgd', learning_rate_init=1e-4, + n_iter_no_change=max_iter * 3, batch_size=batch_size) + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + lr.fit(X, y) + + +onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15) +onx_train = add_loss_output(onx) + +inits = get_train_initializer(onx) +weights = {k: v for k, v in inits.items() if k != "shape_tensor"} +pprint(list((k, v[0].shape) for k, v in weights.items())) + +train_session = OrtGradientOptimizer( + onx_train, list(weights), device='cpu', eta0=5e-4, + warm_start=False, max_iter=max_iter, batch_size=batch_size) + +benches.append(benchmark(lr, train_session, name='LR-CPU')) + +if get_device() == 'GPU': + + train_session = OrtGradientOptimizer( + onx_train, list(weights), device='cuda', eta0=5e-4, + warm_start=False, max_iter=200, batch_size=batch_size) + + benches.append(benchmark(nn, train_session, name='NN-GPU')) + + +###################################### +# Graphs +# ++++++ +# +# Dataframe first. + +df = DataFrame(benches).set_index('name') +df + +####################################### +# Graphs. + +df[['skl', 'ort']].plot.bar(title="Processing time") + +# import matplotlib.pyplot as plt +# plt.show() diff --git a/_doc/examples/plot_orttraining_linear_regression_gpu.py b/_doc/examples/plot_orttraining_linear_regression_gpu.py index aa9e08af..14ba58f3 100644 --- a/_doc/examples/plot_orttraining_linear_regression_gpu.py +++ b/_doc/examples/plot_orttraining_linear_regression_gpu.py @@ -30,7 +30,7 @@ TrainingParameters, SessionOptions, TrainingSession) from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split -from mlprodict.plotting.plotting_onnx import plot_onnxs +from onnxcustom.plotting.plotting_onnx import plot_onnxs from tqdm import tqdm X, y = make_regression(n_features=2, bias=2) diff --git a/_doc/examples/plot_orttraining_nn_gpu.py b/_doc/examples/plot_orttraining_nn_gpu.py index 990fa20f..56ced857 100644 --- a/_doc/examples/plot_orttraining_nn_gpu.py +++ b/_doc/examples/plot_orttraining_nn_gpu.py @@ -26,7 +26,7 @@ from sklearn.model_selection import train_test_split from sklearn.neural_network import MLPRegressor from sklearn.metrics import mean_squared_error -from mlprodict.plotting.plotting_onnx import plot_onnxs +from onnxcustom.plotting.plotting_onnx import plot_onnxs from mlprodict.onnx_conv import to_onnx from onnxcustom.training import add_loss_output, get_train_initializer from onnxcustom.training.optimizers import OrtGradientOptimizer diff --git a/_doc/sphinxdoc/source/api/plotting.rst b/_doc/sphinxdoc/source/api/plotting.rst new file mode 100644 index 00000000..3b7e11b0 --- /dev/null +++ b/_doc/sphinxdoc/source/api/plotting.rst @@ -0,0 +1,5 @@ + +Plotting +======== + +.. autofunction:: onnxcustom.plloting.plotting_onnx.plot_onnxs diff --git a/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst b/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst index 2fca39d0..e6531b99 100644 --- a/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst +++ b/_doc/sphinxdoc/source/tutorial/tutorial_6_training.rst @@ -30,3 +30,4 @@ approaches. ../gyexamples/plot_orttraining_linear_regression_cpu ../gyexamples/plot_orttraining_linear_regression_gpu ../gyexamples/plot_orttraining_nn_gpu + ../gyexamples/plot_orttraining_benchmark From 549e8807dc9caa908c490e8322aa72389b5b3ef9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 26 Nov 2021 00:15:51 +0100 Subject: [PATCH 3/3] Update test_documentation_examples_training_long.py --- .../test_documentation_examples_training_long.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_documentation/test_documentation_examples_training_long.py b/_unittests/ut_documentation/test_documentation_examples_training_long.py index 7303973a..d267bfee 100644 --- a/_unittests/ut_documentation/test_documentation_examples_training_long.py +++ b/_unittests/ut_documentation/test_documentation_examples_training_long.py @@ -14,6 +14,8 @@ except ImportError: ortt = None from pyquickhelper.pycode import skipif_circleci +from pyquickhelper.texthelper import compare_module_version +from mlprodict import __version__ as mlp_version def import_source(module_file_path, module_name): @@ -33,7 +35,10 @@ def import_source(module_file_path, module_name): class TestDocumentationExampleTrainingLong(unittest.TestCase): @unittest.skipIf( - ortt is None, reason="onnxruntime-training not installed.") + compare_module_version(mlp_version, "0.8") < 0, + reason="onnxruntime-training not installed.") + @unittest.skipIf( + ortt is None, reason="plot_onnx was updated.") @skipif_circleci("stuck") def test_documentation_examples_training(self):