2 changes: 2 additions & 0 deletions _doc/sphinxdoc/source/api/training_utils.rst
@@ -22,6 +22,8 @@ ONNX

.. autosignature:: onnxcustom.utils.onnx_rewriter.onnx_rewrite_operator

.. autosignature:: onnxcustom.utils.onnx_helper.replace_initializers_into_onnx

onnxruntime
+++++++++++

24 changes: 23 additions & 1 deletion _unittests/ut_training/test_optimizers_classification.py
@@ -3,8 +3,9 @@
"""
import unittest
import logging
from io import BytesIO
import numpy
from onnx import TensorProto
from onnx import TensorProto, load as load_onnx
from onnx.helper import set_model_props
from pyquickhelper.pycode import (
ExtTestCase, get_temp_folder, ignore_warnings)
@@ -17,6 +18,7 @@
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from mlprodict.onnx_tools.onnx_manipulations import select_model_inputs_outputs
# from mlprodict.onnxrt import OnnxInference
from onnxruntime import InferenceSession
try:
from onnxruntime import TrainingSession
except ImportError:
@@ -58,6 +60,7 @@ def wtest_ort_gradient_optimizers_binary(self, use_weight=False):
onx = to_onnx(reg, X_train, target_opset=opset,
black_op={'LinearClassifier'},
options={'zipmap': False})
onx2 = load_onnx(BytesIO(onx.SerializeToString()))
set_model_props(onx, {'info': 'unit test'})
onx_loss = add_loss_output(
onx, 'log', output_index=1,
@@ -86,6 +89,15 @@ def wtest_ort_gradient_optimizers_binary(self, use_weight=False):
self.assertGreater(len(losses), 1)
self.assertFalse(any(map(numpy.isnan, losses)))

# get_trained_onnx
trained_onnx = train_session.get_trained_onnx(model=onx2)
sess = InferenceSession(onx2.SerializeToString())
got1 = sess.run(None, {'X': X_train})
sess = InferenceSession(trained_onnx.SerializeToString())
got2 = sess.run(None, {'X': X_train})
self.assertEqual(len(got1), len(got2))
self.assertEqual(got1[0].shape, got2[0].shape)

# state
state = train_session.get_state()
self.assertIsInstance(state, dict)
@@ -121,6 +133,7 @@ def wtest_ort_gradient_optimizers_fw_nesterov_binary(self, use_weight):
black_op={'LinearRegressor'},
options={'zipmap': False,
'raw_scores': True})
onx2 = onx
onx = select_model_inputs_outputs(onx, outputs=['score'])
self.assertIn("output: name='score'",
onnx_simple_text_plot(onx))
@@ -144,6 +157,15 @@ def wtest_ort_gradient_optimizers_fw_nesterov_binary(self, use_weight):
__file__, "temp_ort_gradient_optimizers_fw_nesterov_binary")
train_session.save_onnx_graph(temp)

# get_trained_onnx
trained_onnx = train_session.get_trained_onnx(model=onx2)
sess = InferenceSession(onx2.SerializeToString())
got1 = sess.run(None, {'X': X_train})
sess = InferenceSession(trained_onnx.SerializeToString())
got2 = sess.run(None, {'X': X_train})
self.assertEqual(len(got1), len(got2))
self.assertEqual(got1[0].shape, got2[0].shape)

# state
state = train_session.get_state()
self.assertIsInstance(state, list)
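A note on the additions above: `onx2 = load_onnx(BytesIO(onx.SerializeToString()))` round-trips the model through its wire format, a simple way to obtain a deep copy before `set_model_props` mutates `onx`; the untouched copy is what later gets passed to `get_trained_onnx`. A minimal sketch of the idiom (the helper name is illustrative):

from io import BytesIO
from onnx import load as load_onnx

def deep_copy_model(model_proto):
    # serialize/deserialize: a dependency-free deep copy of a ModelProto
    return load_onnx(BytesIO(model_proto.SerializeToString()))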
32 changes: 31 additions & 1 deletion onnxcustom/training/_base_estimator.py
@@ -7,6 +7,7 @@
OrtDevice as C_OrtDevice)
from ..utils.onnxruntime_helper import (
get_ort_device, ort_device_to_string)
from ..utils.onnx_helper import replace_initializers_into_onnx
from ._base import BaseOnnxClass
from ._base_onnx_function import BaseLearningOnnx
from .sgd_learning_rate import BaseLearningRate
@@ -17,13 +18,15 @@ class BaseEstimator(BaseOnnxClass):
Base class for optimizers.
Implements common methods such as `__repr__`.

:param model_onnx: onnx graph to train
:param learning_rate: learning rate class,
see module :mod:`onnxcustom.training.sgd_learning_rate`
:param device: device as :epkg:`C_OrtDevice` or a string
representing this device
"""

def __init__(self, learning_rate, device):
def __init__(self, model_onnx, learning_rate, device):
self.model_onnx = model_onnx
self.learning_rate = BaseLearningRate.select(learning_rate)
self.device = get_ort_device(device)

@@ -98,3 +101,30 @@ def __setstate__(self, state):
setattr(self, att, v)
self.device = get_ort_device(self.device)
return self

def get_trained_onnx(self):
"""
Returns the trained onnx graph, the initial graph
modified by replacing the initializers with the trained
weights.

:return: onnx graph
"""
raise NotImplementedError( # pragma: no cover
"The method needs to be overloaded.")

def _get_trained_onnx(self, state, model=None):
"""
Returns the trained onnx graph, the initial graph
modified by replacing the initializers with the trained
weights.

:param state: trained weights
:param model: if not None, replace the weights in this graph
rather than in the training graph
:return: onnx graph
"""
if model is None:
return replace_initializers_into_onnx(
self.model_onnx, state)
return replace_initializers_into_onnx(model, state)
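The two methods above form a small template-method pattern: the public `get_trained_onnx` is abstract, and each concrete optimizer calls the protected `_get_trained_onnx` with its own representation of the trained state. A hedged sketch of what an override looks like (the subclass is a placeholder; `get_state` is assumed to return a name-to-value mapping):

class MyOptimizer(BaseEstimator):
    def get_trained_onnx(self, model=None):
        # state maps initializer names to trained values
        return self._get_trained_onnx(self.get_state(), model=model)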
21 changes: 18 additions & 3 deletions onnxcustom/training/optimizers.py
@@ -19,7 +19,7 @@ class OrtGradientOptimizer(BaseEstimator):
Implements a simple :epkg:`Stochastic Gradient Descent`
with :epkg:`onnxruntime-training`.

:param model_onnx: ONNX graph used to train
:param model_onnx: onnx graph to train
:param weights_to_train: names of initializers to be optimized
:param loss_output_name: name of the loss output
:param max_iter: number of training iterations
@@ -51,8 +51,7 @@ def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
device='cpu', warm_start=False, verbose=0,
validation_every=0.1, saved_gradient=None,
sample_weight_name="weight"):
BaseEstimator.__init__(self, learning_rate, device)
self.model_onnx = model_onnx
BaseEstimator.__init__(self, model_onnx, learning_rate, device)
self.batch_size = batch_size
self.weights_to_train = weights_to_train
self.loss_output_name = loss_output_name
@@ -336,6 +335,22 @@ def get_state(self):
raise AttributeError("Method fit must be called before.")
return self.train_session_.get_state()

def get_trained_onnx(self, model=None):
"""
Returns the trained onnx graph, the initial graph
modified by replacing the initializers with the trained
weights. If *model* is not specified, the method uses the
training graph given to this class; that graph outputs the
loss, not the predictions. Pass the graph as it was before
the loss was added to obtain a graph that produces the
predictions instead.

:param model: if not None, replace the weights in this graph
rather than in the training graph
:return: onnx graph
"""
return self._get_trained_onnx(self.get_state(), model=model)

def set_state(self, state):
"""
Changes the trained weights.
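To make the docstring above concrete: the training graph ends with the loss node, so `get_trained_onnx()` with no argument returns a graph whose output is the loss; passing the original inference graph returns a predicting graph instead. A minimal sketch mirroring the unit tests (variable names are assumed):

# onx: inference graph, onx_loss: same graph with the loss output added,
# train_session: a fitted OrtGradientOptimizer built on onx_loss
trained_loss_graph = train_session.get_trained_onnx()            # outputs the loss
trained_pred_graph = train_session.get_trained_onnx(model=onx)   # outputs predictions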
20 changes: 16 additions & 4 deletions onnxcustom/training/optimizers_partial.py
@@ -28,7 +28,7 @@ class OrtGradientForwardBackwardOptimizer(BaseEstimator):
with :epkg:`onnxruntime-training`. It leverages class
@see class OrtGradientForwardBackward.

:param model_onnx: ONNX graph used to train
:param model_onnx: onnx graph to train
:param weights_to_train: names of initializers to be optimized,
if None, function @see fn get_train_initializer returns
the list of float initializers
@@ -77,8 +77,7 @@ def __init__(self, model_onnx, weights_to_train=None,
learning_penalty=None, exc=True):
if weights_to_train is None:
weights_to_train = list(get_train_initializer(model_onnx))
BaseEstimator.__init__(self, learning_rate, device)
self.model_onnx = model_onnx
BaseEstimator.__init__(self, model_onnx, learning_rate, device)
self.batch_size = batch_size
self.weights_to_train = weights_to_train
self.loss_output_name = loss_output_name
@@ -147,7 +146,7 @@ def _get_att_state(self, kind):

def get_full_state(self, kind='weight'):
"""
Returns the trained weights.
Returns the trained weights and the inputs.
"""
if isinstance(kind, list):
return [self.get_full_state(kind=k) for k in kind]
@@ -174,6 +173,19 @@ def get_state(self, kind='weight'):
n = len(value) - len(self.weights_to_train)
return value[n:]

def get_trained_onnx(self, model=None):
"""
Returns the trained onnx graph, the initial graph
modified by replacing the initializers with the trained
weights.

:param model: if not None, replace the weights in this graph
rather than in the training graph
:return: onnx graph
"""
state = dict(zip(self.weights_to_train, self.get_state()))
return self._get_trained_onnx(state, model=model)

def set_state(self, state, check_trained=True, kind='weight', zero=False):
"""
Changes the trained weights.
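On the `get_full_state` docstring change above: for this optimizer the stored state keeps the graph inputs first and the trainable weights last, which is why `get_state` returns only the tail `value[n:]` and why `get_trained_onnx` can zip that tail with `weights_to_train`. A toy illustration of the slicing (values are made up):

full_state = ['X', 'coef', 'intercept']     # inputs first, weights last
weights_to_train = ['coef', 'intercept']
n = len(full_state) - len(weights_to_train)
assert full_state[n:] == weights_to_train   # what get_state() keeps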
70 changes: 54 additions & 16 deletions onnxcustom/utils/onnx_helper.py
@@ -6,6 +6,8 @@
import math
import numpy
from onnx import TensorProto, numpy_helper, helper
from onnxruntime import OrtValue
from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue


def onnx_rename_weights(onx):
@@ -88,6 +90,26 @@ def dtype_to_var_type(dtype):
"Unexpected value dtype=%r." % dtype)


def _finalize_new_onnx(graph, onx):
onnx_model = helper.make_model(graph)
onnx_model.ir_version = onx.ir_version
onnx_model.producer_name = onx.producer_name
onnx_model.producer_version = onx.producer_version
onnx_model.domain = onx.domain
onnx_model.model_version = onx.model_version
onnx_model.doc_string = onx.doc_string
if len(onx.metadata_props) > 0: # pragma: no cover
values = {p.key: p.value for p in onx.metadata_props}
helper.set_model_props(onnx_model, values)

del onnx_model.opset_import[:] # pylint: disable=E1101
for oimp in onx.opset_import:
op_set = onnx_model.opset_import.add() # pylint: disable=E1101
op_set.domain = oimp.domain
op_set.version = oimp.version
return onnx_model


def add_initializer(model, name, value):
"""
Adds an initializer to graph.
@@ -109,20 +131,36 @@ def add_initializer(model, name, value):
model.graph.node, model.graph.name,
model.graph.input, model.graph.output,
list_inits)
onnx_model = helper.make_model(graph_def)
onnx_model.ir_version = model.ir_version
onnx_model.producer_name = model.producer_name
onnx_model.producer_version = model.producer_version
onnx_model.domain = model.domain
onnx_model.model_version = model.model_version
onnx_model.doc_string = model.doc_string
if len(model.metadata_props) > 0: # pragma: no cover
values = {p.key: p.value for p in model.metadata_props}
helper.set_model_props(onnx_model, values)
return _finalize_new_onnx(graph_def, model)

del onnx_model.opset_import[:] # pylint: disable=E1101
for oimp in model.opset_import:
op_set = onnx_model.opset_import.add() # pylint: disable=E1101
op_set.domain = oimp.domain
op_set.version = oimp.version
return onnx_model

def replace_initializers_into_onnx(model, results):
"""
Replaces initializers with other initializers,
usually trained ones.

:param model: onnx graph
:param results: dictionary mapping initializer names to their
new values (numpy arrays or OrtValue instances)
:return: new onnx graph
"""
inputs = list(model.graph.input)
outputs = list(model.graph.output)
inits = list(model.graph.initializer)

inits_dict = {init.name: i for i, init in enumerate(inits)}
for k, v in results.items():
if k in inits_dict:
if isinstance(v, numpy.ndarray):
v = numpy_helper.from_array(v, k)
elif isinstance(v, (C_OrtValue, OrtValue)):
v = numpy_helper.from_array(v.numpy(), k)
inits[inits_dict[k]] = v
else:
raise RuntimeError(
"Unable to find initializer %r in "
"%r." % (k, inits_dict))

graph = helper.make_graph(
list(model.graph.node), model.graph.name, inputs,
outputs, inits)
return _finalize_new_onnx(graph, model)
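Finally, an end-to-end sketch of the new helper: build a tiny model with one initializer, swap that initializer for a "trained" value, and check the output moves accordingly (all names are illustrative; model metadata and opsets are carried over by `_finalize_new_onnx`):

import numpy
from onnx import helper, numpy_helper, TensorProto
from onnxruntime import InferenceSession
from onnxcustom.utils.onnx_helper import replace_initializers_into_onnx

# tiny graph computing Y = X + B, where B is an initializer
B = numpy_helper.from_array(numpy.array([1.], dtype=numpy.float32), name='B')
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None])
node = helper.make_node('Add', ['X', 'B'], ['Y'])
graph = helper.make_graph([node], 'g', [X], [Y], [B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 15)])

# replace B as if it had been trained
new_model = replace_initializers_into_onnx(
    model, {'B': numpy.array([10.], dtype=numpy.float32)})

sess = InferenceSession(new_model.SerializeToString())
print(sess.run(None, {'X': numpy.array([1.], dtype=numpy.float32)}))  # -> [11.]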