diff --git a/_doc/sphinxdoc/source/api/training.rst b/_doc/sphinxdoc/source/api/training.rst
index b5e33d8c..698032b4 100644
--- a/_doc/sphinxdoc/source/api/training.rst
+++ b/_doc/sphinxdoc/source/api/training.rst
@@ -11,6 +11,12 @@ BaseEstimator
 .. autosignature:: onnxcustom.training.optimizers.BaseEstimator
     :members:
 
+LearningRate
+++++++++++++
+
+.. autosignature:: onnxcustom.training.sgd_learning_rate.LearningRateSGDRegressor
+    :members:
+
 OrtGradientOptimizer
 ++++++++++++++++++++
 
diff --git a/_doc/sphinxdoc/source/onnxmd/index.rst b/_doc/sphinxdoc/source/onnxmd/index.rst
index d0434296..19ec9cef 100644
--- a/_doc/sphinxdoc/source/onnxmd/index.rst
+++ b/_doc/sphinxdoc/source/onnxmd/index.rst
@@ -8,6 +8,6 @@ The documentation can be searched.
 .. toctree::
     :maxdepth: 1
-    
+
     index_onnx
     index_onnxruntime
diff --git a/_unittests/ut_training/test_optimizers.py b/_unittests/ut_training/test_optimizers.py
index fb0e49f9..967149b7 100644
--- a/_unittests/ut_training/test_optimizers.py
+++ b/_unittests/ut_training/test_optimizers.py
@@ -13,6 +13,7 @@ from sklearn.linear_model import LinearRegression
 from mlprodict.onnx_conv import to_onnx
 from onnxcustom import __max_supported_opset__ as opset
+from onnxcustom.training.sgd_learning_rate import LearningRateSGDRegressor
 try:
     from onnxruntime import TrainingSession
 except ImportError:
@@ -136,7 +137,8 @@ def test_ort_gradient_optimizers_optimal_use_numpy(self):
         onx_loss = add_loss_output(onx)
         inits = ['intercept', 'coef']
         train_session = OrtGradientOptimizer(
-            onx_loss, inits, learning_rate='optimal', max_iter=10)
+            onx_loss, inits, max_iter=10,
+            learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))
         self.assertRaise(lambda: train_session.get_state(), AttributeError)
         train_session.fit(X, y, use_numpy=True)
         state_tensors = train_session.get_state()
@@ -164,7 +166,8 @@ def test_ort_gradient_optimizers_optimal_use_ort(self):
         onx_loss = add_loss_output(onx)
         inits = ['intercept', 'coef']
         train_session = OrtGradientOptimizer(
-            onx_loss, inits, learning_rate='optimal', max_iter=10)
+            onx_loss, inits, max_iter=10,
+            learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))
         self.assertRaise(lambda: train_session.get_state(), AttributeError)
         train_session.fit(X, y, use_numpy=False)
         state_tensors = train_session.get_state()
diff --git a/onnxcustom/training/optimizers.py b/onnxcustom/training/optimizers.py
index e3ea35c9..d9c56f26 100644
--- a/onnxcustom/training/optimizers.py
+++ b/onnxcustom/training/optimizers.py
@@ -8,14 +8,21 @@
     OrtValue, TrainingParameters, SessionOptions, TrainingSession)
 from .data_loader import OrtDataLoader
+from .sgd_learning_rate import BaseLearningRate
 
 
 class BaseEstimator:
     """
     Base class for optimizers.
     Implements common methods such `__repr__`.
+
+    :param learning_rate: learning rate class,
+        see module :mod:`onnxcustom.training.sgd_learning_rate`
     """
 
+    def __init__(self, learning_rate):
+        self.learning_rate = BaseLearningRate.select(learning_rate)
+
     @classmethod
     def _get_param_names(cls):
         init = getattr(cls.__init__, "deprecated_original", cls.__init__)
@@ -32,7 +39,9 @@ def __repr__(self):
             if k not in self.__dict__:
                 continue  # pragma: no cover
             ov = getattr(self, k)
-            if v is not inspect._empty or ov != v:
+            if isinstance(ov, BaseLearningRate):
+                ps.append("%s=%s" % (k, repr(ov)))
+            elif v is not inspect._empty or ov != v:
                 ro = repr(ov)
                 if len(ro) > 50 or "\n" in ro:
                     ro = ro[:10].replace("\n", " ") + "..."
@@ -53,24 +62,16 @@ class OrtGradientOptimizer(BaseEstimator):
     :param max_iter: number of training iterations
     :param training_optimizer_name: optimizing algorithm
     :param batch_size: batch size (see class *DataLoader*)
-    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
-        or `'adaptive'` schedules.
-    :param alpha: constant that multiplies the regularization term,
-        the higher the value, the stronger the regularization.
-        Also used to compute the learning rate when set to *learning_rate*
-        is set to `'optimal'`.
-    :param power_t: exponent for inverse scaling learning rate
-    :param learning_rate: learning rate schedule:
-        * `'constant'`: `eta = eta0`
-        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
-          by a heuristic proposed by Leon Bottou.
-        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`
+    :param learning_rate: a name or a learning rate instance,
+        see module :mod:`onnxcustom.training.sgd_learning_rate`
     :param device: `'cpu'` or `'cuda'`
     :param device_idx: device index
     :param warm_start: when set to True, reuse the solution of the previous
         call to fit as initialization, otherwise, just erase the previous
         solution.
     :param verbose: use :epkg:`tqdm` to display the training progress
+    :param validation_every: validation with a test set every
+        *validation_every* iterations
 
     Once initialized, the class creates the attribute `session_` which holds
     an instance of `onnxruntime.TrainingSession`.
@@ -96,11 +97,12 @@ class OrtGradientOptimizer(BaseEstimator):
 
     def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
                  max_iter=100, training_optimizer_name='SGDOptimizer',
-                 batch_size=10, eta0=0.01, alpha=0.0001, power_t=0.25,
-                 learning_rate='invscaling', device='cpu', device_idx=0,
+                 batch_size=10, learning_rate='SGDRegressor',
+                 device='cpu', device_idx=0,
                  warm_start=False, verbose=0, validation_every=0.1):
         # See https://scikit-learn.org/stable/modules/generated/
         # sklearn.linear_model.SGDRegressor.html
+        BaseEstimator.__init__(self, learning_rate)
         self.model_onnx = model_onnx
         self.batch_size = batch_size
         self.weights_to_train = weights_to_train
@@ -108,10 +110,6 @@ def __init__(self, model_onnx, weights_to_train, loss_output_name='loss',
         self.training_optimizer_name = training_optimizer_name
         self.verbose = verbose
         self.max_iter = max_iter
-        self.eta0 = eta0
-        self.alpha = alpha
-        self.power_t = power_t
-        self.learning_rate = learning_rate.lower()
         self.device = device
         self.device_idx = device_idx
         self.warm_start = warm_start
@@ -133,23 +131,6 @@ def __setstate__(self, state):
             setattr(self, att, v)
         return self
 
-    def _init_learning_rate(self):
-        self.eta0_ = self.eta0
-        if self.learning_rate == "optimal":
-            typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
-            self.eta0_ = typw / max(1.0, (1 + typw) * 2)
-            self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha)
-        else:
-            self.eta0_ = self.eta0
-        return self.eta0_
-
-    def _update_learning_rate(self, t, eta):
-        if self.learning_rate == "optimal":
-            eta = 1.0 / (self.alpha * (self.optimal_init_ + t))
-        elif self.learning_rate == "invscaling":
-            eta = self.eta0_ / numpy.power(t + 1, self.power_t)
-        return eta
-
     def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
         """
         Trains the model.
@@ -186,7 +167,7 @@ def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
                 X_val, y_val, batch_size=X_val.shape[0], device=self.device)
         else:
             data_loader_val = None
-        lr = self._init_learning_rate()
+        self.learning_rate.init_learning_rate()
         self.input_names_ = [i.name for i in self.train_session_.get_inputs()]
         self.output_names_ = [
             o.name for o in self.train_session_.get_outputs()]
@@ -202,13 +183,14 @@
 
         train_losses = []
         val_losses = []
+        lr = self.learning_rate.value
         for it in loop:
             bind_lr = OrtValue.ortvalue_from_numpy(
                 numpy.array([lr / self.batch_size], dtype=numpy.float32),
                 self.device, self.device_idx)
             loss = self._iteration(data_loader, bind_lr, bind,
                                    use_numpy=use_numpy)
-            lr = self._update_learning_rate(it, lr)
+            lr = self.learning_rate.update_learning_rate(it).value
             if self.verbose > 1:  # pragma: no cover
                 loop.set_description(
                     "loss=%1.3g lr=%1.3g" % (  # pylint: disable=E1101,E1307
diff --git a/onnxcustom/training/sgd_learning_rate.py b/onnxcustom/training/sgd_learning_rate.py
new file mode 100644
index 00000000..5a565ff5
--- /dev/null
+++ b/onnxcustom/training/sgd_learning_rate.py
@@ -0,0 +1,153 @@
+"""
+@file
+@brief Helper for :epkg:`onnxruntime-training`.
+"""
+import inspect
+import numpy
+
+
+class BaseLearningRate:
+    """
+    Class handling the learning rate update after every
+    iteration of the gradient descent.
+    """
+
+    def __init__(self):
+        pass
+
+    def init_learning_rate(self):
+        """
+        Initializes the learning rate at the beginning of the training.
+        :return: self
+        """
+        raise NotImplementedError(
+            "This method must be overwritten.")
+
+    def update_learning_rate(self, t):
+        """
+        Updates the learning rate at the end of an iteration.
+        :param t: iteration number
+        :return: self
+        """
+        raise NotImplementedError(
+            "This method must be overwritten.")
+
+    @property
+    def value(self):
+        "Returns the current learning rate."
+        raise NotImplementedError(
+            "This method must be overwritten.")
+
+    @staticmethod
+    def select(class_name, **kwargs):
+        """
+        Returns an instance of a given class name
+        initialized with *kwargs*.
+        :param class_name: an instance of @see cl BaseLearningRate
+            or a string among the following class names (see below)
+        :return: instance of @see cl BaseLearningRate
+
+        Possible values for *class_name*:
+        * `'LearningRateSGDRegressor'`: see @see cl LearningRateSGDRegressor
+        """
+        if isinstance(class_name, BaseLearningRate):
+            return class_name
+        cls = {LearningRateSGDRegressor: ['SGDRegressor']}
+        for cl, aliases in cls.items():
+            if class_name == cl.__name__ or class_name in aliases:
+                return cl(**kwargs)
+        raise ValueError(
+            "Unexpected class name %r. It should be one of %r." % (
+                class_name, list(map(lambda c: c.__name__, cls))))
+
+    @classmethod
+    def _get_param_names(cls):
+        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
+        init_signature = inspect.signature(init)
+        parameters = [
+            p for p in init_signature.parameters.values()
+            if p.name != "self" and p.kind != p.VAR_KEYWORD]
+        return [(p.name, p.default) for p in parameters]
+
+    def __repr__(self):
+        """
+        Usual
+        """
+        param = self._get_param_names()
+        ps = []
+        for k, v in param:
+            if k not in self.__dict__:
+                continue  # pragma: no cover
+            ov = getattr(self, k)
+            if v is not inspect._empty or ov != v:
+                ro = repr(ov)
+                ps.append("%s=%s" % (k, ro))
+        return "%s(%s)" % (self.__class__.__name__, ", ".join(ps))
+
+
+class LearningRateSGDRegressor(BaseLearningRate):
+    """
+    Implements the learning rate schedule the same way as
+    :class:`sklearn.linear_model.SGDRegressor`.
+
+    :param eta0: initial learning rate for the `'constant'`, `'invscaling'`
+        or `'adaptive'` schedules.
+    :param alpha: constant that multiplies the regularization term,
+        the higher the value, the stronger the regularization.
+        Also used to compute the learning rate when *learning_rate*
+        is set to `'optimal'`.
+    :param power_t: exponent for inverse scaling learning rate
+    :param learning_rate: learning rate schedule:
+        * `'constant'`: `eta = eta0`
+        * `'optimal'`: `eta = 1.0 / (alpha * (t + t0))` where *t0* is chosen
+          by a heuristic proposed by Leon Bottou.
+        * `'invscaling'`: `eta = eta0 / pow(t, power_t)`
+
+    Created attributes:
+    * `eta0_`: initial eta0
+    * `optimal_init_`: used when `learning_rate == 'optimal'`
+    * `value_`: value to be returned by property `value`
+    """
+
+    def __init__(self, eta0=0.01, alpha=0.0001, power_t=0.25,
+                 learning_rate='invscaling'):
+        BaseLearningRate.__init__(self)
+        self.eta0 = eta0
+        self.alpha = alpha
+        self.power_t = power_t
+        self.learning_rate = learning_rate.lower()
+        self.value_ = None
+
+    def init_learning_rate(self):
+        """
+        Initializes the learning rate at the beginning of the training.
+        :return: self
+        """
+        self.eta0_ = self.eta0
+        if self.learning_rate == "optimal":
+            typw = numpy.sqrt(1.0 / numpy.sqrt(self.alpha))
+            self.eta0_ = typw / max(1.0, (1 + typw) * 2)
+            self.optimal_init_ = 1.0 / (self.eta0_ * self.alpha)
+        else:
+            self.eta0_ = self.eta0
+        self.value_ = self.eta0_
+        return self
+
+    def update_learning_rate(self, t):
+        """
+        Updates the learning rate at the end of an iteration.
+        :param t: iteration number
+        :return: self
+        """
+        eta = self.value_
+        if self.learning_rate == "optimal":
+            eta = 1.0 / (self.alpha * (self.optimal_init_ + t))
+        elif self.learning_rate == "invscaling":
+            eta = self.eta0_ / numpy.power(t + 1, self.power_t)
+        self.value_ = eta
+        return self
+
+    @property
+    def value(self):
+        "Returns the current learning rate."
+        return self.value_
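A minimal usage sketch of the learning-rate API introduced by this patch, mirroring the updated unit tests. The alias 'SGDRegressor' and the schedule names ('constant', 'optimal', 'invscaling') come from the new module above; the chained `.value` call reproduces what `OrtGradientOptimizer.fit` now does after every iteration. The commented `OrtGradientOptimizer` call assumes an ONNX graph with a loss output (`onx_loss`) built as in `test_optimizers.py`, which is not shown here.

from onnxcustom.training.sgd_learning_rate import (
    BaseLearningRate, LearningRateSGDRegressor)

# The string alias 'SGDRegressor' is resolved by BaseLearningRate.select,
# which is also what BaseEstimator.__init__ applies to the learning_rate
# argument of OrtGradientOptimizer.
lr = BaseLearningRate.select('SGDRegressor', learning_rate='optimal')
assert isinstance(lr, LearningRateSGDRegressor)

lr.init_learning_rate()      # computes eta0_ and optimal_init_
print(repr(lr))              # repr assembled by BaseLearningRate.__repr__
for t in range(5):
    # update_learning_rate returns self, so the new value can be chained,
    # exactly as OrtGradientOptimizer.fit does at the end of each iteration.
    print(t, lr.update_learning_rate(t).value)

# Passing the schedule to the optimizer, as in the updated unit tests
# (onx_loss is an ONNX graph with a loss output, not built here):
# train_session = OrtGradientOptimizer(
#     onx_loss, ['intercept', 'coef'], max_iter=10,
#     learning_rate=LearningRateSGDRegressor(learning_rate='optimal'))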