diff --git a/_doc/examples/plot_orttraining_linear_regression.py b/_doc/examples/plot_orttraining_linear_regression.py index b49293f2..1a598897 100644 --- a/_doc/examples/plot_orttraining_linear_regression.py +++ b/_doc/examples/plot_orttraining_linear_regression.py @@ -81,7 +81,8 @@ # The training requires a loss function. By default, it # is the square function but it could be the absolute error or # include penalties. Function -# :func:`add_loss_output ` +# :func:`add_loss_output +# ` # appends the loss function to the ONNX graph. onx_train = add_loss_output(onx) diff --git a/_doc/examples/plot_orttraining_linear_regression_fwbw.py b/_doc/examples/plot_orttraining_linear_regression_fwbw.py index 194b1d38..1f519c91 100644 --- a/_doc/examples/plot_orttraining_linear_regression_fwbw.py +++ b/_doc/examples/plot_orttraining_linear_regression_fwbw.py @@ -170,7 +170,7 @@ class :class:`OrtGradientForwardBackwardOptimizer # and returns the updated weights. This graph works on tensors of any shape # but with the same element type. -plot_onnxs(train_session.loss_grad_onnx_, +plot_onnxs(train_session.learning_loss.loss_grad_onnx_, train_session.learning_rate.axpy_onnx_, title=['error gradient + loss', 'gradient update']) diff --git a/_doc/examples/plot_orttraining_nn_gpu.py b/_doc/examples/plot_orttraining_nn_gpu.py index 0081a702..b06de60e 100644 --- a/_doc/examples/plot_orttraining_nn_gpu.py +++ b/_doc/examples/plot_orttraining_nn_gpu.py @@ -65,7 +65,8 @@ # ++++++++++++++ # # The loss function is the square function. We use function -# :func:`add_loss_output `. +# :func:`add_loss_output +# `. # It does something what is implemented in example # :ref:`l-orttraining-linreg-cpu`. diff --git a/_doc/sphinxdoc/source/api/training.rst b/_doc/sphinxdoc/source/api/training.rst index e4814669..63eae5b4 100644 --- a/_doc/sphinxdoc/source/api/training.rst +++ b/_doc/sphinxdoc/source/api/training.rst @@ -1,24 +1,58 @@ +======== Training ======== +There exists two APIs in :epkg:`onnxruntime`. One assumes +the loss function is part of the graph to derive, the other +one assumes the users provides the derivative of the loss +against the output of the graph. With the first API, +the weights are automatically updated. In the second API, +the users has to do it. It is more complex but gives more +freedom. + +Both API are wrapped into two classes, +:ref:`l-api-prt-gradient-optimizer` for the first API, +:ref:`l-api-prt-gradient-optimizer-fw` for the second API. +Both classes make it easier to a user accustomed to +:epkg:`scikit-learn` API to train any graph with a +stochastic gradient descent algorithm. + .. contents:: :local: BaseEstimator -+++++++++++++ +============= + +Ancestor to both classes wrapping :epkg:`onnxruntime` API. .. autosignature:: onnxcustom.training.base_estimator.BaseEstimator :members: -LearningRate -++++++++++++ +Exceptions +========== -.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGD - :members: +.. autosignature:: onnxcustom.training.excs.ConvergenceError -.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov - :members: +.. autosignature:: onnxcustom.training.excs.EvaluationError + +.. autosignature:: onnxcustom.training.excs.ProviderError + +First API: loss part of the graph +================================= + +Helpers ++++++++ + +Function `add_loss_output` adds a loss function to the graph +if this loss is part of the a predefined list. It may +be combination of L1, L2 losses and L1, L2 penalties. + +.. autosignature:: onnxcustom.utils.orttraining_helper.add_loss_output + +.. autosignature:: onnxcustom.utils.orttraining_helper.get_train_initializer + +.. _l-api-prt-gradient-optimizer: OrtGradientOptimizer ++++++++++++++++++++ @@ -26,25 +60,58 @@ OrtGradientOptimizer .. autosignature:: onnxcustom.training.optimizers.OrtGradientOptimizer :members: -OrtGradientForwardBackward -++++++++++++++++++++++++++ +Second API: loss part of the graph +================================== -.. autosignature:: onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer +ONNX +++++ + +Second API relies on class :epkg:`TrainingAgent`. It expects to find +the weight to train in alphabetical order. That's usual not the case. +The following function does not change the order but renames all +of them to fulfil that requirement. + +.. autosignature:: onnxcustom.utils.onnx_helper.onnx_rename_weights + +LearningPenalty ++++++++++++++++ + +.. autosignature:: onnxcustom.training.sgd_learning_penalty.NoLearningPenalty :members: -Helpers -+++++++ +.. autosignature:: onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty + :members: -.. autosignature:: onnxcustom.utils.orttraining_helper.add_loss_output +LearningRate +++++++++++++ -.. autosignature:: onnxcustom.utils.orttraining_helper.get_train_initializer +.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGD + :members: -Exceptions -++++++++++ +.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov + :members: -.. autosignature:: onnxcustom.training.excs.ConvergenceError +LearningLoss +++++++++++++ + +.. autosignature:: onnxcustom.training.sgd_learning_loss.AbsoluteLearningLoss + :members: + +.. autosignature:: onnxcustom.training.sgd_learning_loss.ElasticLearningLoss + :members: + +.. autosignature:: onnxcustom.training.sgd_learning_loss.SquareLearningLoss + :members: Loss function +++++++++++++ .. autosignature:: onnxcustom.utils.onnx_function.function_onnx_graph + +.. _l-api-prt-gradient-optimizer-fw: + +OrtGradientForwardBackward +++++++++++++++++++++++++++ + +.. autosignature:: onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer + :members: diff --git a/_doc/sphinxdoc/source/api/utils.rst b/_doc/sphinxdoc/source/api/utils.rst index 41aa6a5f..79de4ba6 100644 --- a/_doc/sphinxdoc/source/api/utils.rst +++ b/_doc/sphinxdoc/source/api/utils.rst @@ -19,11 +19,6 @@ Labelling .. autosignature:: onnxcustom.utils.imagenet_classes.get_class_names -ONNX -++++ - -.. autosignature:: onnxcustom.utils.onnx_helper.onnx_rename_weights - Time ++++ diff --git a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training.rst b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training.rst index feb84749..b7c1f319 100644 --- a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training.rst +++ b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training.rst @@ -1,8 +1,14 @@ .. _l-full-training: -Full Training -============= +Full Training with OrtGradientOptimizer +======================================= + +.. contents:: + :local: + +Design +++++++ :epkg:`onnxruntime` was initially designed to speed up inference and deployment but it can also be used to train a model. @@ -10,6 +16,70 @@ It builds a graph equivalent to the gradient function also based on onnx operators and specific gradient operators. Initializers are weights that can be trained. The gradient graph has as many as outputs as initializers. + +:class:`OrtGradientOptimizer +` wraps +class :epkg:`TrainingSession` from :epkg:`onnxruntime-training`. +It starts with one model converted into ONNX graph. +A loss must be added to this graph. Then class :epkg:`TrainingSession` +is able to compute another ONNX graph equivalent to the gradient +of the loss against the weights defined by intializers. + +The first ONNX graph implements a function *Y=f(W, X)*. +Then function :func:`add_loss_output +` +adds a loss to define a graph *loss, Y=loss(f(W, X), W, expected_Y)*. +This same function is able to add the necessary nodes to compute +L1 and L2 losses or a combination of both, a L1 or L2 penalties +or a combination of both. Assuming the user was able to create +an an ONNX graph, he would add *0.1 L1 loss + 0.9 L2 loss* +and a L2 penalty on the coefficients by calling :func:`add_loss_output +` +like that: + +:: + + onx_loss = add_loss_output( + onx, weight_name='weight', score_name='elastic', + l1_weight=0.1, l2_weight=0.9, + penalty={'coef': {'l2': 0.01}}) + +An instance of class :class:`OrtGradientOptimizer +` is +initialized: + +:: + + train_session = OrtGradientOptimizer( + onx_loss, ['intercept', 'coef'], learning_rate=1e-3) + +And then trained: + +:: + + train_session.fit(X_train, y_train, w_train) + +Coefficients can be retrieved like the following: + +:: + + state_tensors = train_session.get_state() + +And train losses: + +:: + + losses = train_session.train_losses_ + +This design does not allow any training with momentum, +keeping an accumulator for gradients yet. +The class does not expose all the possibilies implemented in +:epkg:`onnxruntime-training`. +Next examples show that in practice. + +Examples +++++++++ + The first example compares a linear regression trained with :epkg:`scikit-learn` and another one trained with :epkg:`onnxruntime-training`. @@ -18,6 +88,9 @@ The two next examples explains in details how the training with :epkg:`onnxruntime-training`. They dig into class :class:`OrtGradientOptimizer `. +It leverages class :epkg:`TrainingSession` from :epkg:`onnxruntime-training`. +This one assumes the loss function is part of the graph to train. +It takes care to the weight updating as well. The fourth example replicates what was done with the linear regression but with a neural network built by :epkg:`scikit-learn`. diff --git a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst index 7952eedf..f18b51bb 100644 --- a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst +++ b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst @@ -1,6 +1,12 @@ -Partial Training -================ +Partial Training with OrtGradientForwardBackwardOptimizer +========================================================= + +.. contents:: + :local: + +Design +++++++ Section :ref:`l-full-training` introduces a class able a while ONNX graph. :epkg:`onnxruntime-training` handles the computation @@ -13,7 +19,64 @@ ONNX, and be trained by a gradient descent implemented in python. Partial training is another way to train an ONNX model. It can be trained as a standalone ONNX graph or be integrated in a :epkg:`torch` model or any framework implementing *forward* and *backward* mechanism. -Next example introduced how this is done with ONNX. +It leverages class :epkg:`TrainingAgent` from :epkg:`onnxruntime-training`. + +Main class is :class:`OrtGradientForwardBackwardOptimizer +`. +It is initialized with an ONNX graph defining + +:: + + train_session = OrtGradientForwardBackwardOptimizer( + onx, ['coef', 'intercept'], + learning_rate=LearningRateSGDNesterov() + learning_loss=ElasticLearningLoss(l1_weight=0.1, l2_weight=0.9), + learning_penalty=ElasticLearningPenalty(l1=0.1, l2=0.9)) + +The class holds three attributes defining the loss, its gradient, +the penalty, its gradient, a learning rate possibly with momentum. + +* an object inheriting from :class:`BaseLearningLoss + ` +* an object inheriting from :class:`BaseLearningPenalty + ` +* an object inheriting from :class:`BaseLearningRate + ` + +Because :epkg:`onnxruntime-training` does not implement any standard +operations on :epkg:`OrtValue`, the only remaining is to create +simple ONNX graph execute by :epkg:`InferenceSession` to compute +loss, penalty and their gradient, and to update the weights accordingly. +These three classes all implement meth `build_onnx_function` which +creates create the ONNX graph based on the argument the classes were +initialized with. Traning happens this way: + +:: + + train_session.fit(X_train, y_train, w_train) + +Coefficients can be retrieved like the following: + +:: + + state_tensors = train_session.get_state() + +And train losses: + +:: + + losses = train_session.train_losses_ + +Next examples show that in practice. + +Examples +++++++++ + +This example assumes the loss function is not part of the graph to train +but the gradient of the loss against the graph output is provided. +It does not take care to the weight. This part must be separatly +implemented as well. Next examples introduce how this is done +with ONNX and :epkg:`onnxruntime-training`. .. toctree:: :maxdepth: 1 diff --git a/_unittests/ut_training/test_optimizers_forward_backward.py b/_unittests/ut_training/test_optimizers_forward_backward.py index eee64efa..57a68b8b 100644 --- a/_unittests/ut_training/test_optimizers_forward_backward.py +++ b/_unittests/ut_training/test_optimizers_forward_backward.py @@ -1,5 +1,5 @@ """ -@brief test log(time=9s) +@brief test log(time=10s) """ import unittest @@ -18,6 +18,11 @@ from onnxcustom import __max_supported_opset__ as opset from onnxcustom.training.sgd_learning_rate import ( LearningRateSGD, LearningRateSGDNesterov) +from onnxcustom.training.sgd_learning_loss import ( + BaseLearningLoss, SquareLearningLoss, AbsoluteLearningLoss, + ElasticLearningLoss) +from onnxcustom.training.sgd_learning_penalty import ( + BaseLearningPenalty, NoLearningPenalty, ElasticLearningPenalty) from onnxcustom.utils.onnx_helper import onnx_rename_weights from onnxcustom.training import ConvergenceError try: @@ -47,6 +52,7 @@ def test_ort_gradient_optimizers_use_numpy_zero(self): set_model_props(onx, {'info': 'unit test'}) inits = ['coef', 'intercept'] train_session = OrtGradientForwardBackwardOptimizer(onx, inits) + self.assertIsInstance(train_session.learning_loss, SquareLearningLoss) self.assertRaise(lambda: train_session.get_state(), AttributeError) train_session.fit(X_train, y_train, use_numpy=True) state_tensors = train_session.get_state() @@ -562,6 +568,113 @@ def test_ort_gradient_optimizers_optimal_use_ort_w(self): self.assertGreater(len(losses), 1) self.assertFalse(any(map(numpy.isnan, losses))) + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_optimal_use_ort_w_absolute(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + inits = ['coef', 'intercept'] + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, max_iter=10, weight_name='weight', + learning_rate=LearningRateSGD(learning_rate='optimal'), + learning_loss='absolute_error') + self.assertIsInstance( + train_session.learning_loss, AbsoluteLearningLoss) + self.assertRaise(lambda: train_session.get_state(), AttributeError) + train_session.fit(X_train, y_train, w_train, use_numpy=False) + state_tensors = train_session.get_state() + self.assertEqual(len(state_tensors), 2) + r = repr(train_session) + self.assertIn("OrtGradientForwardBackwardOptimizer(model_onnx=", r) + self.assertIn("learning_rate='optimal'", r) + losses = train_session.train_losses_ + self.assertGreater(len(losses), 1) + self.assertFalse(any(map(numpy.isnan, losses))) + + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_optimal_use_ort_w_elastic(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + inits = ['coef', 'intercept'] + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, max_iter=10, weight_name='weight', + learning_rate=LearningRateSGD(learning_rate='optimal'), + learning_loss=BaseLearningLoss.select( + 'elastic', l1_weight=0.1, l2_weight=0.9)) + self.assertIsInstance( + train_session.learning_loss, ElasticLearningLoss) + self.assertIsInstance( + train_session.learning_penalty, NoLearningPenalty) + self.assertRaise(lambda: train_session.get_state(), AttributeError) + train_session.fit(X_train, y_train, w_train, use_numpy=False) + state_tensors = train_session.get_state() + self.assertEqual(len(state_tensors), 2) + r = repr(train_session) + self.assertIn("OrtGradientForwardBackwardOptimizer(model_onnx=", r) + self.assertIn("learning_rate='optimal'", r) + losses = train_session.train_losses_ + self.assertGreater(len(losses), 1) + self.assertFalse(any(map(numpy.isnan, losses))) + + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_optimal_use_ort_w_elastic_penalty(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + inits = ['coef', 'intercept'] + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, max_iter=10, weight_name='weight', + learning_rate=LearningRateSGD(learning_rate='optimal'), + learning_loss=BaseLearningLoss.select( + 'elastic', l1_weight=0.1, l2_weight=0.9), + learning_penalty=BaseLearningPenalty.select( + 'elastic', l1=0.1, l2=0.9)) + self.assertIsInstance( + train_session.learning_loss, ElasticLearningLoss) + self.assertIsInstance( + train_session.learning_penalty, ElasticLearningPenalty) + self.assertRaise(lambda: train_session.get_state(), AttributeError) + train_session.fit(X_train, y_train, w_train, use_numpy=False) + state_tensors = train_session.get_state() + self.assertEqual(len(state_tensors), 2) + r = repr(train_session) + self.assertIn("OrtGradientForwardBackwardOptimizer(model_onnx=", r) + self.assertIn("learning_rate='optimal'", r) + losses = train_session.train_losses_ + self.assertGreater(len(losses), 1) + self.assertFalse(any(map(numpy.isnan, losses))) + @unittest.skipIf(TrainingSession is None, reason="not training") def test_ort_gradient_optimizers_evaluation_use_numpy(self): from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer @@ -933,12 +1046,71 @@ def test_ort_gradient_optimizers_use_numpy_pickle_w_nesterov(self): self.assertGreater(len(losses), 1) self.assertFalse(any(map(numpy.isnan, losses))) + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_use_numpy_pickle_w_nesterov_rate(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + set_model_props(onx, {'info': 'unit test'}) + inits = ['coef', 'intercept'] + train_session0 = OrtGradientForwardBackwardOptimizer( + onx, inits, learning_rate="Nesterov", weight_name='weight', + learning_loss=BaseLearningLoss.select( + 'elastic', l1_weight=0.1, l2_weight=0.9), + learning_penalty=BaseLearningPenalty.select( + 'elastic', l1=0.1, l2=0.9)) + self.assertIsInstance(train_session0.learning_rate, + LearningRateSGDNesterov) + self.assertIsInstance(train_session0.learning_loss, + ElasticLearningLoss) + self.assertIsInstance(train_session0.learning_penalty, + ElasticLearningPenalty) + st = io.BytesIO() + pickle.dump(train_session0, st) + st2 = io.BytesIO(st.getvalue()) + train_session1 = pickle.load(st2) + + train_session1.fit(X_train, y_train, w_train, use_numpy=True) + self.assertIsInstance(train_session1.learning_rate, + LearningRateSGDNesterov) + self.assertIsInstance(train_session1.learning_loss, + ElasticLearningLoss) + self.assertIsInstance(train_session1.learning_penalty, + ElasticLearningPenalty) + + st = io.BytesIO() + pickle.dump(train_session1, st) + st2 = io.BytesIO(st.getvalue()) + train_session = pickle.load(st2) + state_tensors = train_session.get_state() + self.assertEqual(len(state_tensors), 2) + + train_session.fit(X_train, y_train, w_train, use_numpy=True) + state_tensors = train_session.get_state() + self.assertEqual(len(state_tensors), 2) + r = repr(train_session) + self.assertIn("OrtGradientForwardBackwardOptimizer(model_onnx=", r) + self.assertIn("learning_rate='invscaling'", r) + losses = train_session.train_losses_ + self.assertGreater(len(losses), 1) + self.assertFalse(any(map(numpy.isnan, losses))) + if __name__ == "__main__": # import logging # logger = logging.getLogger('onnxcustom') # logger.setLevel(logging.DEBUG) # logging.basicConfig(level=logging.DEBUG) - # TestOptimizersForwardBackward().test_ort_gradient_optimizers_use_numpy_nesterov() + # TestOptimizersForwardBackward().test_ort_gradient_optimizers_optimal_use_ort_w_elastic_penalty() # stop unittest.main() diff --git a/_unittests/ut_utils/test_onnx_function.py b/_unittests/ut_utils/test_onnx_function.py index f09a2056..2758289f 100644 --- a/_unittests/ut_utils/test_onnx_function.py +++ b/_unittests/ut_utils/test_onnx_function.py @@ -108,10 +108,11 @@ def common_check_alpha(self, name, fct): def test_grad_onnx_axpy(self): self.common_check_alpha("axpy", lambda x1, x2, alpha: x1 * alpha + x2) - def common_check_2(self, name, fct, weight_name=None): + def common_check_2(self, name, fct, weight_name=None, **kwargs): onx = function_onnx_graph( name, target_opset=get_max_opset(), - dtype=numpy.float32, weight_name=weight_name) + dtype=numpy.float32, weight_name=weight_name, + **kwargs) x1 = numpy.random.randn(10, 1).astype(numpy.float32) x2 = numpy.random.randn(10, 1).astype(numpy.float32) w = numpy.random.rand(10).astype(numpy.float32) @@ -159,6 +160,39 @@ def test_loss_grad_onnx_square_error_w(self): (x1 - x2) * (-2) * w.reshape((-1, 1))), weight_name='weight') + def test_loss_grad_onnx_absolute_error(self): + self.common_check_2( + "grad_loss_absolute_error", + lambda x1, x2: (numpy.abs(x1 - x2).sum(), + numpy.sign(x1 - x2))) + + def test_loss_grad_onnx_absolute_error_w(self): + self.common_check_2( + "grad_loss_absolute_error", + lambda x1, x2, w: ((numpy.abs(x1 - x2) * w.reshape((-1, 1))).sum(), + numpy.sign(x1 - x2) * w.reshape((-1, 1))), + weight_name='weight') + + def test_loss_grad_onnx_elastic_error(self): + self.common_check_2( + "grad_loss_elastic_error", + lambda x1, x2: ( + numpy.abs(x1 - x2).sum() * 0.1 + ((x1 - x2) ** 2).sum() * 0.9, + numpy.sign(x1 - x2) * 0.1 - 2 * 0.9 * (x1 - x2) + ), + l1_weight=0.1, l2_weight=0.9) + + def test_loss_grad_onnx_elastic_error_w(self): + self.common_check_2( + "grad_loss_elastic_error", + lambda x1, x2, w: ( + (numpy.abs(x1 - x2) * w.reshape((-1, 1))).sum() * 0.1 + + ((x1 - x2) ** 2 * w.reshape((-1, 1))).sum() * 0.9, + numpy.sign(x1 - x2) * w.reshape((-1, 1)) * 0.1 + + (x1 - x2) * (-2) * w.reshape((-1, 1)) * 0.9 + ), + weight_name='weight', l1_weight=0.1, l2_weight=0.9) + def common_check_3(self, name, fct): onx = function_onnx_graph( name, target_opset=get_max_opset(), @@ -268,6 +302,88 @@ def test_grad_onnx_axpyw2(self): (x1 * alpha + x2 + beta * (x1 * alpha + beta * g), x1 * alpha + beta * g)) + def common_check_1(self, name, fct, weight_name=None, **kwargs): + onx = function_onnx_graph( + name, target_opset=get_max_opset(), + dtype=numpy.float32, weight_name=weight_name, + **kwargs) + x = numpy.random.randn(10, 1).astype(numpy.float32) + exp_loss, exp_grad = fct(x) + + oinf = OnnxInference(onx) + got = oinf.run({'X': x}) + self.assertEqualArray(exp_loss, got['Y'], decimal=5) + self.assertEqualArray(exp_grad, got['Z'], decimal=5) + + providers = device_to_providers('cpu') + so = SessionOptions() + so.log_severity_level = 4 + sess = InferenceSession( + onx.SerializeToString(), so, providers=providers) + got = sess.run(None, {'X': x}) + self.assertEqualArray(exp_loss, got[0], decimal=5) + self.assertEqualArray(exp_grad, got[1], decimal=5) + + def test_penalty_grad_onnx_elastic_error(self): + self.common_check_1( + "grad_penalty_elastic_error", + lambda x: ( + numpy.abs(x).sum() * 0.1 + ((x) ** 2).sum() * 0.9, + numpy.sign(x) * 0.1 + 2 * 0.9 * x + ), + l1_weight=0.1, l2_weight=0.9) + + def test_penalty_3(self): + loss = numpy.random.randn(1, 1).astype(numpy.float32) + w1 = numpy.random.randn(10, 1).astype(numpy.float32) + w2 = numpy.random.randn(5, 1).astype(numpy.float32) + + def fct(x): + return numpy.abs(x).sum() * 0.1 + ((x) ** 2).sum() * 0.9 + + exp_loss = loss + fct(w1) + fct(w2) + + onx = function_onnx_graph( + 'n_penalty_elastic_error', target_opset=get_max_opset(), + dtype=numpy.float32, n_tensors=2, + l1_weight=0.1, l2_weight=0.9) + + oinf = OnnxInference(onx) + got = oinf.run({'loss': loss, 'W0': w1, 'W1': w2}) + self.assertEqualArray(exp_loss, got['Y'], decimal=5) + + providers = device_to_providers('cpu') + so = SessionOptions() + so.log_severity_level = 4 + sess = InferenceSession( + onx.SerializeToString(), so, providers=providers) + got = sess.run(None, {'loss': loss, 'W0': w1, 'W1': w2}) + self.assertEqualArray(exp_loss, got[0], decimal=5) + + def test_penalty_update(self): + x = numpy.random.randn(10, 1).astype(numpy.float32) + + def fct(x): + return numpy.sign(x) * 0.1 + (x * 0.9 * 2) + + exp_loss = x - fct(x) + + onx = function_onnx_graph( + 'update_penalty_elastic_error', target_opset=get_max_opset(), + dtype=numpy.float32, l1=0.1, l2=0.9) + + oinf = OnnxInference(onx) + got = oinf.run({'X': x}) + self.assertEqualArray(exp_loss, got['Y'], decimal=5) + + providers = device_to_providers('cpu') + so = SessionOptions() + so.log_severity_level = 4 + sess = InferenceSession( + onx.SerializeToString(), so, providers=providers) + got = sess.run(None, {'X': x}) + self.assertEqualArray(exp_loss, got[0], decimal=5) + if __name__ == "__main__": unittest.main() diff --git a/onnxcustom/training/base_onnx_function.py b/onnxcustom/training/base_onnx_function.py index 0b7339db..332a530d 100644 --- a/onnxcustom/training/base_onnx_function.py +++ b/onnxcustom/training/base_onnx_function.py @@ -81,7 +81,7 @@ def __repr__(self): return "%s(%s)%s" % ( self.__class__.__name__, ", ".join(ps), self.__repr_extended__()) - def build_onnx_function(self, opset, device): + def build_onnx_function(self, opset, device, *args): """ This class updates the weights. It assumes it can do operator on *OrtValue*. @@ -91,6 +91,7 @@ def build_onnx_function(self, opset, device): :param opset: opset to use :param device: :epkg:`C_OrtDevice` + :param args: additional arguments """ raise NotImplementedError( "This method must be overwritten.") diff --git a/onnxcustom/training/optimizers.py b/onnxcustom/training/optimizers.py index c2a35690..685964d9 100644 --- a/onnxcustom/training/optimizers.py +++ b/onnxcustom/training/optimizers.py @@ -90,7 +90,7 @@ def fit(self, X, y, sample_weight=None, X_val=None, y_val=None, input_names = [i.name for i in self.model_onnx.graph.input] if ((len(input_names) == 2 and sample_weight is not None) or (len(input_names) == 3 and sample_weight is None)): - raise RuntimeError( + raise RuntimeError( # pragma: no cover "Number of inputs should be 2 if sample_weight is None " "or 3 if not None but it is %d." % len(input_names)) self.train_session_ = self._create_training_session( @@ -171,7 +171,7 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue): it can be also a numpy array """ if not isinstance(bind, C_IOBinding): - raise TypeError( + raise TypeError( # pragma: no cover "Unexpected type %r." % type(bind)) if isinstance(c_ortvalue, C_OrtValue): # does not work diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py index 0c8caa45..d82ee37e 100644 --- a/onnxcustom/training/optimizers_partial.py +++ b/onnxcustom/training/optimizers_partial.py @@ -9,14 +9,16 @@ OrtValue as C_OrtValue) from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype from ..utils.onnxruntime_helper import ( - device_to_providers, numpy_to_ort_value, ort_device_to_string) + device_to_providers, numpy_to_ort_value) from ..utils.onnx_function import function_onnx_graph from ..utils.print_helper import str_ortvalue from ..utils.orttraining_helper import get_train_initializer from .ortgradient import OrtGradientForwardBackward from .base_estimator import BaseEstimator +from .sgd_learning_loss import BaseLearningLoss +from .sgd_learning_penalty import BaseLearningPenalty from .data_loader import OrtDataLoader -from .excs import ConvergenceError, ProviderError +from .excs import ConvergenceError class OrtGradientForwardBackwardOptimizer(BaseEstimator): @@ -40,7 +42,7 @@ class OrtGradientForwardBackwardOptimizer(BaseEstimator): :param warm_start: when set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. - :param loss_function: loss function (see below) + :param learning_loss: loss function (see below) :param verbose: use :epkg:`tqdm` to display the training progress :param validation_every: validation with a test set every *validation_every* iterations @@ -48,9 +50,18 @@ class OrtGradientForwardBackwardOptimizer(BaseEstimator): as it slows down the training) :param weight_name: if not None, the class assumes it is trained with training weight - - *loss_function* can be: - * `square_error`: mean square error, used for regression + :param learning_penalty: weight penalty, None, or instance of + @see cl BaseLearningPenalty + + *learning_rate* can be any instance of @see cl BaseLearningRate or + a nick name in the following list as specified in + :meth:`BaseLearningRate.select + `. + + *learning_loss* can be any instance of @see cl BaseLearningLoss or + a nick name in the following list as specified in + :meth:`BaseLearningLoss.select + `. """ def __init__(self, model_onnx, weights_to_train=None, @@ -58,8 +69,9 @@ def __init__(self, model_onnx, weights_to_train=None, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGD', device='cpu', warm_start=False, verbose=0, - validation_every=0.1, loss_function="square_error", - enable_logging=False, weight_name=None): + validation_every=0.1, learning_loss="square_error", + enable_logging=False, weight_name=None, + learning_penalty=None): if weights_to_train is None: weights_to_train = list(get_train_initializer(model_onnx)) BaseEstimator.__init__(self, learning_rate, device) @@ -71,7 +83,8 @@ def __init__(self, model_onnx, weights_to_train=None, self.verbose = verbose self.max_iter = max_iter self.warm_start = warm_start - self.loss_function = loss_function + self.learning_loss = BaseLearningLoss.select(learning_loss) + self.learning_penalty = BaseLearningPenalty.select(learning_penalty) self.enable_logging = enable_logging self.weight_name = weight_name if validation_every < 1: @@ -115,7 +128,7 @@ def __setstate__(self, state): elif k == 'train_grad_state': self.set_state(v, check_trained=False, kind='grad') else: - raise ValueError( + raise ValueError( # pragma: no cover "Unexpected key state %r." % k) self.build_onnx_function() return self @@ -201,14 +214,15 @@ def build_onnx_function(self): so.log_severity_level = 4 # loss_grad - self.loss_grad_onnx_ = function_onnx_graph( - "grad_loss_" + self.loss_function, target_opset=opset, - weight_name=self.weight_name) - self.loss_grad_sess_ = InferenceSession( - self.loss_grad_onnx_.SerializeToString(), so, - providers=device_to_providers(self.device)) - self.loss_grad_sess_bind_ = ( - self.loss_grad_sess_.io_binding()._iobinding) + self.learning_loss.build_onnx_function( + opset, self.device, self.weight_name) + + # weight update + self.learning_rate.build_onnx_function(opset, self.device) + + # penalty + n = len(self.weights_to_train) + self.learning_penalty.build_onnx_function(opset, self.device, n) # zero self.zero_onnx_ = function_onnx_graph("zero") @@ -222,8 +236,6 @@ def build_onnx_function(self): else: self._logger = None - self.learning_rate.build_onnx_function(opset, self.device) - def fit(self, X, y, sample_weight=None, X_val=None, y_val=None, use_numpy=False): """ @@ -355,67 +367,6 @@ def fit(self, X, y, sample_weight=None, "end loss=%r", self.train_losses_[-1]) return self - def _bind_input_ortvalue(self, name, bind, c_ortvalue): - """ - Binds :epkg:`C_OrtValue` to the structure used by - :epkg:`InferenceSession` to run inference. - - :param name: str - :param bind: python structure - :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`), - it can be also a numpy array - """ - if isinstance(c_ortvalue, C_OrtValue): - bind.bind_ortvalue_input(name, c_ortvalue) - elif isinstance(c_ortvalue, numpy.ndarray): - if self.device_type() != self.device.cpu(): # pylint: disable=E1101 - raise ProviderError( - "device=%s is not CPU." % ort_device_to_string( - self.device)) - bind.bind_input( - name, self.device, c_ortvalue.dtype, c_ortvalue.shape, - c_ortvalue.__array_interface__['data'][0]) - else: - raise TypeError( # pragma: no cover - "Unable to bind type %r for name %r." % ( - type(c_ortvalue), name)) - - def _bind_output_ortvalue(self, name, bind, c_ortvalue): - """ - Binds :epkg:`C_OrtValue` to the structure used by - :epkg:`InferenceSession` to run inference. - - :param name: str - :param bind: python structure - :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`) - - This method can be used for inplace computation. - """ - if isinstance(c_ortvalue, C_OrtValue): - bind.bind_ortvalue_output(name, c_ortvalue) - else: - raise TypeError( # pragma: no cover - "Unable to bind type %r for name %r." % ( - type(c_ortvalue), name)) - - def _loss_gradient(self, expected, predicted, weight=None): - """ - Returns the loss and the gradient as OrtValue. - """ - if weight is not None: - self._bind_input_ortvalue( - "weight", self.loss_grad_sess_bind_, weight) - else: - self.loss_grad_sess_bind_.clear_binding_inputs() - self._bind_input_ortvalue("X1", self.loss_grad_sess_bind_, expected) - self._bind_input_ortvalue("X2", self.loss_grad_sess_bind_, predicted) - self.loss_grad_sess_bind_.bind_output('Y', self.device) - self.loss_grad_sess_bind_.bind_output('Z', self.device) - self.loss_grad_sess_._sess.run_with_iobinding( - self.loss_grad_sess_bind_, None) - loss, grad = self.loss_grad_sess_bind_.get_outputs() - return loss, grad - def _iteration(self, data_loader, states, n_weights): actual_losses = [] bs = data_loader.batch_size @@ -446,8 +397,11 @@ def _iteration(self, data_loader, states, n_weights): "batch %d", ib) prediction = self.train_function_.forward(states[0], training=True) - loss, loss_gradient = self._loss_gradient( - orty, prediction[0], weight=ortw) + loss, loss_gradient = self.learning_loss.loss_gradient( + self.device, orty, prediction[0], weight=ortw) + n = len(state) - n_weights + loss = self.learning_penalty.penalty_loss( + self.device, loss, *state[n:]) cpu_loss = loss.numpy() if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss): raise ConvergenceError( @@ -468,6 +422,7 @@ def _iteration(self, data_loader, states, n_weights): n = len(state) - n_weights for i in range(n, len(state)): + self.learning_penalty.update_weights(self.device, state[i]) self.learning_rate.update_weights( self.device, state[i], gradient[i], bs, None if grad is None else grad[i]) @@ -503,7 +458,8 @@ def _evaluation(self, data_loader, state): "batch %d", ib) prediction = self.train_function_.forward(state, training=False) - loss, _ = self._loss_gradient(orty, prediction[0]) + loss, _ = self.learning_loss.loss_gradient( + self.device, orty, prediction[0]) cpu_loss = loss.numpy() if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss): raise ConvergenceError( diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py index 09d65822..96b3ea92 100644 --- a/onnxcustom/training/ortgradient.py +++ b/onnxcustom/training/ortgradient.py @@ -513,7 +513,7 @@ def device_name(device): return 'Cpu' if device.device_type() == OrtDevice.cuda(): return 'Gpu' - raise RuntimeError( + raise RuntimeError( # pragma: no cover "Unexpected value for device type %r." % device.device_type()) @staticmethod diff --git a/onnxcustom/training/sgd_learning_loss.py b/onnxcustom/training/sgd_learning_loss.py new file mode 100644 index 00000000..2bbfe64b --- /dev/null +++ b/onnxcustom/training/sgd_learning_loss.py @@ -0,0 +1,171 @@ +# pylint: disable=W0105 +""" +@file +@brief Helper for :epkg:`onnxruntime-training`. +""" +from onnxruntime import SessionOptions, InferenceSession +from ..utils.onnx_function import function_onnx_graph +from ..utils.onnxruntime_helper import device_to_providers +from .base_onnx_function import BaseLearningOnnx + + +class BaseLearningLoss(BaseLearningOnnx): + """ + Class handling the loss for class + @see cl OrtGradientForwardBackwardOptimizer. + All classes inheriting from this one creates one ONNX function, + returning the loss and the gradient of the loss against the + outputs. Method `loss_gradient` is the main method, it computes + the loss and the gradient defiend by one ONNX graph and + executed by an instance of :epkg:`InferenceSession`. + """ + + def __init__(self): + BaseLearningOnnx.__init__(self) + + def loss_gradient( # pylint: disable=E1101 + self, device, expected, predicted, weight=None): + """ + Returns the loss and the gradient as OrtValue. + + :param device: device where the training takes place + :param expected: expected value + :param predicted: predicted value + :param weight: optional, training weights + (same dimension as expected and predicted tensors) + :return: loss and gradient + """ + if (not hasattr(self, "loss_grad_sess_") or + not hasattr(self, "loss_grad_sess_bind_")): + raise RuntimeError( # pragma: no cover + "Attributes 'loss_grad_sess_bind_' or 'loss_grad_sess_' " + "is missing. Method 'build_onnx_function' has not been called.") + if weight is not None: + self._bind_input_ortvalue( + "weight", self.loss_grad_sess_bind_, weight, device) + else: + self.loss_grad_sess_bind_.clear_binding_inputs() + self._bind_input_ortvalue( + "X1", self.loss_grad_sess_bind_, expected, device) + self._bind_input_ortvalue( + "X2", self.loss_grad_sess_bind_, predicted, device) + self.loss_grad_sess_bind_.bind_output('Y', device) + self.loss_grad_sess_bind_.bind_output('Z', device) + self.loss_grad_sess_._sess.run_with_iobinding( + self.loss_grad_sess_bind_, None) + loss, grad = self.loss_grad_sess_bind_.get_outputs() + return loss, grad + + @staticmethod + def select(class_name, **kwargs): + """ + Returns an instance of a given initialized with + *kwargs*. + :param class_name: an instance of @see cl BaseLearningLoss + or a string among the following class names (see below) + :return: instance of @see cl BaseLearningLoss + + Possible values for *class_name*: + * `'square_error'`: see @see cl SquareLearningLoss + * `'absolute_error'`: see @see cl AbsoluteLearningLoss + * `'elastic_error'`: see @see cl ElasticLearningLoss + """ + if isinstance(class_name, BaseLearningLoss): + return class_name + cls = {SquareLearningLoss: ['square_error', 'square'], + AbsoluteLearningLoss: ['absolute_error', 'absolute'], + ElasticLearningLoss: ['elastic_error', 'elastic']} + for cl, aliases in cls.items(): + if class_name == cl.__class__.__name__ or class_name in aliases: + return cl(**kwargs) + raise ValueError( # pragma: no cover + "Unexpected class name %r. It should be one of %r." % ( + class_name, list(map(lambda c: c.__name__, cls)))) + + +class SquareLearningLoss(BaseLearningLoss): + """ + Implements a square loss :math:`(Y - Z)^2` + where *Y* is the output and *Z* the expected output. + See @see fn _onnx_grad_loss_square_error for the ONNX + implementation. + """ + + def __init__(self): + BaseLearningLoss.__init__(self) + + def build_onnx_function(self, opset, device, weight_name): + so = SessionOptions() + so.log_severity_level = 4 + + # loss_grad + self.loss_grad_onnx_ = function_onnx_graph( + "grad_loss_square_error", target_opset=opset, + weight_name=weight_name) + self.loss_grad_sess_ = InferenceSession( + self.loss_grad_onnx_.SerializeToString(), so, + providers=device_to_providers(device)) + self.loss_grad_sess_bind_ = ( + self.loss_grad_sess_.io_binding()._iobinding) + + +class AbsoluteLearningLoss(BaseLearningLoss): + """ + Implements a square loss :math:`|Y - Z|` + where *Y* is the output and *Z* the expected output. + See @see fn _onnx_grad_loss_absolute_error for the ONNX + implementation. + """ + + def __init__(self): + BaseLearningLoss.__init__(self) + + def build_onnx_function(self, opset, device, weight_name): + so = SessionOptions() + so.log_severity_level = 4 + + # loss_grad + self.loss_grad_onnx_ = function_onnx_graph( + "grad_loss_absolute_error", target_opset=opset, + weight_name=weight_name) + self.loss_grad_sess_ = InferenceSession( + self.loss_grad_onnx_.SerializeToString(), so, + providers=device_to_providers(device)) + self.loss_grad_sess_bind_ = ( + self.loss_grad_sess_.io_binding()._iobinding) + + +class ElasticLearningLoss(BaseLearningLoss): + """ + Implements a square loss + :math:`(Y - Z)^2 \\alpha + |Y - Z| * \\beta` + where *Y* is the output and *Z* the expected output, + :math:`\\alpha` is *l2_weight* and :math:`\\beta` + is *l1_weight*. + + :param l1_weight: weight of L1 norm + :param l2_weight: weight of L2 norm + + See @see fn _onnx_grad_loss_elastic_error for the ONNX + implementation. + """ + + def __init__(self, l1_weight=0.5, l2_weight=0.5): + BaseLearningLoss.__init__(self) + self.l1_weight = l1_weight + self.l2_weight = l2_weight + + def build_onnx_function(self, opset, device, weight_name): + so = SessionOptions() + so.log_severity_level = 4 + + # loss_grad + self.loss_grad_onnx_ = function_onnx_graph( + "grad_loss_elastic_error", target_opset=opset, + weight_name=weight_name, l1_weight=self.l1_weight, + l2_weight=self.l2_weight) + self.loss_grad_sess_ = InferenceSession( + self.loss_grad_onnx_.SerializeToString(), so, + providers=device_to_providers(device)) + self.loss_grad_sess_bind_ = ( + self.loss_grad_sess_.io_binding()._iobinding) diff --git a/onnxcustom/training/sgd_learning_penalty.py b/onnxcustom/training/sgd_learning_penalty.py new file mode 100644 index 00000000..adfb6e4a --- /dev/null +++ b/onnxcustom/training/sgd_learning_penalty.py @@ -0,0 +1,174 @@ +# pylint: disable=W0105 +""" +@file +@brief Helper for :epkg:`onnxruntime-training`. +""" +from onnxruntime import SessionOptions, InferenceSession +from ..utils.onnx_function import function_onnx_graph +from ..utils.onnxruntime_helper import device_to_providers +from .base_onnx_function import BaseLearningOnnx + + +class BaseLearningPenalty(BaseLearningOnnx): + """ + Class handling the penalty on the coefficients for class + @see cl OrtGradientForwardBackwardOptimizer. + """ + + def __init__(self): + BaseLearningOnnx.__init__(self) + + @staticmethod + def select(class_name, **kwargs): + """ + Returns an instance of a given initialized with + *kwargs*. + :param class_name: an instance of @see cl BaseLearningPenalty + or a string among the following class names (see below) + :return: instance of @see cl BaseLearningPenalty + + Possible values for *class_name*: + * None or `'penalty'`: see @see cl L1L2PenaltyLearning + """ + if isinstance(class_name, BaseLearningPenalty): + return class_name + cls = {NoLearningPenalty: [None, ''], + ElasticLearningPenalty: ['elastic', 'l1l2']} + for cl, aliases in cls.items(): + if class_name == cl.__class__.__name__ or class_name in aliases: + return cl(**kwargs) + raise ValueError( # pragma: no cover + "Unexpected class name %r. It should be one of %r." % ( + class_name, list(map(lambda c: c.__name__, cls)))) + + def penalty_loss(self, device, loss, *weights): + """ + Returns the received loss. Updates the loss inplace. + + :param device: device where the training takes place + :param loss: loss without penalty + :param weights: any weights to be penalized + :return: loss + """ + raise NotImplementedError( + "penalty_loss must be overwritten.") + + def update_weights(self, device, statei): + """ + Returns the received loss. Updates the weight inplace. + + :param device: device where the training takes place + :param statei: loss without penalty + :return: weight + """ + raise NotImplementedError( + "update_weights must be overwritten.") + + +class NoLearningPenalty(BaseLearningPenalty): + """ + No weight penalty. + """ + + def __init__(self): + BaseLearningPenalty.__init__(self) + + def build_onnx_function(self, opset, device, n_tensors): + # Nothing to do. + pass + + def penalty_loss(self, device, loss, *weights): + """ + Returns the received loss. Updates the loss inplace. + + :param device: device where the training takes place + :param loss: loss without penalty + :param weights: any weights to be penalized + :return: loss + """ + return loss + + def update_weights(self, device, statei): + """ + Returns the received loss. Updates the weight inplace. + + :param device: device where the training takes place + :param statei: loss without penalty + :return: weight + """ + return statei + + +class ElasticLearningPenalty(BaseLearningPenalty): + """ + Implements a L1 or L2 penalty on weights. + """ + + def __init__(self, l1=0.5, l2=0.5): + BaseLearningPenalty.__init__(self) + self.l1 = l1 + self.l2 = l2 + + def build_onnx_function(self, opset, device, n_tensors): + so = SessionOptions() + so.log_severity_level = 4 + + # loss_grad + self.penalty_onnx_ = function_onnx_graph( + "n_penalty_elastic_error", target_opset=opset, n_tensors=n_tensors) + self.penalty_sess_ = InferenceSession( + self.penalty_onnx_.SerializeToString(), so, + providers=device_to_providers(device)) + self.penalty_sess_bind_ = ( + self.penalty_sess_.io_binding()._iobinding) + self.names_ = [i.name for i in self.penalty_onnx_.graph.input] + + # weight updates + self.penalty_grad_onnx_ = function_onnx_graph( + "update_penalty_elastic_error", target_opset=opset) + self.penalty_grad_sess_ = InferenceSession( + self.penalty_grad_onnx_.SerializeToString(), so, + providers=device_to_providers(device)) + self.penalty_grad_sess_bind_ = ( + self.penalty_grad_sess_.io_binding()._iobinding) + + def penalty_loss(self, device, *inputs): + """ + Computes the penalty associated to every + weights and adds them up to the loss. + + :param device: device where the training takes place + :param inputs: loss without penalty and weights + :return: loss + penatlies + """ + if (not hasattr(self, "penalty_onnx_") or + not hasattr(self, "penalty_sess_bind_")): + raise RuntimeError( # pragma: no cover + "Attributes 'penalty_sess_bind_' or 'penalty_onnx_' " + "is missing. Method 'build_onnx_function' has not been called.") + if len(self.names_) != len(inputs): + raise RuntimeError( + "Mismatched number of inputs: %d != %d." % ( + len(self.names_), len(inputs))) + + for name, inp in zip(self.names_, inputs): + self._bind_input_ortvalue( + name, self.penalty_sess_bind_, inp, device) + self._bind_output_ortvalue('Y', self.penalty_sess_bind_, inputs[0]) + self.penalty_sess_._sess.run_with_iobinding( + self.penalty_sess_bind_, None) + return self.penalty_sess_bind_.get_outputs()[0] + + def update_weights(self, device, statei): + if (not hasattr(self, "penalty_grad_onnx_") or + not hasattr(self, "penalty_grad_sess_bind_")): + raise RuntimeError( # pragma: no cover + "Attributes 'penalty_grad_sess_bind_' or " + "'penalty_grad_onnx_' is missing. Method " + "'build_onnx_function' has not been called.") + self._bind_input_ortvalue( + "X", self.penalty_grad_sess_bind_, statei, device) + self._bind_output_ortvalue('Y', self.penalty_grad_sess_bind_, statei) + self.penalty_grad_sess_._sess.run_with_iobinding( + self.penalty_grad_sess_bind_, None) + return self.penalty_grad_sess_bind_.get_outputs()[0] # X diff --git a/onnxcustom/training/sgd_learning_rate.py b/onnxcustom/training/sgd_learning_rate.py index c8a36a75..7497e5b0 100644 --- a/onnxcustom/training/sgd_learning_rate.py +++ b/onnxcustom/training/sgd_learning_rate.py @@ -210,7 +210,7 @@ def build_onnx_function(self, opset, device): def update_weights(self, device, statei, gradienti, batch_size, velocity=None): if velocity is not None: - raise RuntimeError( + raise RuntimeError( # pragma: no cover "Velocity must be None for this way of updating weights.") self._bind_input_ortvalue( "X1", self.axpy_sess_bind_, gradienti, device) @@ -320,7 +320,7 @@ def build_onnx_function(self, opset, device): def update_weights(self, device, statei, gradienti, batch_size, velocity=None): if velocity is None: - raise RuntimeError( + raise RuntimeError( # pragma: no cover "Velocity must not be None for this way of updating weights.") self._bind_input_ortvalue( "X1", self.axpyw_sess_bind_, gradienti, device) diff --git a/onnxcustom/utils/onnx_function.py b/onnxcustom/utils/onnx_function.py index 1c955918..2b15c0c2 100644 --- a/onnxcustom/utils/onnx_function.py +++ b/onnxcustom/utils/onnx_function.py @@ -20,7 +20,7 @@ def get_supported_functions(): def function_onnx_graph(name, target_opset=None, dtype=numpy.float32, - weight_name=None): + weight_name=None, **kwargs): """ Returns the ONNX graph corresponding to a function. @@ -28,6 +28,7 @@ def function_onnx_graph(name, target_opset=None, dtype=numpy.float32, :param target_opset: opset version :param dtype: computation type :param weight_name: weight name if any + :param kwargs: additional parameters :return: ONNX graph A wrong name will raise an exception giving the whole of @@ -72,9 +73,10 @@ def function_onnx_graph(name, target_opset=None, dtype=numpy.float32, full_name = "_onnx_" + name if full_name in glo: if weight_name is None: - return glo[full_name](target_opset=target_opset, dtype=dtype) + return glo[full_name](target_opset=target_opset, + dtype=dtype, **kwargs) return glo[full_name](target_opset=target_opset, dtype=dtype, - weight_name=weight_name) + weight_name=weight_name, **kwargs) raise ValueError( "Unable to find function %r in %r." % ( full_name, list(sorted( @@ -323,12 +325,44 @@ def _onnx_zero(target_opset=None, dtype=numpy.float32): return onx +def _onnx_linear_regression(target_opset=None, dtype=numpy.float32): + """ + Returns the ONNX graph for function + :math:`Y = f(X, A, B) = A X + B`. + + .. gdot:: + :script: DOT-SECTION + + from mlprodict.onnxrt import OnnxInference + from onnxcustom.utils.onnx_function import function_onnx_graph + + model_onnx = function_onnx_graph('linear_regression') + oinf = OnnxInference(model_onnx, inplace=False) + + print("DOT-SECTION", oinf.to_dot()) + """ + from skl2onnx.algebra.onnx_ops import ( + OnnxMatMul, OnnxAdd) + res = OnnxAdd( + OnnxMatMul('X', 'A', op_version=target_opset), + 'B', op_version=target_opset, output_names=['Y']) + + var_type = dtype_to_var_type(dtype) + varsx = [('X', var_type([None, None])), + ('A', var_type([None, None])), + ('B', var_type([None, None]))] + onx = res.to_onnx( + varsx, outputs=[('Y', var_type())], + target_opset=target_opset, other_outputs=[res]) + return onx + + def _onnx_grad_loss_square_error(target_opset=None, dtype=numpy.float32, weight_name=None): """ Returns the ONNX graph for function - :math:`Y = f(X1, X2) = \\lVert X1 - X2 \\rVert ^2` or - :math:`Y = f(X1, X2) = \\lVert X1 - X2 \\rVert ^2 w` if + :math:`Y = f(X1, X2) = \\lVert (X1 - X2) \\rVert ^2` or + :math:`Y = f(X1, X2) = \\lVert (w**0.5)(X1 - X2) \\rVert ^2 w` if *weight_name* is not None and its gradient. .. gdot:: @@ -337,7 +371,7 @@ def _onnx_grad_loss_square_error(target_opset=None, dtype=numpy.float32, from mlprodict.onnxrt import OnnxInference from onnxcustom.utils.onnx_function import function_onnx_graph - model_onnx = function_onnx_graph('square_error') + model_onnx = function_onnx_graph('grad_loss_square_error') oinf = OnnxInference(model_onnx, inplace=False) print("DOT-SECTION", oinf.to_dot()) @@ -379,10 +413,13 @@ def _onnx_grad_loss_square_error(target_opset=None, dtype=numpy.float32, return onx -def _onnx_linear_regression(target_opset=None, dtype=numpy.float32): +def _onnx_grad_loss_absolute_error(target_opset=None, dtype=numpy.float32, + weight_name=None): """ Returns the ONNX graph for function - :math:`Y = f(X, A, B) = A X + B`. + :math:`Y = f(X1, X2) = \\lVert X1 - X2 \\rVert` or + :math:`Y = f(X1, X2) = \\lVert (X1 - X2)w \\rVert` if + *weight_name* is not None and its gradient. .. gdot:: :script: DOT-SECTION @@ -390,22 +427,275 @@ def _onnx_linear_regression(target_opset=None, dtype=numpy.float32): from mlprodict.onnxrt import OnnxInference from onnxcustom.utils.onnx_function import function_onnx_graph - model_onnx = function_onnx_graph('linear_regression') + model_onnx = function_onnx_graph('grad_loss_absolute_error') oinf = OnnxInference(model_onnx, inplace=False) print("DOT-SECTION", oinf.to_dot()) """ from skl2onnx.algebra.onnx_ops import ( - OnnxMatMul, OnnxAdd) + OnnxSub, OnnxMul, + OnnxReduceSum, OnnxReshape, OnnxSign, OnnxAbs) + diff = OnnxSub('X1', 'X2', op_version=target_opset) + abs_diff = OnnxAbs(diff, op_version=target_opset) + if weight_name is None: + res = OnnxReduceSum(abs_diff, op_version=target_opset, + keepdims=0, output_names=['Y']) + res2 = OnnxSign(diff, op_version=target_opset, + output_names=['Z']) + else: + resh = OnnxReshape(weight_name, + numpy.array([-1, 1], dtype=numpy.int64), + op_version=target_opset) + mul = OnnxMul(abs_diff, resh, op_version=target_opset) + res = OnnxReduceSum(mul, op_version=target_opset, + keepdims=0, output_names=['Y']) + res2 = OnnxMul( + OnnxSign(diff, op_version=target_opset), + resh, op_version=target_opset, output_names=['Z']) + + var_type = dtype_to_var_type(dtype) + varsx = [('X1', var_type([None, None])), + ('X2', var_type([None, None]))] + if weight_name is not None: + varsx.append((weight_name, var_type([None]))) + onx = res.to_onnx( + varsx, outputs=[('Y', var_type()), ('Z', var_type())], + target_opset=target_opset, other_outputs=[res2]) + if weight_name is not None: + onx = add_initializer( + onx, weight_name, numpy.array([1], dtype=dtype)) + return onx + + +def _onnx_grad_loss_elastic_error(target_opset=None, dtype=numpy.float32, + weight_name=None, + l1_weight=0.01, l2_weight=0.01): + """ + Returns the ONNX graph for function + :math:`Y = f(X1, X2) = \\beta \\lVert X1 - X2 \\rVert + + \\alpha \\lVert X1 - X2 \\rVert^2` or + :math:`Y = f(X1, X2) = \\beta \\lVert w(X1 - X2) \\rVert + + \\alpha \\lVert (w**0.5)(X1 - X2) \\rVert^2` if + *weight_name* is not None and its gradient. + *l1_weight* is :math:`\\beta` and + *l2_weight* is :math:`\\alpha`. + + .. gdot:: + :script: DOT-SECTION + + from mlprodict.onnxrt import OnnxInference + from onnxcustom.utils.onnx_function import function_onnx_graph + + model_onnx = function_onnx_graph('grad_loss_elastic_error') + oinf = OnnxInference(model_onnx, inplace=False) + + print("DOT-SECTION", oinf.to_dot()) + """ + from skl2onnx.algebra.onnx_ops import ( + OnnxSub, OnnxMul, OnnxAdd, OnnxReduceSumSquare, + OnnxReduceSum, OnnxReshape, OnnxSign, OnnxAbs) + diff = OnnxSub('X1', 'X2', op_version=target_opset) + abs_diff = OnnxAbs(diff, op_version=target_opset) + if weight_name is None: + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, + keepdims=0) + res2_l1 = OnnxSign(diff, op_version=target_opset) + res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset, + keepdims=0) + res2_l2 = diff + else: + resh = OnnxReshape(weight_name, + numpy.array([-1, 1], dtype=numpy.int64), + op_version=target_opset) + mul = OnnxMul(abs_diff, resh, op_version=target_opset) + res_l1 = OnnxReduceSum(mul, op_version=target_opset, keepdims=0) + res2_l1 = OnnxMul( + OnnxSign(diff, op_version=target_opset), + resh, op_version=target_opset) + + mul = OnnxMul( + OnnxMul(diff, diff, op_version=target_opset), + resh, op_version=target_opset) + res_l2 = OnnxReduceSum(mul, op_version=target_opset) + res2_l2 = OnnxMul(diff, resh, op_version=target_opset) + res = OnnxAdd( - OnnxMatMul('X', 'A', op_version=target_opset), - 'B', op_version=target_opset, output_names=['Y']) + OnnxMul(res_l1, numpy.array([l1_weight], dtype=dtype), + op_version=target_opset), + OnnxMul(res_l2, numpy.array([l2_weight], dtype=dtype), + op_version=target_opset), + op_version=target_opset, output_names=['Y']) + + res2 = OnnxAdd( + OnnxMul(res2_l1, numpy.array([l1_weight], dtype=dtype), + op_version=target_opset), + OnnxMul(res2_l2, numpy.array([l2_weight * (-2)], dtype=dtype), + op_version=target_opset), + op_version=target_opset, output_names=['Z']) var_type = dtype_to_var_type(dtype) - varsx = [('X', var_type([None, None])), - ('A', var_type([None, None])), - ('B', var_type([None, None]))] + varsx = [('X1', var_type([None, None])), + ('X2', var_type([None, None]))] + if weight_name is not None: + varsx.append((weight_name, var_type([None]))) + onx = res.to_onnx( + varsx, outputs=[('Y', var_type()), ('Z', var_type())], + target_opset=target_opset, other_outputs=[res2]) + if weight_name is not None: + onx = add_initializer( + onx, weight_name, numpy.array([1], dtype=dtype)) + return onx + + +def _onnx_grad_penalty_elastic_error(target_opset=None, dtype=numpy.float32, + l1_weight=0.01, l2_weight=0.01): + """ + Returns the ONNX graph for function + :math:`Y = f(W) = \\beta \\lVert W \\rVert + + \\alpha \\lVert W \\rVert^2` + *l1_weight* is :math:`\\beta` and + *l2_weight* is :math:`\\alpha`. + + .. gdot:: + :script: DOT-SECTION + + from mlprodict.onnxrt import OnnxInference + from onnxcustom.utils.onnx_function import function_onnx_graph + + model_onnx = function_onnx_graph('grad_penalty_elastic_error') + oinf = OnnxInference(model_onnx, inplace=False) + + print("DOT-SECTION", oinf.to_dot()) + """ + from skl2onnx.algebra.onnx_ops import ( + OnnxMul, OnnxAdd, OnnxReduceSumSquare, + OnnxReduceSum, OnnxSign, OnnxAbs) + diff = 'X' + abs_diff = OnnxAbs(diff, op_version=target_opset) + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, + keepdims=0) + res2_l1 = OnnxSign(diff, op_version=target_opset) + res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset, + keepdims=0) + res2_l2 = diff + + res = OnnxAdd( + OnnxMul(res_l1, numpy.array([l1_weight], dtype=dtype), + op_version=target_opset), + OnnxMul(res_l2, numpy.array([l2_weight], dtype=dtype), + op_version=target_opset), + op_version=target_opset, output_names=['Y']) + + res2 = OnnxAdd( + OnnxMul(res2_l1, numpy.array([l1_weight], dtype=dtype), + op_version=target_opset), + OnnxMul(res2_l2, numpy.array([l2_weight * (2)], dtype=dtype), + op_version=target_opset), + op_version=target_opset, output_names=['Z']) + + var_type = dtype_to_var_type(dtype) + varsx = [('X', var_type([None, None]))] + onx = res.to_onnx( + varsx, outputs=[('Y', var_type()), ('Z', var_type())], + target_opset=target_opset, other_outputs=[res2]) + return onx + + +def _onnx_n_penalty_elastic_error(target_opset=None, dtype=numpy.float32, + weight_name=None, + l1_weight=0.01, l2_weight=0.01, n_tensors=1): + """ + Returns the ONNX graph for function + :math:`Y = f(W) = \\beta \\lVert W \\rVert + + \\alpha \\lVert W \\rVert^2` + *l1_weight* is :math:`\\beta` and + *l2_weight* is :math:`\\alpha`. + It does that for *n_tensors* and adds all of the results + to an input loss. + + .. gdot:: + :script: DOT-SECTION + + from mlprodict.onnxrt import OnnxInference + from onnxcustom.utils.onnx_function import function_onnx_graph + + model_onnx = function_onnx_graph( + 'n_penalty_elastic_error', n_tensors=2) + oinf = OnnxInference(model_onnx, inplace=False) + + print("DOT-SECTION", oinf.to_dot()) + """ + from skl2onnx.algebra.onnx_ops import ( + OnnxMul, OnnxAdd, OnnxReduceSumSquare, + OnnxReduceSum, OnnxAbs, OnnxSum) + + if n_tensors <= 0: + raise ValueError( + "This function is useless if the number of tensors is null.") + + var_type = dtype_to_var_type(dtype) + varsx = [('loss', var_type([1, 1]))] + names = ['loss'] + for n in range(n_tensors): + name = 'W%d' % n + abs_diff = OnnxAbs(name, op_version=target_opset) + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, + keepdims=0) + # res2_l1 = OnnxSign(diff, op_version=target_opset) + res_l2 = OnnxReduceSumSquare(name, op_version=target_opset, + keepdims=0) + # res2_l2 = diff + res = OnnxAdd( + OnnxMul(res_l1, numpy.array([l1_weight], dtype=dtype), + op_version=target_opset), + OnnxMul(res_l2, numpy.array([l2_weight], dtype=dtype), + op_version=target_opset), + op_version=target_opset) + names.append(res) + varsx.append(('W%d' % n, var_type())) + + res = OnnxSum(*names, op_version=target_opset, output_names=['Y']) onx = res.to_onnx( varsx, outputs=[('Y', var_type())], - target_opset=target_opset, other_outputs=[res]) + target_opset=target_opset) + return onx + + +def _onnx_update_penalty_elastic_error(target_opset=None, dtype=numpy.float32, + l1=0.01, l2=0.01): + """ + Returns the ONNX graph for function + :math:`Y = f(W) = W - 2 \\beta W + - \\alpha sign(W)` + *l1* is :math:`\\beta` and + *l2* is :math:`\\alpha`. + + .. gdot:: + :script: DOT-SECTION + + from mlprodict.onnxrt import OnnxInference + from onnxcustom.utils.onnx_function import function_onnx_graph + + model_onnx = function_onnx_graph( + 'update_penalty_elastic_error') + oinf = OnnxInference(model_onnx, inplace=False) + + print("DOT-SECTION", oinf.to_dot()) + """ + from skl2onnx.algebra.onnx_ops import ( + OnnxSub, OnnxMul, OnnxSign) + + res = OnnxSub( + OnnxMul('X', numpy.array([1 - 2 * l2], dtype=dtype), + op_version=target_opset), + OnnxMul(OnnxSign('X', op_version=target_opset), + numpy.array([l1], dtype=dtype), + op_version=target_opset), + op_version=target_opset, + output_names=['Y']) + + var_type = dtype_to_var_type(dtype) + varsx = [('X', var_type())] + onx = res.to_onnx( + varsx, outputs=[('Y', var_type())], + target_opset=target_opset) return onx diff --git a/onnxcustom/utils/onnxruntime_helper.py b/onnxcustom/utils/onnxruntime_helper.py index 812ef68b..433310dd 100644 --- a/onnxcustom/utils/onnxruntime_helper.py +++ b/onnxcustom/utils/onnxruntime_helper.py @@ -87,7 +87,7 @@ def get_ort_device(device): C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx) raise ValueError( "Unable to interpret string %r as a device." % device) - raise TypeError( + raise TypeError( # pragma: no cover "Unable to interpret type %r, (%r) as de device." % ( type(device), device))