3 changes: 2 additions & 1 deletion _doc/examples/plot_orttraining_linear_regression.py
@@ -81,7 +81,8 @@
# The training requires a loss function. By default, it
# is the square function but it could be the absolute error or
# include penalties. Function
# :func:`add_loss_output <onnxcustom.training.orttraining.add_loss_output>`
# :func:`add_loss_output
# <onnxcustom.utils.orttraining_helper.add_loss_output>`
# appends the loss function to the ONNX graph.

onx_train = add_loss_output(onx)
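
# A quick check (a sketch): ``add_loss_output`` is expected to return an
# ONNX model whose graph now exposes the original output plus the
# appended loss output.
print([o.name for o in onx_train.graph.output])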
2 changes: 1 addition & 1 deletion _doc/examples/plot_orttraining_linear_regression_fwbw.py
@@ -170,7 +170,7 @@ class :class:`OrtGradientForwardBackwardOptimizer
# and returns the updated weights. This graph works on tensors of any shape
# but with the same element type.

plot_onnxs(train_session.loss_grad_onnx_,
plot_onnxs(train_session.learning_loss.loss_grad_onnx_,
train_session.learning_rate.axpy_onnx_,
title=['error gradient + loss', 'gradient update'])

3 changes: 2 additions & 1 deletion _doc/examples/plot_orttraining_nn_gpu.py
@@ -65,7 +65,8 @@
# ++++++++++++++
#
# The loss function is the square function. We use function
# :func:`add_loss_output <onnxcustom.training.orttraining.add_loss_output>`.
# :func:`add_loss_output
# <onnxcustom.utils.orttraining_helper.add_loss_output>`.
# It does something similar to what is implemented in example
# :ref:`l-orttraining-linreg-cpu`.

101 changes: 84 additions & 17 deletions _doc/sphinxdoc/source/api/training.rst
@@ -1,50 +1,117 @@

========
Training
========

There exist two APIs in :epkg:`onnxruntime`. One assumes
the loss function is part of the graph to derive, the other
one assumes the user provides the derivative of the loss
against the output of the graph. With the first API,
the weights are automatically updated. With the second API,
the user has to do it. It is more complex but gives more
freedom.

Both APIs are wrapped into two classes,
:ref:`l-api-prt-gradient-optimizer` for the first API,
:ref:`l-api-prt-gradient-optimizer-fw` for the second API.
Both classes make it easier for a user accustomed to the
:epkg:`scikit-learn` API to train any graph with a
stochastic gradient descent algorithm.
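
Both follow a :epkg:`scikit-learn`-like ``fit`` pattern. A minimal sketch
with the first API, assuming ``onx_loss`` is an ONNX graph which already
contains its loss (see the helpers below)::

    train_session = OrtGradientOptimizer(
        onx_loss, ['intercept', 'coef'], learning_rate=1e-3)
    train_session.fit(X_train, y_train)
    state_tensors = train_session.get_state()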

.. contents::
:local:

BaseEstimator
+++++++++++++
=============

Ancestor to both classes wrapping :epkg:`onnxruntime` API.

.. autosignature:: onnxcustom.training.base_estimator.BaseEstimator
:members:

LearningRate
++++++++++++
Exceptions
==========

.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGD
:members:
.. autosignature:: onnxcustom.training.excs.ConvergenceError

.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov
:members:
.. autosignature:: onnxcustom.training.excs.EvaluationError

.. autosignature:: onnxcustom.training.excs.ProviderError

First API: loss part of the graph
=================================

Helpers
+++++++

Function `add_loss_output` adds a loss function to the graph
if this loss is part of a predefined list. It may be a
combination of L1, L2 losses and L1, L2 penalties.

.. autosignature:: onnxcustom.utils.orttraining_helper.add_loss_output

.. autosignature:: onnxcustom.utils.orttraining_helper.get_train_initializer
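
A minimal sketch, assuming ``onx`` is an ONNX regression model and that
`get_train_initializer` returns the initializers which can be trained::

    onx_train = add_loss_output(onx)
    initializers = get_train_initializer(onx)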

.. _l-api-prt-gradient-optimizer:

OrtGradientOptimizer
++++++++++++++++++++

.. autosignature:: onnxcustom.training.optimizers.OrtGradientOptimizer
:members:

OrtGradientForwardBackward
++++++++++++++++++++++++++
Second API: loss outside the graph
==================================

.. autosignature:: onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer
ONNX
++++

The second API relies on class :epkg:`TrainingAgent`. It expects to find
the weights to train in alphabetical order. That's usually not the case.
The following function does not change the order but renames all
of them to fulfil that requirement.

.. autosignature:: onnxcustom.utils.onnx_helper.onnx_rename_weights
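
A minimal sketch, assuming ``onx`` already holds the initializers to train;
the renamed model is then given to the optimizer introduced below::

    from onnxcustom.utils.onnx_helper import onnx_rename_weights

    onx = onnx_rename_weights(onx)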

LearningPenalty
+++++++++++++++

.. autosignature:: onnxcustom.training.sgd_learning_penalty.NoLearningPenalty
:members:

Helpers
+++++++
.. autosignature:: onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty
:members:

.. autosignature:: onnxcustom.utils.orttraining_helper.add_loss_output
LearningRate
++++++++++++

.. autosignature:: onnxcustom.utils.orttraining_helper.get_train_initializer
.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGD
:members:

Exceptions
++++++++++
.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDNesterov
:members:

.. autosignature:: onnxcustom.training.excs.ConvergenceError
LearningLoss
++++++++++++

.. autosignature:: onnxcustom.training.sgd_learning_loss.AbsoluteLearningLoss
:members:

.. autosignature:: onnxcustom.training.sgd_learning_loss.ElasticLearningLoss
:members:

.. autosignature:: onnxcustom.training.sgd_learning_loss.SquareLearningLoss
:members:

Loss function
+++++++++++++

.. autosignature:: onnxcustom.utils.onnx_function.function_onnx_graph

.. _l-api-prt-gradient-optimizer-fw:

OrtGradientForwardBackward
++++++++++++++++++++++++++

.. autosignature:: onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer
:members:
5 changes: 0 additions & 5 deletions _doc/sphinxdoc/source/api/utils.rst
@@ -19,11 +19,6 @@ Labelling

.. autosignature:: onnxcustom.utils.imagenet_classes.get_class_names

ONNX
++++

.. autosignature:: onnxcustom.utils.onnx_helper.onnx_rename_weights

Time
++++

77 changes: 75 additions & 2 deletions _doc/sphinxdoc/source/tutorial_training/tutorial_6_training.rst
@@ -1,15 +1,85 @@

.. _l-full-training:

Full Training
=============
Full Training with OrtGradientOptimizer
=======================================

.. contents::
:local:

Design
++++++

:epkg:`onnxruntime` was initially designed to speed up inference
and deployment but it can also be used to train a model.
It builds a graph equivalent to the gradient function,
also based on ONNX operators and specific gradient operators.
Initializers are weights that can be trained. The gradient graph
has as many outputs as initializers.
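
The initializers, and therefore the candidate weights, can be listed with
the onnx package (a sketch, *model.onnx* standing for any converted model)::

    import onnx

    onx = onnx.load("model.onnx")
    print([init.name for init in onx.graph.initializer])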

:class:`OrtGradientOptimizer
<onnxcustom.training.optimizers.OrtGradientOptimizer>` wraps
class :epkg:`TrainingSession` from :epkg:`onnxruntime-training`.
It starts with one model converted into an ONNX graph.
A loss must be added to this graph. Then class :epkg:`TrainingSession`
is able to compute another ONNX graph equivalent to the gradient
of the loss against the weights defined by initializers.

The first ONNX graph implements a function *Y=f(W, X)*.
Then function :func:`add_loss_output
<onnxcustom.utils.orttraining_helper.add_loss_output>`
adds a loss to define a graph *loss, Y=loss(f(W, X), W, expected_Y)*.
This same function is able to add the necessary nodes to compute
an L1 or L2 loss or a combination of both, and an L1 or L2 penalty
or a combination of both. Assuming the user was able to create
an ONNX graph, he would add *0.1 L1 loss + 0.9 L2 loss*
and an L2 penalty on the coefficients by calling :func:`add_loss_output
<onnxcustom.utils.orttraining_helper.add_loss_output>`
like this:

::

onx_loss = add_loss_output(
onx, weight_name='weight', score_name='elastic',
l1_weight=0.1, l2_weight=0.9,
penalty={'coef': {'l2': 0.01}})
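
For a plain squared error without any penalty, the call reduces to
a single line (a sketch, using the same ``onx`` model)::

    onx_loss = add_loss_output(onx)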

An instance of class :class:`OrtGradientOptimizer
<onnxcustom.training.optimizers.OrtGradientOptimizer>` is
initialized:

::

train_session = OrtGradientOptimizer(
onx_loss, ['intercept', 'coef'], learning_rate=1e-3)

And then trained:

::

train_session.fit(X_train, y_train, w_train)

Coefficients can be retrieved like the following:

::

state_tensors = train_session.get_state()

And train losses:

::

losses = train_session.train_losses_

This design does not yet allow training with momentum or
keeping an accumulator for gradients.
The class does not expose all the possibilities implemented in
:epkg:`onnxruntime-training`.
The next examples show that in practice.

Examples
++++++++

The first example compares a linear regression trained with
:epkg:`scikit-learn` and another one trained with
:epkg:`onnxruntime-training`.
@@ -18,6 +88,9 @@ The next two examples explain in detail how the training
with :epkg:`onnxruntime-training` works. They dig into class
:class:`OrtGradientOptimizer
<onnxcustom.training.optimizers.OrtGradientOptimizer>`.
It leverages class :epkg:`TrainingSession` from :epkg:`onnxruntime-training`.
This one assumes the loss function is part of the graph to train.
It takes care of the weight updates as well.

The fourth example replicates what was done with the linear regression
but with a neural network built by :epkg:`scikit-learn`.
@@ -1,6 +1,12 @@

Partial Training
================
Partial Training with OrtGradientForwardBackwardOptimizer
=========================================================

.. contents::
:local:

Design
++++++

Section :ref:`l-full-training` introduces a class able to train a whole
ONNX graph. :epkg:`onnxruntime-training` handles the computation
@@ -13,7 +19,64 @@ ONNX, and be trained by a gradient descent implemented in Python.
Partial training is another way to train an ONNX model. It can be trained
as a standalone ONNX graph or be integrated in a :epkg:`torch` model or any
framework implementing the *forward* and *backward* mechanism.
The next example introduces how this is done with ONNX.
It leverages class :epkg:`TrainingAgent` from :epkg:`onnxruntime-training`.

The main class is :class:`OrtGradientForwardBackwardOptimizer
<onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer>`.
It is initialized with an ONNX graph defining the function to learn
and the names of the weights to train:

::

train_session = OrtGradientForwardBackwardOptimizer(
onx, ['coef', 'intercept'],
learning_rate=LearningRateSGDNesterov(),
learning_loss=ElasticLearningLoss(l1_weight=0.1, l2_weight=0.9),
learning_penalty=ElasticLearningPenalty(l1=0.1, l2=0.9))

The class holds three attributes defining the loss and its gradient,
the penalty and its gradient, and a learning rate possibly with momentum:

* an object inheriting from :class:`BaseLearningLoss
<onnxcustom.training.sgd_learning_loss.BaseLearningLoss>`
* an object inheriting from :class:`BaseLearningPenalty
<onnxcustom.training.sgd_learning_penalty.BaseLearningPenalty>`
* an object inheriting from :class:`BaseLearningRate
<onnxcustom.training.sgd_learning_rate.BaseLearningRate>`
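
Each of these components can be chosen independently. A minimal sketch
with the simplest implementations, assuming the default constructor
arguments are acceptable::

    from onnxcustom.training.optimizers_partial import (
        OrtGradientForwardBackwardOptimizer)
    from onnxcustom.training.sgd_learning_rate import LearningRateSGD
    from onnxcustom.training.sgd_learning_loss import SquareLearningLoss

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, ['coef', 'intercept'],
        learning_rate=LearningRateSGD(),
        learning_loss=SquareLearningLoss())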

Because :epkg:`onnxruntime-training` does not implement any standard
operations on :epkg:`OrtValue`, the only remaining option is to create
simple ONNX graphs executed by :epkg:`InferenceSession` to compute the
loss, the penalty and their gradients, and to update the weights accordingly.
These three classes all implement method `build_onnx_function`, which
creates the ONNX graph based on the arguments the classes were
initialized with. Training happens this way:

::

train_session.fit(X_train, y_train, w_train)

Coefficients can be retrieved like the following:

::

state_tensors = train_session.get_state()

And train losses:

::

losses = train_session.train_losses_

The next examples show that in practice.

Examples
++++++++

This example assumes the loss function is not part of the graph to train
but that the gradient of the loss against the graph output is provided.
It does not take care of the weight updates. This part must be separately
implemented as well. The next examples introduce how this is done
with ONNX and :epkg:`onnxruntime-training`.

.. toctree::
:maxdepth: 1