6 changes: 6 additions & 0 deletions _doc/sphinxdoc/source/api/training.rst
@@ -11,6 +11,12 @@ BaseEstimator
.. autosignature:: onnxcustom.training.optimizers.BaseEstimator
:members:

LearningRate
++++++++++++

.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDRegressor
:members:

OrtGradientOptimizer
++++++++++++++++++++

2 changes: 1 addition & 1 deletion _doc/sphinxdoc/source/onnxmd/index.rst
@@ -8,6 +8,6 @@ The documentation can be searched.

.. toctree::
:maxdepth: 1

index_onnx
index_onnxruntime
7 changes: 5 additions & 2 deletions _unittests/ut_training/test_optimizers.py
@@ -13,6 +13,7 @@
from sklearn.linear_model import LinearRegression
from mlprodict.onnx_conv import to_onnx
from onnxcustom import __max_supported_opset__ as opset
from onnxcustom.training.sgd_learning_rate import LearningRateSGDRegressor
try:
from onnxruntime import TrainingSession
except ImportError:
@@ -136,7 +137,8 @@ def test_ort_gradient_optimizers_optimal_use_numpy(self):
onx_loss = add_loss_output(onx)
inits = ['intercept', 'coef']
train_session = OrtGradientOptimizer(
onx_loss, inits, learning_rate='optimal', max_iter=10)
onx_loss, inits, max_iter=10,
learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))
self.assertRaise(lambda: train_session.get_state(), AttributeError)
train_session.fit(X, y, use_numpy=True)
state_tensors = train_session.get_state()
@@ -164,7 +166,8 @@ def test_ort_gradient_optimizers_optimal_use_ort(self):
onx_loss = add_loss_output(onx)
inits = ['intercept', 'coef']
train_session = OrtGradientOptimizer(
onx_loss, inits, learning_rate='optimal', max_iter=10)
onx_loss, inits, max_iter=10,
learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))
self.assertRaise(lambda: train_session.get_state(), AttributeError)
train_session.fit(X, y, use_numpy=False)
state_tensors = train_session.get_state()
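
The two tests above capture the new construction pattern: the scalar hyperparameters (`eta0`, `alpha`, `power_t`) and the schedule name no longer go to `OrtGradientOptimizer` directly but to a learning-rate object. A minimal sketch of the updated call, assuming `onx_loss`, `X` and `y` are prepared exactly as in these tests:

    from onnxcustom.training.optimizers import OrtGradientOptimizer
    from onnxcustom.training.sgd_learning_rate import LearningRateSGDRegressor

    # the schedule and its hyperparameters now live in a dedicated object
    train_session = OrtGradientOptimizer(
        onx_loss, ['intercept', 'coef'], max_iter=10,
        learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))
    train_session.fit(X, y, use_numpy=True)
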
58 changes: 20 additions & 38 deletions onnxcustom/training/optimizers.py
@@ -8,14 +8,21 @@
OrtValue, TrainingParameters,
SessionOptions, TrainingSession)
from .data_loader import OrtDataLoader
from .sgd_learning_rate import BaseLearningRate


class BaseEstimator:
"""
Base class for optimizers.
Implements common methods such as `__repr__`.

:param learning_rate: learning rate class,
see module :mod:`onnxcustom.training.sgd_learning_rate`
"""

def __init__(self, learning_rate):
self.learning_rate = BaseLearningRate.select(learning_rate)

@classmethod
def _get_param_names(cls):
init = getattr(cls.__init__, "deprecated_original", cls.__init__)
@@ -32,7 +39,9 @@ def __repr__(self):
if k not in self.__dict__:
continue # pragma: no cover
ov = getattr(self, k)
if v is not inspect._empty or ov != v:
if isinstance(ov, BaseLearningRate):
ps.append("%s=%s" % (k, repr(ov)))
elif v is not inspect._empty or ov != v:
ro = repr(ov)
if len(ro) > 50 or "\n" in ro:
ro = ro[:10].replace("\n", " ") + "..."
@@ -53,24 +62,16 @@ class OrtGradientOptimizer(BaseEstimator):
:param max_iter: number of training iterations
:param training_optimizer_name: optimizing algorithm
:param batch_size: batch size (see class *DataLoader*)
:param eta0: initial learning rate for the `'constant'`, `'invscaling'`
or `'adaptive'` schedules.
:param alpha: constant that multiplies the regularization term,
the higher the value, the stronger the regularization.
Also used to compute the learning rate when set to *learning_rate*
is set to `'optimal'`.
:param power_t: exponent for inverse scaling learning rate
:param learning_rate: learning rate schedule:
* `'constant'`: `eta = eta0`
* `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
by a heuristic proposed by Leon Bottou.
* `'invscaling'`: `eta = eta0 / pow(t, power_t)`
:param learning_rate: a name or a learning rate instance,
see module :mod:`onnxcustom.training.sgd_learning_rate`
:param device: `'cpu'` or `'cuda'`
:param device_idx: device index
:param warm_start: when set to True, reuse the solution of the previous
call to fit as initialization, otherwise, just erase the previous
solution.
:param verbose: use :epkg:`tqdm` to display the training progress
:param validation_every: validation with a test set every
*validation_every* iterations

Once initialized, the class creates the attribute
`session_` which holds an instance of `onnxruntime.TrainingSession`.
@@ -96,22 +97,19 @@

def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
max_iter=100, training_optimizer_name='SGDOptimizer',
batch_size=10, eta0=0.01, alpha=0.0001, power_t=0.25,
learning_rate='invscaling', device='cpu', device_idx=0,
batch_size=10, learning_rate='SGDRegressor',
device='cpu', device_idx=0,
warm_start=False, verbose=0, validation_every=0.1):
# See https://scikit-learn.org/stable/modules/generated/
# sklearn.linear_model.SGDRegressor.html
BaseEstimator.__init__(self, learning_rate)
self.model_onnx = model_onnx
self.batch_size = batch_size
self.weights_to_train = weights_to_train
self.loss_output_name = loss_output_name
self.training_optimizer_name = training_optimizer_name
self.verbose = verbose
self.max_iter = max_iter
self.eta0 = eta0
self.alpha = alpha
self.power_t = power_t
self.learning_rate = learning_rate.lower()
self.device = device
self.device_idx = device_idx
self.warm_start = warm_start
@@ -133,23 +131,6 @@ def __setstate__(self, state):
setattr(self, att, v)
return self

def _init_learning_rate(self):
self.eta0_ = self.eta0
if self.learning_rate == "optimal":
typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
self.eta0_ = typw / max(1.0, (1 + typw) * 2)
self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha)
else:
self.eta0_ = self.eta0
return self.eta0_

def _update_learning_rate(self, t, eta):
if self.learning_rate == "optimal":
eta = 1.0 / (self.alpha * (self.optimal_init_ + t))
elif self.learning_rate == "invscaling":
eta = self.eta0_ / numpy.power(t + 1, self.power_t)
return eta

def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
"""
Trains the model.
@@ -186,7 +167,7 @@ def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
X_val, y_val, batch_size=X_val.shape[0], device=self.device)
else:
data_loader_val = None
lr = self._init_learning_rate()
self.learning_rate.init_learning_rate()
self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
self.output_names_ = [
o.name for o in self.train_session_.get_outputs()]
@@ -202,13 +183,14 @@

train_losses = []
val_losses = []
lr = self.learning_rate.value
for it in loop:
bind_lr = OrtValue.ortvalue_from_numpy(
numpy.array([lr / self.batch_size], dtype=numpy.float32),
self.device, self.device_idx)
loss = self._iteration(data_loader, bind_lr,
bind, use_numpy=use_numpy)
lr = self._update_learning_rate(it, lr)
lr = self.learning_rate.update_learning_rate(it).value
if self.verbose > 1: # pragma: no cover
loop.set_description(
"loss=%1.3g lr=%1.3g" % ( # pylint: disable=E1101,E1307
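
The removed `_init_learning_rate` and `_update_learning_rate` helpers are now reproduced by the learning-rate object, which `fit` queries through `init_learning_rate()`, `value` and `update_learning_rate(t)`. A small standalone check of the default `'invscaling'` schedule, which follows `eta = eta0 / pow(t + 1, power_t)` (the tolerance and the three iterations are only for illustration):

    import numpy
    from onnxcustom.training.sgd_learning_rate import LearningRateSGDRegressor

    lr = LearningRateSGDRegressor()   # defaults: eta0=0.01, power_t=0.25, 'invscaling'
    lr.init_learning_rate()           # sets the starting value to eta0
    for t in range(3):
        expected = 0.01 / numpy.power(t + 1, 0.25)
        # update_learning_rate returns self, so .value reads the new rate
        assert abs(lr.update_learning_rate(t).value - expected) < 1e-12
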
153 changes: 153 additions & 0 deletions onnxcustom/training/sgd_learning_rate.py
@@ -0,0 +1,153 @@
"""
@file
@brief Helper for :epkg:`onnxruntime-training`.
"""
import inspect
import numpy


class BaseLearningRate:
"""
Class handling the learning rate update after every
iteration of gradient descent.
"""

def __init__(self):
pass

def init_learning_rate(self):
"""
Initializes the learning rate at the beginning of the training.
:return: self
"""
raise NotImplementedError(
"This method must be overwritten.")

def update_learning_rate(self, t):
"""
Updates the learning rate at the end of an iteration.
:param t: iteration number
:return: self
"""
raise NotImplementedError(
"This method must be overwritten.")

@property
def value(self):
"Returns the current learning rate."
raise NotImplementedError(
"This method must be overwritten.")

@staticmethod
def select(class_name, **kwargs):
"""
Returns an instance of a given class name initialized with
*kwargs*.
:param class_name: an instance of @see cl BaseLearningRate
or a string among the following class names (see below)
:return: instance of @see cl BaseLearningRate

Possible values for *class_name*:
* `'LearningRateSGDRegressor'`: see @see cl LearningRateSGDRegressor
"""
if isinstance(class_name, BaseLearningRate):
return class_name
cls = {LearningRateSGDRegressor: ['SGDRegressor']}
for cl, aliases in cls.items():
if class_name == cl.__name__ or class_name in aliases:
return cl(**kwargs)
raise ValueError(
"Unexpected class name %r. It should be one of %r." % (
class_name, list(map(lambda c: c.__name__, cls))))

@classmethod
def _get_param_names(cls):
init = getattr(cls.__init__, "deprecated_original", cls.__init__)
init_signature = inspect.signature(init)
parameters = [
p for p in init_signature.parameters.values()
if p.name != "self" and p.kind != p.VAR_KEYWORD]
return [(p.name, p.default) for p in parameters]

def __repr__(self):
"""
Usual
"""
param = self._get_param_names()
ps = []
for k, v in param:
if k not in self.__dict__:
continue # pragma: no cover
ov = getattr(self, k)
if v is not inspect._empty or ov != v:
ro = repr(ov)
ps.append("%s=%s" % (k, ro))
return "%s(%s)" % (self.__class__.__name__, ", ".join(ps))


class LearningRateSGDRegressor(BaseLearningRate):
"""
Implements the learning rate schedule the same way as
:class:`sklearn.linear_model.SGDRegressor`.

:param eta0: initial learning rate for the `'constant'`, `'invscaling'`
or `'adaptive'` schedules.
:param alpha: constant that multiplies the regularization term,
the higher the value, the stronger the regularization.
Also used to compute the learning rate when *learning_rate*
is set to `'optimal'`.
:param power_t: exponent for inverse scaling learning rate
:param learning_rate: learning rate schedule:
* `'constant'`: `eta = eta0`
* `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
by a heuristic proposed by Leon Bottou.
* `'invscaling'`: `eta = eta0 / pow(t, power_t)`

Created attributes:
* `eta0_`: initial eta0
* `optimal_init_`: used when `learning_rate == 'optimal'`
* `value_`: value to be returned by property `value`
"""

def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
learning_rate='invscaling'):
BaseLearningRate.__init__(self)
self.eta0 = eta0
self.alpha = alpha
self.power_t = power_t
self.learning_rate = learning_rate.lower()
self.value_ = None

def init_learning_rate(self):
"""
Initializes the learning rate at the beginning of the training.
:return: self
"""
self.eta0_ = self.eta0
if self.learning_rate == "optimal":
typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
self.eta0_ = typw / max(1.0, (1 + typw) * 2)
self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha)
else:
self.eta0_ = self.eta0
self.value_ = self.eta0_
return self

def update_learning_rate(self, t):
"""
Updates the learning rate at the end of an iteration.
:param t: iteration number
:return: self
"""
eta = self.value_
if self.learning_rate == "optimal":
eta = 1.0 / (self.alpha * (self.optimal_init_ + t))
elif self.learning_rate == "invscaling":
eta = self.eta0_ / numpy.power(t + 1, self.power_t)
self.value_ = eta
return self

@property
def value(self):
"Returns the current learning rate."
return self.value_
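
Because `BaseLearningRate.select` accepts an instance directly, a custom schedule only has to provide the three members used by the optimizer: `init_learning_rate()`, `update_learning_rate(t)` and `value`. A minimal sketch under that assumption; `ConstantRate` is a hypothetical class, not part of the module:

    from onnxcustom.training.sgd_learning_rate import BaseLearningRate

    class ConstantRate(BaseLearningRate):
        "Hypothetical schedule keeping the same rate at every iteration."

        def __init__(self, eta0=0.01):
            BaseLearningRate.__init__(self)
            self.eta0 = eta0

        def init_learning_rate(self):
            # called once by fit() before the first iteration
            self.value_ = self.eta0
            return self

        def update_learning_rate(self, t):
            # a constant schedule has nothing to update
            return self

        @property
        def value(self):
            return self.value_

    # such an instance can then be passed as
    # OrtGradientOptimizer(onx_loss, inits, learning_rate=ConstantRate(0.001))
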