diff --git a/_doc/examples/plot_orttraining_nn_gpu_fwbw_nesterov.py b/_doc/examples/plot_orttraining_nn_gpu_fwbw_nesterov.py
index 87147305..9b4eda3a 100644
--- a/_doc/examples/plot_orttraining_nn_gpu_fwbw_nesterov.py
+++ b/_doc/examples/plot_orttraining_nn_gpu_fwbw_nesterov.py
@@ -2,8 +2,8 @@
 .. _l-orttraining-nn-gpu-fwbw-nesterov:
 
-Forward backward on a neural network on GPU (Nesterov)
-======================================================
+Forward backward on a neural network on GPU (Nesterov) and penalty
+==================================================================
 
 This example does the same as :ref:`l-orttraining-nn-gpu-fwbw`
 but updates the weights using `Nesterov momentum
@@ -18,6 +18,7 @@
 """
 import warnings
 import numpy
+import onnx
 from pandas import DataFrame
 from onnxruntime import get_device
 from sklearn.datasets import make_regression
@@ -26,11 +27,13 @@
 from sklearn.metrics import mean_squared_error
 from onnxcustom.plotting.plotting_onnx import plot_onnxs
 from mlprodict.onnx_conv import to_onnx
+from mlprodict.plotting.text_plot import onnx_simple_text_plot
 from onnxcustom.utils.orttraining_helper import get_train_initializer
 from onnxcustom.utils.onnx_helper import onnx_rename_weights
 from onnxcustom.training.optimizers_partial import (
     OrtGradientForwardBackwardOptimizer)
 from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov
+from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty
 
 
 X, y = make_regression(1000, n_features=10, bias=2)
@@ -92,11 +95,76 @@
 df = DataFrame({'ort losses': train_session.train_losses_,
                 'skl losses:': nn.loss_curve_})
-df.plot(title="Train loss against iterations", logy=True)
+df.plot(title="Train loss against iterations (Nesterov)", logy=True)
 
 ##############################################
 # The convergence rate is different but both classes
 # do not update the learning exactly the same way.
 
+##############################################
+# Penalty
+# +++++++
+#
+# Default parameters for MLPRegressor include a penalty on the weights
+# during training: `alpha=1e-4`.
+
+nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
+                  solver='sgd', learning_rate_init=5e-5,
+                  n_iter_no_change=1000, batch_size=10, alpha=1e-4,
+                  momentum=0.9, nesterovs_momentum=True)
+
+with warnings.catch_warnings():
+    warnings.simplefilter('ignore')
+    nn.fit(X_train, y_train)
+
+print(nn.loss_curve_)
+
+################################################
+# Let's do the same with onnxruntime.
+
+train_session = OrtGradientForwardBackwardOptimizer(
+    onx, device=device, verbose=1,
+    learning_rate=LearningRateSGDNesterov(1e-4, nesterov=True, momentum=0.9),
+    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4),
+    warm_start=False, max_iter=100, batch_size=10)
+train_session.fit(X, y)
+
+#########################################
+# Let's see the weights.
+
+state_tensors = train_session.get_state()
+
+##########################################
+# And the loss.
+
+print(train_session.train_losses_)
+
+df = DataFrame({'ort losses': train_session.train_losses_,
+                'skl losses:': nn.loss_curve_})
+df.plot(title="Train loss against iterations (Nesterov + penalty)", logy=True)
+
+###########################################
+# All ONNX graphs
+# +++++++++++++++
+#
+# Method :meth:`save_onnx_graph
+# `
+# can export all the ONNX graphs used by the model on disk.
+ + +def print_graph(d): + for k, v in sorted(d.items()): + if isinstance(v, dict): + print_graph(v) + else: + print("\n++++++", v.replace("\\", "/"), "\n") + with open(v, "rb") as f: + print(onnx_simple_text_plot(onnx.load(f))) + + +all_files = train_session.save_onnx_graph('.') +print_graph(all_files) + + # import matplotlib.pyplot as plt # plt.show() diff --git a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst index b424432a..9deb14db 100644 --- a/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst +++ b/_doc/sphinxdoc/source/tutorial_training/tutorial_6_training_partial.rst @@ -107,6 +107,10 @@ And train losses: losses = train_session.train_losses_ +Method :meth:`save_onnx_graph +` +exports all graphs used by a model. It can be saved on disk +or just serialized in memory. Next examples show that in practice. Cache diff --git a/_unittests/ut_training/test_optimizers_forward_backward.py b/_unittests/ut_training/test_optimizers_forward_backward.py index 57a68b8b..aa3034c1 100644 --- a/_unittests/ut_training/test_optimizers_forward_backward.py +++ b/_unittests/ut_training/test_optimizers_forward_backward.py @@ -6,8 +6,11 @@ import io import pickle import logging -from pyquickhelper.pycode import ExtTestCase, ignore_warnings, skipif_appveyor +from pyquickhelper.pycode import ( + ExtTestCase, ignore_warnings, skipif_appveyor, + get_temp_folder) import numpy +import onnx from onnx.helper import set_model_props from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split @@ -1105,12 +1108,139 @@ def test_ort_gradient_optimizers_use_numpy_pickle_w_nesterov_rate(self): self.assertGreater(len(losses), 1) self.assertFalse(any(map(numpy.isnan, losses))) + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_nesterov_penalty_l2(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx_model = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + set_model_props(onx_model, {'info': 'unit test'}) + inits = ['coef', 'intercept'] + + train_session = OrtGradientForwardBackwardOptimizer( + onx_model, inits, + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=True, momentum=0.85), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + + temp = get_temp_folder( + __file__, "temp_ort_gradient_optimizers_nesterov_penalty_l2") + + saved = train_session.save_onnx_graph(temp) + saved_bytes = train_session.save_onnx_graph(bytes) + self.assertIsInstance(saved, dict) + self.assertNotEmpty(saved) + self.assertEqual(len(saved), len(saved_bytes)) + checked = [] + for k, v in saved_bytes.items(): + if k == "learning_penalty": + for att, onxb in v.items(): + if att in ('penalty_grad_onnx_', 'penalty_onnx_'): + onx = onnx.load(io.BytesIO(onxb)) + for init in onx.graph.initializer: # pylint: disable=E1101 + vals = init.float_data + if len(vals) == 1 and vals[0] == 0: + checked.append((k, att)) + if len(checked) != 2: 
+ raise AssertionError("Unexpected parameter %r." % checked) + train_session.fit(X, y) + + train_session = OrtGradientForwardBackwardOptimizer( + onx_model, inits, weight_name='weight', + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=True, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + temp = get_temp_folder( + __file__, "temp_ort_gradient_optimizers_nesterov_penalty_l2_weight") + train_session.save_onnx_graph(temp) + train_session.fit(X, y, w) + + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_nesterov_penalty_l1l2(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + set_model_props(onx, {'info': 'unit test'}) + inits = ['coef', 'intercept'] + + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=True, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + train_session.fit(X, y) + + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, weight_name='weight', + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=True, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + train_session.fit(X, y, w) + + @unittest.skipIf(TrainingSession is None, reason="not training") + def test_ort_gradient_optimizers_nesterov_penalty_l1l2_no(self): + from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=10, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + w = (numpy.random.rand(y.shape[0]) + 1).astype(X.dtype) + X_train, _, y_train, __, w_train, ___ = train_test_split( + X, y, w) + reg = LinearRegression() + reg.fit(X_train, y_train, w_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + set_model_props(onx, {'info': 'unit test'}) + inits = ['coef', 'intercept'] + + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=False, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + train_session.fit(X, y) + + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, weight_name='weight', + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=False, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + train_session.fit(X, y, w) + if __name__ == "__main__": # import logging # logger = logging.getLogger('onnxcustom') # logger.setLevel(logging.DEBUG) # logging.basicConfig(level=logging.DEBUG) - # TestOptimizersForwardBackward().test_ort_gradient_optimizers_optimal_use_ort_w_elastic_penalty() + # cl = 
TestOptimizersForwardBackward() + # cl.test_ort_gradient_optimizers_nesterov_penalty_l2() # stop unittest.main() diff --git a/_unittests/ut_utils/test_onnx_function.py b/_unittests/ut_utils/test_onnx_function.py index 2758289f..995cc9aa 100644 --- a/_unittests/ut_utils/test_onnx_function.py +++ b/_unittests/ut_utils/test_onnx_function.py @@ -350,7 +350,7 @@ def fct(x): oinf = OnnxInference(onx) got = oinf.run({'loss': loss, 'W0': w1, 'W1': w2}) - self.assertEqualArray(exp_loss, got['Y'], decimal=5) + self.assertEqualArray(exp_loss.reshape((-1, )), got['Y'], decimal=5) providers = device_to_providers('cpu') so = SessionOptions() @@ -358,7 +358,7 @@ def fct(x): sess = InferenceSession( onx.SerializeToString(), so, providers=providers) got = sess.run(None, {'loss': loss, 'W0': w1, 'W1': w2}) - self.assertEqualArray(exp_loss, got[0], decimal=5) + self.assertEqualArray(exp_loss.reshape((-1, )), got[0], decimal=5) def test_penalty_update(self): x = numpy.random.randn(10, 1).astype(numpy.float32) diff --git a/onnxcustom/training/_base.py b/onnxcustom/training/_base.py new file mode 100644 index 00000000..a73d1b53 --- /dev/null +++ b/onnxcustom/training/_base.py @@ -0,0 +1,139 @@ +""" +@file +@brief Base class for @see cl BaseEstimator and @see cl BaseOnnxFunction. +""" +import os +import inspect +import warnings + + +class BaseOnnxClass: + """ + Bases class with common functions to handle attributes + in classes owning ONNX graphs. + """ + + @classmethod + def _get_param_names(cls): + "Extracts all parameters to serialize." + init = getattr(cls.__init__, "deprecated_original", cls.__init__) + init_signature = inspect.signature(init) + parameters = [ + p for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD] + return [(p.name, p.default) for p in parameters] + + def save_onnx_graph(self, folder, prefix=None, suffix=None): + """ + Saves all ONNX files stored in this class. + + :param folder: folder where to save (it must exists) or + ``bytes`` if the onnx graph must be returned as bytes, + not files + :param prefix: suffix to add to the name + :param suffix: suffix to add to the name + :return: list of saved files (dictionary + `{ attribute: filename or dictionary }`) + + The function raises a warning if a file already exists. + The function uses class name, attribute names to compose + file names. It shortens them for frequent classes. + + * 'Learning' -> 'L' + * 'OrtGradient' -> 'Grad' + * 'ForwardBackward' -> 'FB' + + .. runpython:: + :showcode: + + import io + import numpy + import onnx + from sklearn.datasets import make_regression + from sklearn.model_selection import train_test_split + from sklearn.linear_model import LinearRegression + from skl2onnx import to_onnx + from mlprodict.plotting.text_plot import onnx_simple_text_plot + from onnxcustom.training.optimizers_partial import ( + OrtGradientForwardBackwardOptimizer) + from onnxcustom.training.sgd_learning_rate import ( + LearningRateSGDNesterov) + from onnxcustom.training.sgd_learning_penalty import ( + ElasticLearningPenalty) + + + def walk_through(obj, prefix="", only_name=True): + for k, v in obj.items(): + if isinstance(v, dict): + p = prefix + "." 
+ k if prefix else k + walk_through(v, prefix=p, only_name=only_name) + elif only_name: + name = "%s.%s" % (prefix, k) if prefix else k + print('+', name) + else: + name = "%s.%s" % (prefix, k) if prefix else k + print('\n++++++', name) + print() + bf = io.BytesIO(v) + onx = onnx.load(bf) + print(onnx_simple_text_plot(onx)) + + + X, y = make_regression( # pylint: disable=W0632 + 100, n_features=3, bias=2, random_state=0) + X = X.astype(numpy.float32) + y = y.astype(numpy.float32) + X_train, _, y_train, __ = train_test_split(X, y) + reg = LinearRegression() + reg.fit(X_train, y_train) + reg.coef_ = reg.coef_.reshape((1, -1)) + opset = 15 + onx = to_onnx(reg, X_train, target_opset=opset, + black_op={'LinearRegressor'}) + inits = ['coef', 'intercept'] + + train_session = OrtGradientForwardBackwardOptimizer( + onx, inits, + learning_rate=LearningRateSGDNesterov( + 1e-4, nesterov=False, momentum=0.9), + learning_penalty=ElasticLearningPenalty(l1=1e-3, l2=1e-4), + warm_start=False, max_iter=100, batch_size=10) + + onxs = train_session.save_onnx_graph(bytes) + + print("+ all onnx graphs") + walk_through(onxs, only_name=True) + walk_through(onxs, only_name=False) + """ + repls = {'Learning': 'L', 'OrtGradient': 'Grad', + 'ForwardBackward': 'FB'} + if folder is None: + return None # pragma: no cover + if prefix is None: + prefix = '' + if suffix is None: + suffix = '' + if isinstance(folder, str) and not os.path.exists(folder): + raise FileNotFoundError( # pragma: no cover + "Folder %r does not exist." % folder) + saved = {} + for k, v in self.__dict__.items(): + if hasattr(v, "SerializeToString"): + if isinstance(folder, str): + name = "%s%s%s.%s.onnx" % ( + prefix, self.__class__.__name__, suffix, k) + for a, b in repls.items(): + name = name.replace(a, b) + filename = os.path.join(folder, name) + if os.path.exists(filename): + warnings.warn( # pragma: no cover + "Filename %r already exists." % filename) + with open(filename, "wb") as f: + f.write(v.SerializeToString()) + saved[k] = filename + else: + saved[k] = v.SerializeToString() + elif hasattr(v, "save_onnx_graph"): + saved[k] = v.save_onnx_graph( + folder, prefix=prefix, suffix="%s.%s" % (suffix, k)) + return saved diff --git a/onnxcustom/training/base_estimator.py b/onnxcustom/training/base_estimator.py index f54461af..bc26a76d 100644 --- a/onnxcustom/training/base_estimator.py +++ b/onnxcustom/training/base_estimator.py @@ -7,10 +7,12 @@ OrtDevice as C_OrtDevice) from ..utils.onnxruntime_helper import ( get_ort_device, ort_device_to_string) +from ._base import BaseOnnxClass from .sgd_learning_rate import BaseLearningRate +from .base_onnx_function import BaseLearningOnnx -class BaseEstimator: +class BaseEstimator(BaseOnnxClass): """ Base class for optimizers. Implements common methods such `__repr__`. 
@@ -43,7 +45,7 @@ def __repr__(self): if k not in self.__dict__: continue # pragma: no cover ov = getattr(self, k) - if isinstance(ov, BaseLearningRate): + if isinstance(ov, BaseLearningOnnx): ps.append("%s=%s" % (k, repr(ov))) elif isinstance(ov, C_OrtDevice): ps.append("%s=%r" % (k, ort_device_to_string(ov))) diff --git a/onnxcustom/training/base_onnx_function.py b/onnxcustom/training/base_onnx_function.py index bed3dd27..2356f524 100644 --- a/onnxcustom/training/base_onnx_function.py +++ b/onnxcustom/training/base_onnx_function.py @@ -7,14 +7,15 @@ from io import BytesIO import numpy import onnx -from onnxruntime import SessionOptions, InferenceSession +from onnxruntime import SessionOptions, InferenceSession, RunOptions from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 OrtValue as C_OrtValue) from ..utils.onnxruntime_helper import ort_device_to_string from .excs import ProviderError +from ._base import BaseOnnxClass -class BaseLearningOnnx: +class BaseLearningOnnx(BaseOnnxClass): """ Class handling ONNX function to manipulate OrtValue. Base class for @see cl BaseLearningRate and @@ -31,6 +32,8 @@ def __getstate__(self): """ atts = [k for k in self.__dict__ if not k.endswith('_')] state = {k: getattr(self, k) for k in atts} + if hasattr(self, 'ro_'): + state['ro_'] = True onx = [k for k in self.__dict__ if k.endswith('_onnx_')] for o in onx: state[o] = getattr(self, o).SerializeToString() @@ -50,7 +53,9 @@ def __setstate__(self, state): Overwrites getstate to get rid of InferenceSession. """ for k, v in state.items(): - if not k.endswith('_onnx_') and not k.endswith('_sess_'): + if k == 'ro_': + self.ro_ = RunOptions() + elif not k.endswith('_onnx_') and not k.endswith('_sess_'): setattr(self, k, v) so = SessionOptions() @@ -219,12 +224,3 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False): raise TypeError( # pragma: no cover "Unable to bind type %r for name %r." % ( type(c_ortvalue), name)) - - @classmethod - def _get_param_names(cls): - init = getattr(cls.__init__, "deprecated_original", cls.__init__) - init_signature = inspect.signature(init) - parameters = [ - p for p in init_signature.parameters.values() - if p.name != "self" and p.kind != p.VAR_KEYWORD] - return [(p.name, p.default) for p in parameters] diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py index 2774484f..eec6b38d 100644 --- a/onnxcustom/training/optimizers_partial.py +++ b/onnxcustom/training/optimizers_partial.py @@ -418,10 +418,19 @@ def _iteration(self, data_loader, states, n_weights): # loss loss, loss_gradient = self.learning_loss.loss_gradient( self.device, orty, prediction[0], weight=ortw) + + if logger is not None: + logger.debug( + "[OrtGradientForwardBackwardOptimizer._iteration] " + "loss=%g has_weight=%r", + loss.numpy(), ortw is not None) + n = len(state) - n_weights loss = self.learning_penalty.penalty_loss( self.device, loss, *state[n:]) + cpu_loss = loss.numpy() + if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss): raise ConvergenceError( "Loss is nan, learning_rate=%r, " @@ -445,6 +454,7 @@ def _iteration(self, data_loader, states, n_weights): "%r != %r." 
% (len(gradient), len(state))) n = len(state) - n_weights + for i in range(n, len(state)): self.learning_penalty.update_weights( i - n, self.device, state[i]) @@ -456,7 +466,7 @@ def _iteration(self, data_loader, states, n_weights): if logger is not None: logger.debug( "[OrtGradientForwardBackwardOptimizer._iteration] " - "loss=%g", cpu_loss) + "loss=%g n_weights=%d", cpu_loss, n) for i in range(n, len(state)): logger.debug( "[OrtGradientForwardBackwardOptimizer._iteration] " diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py index 32a54b4a..717671b4 100644 --- a/onnxcustom/training/ortgradient.py +++ b/onnxcustom/training/ortgradient.py @@ -122,6 +122,9 @@ def __init__(self, onnx_model, weights_to_train=None, "You shoud use function onnx_rename_weights to do that " "before calling this class." % self.weights_to_train) set_weights = set(self.weights_to_train) + if len(set_weights) != len(self.weights_to_train): + raise ValueError( # pragma: no cover + "One weight is not unique in %r." % self.weights_to_train) found = [] for i in self.onnx_model.graph.initializer: if i.name not in set_weights: @@ -130,7 +133,9 @@ def __init__(self, onnx_model, weights_to_train=None, if len(found) != len(self.weights_to_train): raise ValueError( "One weight name in self.weights_to_train was not found in " - "the initializers %r." % (self.weights_to_train, )) + "the initializers %r found=%r init names=%r." % ( + self.weights_to_train, found, + [i.name for i in self.onnx_model.graph.initializer])) if found != self.weights_to_train: raise ValueError( "List of weights to train must be sorted and follow the " diff --git a/onnxcustom/training/sgd_learning_loss.py b/onnxcustom/training/sgd_learning_loss.py index 6da86304..2cbdb9eb 100644 --- a/onnxcustom/training/sgd_learning_loss.py +++ b/onnxcustom/training/sgd_learning_loss.py @@ -3,7 +3,7 @@ @file @brief Helper for :epkg:`onnxruntime-training`. """ -from onnxruntime import SessionOptions, InferenceSession +from onnxruntime import SessionOptions, InferenceSession, RunOptions from ..utils.onnx_function import function_onnx_graph from ..utils.onnxruntime_helper import device_to_providers from .base_onnx_function import BaseLearningOnnx @@ -22,6 +22,10 @@ class BaseLearningLoss(BaseLearningOnnx): def __init__(self): BaseLearningOnnx.__init__(self) + self.ro_ = RunOptions() + + def _call_iobinding(self, sess, bind): + sess.run_with_iobinding(bind, self.ro_) def loss_gradient( # pylint: disable=E1101 self, device, expected, predicted, weight=None): @@ -50,7 +54,7 @@ def loss_gradient( # pylint: disable=E1101 self._bind_input_ortvalue("X2", bind, predicted, device, cache=True) self.loss_grad_sess_bind_.bind_output('Y', device) self.loss_grad_sess_bind_.bind_output('Z', device) - self.loss_grad_sess_._sess.run_with_iobinding(bind, None) + self._call_iobinding(self.loss_grad_sess_._sess, bind) loss, grad = bind.get_outputs() return loss, grad diff --git a/onnxcustom/training/sgd_learning_penalty.py b/onnxcustom/training/sgd_learning_penalty.py index 3be6ed92..3d8bc151 100644 --- a/onnxcustom/training/sgd_learning_penalty.py +++ b/onnxcustom/training/sgd_learning_penalty.py @@ -3,7 +3,7 @@ @file @brief Helper for :epkg:`onnxruntime-training`. 
""" -from onnxruntime import SessionOptions, InferenceSession +from onnxruntime import SessionOptions, InferenceSession, RunOptions from ..utils.onnx_function import function_onnx_graph from ..utils.onnxruntime_helper import device_to_providers from .base_onnx_function import BaseLearningOnnx @@ -17,6 +17,10 @@ class BaseLearningPenalty(BaseLearningOnnx): def __init__(self): BaseLearningOnnx.__init__(self) + self.ro_ = RunOptions() + + def _call_iobinding(self, sess, bind): + sess.run_with_iobinding(bind, self.ro_) @staticmethod def select(class_name, **kwargs): @@ -115,7 +119,8 @@ def build_onnx_function(self, opset, device, n_tensors): # loss_grad self.penalty_onnx_ = function_onnx_graph( - "n_penalty_elastic_error", target_opset=opset, n_tensors=n_tensors) + "n_penalty_elastic_error", target_opset=opset, n_tensors=n_tensors, + loss_shape=None, l1_weight=self.l1, l2_weight=self.l2) self.penalty_sess_ = InferenceSession( self.penalty_onnx_.SerializeToString(), so, providers=device_to_providers(device)) @@ -125,7 +130,8 @@ def build_onnx_function(self, opset, device, n_tensors): # weight updates self.penalty_grad_onnx_ = function_onnx_graph( - "update_penalty_elastic_error", target_opset=opset) + "update_penalty_elastic_error", target_opset=opset, + l1=self.l1, l2=self.l2) self.penalty_grad_sess_ = InferenceSession( self.penalty_grad_onnx_.SerializeToString(), so, providers=device_to_providers(device)) @@ -157,8 +163,7 @@ def penalty_loss(self, device, *inputs): name, self.penalty_sess_bind_, inp, device, cache=True) self._bind_output_ortvalue( 'Y', self.penalty_sess_bind_, inputs[0], cache=True) - self.penalty_sess_._sess.run_with_iobinding( - self.penalty_sess_bind_, None) + self._call_iobinding(self.penalty_sess_._sess, self.penalty_sess_bind_) return self.penalty_sess_bind_.get_outputs()[0] def update_weights(self, n_bind, device, statei): @@ -171,5 +176,5 @@ def update_weights(self, n_bind, device, statei): bind = self.penalty_grad_sess_binds_[n_bind] self._bind_input_ortvalue("X", bind, statei, device, cache=True) self._bind_output_ortvalue('Y', bind, statei, cache=True) - self.penalty_grad_sess_._sess.run_with_iobinding(bind, None) + self._call_iobinding(self.penalty_grad_sess_._sess, bind) return bind.get_outputs()[0] # X diff --git a/onnxcustom/training/sgd_learning_rate.py b/onnxcustom/training/sgd_learning_rate.py index 65330aab..ec2db721 100644 --- a/onnxcustom/training/sgd_learning_rate.py +++ b/onnxcustom/training/sgd_learning_rate.py @@ -5,7 +5,7 @@ """ import numpy from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE -from onnxruntime import SessionOptions, InferenceSession +from onnxruntime import SessionOptions, InferenceSession, RunOptions from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 OrtValue as C_OrtValue) from ..utils.onnx_function import function_onnx_graph @@ -23,6 +23,10 @@ class BaseLearningRate(BaseLearningOnnx): def __init__(self): BaseLearningOnnx.__init__(self) + self.ro_ = RunOptions() + + def _call_iobinding(self, sess, bind): + sess.run_with_iobinding(bind, self.ro_) def init_learning_rate(self): """ @@ -231,7 +235,7 @@ def update_weights(self, n_bind, device, statei, gradienti, batch_size, ort_alpha = C_OrtValue.ortvalue_from_numpy(self.alpha_, device) self._bind_input_ortvalue("alpha", bind, ort_alpha, device, cache=True) self._bind_output_ortvalue('Y', bind, statei, cache=True) - self.axpy_sess_._sess.run_with_iobinding(bind, None) + self._call_iobinding(self.axpy_sess_._sess, bind) loss = bind.get_outputs()[0] return loss @@ -358,5 
+362,5 @@ def update_weights(self, n_bind, device, statei, gradienti, batch_size, self._bind_input_ortvalue("beta", bind, ort_beta, device, cache=True) self._bind_output_ortvalue('Y', bind, statei, cache=True) self._bind_output_ortvalue('Z', bind, velocity, cache=True) - self.axpyw_sess_._sess.run_with_iobinding(bind, None) + self._call_iobinding(self.axpyw_sess_._sess, bind) return bind.get_outputs() # loss, velocity diff --git a/onnxcustom/utils/onnx_function.py b/onnxcustom/utils/onnx_function.py index 8164eeea..7306143f 100644 --- a/onnxcustom/utils/onnx_function.py +++ b/onnxcustom/utils/onnx_function.py @@ -204,8 +204,7 @@ def _onnx_square_error(target_opset=None, dtype=numpy.float32, OnnxReduceSum, OnnxMul) diff = OnnxSub('X1', 'X2', op_version=target_opset) if weight_name is None: - res = OnnxReduceSumSquare(diff, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSumSquare(diff, op_version=target_opset) else: mul = OnnxMul( OnnxMul(diff, diff, op_version=target_opset), @@ -213,8 +212,10 @@ def _onnx_square_error(target_opset=None, dtype=numpy.float32, numpy.array([-1, 1], dtype=numpy.int64), op_version=target_opset), op_version=target_opset) - res = OnnxReduceSum(mul, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSum(mul, op_version=target_opset) + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) var_type = dtype_to_var_type(dtype) varsx = [('X1', var_type([None, None])), ('X2', var_type([None, None]))] @@ -381,8 +382,7 @@ def _onnx_grad_loss_square_error(target_opset=None, dtype=numpy.float32, OnnxReduceSum, OnnxReshape) diff = OnnxSub('X1', 'X2', op_version=target_opset) if weight_name is None: - res = OnnxReduceSumSquare(diff, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSumSquare(diff, op_version=target_opset) res2 = OnnxMul(diff, numpy.array([-2], dtype=dtype), op_version=target_opset, output_names=['Z']) else: @@ -392,13 +392,16 @@ def _onnx_grad_loss_square_error(target_opset=None, dtype=numpy.float32, mul = OnnxMul( OnnxMul(diff, diff, op_version=target_opset), resh, op_version=target_opset) - res = OnnxReduceSum(mul, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSum(mul, op_version=target_opset) res2 = OnnxMul( OnnxMul(diff, numpy.array([-2], dtype=dtype), op_version=target_opset), resh, op_version=target_opset, output_names=['Z']) + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) + var_type = dtype_to_var_type(dtype) varsx = [('X1', var_type([None, None])), ('X2', var_type([None, None]))] @@ -433,13 +436,12 @@ def _onnx_grad_loss_absolute_error(target_opset=None, dtype=numpy.float32, print("DOT-SECTION", oinf.to_dot()) """ from skl2onnx.algebra.onnx_ops import ( - OnnxSub, OnnxMul, - OnnxReduceSum, OnnxReshape, OnnxSign, OnnxAbs) + OnnxSub, OnnxMul, OnnxReduceSum, OnnxReshape, + OnnxSign, OnnxAbs) diff = OnnxSub('X1', 'X2', op_version=target_opset) abs_diff = OnnxAbs(diff, op_version=target_opset) if weight_name is None: - res = OnnxReduceSum(abs_diff, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSum(abs_diff, op_version=target_opset) res2 = OnnxSign(diff, op_version=target_opset, output_names=['Z']) else: @@ -447,12 +449,14 @@ def _onnx_grad_loss_absolute_error(target_opset=None, dtype=numpy.float32, numpy.array([-1, 1], dtype=numpy.int64), op_version=target_opset) mul = OnnxMul(abs_diff, resh, 
op_version=target_opset) - res = OnnxReduceSum(mul, op_version=target_opset, - keepdims=0, output_names=['Y']) + res = OnnxReduceSum(mul, op_version=target_opset) res2 = OnnxMul( OnnxSign(diff, op_version=target_opset), resh, op_version=target_opset, output_names=['Z']) + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) var_type = dtype_to_var_type(dtype) varsx = [('X1', var_type([None, None])), ('X2', var_type([None, None]))] @@ -497,18 +501,16 @@ def _onnx_grad_loss_elastic_error(target_opset=None, dtype=numpy.float32, diff = OnnxSub('X1', 'X2', op_version=target_opset) abs_diff = OnnxAbs(diff, op_version=target_opset) if weight_name is None: - res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, - keepdims=0) + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset) res2_l1 = OnnxSign(diff, op_version=target_opset) - res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset, - keepdims=0) + res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset) res2_l2 = diff else: resh = OnnxReshape(weight_name, numpy.array([-1, 1], dtype=numpy.int64), op_version=target_opset) mul = OnnxMul(abs_diff, resh, op_version=target_opset) - res_l1 = OnnxReduceSum(mul, op_version=target_opset, keepdims=0) + res_l1 = OnnxReduceSum(mul, op_version=target_opset) res2_l1 = OnnxMul( OnnxSign(diff, op_version=target_opset), resh, op_version=target_opset) @@ -524,7 +526,10 @@ def _onnx_grad_loss_elastic_error(target_opset=None, dtype=numpy.float32, op_version=target_opset), OnnxMul(res_l2, numpy.array([l2_weight], dtype=dtype), op_version=target_opset), - op_version=target_opset, output_names=['Y']) + op_version=target_opset) + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) res2 = OnnxAdd( OnnxMul(res2_l1, numpy.array([l1_weight], dtype=dtype), @@ -569,14 +574,12 @@ def _onnx_grad_penalty_elastic_error(target_opset=None, dtype=numpy.float32, """ from skl2onnx.algebra.onnx_ops import ( OnnxMul, OnnxAdd, OnnxReduceSumSquare, - OnnxReduceSum, OnnxSign, OnnxAbs) + OnnxReduceSum, OnnxSign, OnnxAbs, OnnxReshape) diff = 'X' abs_diff = OnnxAbs(diff, op_version=target_opset) - res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, - keepdims=0) + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset) res2_l1 = OnnxSign(diff, op_version=target_opset) - res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset, - keepdims=0) + res_l2 = OnnxReduceSumSquare(diff, op_version=target_opset) res2_l2 = diff res = OnnxAdd( @@ -584,7 +587,10 @@ def _onnx_grad_penalty_elastic_error(target_opset=None, dtype=numpy.float32, op_version=target_opset), OnnxMul(res_l2, numpy.array([l2_weight], dtype=dtype), op_version=target_opset), - op_version=target_opset, output_names=['Y']) + op_version=target_opset) + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) res2 = OnnxAdd( OnnxMul(res2_l1, numpy.array([l1_weight], dtype=dtype), @@ -596,14 +602,15 @@ def _onnx_grad_penalty_elastic_error(target_opset=None, dtype=numpy.float32, var_type = dtype_to_var_type(dtype) varsx = [('X', var_type([None, None]))] onx = res.to_onnx( - varsx, outputs=[('Y', var_type()), ('Z', var_type())], + varsx, outputs=[('Y', var_type([None])), ('Z', var_type())], target_opset=target_opset, other_outputs=[res2]) return onx def _onnx_n_penalty_elastic_error(target_opset=None, dtype=numpy.float32, weight_name=None, - l1_weight=0.01, l2_weight=0.01, n_tensors=1): + l1_weight=0.01, 
l2_weight=0.01, n_tensors=1, + loss_shape=(1, 1)): """ Returns the ONNX graph for function :math:`Y = f(W) = \\beta \\lVert W \\rVert + @@ -627,23 +634,21 @@ def _onnx_n_penalty_elastic_error(target_opset=None, dtype=numpy.float32, """ from skl2onnx.algebra.onnx_ops import ( OnnxMul, OnnxAdd, OnnxReduceSumSquare, - OnnxReduceSum, OnnxAbs, OnnxSum) + OnnxReduceSum, OnnxAbs, OnnxReshape) if n_tensors <= 0: raise ValueError( # pragma: no cover "This function is useless if the number of tensors is null.") var_type = dtype_to_var_type(dtype) - varsx = [('loss', var_type([1, 1]))] + varsx = [('loss', var_type(loss_shape))] names = ['loss'] for n in range(n_tensors): name = 'W%d' % n abs_diff = OnnxAbs(name, op_version=target_opset) - res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset, - keepdims=0) + res_l1 = OnnxReduceSum(abs_diff, op_version=target_opset) # res2_l1 = OnnxSign(diff, op_version=target_opset) - res_l2 = OnnxReduceSumSquare(name, op_version=target_opset, - keepdims=0) + res_l2 = OnnxReduceSumSquare(name, op_version=target_opset) # res2_l2 = diff res = OnnxAdd( OnnxMul(res_l1, numpy.array([l1_weight], dtype=dtype), @@ -654,18 +659,28 @@ def _onnx_n_penalty_elastic_error(target_opset=None, dtype=numpy.float32, names.append(res) varsx.append(('W%d' % n, var_type())) - res = OnnxSum(*names, op_version=target_opset, output_names=['Y']) + if len(names) == 2: + res = OnnxAdd(*names, op_version=target_opset) + else: + res = OnnxAdd(names[1], names[2], op_version=target_opset) + for i in range(3, len(names)): + res = OnnxAdd(res, names[i], op_version=target_opset) + res = OnnxAdd(names[0], res, op_version=target_opset) + + res = OnnxReshape(res, numpy.array([-1], numpy.int64), + op_version=target_opset, + output_names=['Y']) onx = res.to_onnx( - varsx, outputs=[('Y', var_type())], + varsx, outputs=[('Y', var_type([None]))], target_opset=target_opset) return onx def _onnx_update_penalty_elastic_error(target_opset=None, dtype=numpy.float32, - l1=0.01, l2=0.01): + l1=1e-4, l2=1e-4): """ Returns the ONNX graph for function - :math:`Y = f(W) = W - 2 \\beta W + - \\alpha sign(W)` + :math:`Y = f(W) = W - 2 \\beta W - \\alpha sign(W)` *l1* is :math:`\\beta` and *l2* is :math:`\\alpha`.
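
The last two hunks above change how the penalty graphs are built, so here is a small NumPy sketch of the two formulas they implement. It is an illustration only, not part of the patch, and the helper names are reused purely for readability: `n_penalty_elastic_error` adds `l1 * ||W||_1 + l2 * ||W||_2^2` for every trained tensor to the incoming loss and returns a 1-D tensor (matching the reshaped output `Y`), and `update_penalty_elastic_error` applies `Y = W - 2*beta*W - alpha*sign(W)` to a weight tensor.

import numpy


def n_penalty_elastic_error(loss, tensors, l1=0.0, l2=1e-4):
    # Mirrors the ONNX graph: Y = loss + sum_i(l1 * |W_i|_1 + l2 * |W_i|_2^2),
    # returned as a 1-D tensor like the reshaped output 'Y'.
    total = numpy.asarray(loss, dtype=numpy.float32).reshape((-1, ))
    for w in tensors:
        total = total + l1 * numpy.abs(w).sum() + l2 * (w ** 2).sum()
    return total


def update_penalty_elastic_error(w, alpha, beta):
    # Mirrors the update graph: Y = W - 2 * beta * W - alpha * sign(W).
    return w - 2 * beta * w - alpha * numpy.sign(w)


coef = numpy.random.randn(10, 1).astype(numpy.float32)
intercept = numpy.array([0.5], dtype=numpy.float32)
print(n_penalty_elastic_error(
    numpy.array([[1.5]], dtype=numpy.float32), [coef, intercept], l1=0, l2=1e-4))
print(update_penalty_elastic_error(coef, alpha=1e-4, beta=1e-4)[:3])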