From 7c5c037d2f51fbdbd6d0f99f21f39b988a668e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Fri, 4 Feb 2022 01:02:55 +0100 Subject: [PATCH] Implements get_trained_onnx to retreived the trained model --- _doc/sphinxdoc/source/api/training_utils.rst | 2 + .../test_optimizers_classification.py | 24 ++++++- onnxcustom/training/_base_estimator.py | 32 ++++++++- onnxcustom/training/optimizers.py | 21 +++++- onnxcustom/training/optimizers_partial.py | 20 ++++-- onnxcustom/utils/onnx_helper.py | 70 ++++++++++++++----- 6 files changed, 144 insertions(+), 25 deletions(-) diff --git a/_doc/sphinxdoc/source/api/training_utils.rst b/_doc/sphinxdoc/source/api/training_utils.rst index 6f872943..ac15c6ca 100644 --- a/_doc/sphinxdoc/source/api/training_utils.rst +++ b/_doc/sphinxdoc/source/api/training_utils.rst @@ -22,6 +22,8 @@ ONNX .. autosignature:: onnxcustom.utils.onnx_rewriter.onnx_rewrite_operator +.. autosignature:: onnxcustom.utils.onnx_helper.replace_initializers_into_onnx + onnxruntime +++++++++++ diff --git a/_unittests/ut_training/test_optimizers_classification.py b/_unittests/ut_training/test_optimizers_classification.py index 4d4f5cde..1f6098a2 100644 --- a/_unittests/ut_training/test_optimizers_classification.py +++ b/_unittests/ut_training/test_optimizers_classification.py @@ -3,8 +3,9 @@ """ import unittest import logging +from io import BytesIO import numpy -from onnx import TensorProto +from onnx import TensorProto, load as load_onnx from onnx.helper import set_model_props from pyquickhelper.pycode import ( ExtTestCase, get_temp_folder, ignore_warnings) @@ -17,6 +18,7 @@ from mlprodict.plotting.text_plot import onnx_simple_text_plot from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs # from mlprodict.onnxrt import OnnxInference +from onnxruntime import InferenceSession try: from onnxruntime import TrainingSession except ImportError: @@ -58,6 +60,7 @@ def wtest_ort_gradient_optimizers_binary(self, use_weight=False): onx = to_onnx(reg, X_train, target_opset=opset, black_op={'LinearClassifier'}, options={'zipmap': False}) + onx2 = load_onnx(BytesIO(onx.SerializeToString())) set_model_props(onx, {'info': 'unit test'}) onx_loss = add_loss_output( onx, 'log', output_index=1, @@ -86,6 +89,15 @@ def wtest_ort_gradient_optimizers_binary(self, use_weight=False): self.assertGreater(len(losses), 1) self.assertFalse(any(map(numpy.isnan, losses))) + # get_trained_weight + trained_onnx = train_session.get_trained_onnx(model=onx2) + sess = InferenceSession(onx2.SerializeToString()) + got1 = sess.run(None, {'X': X_train}) + sess = InferenceSession(trained_onnx.SerializeToString()) + got2 = sess.run(None, {'X': X_train}) + self.assertEqual(len(got1), len(got2)) + self.assertEqual(got1[0].shape, got2[0].shape) + # state state = train_session.get_state() self.assertIsInstance(state, dict) @@ -121,6 +133,7 @@ def wtest_ort_gradient_optimizers_fw_nesterov_binary(self, use_weight): black_op={'LinearRegressor'}, options={'zipmap': False, 'raw_scores': True}) + onx2 = onx onx = select_model_inputs_outputs(onx, outputs=['score']) self.assertIn("output: name='score'", onnx_simple_text_plot(onx)) @@ -144,6 +157,15 @@ def wtest_ort_gradient_optimizers_fw_nesterov_binary(self, use_weight): __file__, "temp_ort_gradient_optimizers_fw_nesterov_binary") train_session.save_onnx_graph(temp) + # get_trained_weight + trained_onnx = train_session.get_trained_onnx(model=onx2) + sess = InferenceSession(onx2.SerializeToString()) + got1 = sess.run(None, {'X': X_train}) + sess = InferenceSession(trained_onnx.SerializeToString()) + got2 = sess.run(None, {'X': X_train}) + self.assertEqual(len(got1), len(got2)) + self.assertEqual(got1[0].shape, got2[0].shape) + # state state = train_session.get_state() self.assertIsInstance(state, list) diff --git a/onnxcustom/training/_base_estimator.py b/onnxcustom/training/_base_estimator.py index a1ce57f7..7ad2ea9b 100644 --- a/onnxcustom/training/_base_estimator.py +++ b/onnxcustom/training/_base_estimator.py @@ -7,6 +7,7 @@ OrtDevice as C_OrtDevice) from ..utils.onnxruntime_helper import ( get_ort_device, ort_device_to_string) +from ..utils.onnx_helper import replace_initializers_into_onnx from ._base import BaseOnnxClass from ._base_onnx_function import BaseLearningOnnx from .sgd_learning_rate import BaseLearningRate @@ -17,13 +18,15 @@ class BaseEstimator(BaseOnnxClass): Base class for optimizers. Implements common methods such `__repr__`. + :param model_onnx: onnx graph to train :param learning_rate: learning rate class, see module :mod:`onnxcustom.training.sgd_learning_rate` :param device: device as :epkg:`C_OrtDevice` or a string representing this device """ - def __init__(self, learning_rate, device): + def __init__(self, model_onnx, learning_rate, device): + self.model_onnx = model_onnx self.learning_rate = BaseLearningRate.select(learning_rate) self.device = get_ort_device(device) @@ -98,3 +101,30 @@ def __setstate__(self, state): setattr(self, att, v) self.device = get_ort_device(self.device) return self + + def get_trained_onnx(self): + """ + Returns the trained onnx graph, the initial graph + modified by replacing the initializers with the trained + weights. + + :return: onnx graph + """ + raise NotImplementedError( # pragma: no cover + "The method needs to be overloaded.") + + def _get_trained_onnx(self, state, model=None): + """ + Returns the trained onnx graph, the initial graph + modified by replacing the initializers with the trained + weights. + + :param state: trained weights + :param model: replace the weights in another graph + than the training graph + :return: onnx graph + """ + if model is None: + return replace_initializers_into_onnx( + self.model_onnx, state) + return replace_initializers_into_onnx(model, state) diff --git a/onnxcustom/training/optimizers.py b/onnxcustom/training/optimizers.py index 9615b5f3..1251b1e5 100644 --- a/onnxcustom/training/optimizers.py +++ b/onnxcustom/training/optimizers.py @@ -19,7 +19,7 @@ class OrtGradientOptimizer(BaseEstimator): Implements a simple :epkg:`Stochastic Gradient Descent` with :epkg:`onnxruntime-training`. - :param model_onnx: ONNX graph used to train + :param model_onnx: onnx graph to train :param weights_to_train: names of initializers to be optimized :param loss_output_name: name of the loss output :param max_iter: number of training iterations @@ -51,8 +51,7 @@ def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', device='cpu', warm_start=False, verbose=0, validation_every=0.1, saved_gradient=None, sample_weight_name="weight"): - BaseEstimator.__init__(self, learning_rate, device) - self.model_onnx = model_onnx + BaseEstimator.__init__(self, model_onnx, learning_rate, device) self.batch_size = batch_size self.weights_to_train = weights_to_train self.loss_output_name = loss_output_name @@ -336,6 +335,22 @@ def get_state(self): raise AttributeError("Method fit must be called before.") return self.train_session_.get_state() + def get_trained_onnx(self, model=None): + """ + Returns the trained onnx graph, the initial graph + modified by replacing the initializers with the trained + weights. If model is not specified, it uses the model + given as an argument to this class. This graph outputs + the loss and not the predictions. Parameter *model* + can be used to use the graph before loss was added + and then the returned graph will produce the predictions. + + :param model: replace the weights in another graph + than the training graph + :return: onnx graph + """ + return self._get_trained_onnx(self.get_state(), model=model) + def set_state(self, state): """ Changes the trained weights. diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py index 3bac8186..e12761fe 100644 --- a/onnxcustom/training/optimizers_partial.py +++ b/onnxcustom/training/optimizers_partial.py @@ -28,7 +28,7 @@ class OrtGradientForwardBackwardOptimizer(BaseEstimator): with :epkg:`onnxruntime-training`. It leverages class @see class OrtGradientForwardBackward. - :param model_onnx: ONNX graph used to train + :param model_onnx: onnx graph to train :param weights_to_train: names of initializers to be optimized, if None, function @see fn get_train_initialize returns the list of float iniitializer @@ -77,8 +77,7 @@ def __init__(self, model_onnx, weights_to_train=None, learning_penalty=None, exc=True): if weights_to_train is None: weights_to_train = list(get_train_initializer(model_onnx)) - BaseEstimator.__init__(self, learning_rate, device) - self.model_onnx = model_onnx + BaseEstimator.__init__(self, model_onnx, learning_rate, device) self.batch_size = batch_size self.weights_to_train = weights_to_train self.loss_output_name = loss_output_name @@ -147,7 +146,7 @@ def _get_att_state(self, kind): def get_full_state(self, kind='weight'): """ - Returns the trained weights. + Returns the trained weights and the inputs. """ if isinstance(kind, list): return [self.get_full_state(kind=k) for k in kind] @@ -174,6 +173,19 @@ def get_state(self, kind='weight'): n = len(value) - len(self.weights_to_train) return value[n:] + def get_trained_onnx(self, model=None): + """ + Returns the trained onnx graph, the initial graph + modified by replacing the initializers with the trained + weights. + + :param model: replace the weights in another graph + than the training graph + :return: onnx graph + """ + state = dict(zip(self.weights_to_train, self.get_state())) + return self._get_trained_onnx(state, model=model) + def set_state(self, state, check_trained=True, kind='weight', zero=False): """ Changes the trained weights. diff --git a/onnxcustom/utils/onnx_helper.py b/onnxcustom/utils/onnx_helper.py index 1959fa50..0d94eabd 100644 --- a/onnxcustom/utils/onnx_helper.py +++ b/onnxcustom/utils/onnx_helper.py @@ -6,6 +6,8 @@ import math import numpy from onnx import TensorProto, numpy_helper, helper +from onnxruntime import OrtValue +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue def onnx_rename_weights(onx): @@ -88,6 +90,26 @@ def dtype_to_var_type(dtype): "Unexpected value dtype=%r." % dtype) +def _finalize_new_onnx(graph, onx): + onnx_model = helper.make_model(graph) + onnx_model.ir_version = onx.ir_version + onnx_model.producer_name = onx.producer_name + onnx_model.producer_version = onx.producer_version + onnx_model.domain = onx.domain + onnx_model.model_version = onx.model_version + onnx_model.doc_string = onx.doc_string + if len(onx.metadata_props) > 0: # pragma: no cover + values = {p.key: p.value for p in onx.metadata_props} + helper.set_model_props(onnx_model, values) + + del onnx_model.opset_import[:] # pylint: disable=E1101 + for oimp in onx.opset_import: + op_set = onnx_model.opset_import.add() # pylint: disable=E1101 + op_set.domain = oimp.domain + op_set.version = oimp.version + return onnx_model + + def add_initializer(model, name, value): """ Adds an initializer to graph. @@ -109,20 +131,36 @@ def add_initializer(model, name, value): model.graph.node, model.graph.name, model.graph.input, model.graph.output, list_inits) - onnx_model = helper.make_model(graph_def) - onnx_model.ir_version = model.ir_version - onnx_model.producer_name = model.producer_name - onnx_model.producer_version = model.producer_version - onnx_model.domain = model.domain - onnx_model.model_version = model.model_version - onnx_model.doc_string = model.doc_string - if len(model.metadata_props) > 0: # pragma: no cover - values = {p.key: p.value for p in model.metadata_props} - helper.set_model_props(onnx_model, values) + return _finalize_new_onnx(graph_def, model) - del onnx_model.opset_import[:] # pylint: disable=E1101 - for oimp in model.opset_import: - op_set = onnx_model.opset_import.add() # pylint: disable=E1101 - op_set.domain = oimp.domain - op_set.version = oimp.version - return onnx_model + +def replace_initializers_into_onnx(model, results): + """ + Replaces initializers by other initializers, + usually trained ones. + + :param model: onnx graph + :param results: results to be added in a dictionary + :return: new onnx graph + """ + inputs = list(model.graph.input) + outputs = list(model.graph.output) + inits = list(model.graph.initializer) + + inits_dict = {init.name: i for i, init in enumerate(inits)} + for k, v in results.items(): + if k in inits_dict: + if isinstance(v, numpy.ndarray): + v = numpy_helper.from_array(v, k) + elif isinstance(v, (C_OrtValue, OrtValue)): + v = numpy_helper.from_array(v.numpy(), k) + inits[inits_dict[k]] = v + else: + raise RuntimeError( + "Unable to find initializer %r in " + "%r." % (k, inits_dict)) + + graph = helper.make_graph( + list(model.graph.node), model.graph.name, inputs, + outputs, inits) + return _finalize_new_onnx(graph, model)