diff --git a/_doc/examples/plot_benchmark_op_leakyrelu.py b/_doc/examples/plot_benchmark_op_leakyrelu.py
index 26efed8b..714d0f7a 100644
--- a/_doc/examples/plot_benchmark_op_leakyrelu.py
+++ b/_doc/examples/plot_benchmark_op_leakyrelu.py
@@ -15,7 +15,6 @@
 .. contents::
     :local:
 
-
 The ONNX graphs for both implementations of LeakyRelu
 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 
diff --git a/_doc/examples/plot_benchmark_op_short.py b/_doc/examples/plot_benchmark_op_short.py
index 5b605512..4359f1c6 100644
--- a/_doc/examples/plot_benchmark_op_short.py
+++ b/_doc/examples/plot_benchmark_op_short.py
@@ -149,8 +149,9 @@ def build_ort_op(op_version=14, save=None, slices=None):  # opset=13, 14, ...
             shape=str(shape).replace(
                 " ", ""),
             slice=str(slices).replace(
                 " ", ""))
-        r = measure_time('sess.run(None, dx)', number=number, div_by_number=True,
-                         context={'sess': sess, 'dx': {'X': x}})
+        r = measure_time(lambda: sess.run(None, {'X': x}),
+                         number=number, div_by_number=True,
+                         context={})
         obs.update(r)
         obs['provider'] = 'CPU'
         data.append(obs)
diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py
index 0b3a9de9..db5a623a 100644
--- a/_doc/sphinxdoc/source/conf.py
+++ b/_doc/sphinxdoc/source/conf.py
@@ -159,7 +159,7 @@ def callback_begin():
     'C_OrtValue':
         'http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/'
         'onnxmd/onnxruntime_python/ortvalue.html#c-class-ortvaluevector',
-    'Contrib Operators' :
+    'Contrib Operators':
         'http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/onnxmd/'
         'onnxruntime_docs/ContribOperators.html',
     'Gemm':
diff --git a/_doc/sphinxdoc/source/tutorial_onnx/challenges.rst b/_doc/sphinxdoc/source/tutorial_onnx/challenges.rst
index f652e27d..49be97bd 100644
--- a/_doc/sphinxdoc/source/tutorial_onnx/challenges.rst
+++ b/_doc/sphinxdoc/source/tutorial_onnx/challenges.rst
@@ -121,7 +121,7 @@ Then unit test must be updated.
 * Update unit test.
 
 The PR should include the modified files and the modified markdown documentation,
-usually a subset of 
+usually a subset of
 `docs/docs/Changelog-ml.md`, `docs/Changelog.md`,
 `docs/Operators-ml.md`, `docs/Operators.md`,
 `docs/TestCoverage-ml.md`, `docs/TestCoverage.md`.
diff --git a/_doc/sphinxdoc/source/tutorial_onnx/concepts.rst b/_doc/sphinxdoc/source/tutorial_onnx/concepts.rst
index bada2f81..edf54286 100644
--- a/_doc/sphinxdoc/source/tutorial_onnx/concepts.rst
+++ b/_doc/sphinxdoc/source/tutorial_onnx/concepts.rst
@@ -397,7 +397,7 @@ One example is the operator CDist. Notebook `Pairwise distances with ONNX (pdist
 goes into the details of it. Pairwise distances, as shown in section
 :ref:`l-operator-scan-onnx-tutorial`, can be implemented with operator Scan.
 However, a dedicated operator called CDist is proved significantly
-faster, significantly to make the effort to implement a dedicated runtime 
+faster, enough to justify the effort of implementing a dedicated runtime
 for it.
 
 Shape (and Type) Inference
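Passing a callable to measure_time instead of a string expression removes the need to eval code against a context dictionary and measures exactly the call being benchmarked. A minimal sketch of the same measure using only the standard library; the model path 'model.onnx' and the input name 'X' are placeholders::

    import timeit
    import numpy
    from onnxruntime import InferenceSession

    sess = InferenceSession('model.onnx')
    x = numpy.random.randn(100, 10).astype(numpy.float32)
    number = 50
    # total time divided by the number of runs, the div_by_number equivalent
    print(timeit.timeit(lambda: sess.run(None, {'X': x}),
                        number=number) / number)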
diff --git a/_unittests/ut_training/test_optimizers_forward_backward.py b/_unittests/ut_training/test_optimizers_forward_backward.py
index adc0f970..5a65ce6b 100644
--- a/_unittests/ut_training/test_optimizers_forward_backward.py
+++ b/_unittests/ut_training/test_optimizers_forward_backward.py
@@ -6,7 +6,7 @@
 import io
 import pickle
 import logging
-from pyquickhelper.pycode import ExtTestCase, ignore_warnings
+from pyquickhelper.pycode import ExtTestCase, ignore_warnings, skipif_appveyor
 import numpy
 from onnx.helper import set_model_props
 from sklearn.datasets import make_regression
@@ -111,6 +111,7 @@ def test_ort_gradient_optimizers_use_numpy_exc(self):
             ConvergenceError)
 
     @unittest.skipIf(TrainingSession is None, reason="not training")
+    @skipif_appveyor("logging issue")
     def test_ort_gradient_optimizers_use_numpy_log(self):
         from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer
         X, y = make_regression(  # pylint: disable=W0632
@@ -135,6 +136,28 @@ def test_ort_gradient_optimizers_use_numpy_log(self):
         self.assertTrue(res is train_session)
         self.assertIn("[OrtGradientForwardBackwardOptimizer._iteration]", logs)
 
+    @unittest.skipIf(TrainingSession is None, reason="not training")
+    def test_ort_gradient_optimizers_use_numpy_log_appveyor(self):
+        from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer
+        X, y = make_regression(  # pylint: disable=W0632
+            100, n_features=2, bias=2, random_state=0)
+        X[:10, :] = 0
+        X = X.astype(numpy.float32)
+        y = (X.sum(axis=1) + y / 1000).astype(numpy.float32)
+        X_train, _, y_train, __ = train_test_split(X, y)
+        reg = LinearRegression()
+        reg.fit(X_train, y_train)
+        reg.coef_ = reg.coef_.reshape((1, -1))
+        onx = to_onnx(reg, X_train, target_opset=opset,
+                      black_op={'LinearRegressor'})
+        # onx = onnx_rename_weights(onx)
+        set_model_props(onx, {'info': 'unit test'})
+        inits = ['coef', 'intercept']
+        train_session = OrtGradientForwardBackwardOptimizer(
+            onx, inits, enable_logging=True)
+        res = train_session.fit(X, y, use_numpy=True)
+        self.assertTrue(res is train_session)
+
     @unittest.skipIf(TrainingSession is None, reason="not training")
     def test_ort_gradient_optimizers_use_numpy_pickle(self):
         from onnxcustom.training.optimizers_partial import OrtGradientForwardBackwardOptimizer
diff --git a/_unittests/ut_training/test_orttraining_forward_backward.py b/_unittests/ut_training/test_orttraining_forward_backward.py
index 387d3c4a..1e99c49f 100644
--- a/_unittests/ut_training/test_orttraining_forward_backward.py
+++ b/_unittests/ut_training/test_orttraining_forward_backward.py
@@ -25,12 +25,15 @@
 
 class TestOrtTrainingForwardBackward(ExtTestCase):
 
-    def forward_no_training(self):
+    def forward_no_training(self, exc=None, verbose=False):
+        if exc is None:
+            exc = __name__ != '__main__'
         from onnxruntime.capi._pybind_state import (
-            OrtValue as C_OrtValue, OrtDevice, OrtMemType)
+            OrtValue as C_OrtValue, OrtDevice as C_OrtDevice, OrtMemType)
         from onnxruntime.capi._pybind_state import (
             OrtValueVector)
         from onnxcustom.training.ortgradient import OrtGradientForwardBackward
+
         X, y = make_regression(  # pylint: disable=W0632
             100, n_features=10, bias=2)
         X = X.astype(numpy.float32)
@@ -43,10 +46,17 @@ def forward_no_training(self):
             black_op={'LinearRegressor'})
 
         # starts testing
-        self.assertRaise(
-            lambda: OrtGradientForwardBackward(
-                onx, debug=True, enable_logging=True, providers=['NONE']),
-            ValueError)
+        if verbose:
+            print("[forward_no_training] start testing")
+        if exc:
+            if verbose:
+ print("[forward_no_training] check exception") + self.assertRaise( + lambda: OrtGradientForwardBackward( + onx, debug=True, enable_logging=True, providers=['NONE']), + ValueError) + if verbose: + print("[forward_no_training] instantiate") forback = OrtGradientForwardBackward( onx, debug=True, enable_logging=True) self.assertEqual(repr(forback), "OrtGradientForwardBackward(...)") @@ -64,16 +74,25 @@ def forward_no_training(self): ['X_grad', 'coef_grad', 'intercept_grad']) self.assertEqual(forback.cls_type_._output_names, ['variable']) + if verbose: + print("[forward_no_training] expected prediction") + expected = reg.predict(X_test) coef = reg.coef_.astype(numpy.float32).reshape((-1, 1)) intercept = numpy.array([reg.intercept_], dtype=numpy.float32) + if verbose: + print("[forward_no_training] InferenceSession") + sess0 = InferenceSession(onx.SerializeToString()) inames = [i.name for i in sess0.get_inputs()] # pylint: disable=E1101 self.assertEqual(inames, ['X']) got = sess0.run(None, {'X': X_test}) self.assertEqualArray(expected.ravel(), got[0].ravel(), decimal=4) + if verbose: + print("[forward_no_training] evaluation") + sess_eval = forback.cls_type_._sess_eval # pylint: disable=E1101 inames = [i.name for i in sess_eval.get_inputs()] self.assertEqual(inames, ['X', 'coef', 'intercept']) @@ -82,8 +101,10 @@ def forward_no_training(self): self.assertEqualArray(expected.ravel(), got[0].ravel(), decimal=4) # OrtValue + if verbose: + print("[forward_no_training] OrtValue") inst = forback.new_instance() - device = OrtDevice(OrtDevice.cpu(), OrtMemType.DEFAULT, 0) + device = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0) # list of OrtValues inputs = [] @@ -95,6 +116,8 @@ def forward_no_training(self): self.assertEqualArray(expected.ravel(), got[0].ravel(), decimal=4) # OrtValueVector + if verbose: + print("[forward_no_training] OrtValueVector") inputs = OrtValueVector() for a in [X_test, coef, intercept]: inputs.push_back(C_OrtValue.ortvalue_from_numpy(a, device)) @@ -104,16 +127,20 @@ def forward_no_training(self): expected.ravel(), got[0].numpy().ravel(), decimal=4) # numpy + if verbose: + print("[forward_no_training] numpy") inputs = [X_test, coef, intercept] got = inst.forward(inputs) self.assertEqual(len(got), 1) self.assertEqualArray( expected.ravel(), got[0].numpy().ravel(), decimal=4) + if verbose: + print("[forward_no_training] end") @unittest.skipIf(TrainingSession is None, reason="no training") def test_forward_no_training(self): res, logs = self.assertLogging( - self.forward_no_training, 'onnxcustom') + lambda: self.forward_no_training(exc=True), 'onnxcustom') self.assertEmpty(res) if len(logs) > 0: self.assertIn("[OrtGradientForwardBackward]", logs) @@ -122,7 +149,7 @@ def test_forward_no_training(self): @unittest.skipIf(TrainingSession is None, reason="no training") def test_forward_no_training_pickle(self): from onnxruntime.capi._pybind_state import ( - OrtValue as C_OrtValue, OrtDevice, OrtMemType) + OrtValue as C_OrtValue, OrtMemType, OrtDevice as C_OrtDevice) from onnxruntime.capi._pybind_state import ( OrtValueVector) from onnxcustom.training.ortgradient import OrtGradientForwardBackward @@ -176,7 +203,7 @@ def test_forward_no_training_pickle(self): # OrtValue inst = forback.new_instance() inputs = [] - device = OrtDevice(OrtDevice.cpu(), OrtMemType.DEFAULT, 0) + device = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0) for a in [X_test, coef, intercept]: inputs.append(C_OrtValue.ortvalue_from_numpy(a, device)) got_ort = inst.forward(inputs) @@ -203,7 +230,7 @@ 
@@ -203,7 +230,7 @@ def test_forward_no_training_pickle(self):
     def forward_training(self, model, debug=False, n_classes=3,
                          add_print=False):
         from onnxruntime.capi._pybind_state import (
-            OrtValue as C_OrtValue, OrtDevice, OrtMemType)
+            OrtValue as C_OrtValue, OrtMemType, OrtDevice as C_OrtDevice)
         from onnxruntime.capi._pybind_state import (
             OrtValueVector)
         from onnxcustom.training.ortgradient import OrtGradientForwardBackward
@@ -284,7 +311,7 @@ def to_proba(yt):
 
         # OrtValue
         inst = forback.new_instance()
-        device = OrtDevice(OrtDevice.cpu(), OrtMemType.DEFAULT, 0)
+        device = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0)
 
         # OrtValueVector
         if add_print:
@@ -370,4 +397,5 @@ def test_forward_training_logreg(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    # TestOrtTrainingForwardBackward().forward_no_training(verbose=True)
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_utils/test_onnxruntime_helper.py b/_unittests/ut_utils/test_onnxruntime_helper.py
index e3caf21b..153f6c11 100644
--- a/_unittests/ut_utils/test_onnxruntime_helper.py
+++ b/_unittests/ut_utils/test_onnxruntime_helper.py
@@ -4,7 +4,9 @@
 import unittest
 from pyquickhelper.pycode import ExtTestCase
 from onnxcustom.utils.onnxruntime_helper import (
-    device_to_provider, provider_to_device, get_ort_device_type)
+    provider_to_device, get_ort_device_type,
+    get_ort_device, ort_device_to_string,
+    device_to_providers)
 
 
 class TestOnnxRuntimeHelper(ExtTestCase):
@@ -15,15 +17,28 @@ def test_provider_to_device(self):
         self.assertRaise(lambda: provider_to_device('NONE'), ValueError)
 
     def test_device_to_provider(self):
-        self.assertEqual(device_to_provider('cpu'), 'CPUExecutionProvider')
-        self.assertEqual(device_to_provider('gpu'), 'CUDAExecutionProvider')
-        self.assertRaise(lambda: device_to_provider('NONE'), ValueError)
+        self.assertEqual(device_to_providers('cpu'), ['CPUExecutionProvider'])
+        self.assertEqual(device_to_providers('gpu'), ['CUDAExecutionProvider'])
+        self.assertRaise(lambda: device_to_providers('NONE'), ValueError)
 
     def test_get_ort_device_type(self):
         self.assertEqual(get_ort_device_type('cpu'), 0)
         self.assertEqual(get_ort_device_type('cuda'), 1)
         self.assertRaise(lambda: get_ort_device_type('none'), ValueError)
 
+    def test_ort_device_to_string(self):
+        for value in ['cpu', 'cuda', ('gpu', 'cuda'),
+                      ('gpu:0', 'cuda'), ('cuda:0', 'cuda'),
+                      ('gpu:1', 'cuda:1'), 'cuda:1']:
+            with self.subTest(device=value):
+                if isinstance(value, str):
+                    a, b = value, value
+                else:
+                    a, b = value
+                dev = get_ort_device(a)
+                back = ort_device_to_string(dev)
+                self.assertEqual(b, back)
+
 
 if __name__ == "__main__":
     unittest.main()
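The new test checks that get_ort_device and ort_device_to_string round-trip, normalizing the 'gpu' aliases to 'cuda' and keeping a non-zero device index. A short sketch of the expected behaviour::

    from onnxcustom.utils.onnxruntime_helper import (
        get_ort_device, ort_device_to_string)

    for name in ['cpu', 'gpu', 'cuda:0', 'gpu:1']:
        dev = get_ort_device(name)
        # 'gpu' and 'cuda:0' come back as 'cuda', 'gpu:1' as 'cuda:1'
        print(name, '->', ort_device_to_string(dev))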
""" - def __init__(self, X, y, batch_size=20, device='cpu', device_index=0, + def __init__(self, X, y, batch_size=20, device='cpu', random_iter=True): if len(y.shape) == 1: y = y.reshape((-1, 1)) @@ -30,45 +30,42 @@ def __init__(self, X, y, batch_size=20, device='cpu', device_index=0, raise ValueError( # pragma: no cover "Shape mismatch X.shape=%r, y.shape=%r." % (X.shape, y.shape)) + self.batch_size = batch_size + self.device = get_ort_device(device) + self.random_iter = random_iter + self.X_np = numpy.ascontiguousarray(X) self.y_np = numpy.ascontiguousarray(y).reshape((-1, 1)) - self.X_ort = PyOrtValue.ortvalue_from_numpy( - self.X_np, device, device_index)._ortvalue - self.y_ort = PyOrtValue.ortvalue_from_numpy( - self.y_np, device, device_index)._ortvalue + self.X_ort = numpy_to_ort_value(self.X_np, self.device) + self.y_ort = numpy_to_ort_value(self.y_np, self.device) self.desc = [(self.X_np.shape, self.X_np.dtype), (self.y_np.shape, self.y_np.dtype)] - self.batch_size = batch_size - self.device = device - self.device_index = device_index - self.random_iter = random_iter - def __getstate__(self): "Removes any non pickable attribute." state = {} for att in ['X_np', 'y_np', 'desc', 'batch_size', - 'device', 'device_index', 'random_iter']: + 'random_iter']: state[att] = getattr(self, att) + state['device'] = ort_device_to_string(self.device) return state def __setstate__(self, state): "Restores any non pickable attribute." for att, v in state.items(): setattr(self, att, v) - self.X_ort = PyOrtValue.ortvalue_from_numpy( - self.X_np, self.device, self.device_index)._ortvalue - self.y_ort = PyOrtValue.ortvalue_from_numpy( - self.y_np, self.device, self.device_index)._ortvalue + self.device = get_ort_device(self.device) + self.X_ort = numpy_to_ort_value(self.X_np, self.device) + self.y_ort = numpy_to_ort_value(self.y_np, self.device) return self def __repr__(self): "usual" - return "%s(..., ..., batch_size=%r, device=%r, device_index=%r)" % ( - self.__class__.__name__, self.batch_size, self.device, - self.device_index) + return "%s(..., ..., batch_size=%r, device=%r)" % ( + self.__class__.__name__, self.batch_size, + ort_device_to_string(self.device)) def __len__(self): "Returns the number of observations." @@ -92,10 +89,10 @@ def iter_numpy(self): This iterator is slow as it copies the data of every batch. The function yields :epkg:`OrtValue`. """ - if self.device not in ('Cpu', 'cpu'): + if self.device.device_type() != self.device.cpu(): raise RuntimeError( # pragma: no cover "Only CPU device is allowed if numpy arrays are requested " - "not %r." % self.device) + "not %r." 
@@ -92,10 +89,10 @@ def iter_numpy(self):
         This iterator is slow as it copies the data of every batch.
         The function yields :epkg:`OrtValue`.
         """
-        if self.device not in ('Cpu', 'cpu'):
+        if self.device.device_type() != self.device.cpu():
             raise RuntimeError(  # pragma: no cover
                 "Only CPU device is allowed if numpy arrays are requested "
-                "not %r." % self.device)
+                "not %r." % ort_device_to_string(self.device))
         N = 0
         b = len(self) - self.batch_size
         if b <= 0 or self.batch_size <= 0:
@@ -119,10 +116,8 @@ def iter_ortvalue(self):
         b = len(self) - self.batch_size
         if b <= 0 or self.batch_size <= 0:
             yield (
-                PyOrtValue.ortvalue_from_numpy(
-                    self.X_np, self.device, self.device_index)._ortvalue,
-                PyOrtValue.ortvalue_from_numpy(
-                    self.y_np, self.device, self.device_index)._ortvalue)
+                numpy_to_ort_value(self.X_np, self.device),
+                numpy_to_ort_value(self.y_np, self.device))
         else:
             i = -1
             while N < len(self):
@@ -131,10 +126,8 @@ def iter_ortvalue(self):
                 xp = self.X_np[i:i + self.batch_size]
                 yp = self.y_np[i:i + self.batch_size]
                 yield (
-                    PyOrtValue.ortvalue_from_numpy(
-                        xp, self.device, self.device_index)._ortvalue,
-                    PyOrtValue.ortvalue_from_numpy(
-                        yp, self.device, self.device_index)._ortvalue)
+                    numpy_to_ort_value(xp, self.device),
+                    numpy_to_ort_value(yp, self.device))
 
     def iter_bind(self, bind, names):
         """
@@ -158,20 +151,11 @@ def local_bind(bind, offset, n):
                 shape_y = (n, n_col_y)
 
             bind.bind_input(
-                name=names[0],
-                device_type=self.device,
-                device_id=self.device_index,
-                element_type=self.desc[0][1],
-                shape=shape_X,
-                buffer_ptr=self.X_ort.data_ptr() + offset * n_col_x * size_x)
-
+                names[0], self.device, self.desc[0][1], shape_X,
+                self.X_ort.data_ptr() + offset * n_col_x * size_x)
             bind.bind_input(
-                name=names[1],
-                device_type=self.device,
-                device_id=self.device_index,
-                element_type=self.desc[0][1],
-                shape=shape_y,
-                buffer_ptr=self.y_ort.data_ptr() + offset * n_col_y * size_y)
+                names[1], self.device, self.desc[0][1], shape_y,
+                self.y_ort.data_ptr() + offset * n_col_y * size_y)
 
         N = 0
         b = len(self) - self.batch_size
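iter_bind never copies a batch: it binds the input name to an address computed from the data pointer of the full OrtValue plus a row offset. A sketch of the same idea for a single batch, assuming a float32 model input named 'X' and a placeholder 'model.onnx'::

    import numpy
    from onnxruntime import InferenceSession
    from onnxcustom.utils.onnxruntime_helper import (
        get_ort_device, numpy_to_ort_value)

    sess = InferenceSession('model.onnx')
    bind = sess.io_binding()._iobinding
    device = get_ort_device('cpu')

    X = numpy.ascontiguousarray(
        numpy.random.randn(100, 10).astype(numpy.float32))
    X_ort = numpy_to_ort_value(X, device)

    offset, n = 20, 10  # batch made of rows 20..29
    size_x = X.dtype.itemsize  # 4 bytes for float32
    bind.bind_input(
        'X', device, X.dtype, (n, X.shape[1]),
        X_ort.data_ptr() + offset * X.shape[1] * size_x)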
attribute." + atts = [k for k in self.__dict__ if not k.endswith('_')] + if hasattr(self, 'trained_coef_'): + atts.append('trained_coef_') + state = {att: getattr(self, att) for att in atts} + state['device'] = ort_device_to_string(state['device']) + return state + + def __setstate__(self, state): + "Restores any non pickable attribute." + for att, v in state.items(): + setattr(self, att, v) + self.device = get_ort_device(self.device) + return self + class OrtGradientOptimizer(BaseEstimator): """ @@ -70,8 +94,8 @@ class OrtGradientOptimizer(BaseEstimator): :param batch_size: batch size (see class *DataLoader*) :param learning_rate: a name or a learning rate instance or a float, see module :mod:`onnxcustom.training.sgd_learning_rate` - :param device: `'cpu'` or `'cuda'` - :param device_index: device index + :param device: device as :epkg:`C_OrtDevice` or a string + representing this device :param warm_start: when set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. @@ -88,9 +112,9 @@ class OrtGradientOptimizer(BaseEstimator): def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGDRegressor', - device='cpu', device_index=0, - warm_start=False, verbose=0, validation_every=0.1): - BaseEstimator.__init__(self, learning_rate) + device='cpu', warm_start=False, verbose=0, + validation_every=0.1): + BaseEstimator.__init__(self, learning_rate, device) self.model_onnx = model_onnx self.batch_size = batch_size self.weights_to_train = weights_to_train @@ -98,27 +122,12 @@ def __init__(self, model_onnx, weights_to_train, loss_output_name='loss', self.training_optimizer_name = training_optimizer_name self.verbose = verbose self.max_iter = max_iter - self.device = device - self.device_index = device_index self.warm_start = warm_start if validation_every < 1: self.validation_every = int(self.max_iter * validation_every) else: self.validation_every = validation_every # pragma: no cover - def __getstate__(self): - "Removes any non pickable attribute." - atts = [k for k in self.__dict__ if not k.endswith('_')] - if hasattr(self, 'trained_coef_'): - atts.append('trained_coef_') - return {att: getattr(self, att) for att in atts} - - def __setstate__(self, state): - "Restores any non pickable attribute." - for att, v in state.items(): - setattr(self, att, v) - return self - def fit(self, X, y, X_val=None, y_val=None, use_numpy=False): """ Trains the model. 
@@ -164,7 +173,7 @@ def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
             o.name for o in self.train_session_.get_outputs()]
         self.loss_index_ = self.output_names_.index(self.loss_output_name)
 
-        bind = self.train_session_.io_binding()
+        bind = self.train_session_.io_binding()._iobinding
 
         if self.verbose > 0:  # pragma: no cover
             from tqdm import tqdm  # pylint: disable=C0415
@@ -177,8 +186,7 @@ def fit(self, X, y, X_val=None, y_val=None, use_numpy=False):
         lr = self.learning_rate.value
         for it in loop:
             lr_alive = numpy.array([lr / self.batch_size], dtype=numpy.float32)
-            ort_lr = PyOrtValue.ortvalue_from_numpy(
-                lr_alive, self.device, self.device_index)._ortvalue
+            ort_lr = numpy_to_ort_value(lr_alive, self.device)
             loss = self._iteration(data_loader, ort_lr, bind,
                                    use_numpy=use_numpy)
             lr = self.learning_rate.update_learning_rate(it).value
@@ -207,6 +215,9 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue):
         :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
             it can be also a numpy array
         """
+        if not isinstance(bind, C_IOBinding):
+            raise TypeError(
+                "Unexpected type %r." % type(bind))
         if isinstance(c_ortvalue, C_OrtValue):
             # does not work
             # bind._iobinding.bind_ortvalue_input(name, c_ortvalue)
@@ -214,18 +225,12 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue):
                 c_ortvalue.proto_type()
                 if hasattr(c_ortvalue, 'proto_type')
                 else c_ortvalue.data_type())
             bind.bind_input(
-                name=name, device_type=self.device,
-                device_id=self.device_index,
-                element_type=dtype,
-                shape=c_ortvalue.shape(),
-                buffer_ptr=c_ortvalue.data_ptr())
+                name, self.device, dtype, c_ortvalue.shape(),
+                c_ortvalue.data_ptr())
         elif isinstance(c_ortvalue, numpy.ndarray):
             bind.bind_input(
-                name, device_type=self.device,
-                device_id=self.device_index,
-                element_type=c_ortvalue.dtype,
-                shape=c_ortvalue.shape,
-                buffer_ptr=c_ortvalue.__array_interface__['data'][0])
+                name, self.device, c_ortvalue.dtype, c_ortvalue.shape,
+                c_ortvalue.__array_interface__['data'][0])
         else:
             raise TypeError(  # pragma: no cover
                 "Unable to bind type %r for name %r." % (
@@ -234,7 +239,7 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue):
 
     def _iteration(self, data_loader, ort_lr, bind, use_numpy):
         actual_losses = []
-        bind.bind_output('loss')
+        bind.bind_output('loss', self.device)
 
         if use_numpy:
             # onnxruntime does not copy the data, so the numpy
@@ -251,7 +256,7 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy):
                 self._bind_input_ortvalue(
                     self.input_names_[1], bind, target)
 
-                self.train_session_.run_with_iobinding(bind)
+                self.train_session_._sess.run_with_iobinding(bind, None)
                 outputs = bind.copy_outputs_to_cpu()
                 if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
                     raise ConvergenceError(
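Binding a numpy array goes through __array_interface__['data'][0], the raw buffer address, so onnxruntime reads the array in place; the buffer must therefore stay alive until the run completes. A sketch against the C API, assuming a placeholder 'model.onnx' with a float32 input 'X' and an output 'variable'::

    import numpy
    from onnxruntime import InferenceSession
    from onnxcustom.utils.onnxruntime_helper import get_ort_device

    sess = InferenceSession('model.onnx')
    bind = sess.io_binding()._iobinding
    device = get_ort_device('cpu')

    x = numpy.random.randn(5, 10).astype(numpy.float32)
    # bind the raw buffer, no copy is made
    bind.bind_input('X', device, x.dtype, x.shape,
                    x.__array_interface__['data'][0])
    bind.bind_output('variable', device)
    sess._sess.run_with_iobinding(bind, None)
    print(bind.copy_outputs_to_cpu()[0])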
@@ -269,7 +274,7 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy):
             # Fast iterations
             # Slow iterations.
             for batch_size in data_loader.iter_bind(bind, self.input_names_):
-                self.train_session_.run_with_iobinding(bind)
+                self.train_session_._sess.run_with_iobinding(bind, None)
                 # We copy the predicted output as well which is not needed.
                 outputs = bind.copy_outputs_to_cpu()
                 if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
@@ -288,11 +293,11 @@ def _iteration(self, data_loader, ort_lr, bind, use_numpy):
     def _evaluation(self, data_loader, bind):
         lr_alive = numpy.array([0], dtype=numpy.float32)
         self._bind_input_ortvalue(self.input_names_[2], bind, lr_alive)
-        bind.bind_output('loss')
+        bind.bind_output('loss', self.device)
 
         actual_losses = []
         for batch_size in data_loader.iter_bind(bind, self.input_names_):
-            self.train_session_.run_with_iobinding(bind)
+            self.train_session_._sess.run_with_iobinding(bind, None)
             outputs = bind.copy_outputs_to_cpu()
             if numpy.isinf(outputs[0]) or numpy.isnan(outputs[0]):
                 raise EvaluationError(  # pragma: no cover
@@ -339,19 +344,10 @@ def _create_training_session(
         session_options = SessionOptions()
         # session_options.use_deterministic_compute = True
 
-        lower_device = device.lower()
-        if lower_device == 'cpu':
-            provider = ['CPUExecutionProvider']
-        elif (lower_device.startswith("cuda") or
-                lower_device == 'gpu'):  # pragma: no cover
-            provider = ['CUDAExecutionProvider']
-        else:
-            raise ValueError(  # pragma: no cover
-                "Unexpected device %r." % device)
-
+        providers = device_to_providers(self.device)
         session = TrainingSession(
             training_onnx.SerializeToString(), ort_parameters, session_options,
-            providers=provider)
+            providers=providers)
 
         return session
diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py
index f1802ed0..cb884e44 100644
--- a/onnxcustom/training/optimizers_partial.py
+++ b/onnxcustom/training/optimizers_partial.py
@@ -4,11 +4,12 @@
 """
 import logging
 import numpy
-from onnxruntime import InferenceSession, OrtValue
+from onnxruntime import InferenceSession
 from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
     OrtValue as C_OrtValue)
 from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype
-from ..utils.onnxruntime_helper import device_to_provider
+from ..utils.onnxruntime_helper import (
+    device_to_providers, numpy_to_ort_value)
 from ..utils.onnx_function import function_onnx_graph
 from ..utils.print_helper import str_ortvalue
 from ..utils.onnx_orttraining import get_train_initializer
@@ -34,8 +35,8 @@ class OrtGradientForwardBackwardOptimizer(BaseEstimator):
     :param batch_size: batch size (see class *DataLoader*)
     :param learning_rate: a name or a learning rate instance or a float,
         see module :mod:`onnxcustom.training.sgd_learning_rate`
-    :param device: `'cpu'` or `'cuda'`
-    :param device_index: device index
+    :param device: device as :epkg:`C_OrtDevice` or a string
+        representing this device
     :param warm_start: when set to True, reuse the solution of the previous
         call to fit as initialization, otherwise, just erase the previous
         solution.
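device_to_providers replaces the inlined if/elif ladder and accepts either a string or a C_OrtDevice, always returning a list as expected by TrainingSession and InferenceSession::

    from onnxcustom.utils.onnxruntime_helper import (
        get_ort_device, device_to_providers)

    print(device_to_providers('cpu'))
    # ['CPUExecutionProvider']
    print(device_to_providers(get_ort_device('cuda')))
    # ['CUDAExecutionProvider']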
@@ -54,13 +55,12 @@ def __init__(self, model_onnx, weights_to_train=None,
                  loss_output_name='loss', max_iter=100,
                  training_optimizer_name='SGDOptimizer',
                  batch_size=10, learning_rate='SGDRegressor',
-                 device='cpu', device_index=0,
-                 warm_start=False, verbose=0, validation_every=0.1,
-                 loss_function="square_error",
+                 device='cpu', warm_start=False, verbose=0,
+                 validation_every=0.1, loss_function="square_error",
                  enable_logging=False):
         if weights_to_train is None:
             weights_to_train = list(get_train_initializer(model_onnx))
-        BaseEstimator.__init__(self, learning_rate)
+        BaseEstimator.__init__(self, learning_rate, device)
         self.model_onnx = model_onnx
         self.batch_size = batch_size
         self.weights_to_train = weights_to_train
@@ -68,8 +68,6 @@ def __init__(self, model_onnx, weights_to_train=None,
         self.training_optimizer_name = training_optimizer_name
         self.verbose = verbose
         self.max_iter = max_iter
-        self.device = device
-        self.device_index = device_index
         self.warm_start = warm_start
         self.loss_function = loss_function
         self.enable_logging = enable_logging
@@ -81,8 +79,7 @@ def __init__(self, model_onnx, weights_to_train=None,
 
     def __getstate__(self):
         "Removes any non pickable attribute."
-        atts = [k for k in self.__dict__ if not k.endswith('_')]
-        state = {att: getattr(self, att) for att in atts}
+        state = BaseEstimator.__getstate__(self)
         if hasattr(self, 'train_state_'):
             train_state = []
             for v in self.get_state():
@@ -99,8 +96,7 @@ def __setstate__(self, state):
             train_state = state.pop('train_state')
         else:
             train_state = None
-        for att, v in state.items():
-            setattr(self, att, v)
+        BaseEstimator.__setstate__(self, state)
         if train_state is not None:
             self.set_state(train_state, check_trained=False)
         self._build_loss_function()
@@ -144,8 +140,7 @@ def set_state(self, state, check_trained=True):
                 self.train_state_.append(None)
                 self.train_state_numpy_.append(None)
             elif isinstance(v, numpy.ndarray):
-                ortvalue = OrtValue.ortvalue_from_numpy(
-                    v, self.device, self.device_index)._ortvalue
+                ortvalue = numpy_to_ort_value(v, self.device)
                 self.train_state_.append(ortvalue)
                 # The numpy container must be retained as the ortvalue
                 # just borrows the pointer.
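The comment about the retained numpy container matters: numpy_to_ort_value borrows the buffer instead of copying it, so the array must stay referenced for as long as the OrtValue is used::

    import numpy
    from onnxcustom.utils.onnxruntime_helper import numpy_to_ort_value

    state_numpy = numpy.random.randn(10, 1).astype(numpy.float32)
    # the OrtValue borrows the buffer, state_numpy must outlive state_ort
    state_ort = numpy_to_ort_value(state_numpy)
    print(state_ort.numpy())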
@@ -164,11 +159,12 @@ def _build_loss_function(self):
             "grad_loss_" + self.loss_function, target_opset=opset)
         self.loss_grad_sess_ = InferenceSession(
             self.loss_grad_onnx_.SerializeToString())
-        self.loss_grad_sess_bind_ = self.loss_grad_sess_.io_binding()
+        self.loss_grad_sess_bind_ = (
+            self.loss_grad_sess_.io_binding()._iobinding)
 
         self.axpy_onnx_ = function_onnx_graph("axpy")
         self.axpy_sess_ = InferenceSession(self.axpy_onnx_.SerializeToString())
-        self.axpy_sess_bind_ = self.axpy_sess_.io_binding()
+        self.axpy_sess_bind_ = self.axpy_sess_.io_binding()._iobinding
 
         if self.enable_logging:
             self._logger = logging.getLogger("onnxcustom")
@@ -313,18 +309,12 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue):
             else:
                 dtype = proto_type_to_dtype(c_ortvalue.data_type())
             bind.bind_input(
-                name=name, device_type=self.device,
-                device_id=self.device_index,
-                element_type=dtype,
-                shape=c_ortvalue.shape(),
-                buffer_ptr=c_ortvalue.data_ptr())
+                name, self.device, dtype, c_ortvalue.shape(),
+                c_ortvalue.data_ptr())
         elif isinstance(c_ortvalue, numpy.ndarray):
             bind.bind_input(
-                name, device_type=self.device,
-                device_id=self.device_index,
-                element_type=c_ortvalue.dtype,
-                shape=c_ortvalue.shape,
-                buffer_ptr=c_ortvalue.__array_interface__['data'][0])
+                name, self.device, c_ortvalue.dtype, c_ortvalue.shape,
+                c_ortvalue.__array_interface__['data'][0])
         else:
             raise TypeError(  # pragma: no cover
                 "Unable to bind type %r for name %r." % (
@@ -349,11 +339,8 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue):
             else:
                 dtype = proto_type_to_dtype(c_ortvalue.data_type())
             bind.bind_output(
-                name=name, device_type=self.device,
-                device_id=self.device_index,
-                element_type=dtype,
-                shape=c_ortvalue.shape(),
-                buffer_ptr=c_ortvalue.data_ptr())
+                name, self.device, dtype, c_ortvalue.shape(),
+                c_ortvalue.data_ptr())
         else:
             raise TypeError(  # pragma: no cover
                 "Unable to bind type %r for name %r." % (
@@ -365,10 +352,11 @@ def _loss_gradient(self, expected, predicted):
         """
         self._bind_input_ortvalue("X1", self.loss_grad_sess_bind_, expected)
         self._bind_input_ortvalue("X2", self.loss_grad_sess_bind_, predicted)
-        self.loss_grad_sess_bind_.bind_output('Y')
-        self.loss_grad_sess_bind_.bind_output('Z')
-        self.loss_grad_sess_.run_with_iobinding(self.loss_grad_sess_bind_)
-        loss, grad = self.loss_grad_sess_bind_._iobinding.get_outputs()
+        self.loss_grad_sess_bind_.bind_output('Y', self.device)
+        self.loss_grad_sess_bind_.bind_output('Z', self.device)
+        self.loss_grad_sess_._sess.run_with_iobinding(
+            self.loss_grad_sess_bind_, None)
+        loss, grad = self.loss_grad_sess_bind_.get_outputs()
         return loss, grad
 
     def _update_weights(self, statei, gradienti, alpha):
@@ -378,8 +366,8 @@ def _update_weights(self, statei, gradienti, alpha):
         self._bind_input_ortvalue(
             "alpha", self.axpy_sess_bind_, alpha_alive)
         self._bind_output_ortvalue('Y', self.axpy_sess_bind_, statei)
-        self.axpy_sess_.run_with_iobinding(self.axpy_sess_bind_)
-        return self.axpy_sess_bind_._iobinding.get_outputs()[0]
+        self.axpy_sess_._sess.run_with_iobinding(self.axpy_sess_bind_, None)
+        return self.axpy_sess_bind_.get_outputs()[0]
 
     def _iteration(self, data_loader, learning_rate, state, n_weights):
         actual_losses = []
@@ -475,6 +463,6 @@ def _create_training_session(
         forback = OrtGradientForwardBackward(
             model_onnx, weights_to_train=weights_to_train,
             debug=False, enable_logging=False,
-            providers=[device_to_provider(device)])
+            providers=device_to_providers(device))
         inst = forback.new_instance()
         return (forback, inst)
diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py
index db1e3151..bb00df89 100644
--- a/onnxcustom/training/ortgradient.py
+++ b/onnxcustom/training/ortgradient.py
@@ -92,7 +92,8 @@ def __init__(self, onnx_model, weights_to_train=None,
         self.output_names = [
             obj.name for obj in self.onnx_model.graph.output]
         if self.class_name is None:
-            self.class_name = "TorchOrtFunction_%r" % id(self)  # pragma: no cover
+            self.class_name = "TorchOrtFunction_%r" % id(
+                self)  # pragma: no cover
         if hasattr(self.providers, 'type'):
             if self.providers.type != 'cpu':
                 self.device_index = self.providers.index
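The axpy graph performs the gradient step Y = alpha * X1 + X2 inside onnxruntime instead of numpy. A sketch running it through a plain InferenceSession, assuming (as the bindings above suggest) inputs named X1, X2 and alpha and an output Y::

    import numpy
    from onnxruntime import InferenceSession
    from onnxcustom.utils.onnx_function import function_onnx_graph

    onx = function_onnx_graph("axpy")
    sess = InferenceSession(onx.SerializeToString())
    gradient = numpy.array([0.5, -1.0], dtype=numpy.float32)
    weights = numpy.array([1.0, 1.0], dtype=numpy.float32)
    alpha = numpy.array([-0.1], dtype=numpy.float32)  # -learning_rate
    print(sess.run(None, {'X1': gradient, 'X2': weights, 'alpha': alpha})[0])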
diff --git a/onnxcustom/utils/onnxruntime_helper.py b/onnxcustom/utils/onnxruntime_helper.py
index 395919a9..bec281f4 100644
--- a/onnxcustom/utils/onnxruntime_helper.py
+++ b/onnxcustom/utils/onnxruntime_helper.py
@@ -3,28 +3,7 @@
 @brief Onnxruntime helper.
 """
 from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
-    OrtDevice as C_OrtDevice)
-
-
-def device_to_provider(device_name):
-    """
-    Converts device into a provider.
-
-    :param device_name: device name (cpu or gpu or cuda)
-    :return: provider
-
-    .. runpython::
-        :showcode:
-
-        from onnxcustom.utils.onnxruntime_helper import device_to_provider
-        print(device_to_provider('cpu'))
-    """
-    if device_name in ('cpu', 'Cpu'):
-        return 'CPUExecutionProvider'
-    if device_name in ('Gpu', 'gpu', 'Cuda', 'cuda', 'cuda:0', 'cuda:1'):
-        return 'CUDAExecutionProvider'
-    raise ValueError(
-        "Unexpected value for device_name=%r." % device_name)
+    OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
 
 
 def provider_to_device(provider_name):
@@ -50,10 +29,10 @@ def provider_to_device(provider_name):
 
 def get_ort_device_type(device):
     """
-    Converts device into :epkg:`C_OrtDevice`.
+    Converts device into device type.
 
     :param device: string
-    :return: :epkg:`C_OrtDevice`
+    :return: device type
     """
     device_type = device if isinstance(device, str) else device.type
     if device_type == 'cuda':
@@ -61,3 +40,89 @@ def get_ort_device_type(device):
     if device_type == 'cpu':
         return C_OrtDevice.cpu()
     raise ValueError('Unsupported device type: %r.' % device_type)
+
+
+def get_ort_device(device):
+    """
+    Converts device into :epkg:`C_OrtDevice`.
+
+    :param device: any type
+    :return: :epkg:`C_OrtDevice`
+    """
+    if isinstance(device, C_OrtDevice):
+        return device
+    if isinstance(device, str):
+        if device == 'cpu':
+            return C_OrtDevice(
+                C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
+        if device in {'gpu', 'cuda:0', 'cuda', 'gpu:0'}:
+            return C_OrtDevice(
+                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)
+        if device.startswith('gpu:'):
+            idx = int(device[4:])
+            return C_OrtDevice(
+                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
+        if device.startswith('cuda:'):
+            idx = int(device[5:])
+            return C_OrtDevice(
+                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
+        raise ValueError(
+            "Unable to interpret string %r as a device." % device)
+    raise TypeError(
+        "Unable to interpret type %r, (%r) as a device." % (
+            type(device), device))
+
+
+def ort_device_to_string(device):
+    """
+    Returns a string representing the device.
+    Opposite of function @see fn get_ort_device.
+
+    :param device: see :epkg:`C_OrtDevice`
+    :return: string
+    """
+    if not isinstance(device, C_OrtDevice):
+        raise TypeError(
+            "device must be of type C_OrtDevice not %r." % type(device))
+    ty = device.device_type()
+    if ty == C_OrtDevice.cpu():
+        sty = 'cpu'
+    elif ty == C_OrtDevice.cuda():
+        sty = 'cuda'
+    else:
+        raise NotImplementedError(  # pragma: no cover
+            "Unable to guess device for %r and type=%r." % (device, ty))
+    idx = device.device_id()
+    if idx == 0:
+        return sty
+    return "%s:%d" % (sty, idx)
+
+
+def numpy_to_ort_value(arr, device=None):
+    """
+    Converts a numpy array to :epkg:`C_OrtValue`.
+
+    :param arr: numpy array
+    :param device: :epkg:`C_OrtDevice` or None for cpu
+    :return: :epkg:`C_OrtValue`
+    """
+    if device is None:
+        device = get_ort_device('cpu')
+    return C_OrtValue.ortvalue_from_numpy(arr, device)
+
+
+def device_to_providers(device):
+    """
+    Returns the corresponding providers for a specific device.
+
+    :param device: :epkg:`C_OrtDevice`
+    :return: providers
+    """
+    if isinstance(device, str):
+        device = get_ort_device(device)
+    if device.device_type() == device.cpu():
+        return ['CPUExecutionProvider']
+    if device.device_type() == device.cuda():
+        return ['CUDAExecutionProvider']
+    raise ValueError(  # pragma: no cover
+        "Unexpected device %r." % device)
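Taken together, the new helpers centralize every device conversion the training code needs::

    import numpy
    from onnxcustom.utils.onnxruntime_helper import (
        get_ort_device, ort_device_to_string, numpy_to_ort_value,
        device_to_providers)

    device = get_ort_device('gpu:1')
    print(ort_device_to_string(device))  # 'cuda:1'
    print(device_to_providers(device))   # ['CUDAExecutionProvider']

    value = numpy_to_ort_value(
        numpy.array([[0, 1]], dtype=numpy.float32))  # cpu by default
    print(value.shape())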