Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion _doc/bench/bench_orttraining_nn_gpu_fwbw.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def benchmark(N=1000, n_features=100, hidden_layer_sizes="50,50", max_iter=500,
nn = MLPRegressor(hidden_layer_sizes=hidden_layer_sizes,
max_iter=max_iter if run_skl else 1,
solver='sgd', learning_rate_init=learning_rate_init,
n_iter_no_change=max_iter, batch_size=batch_size)
n_iter_no_change=max_iter, batch_size=batch_size,
alpha=0, nesterovs_momentum=False, momentum=0,
learning_rate="invscaling")

begin = time.perf_counter()
with warnings.catch_warnings():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ the penalty, its gradient, a learning rate possibly with momentum.
* an object inheriting from :class:`BaseLearningLoss
<onnxcustom.training.sgd_learning_loss.BaseLearningLoss>`
* an object inheriting from :class:`BaseLearningPenalty
<onnxcustom.training.sgd_learning_loss.BaseLearningPenalty>`
<onnxcustom.training.sgd_learning_penalty.BaseLearningPenalty>`
* an object inheriting from :class:`BaseLearningRate
<onnxcustom.training.sgd_learning_rate.BaseLearningRate>`

Expand Down
8 changes: 8 additions & 0 deletions _unittests/ut_utils/test_onnxruntime_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def test_get_ort_device_type(self):
self.assertEqual(get_ort_device_type('cuda'), 1)
self.assertRaise(lambda: get_ort_device_type('none'), ValueError)

def test_get_ort_device_type_exc_2(self):
dev = get_ort_device('cpu')
self.assertEqual(get_ort_device_type(dev), 0)
dev = get_ort_device('cuda')
self.assertEqual(get_ort_device_type(dev), 1)
self.assertRaise(lambda: get_ort_device_type(''), ValueError)
self.assertRaise(lambda: get_ort_device_type(0), TypeError)

def test_get_ort_device_type_exc(self):
self.assertRaise(
lambda: get_ort_device_type(['cpu']),
Expand Down
20 changes: 20 additions & 0 deletions _unittests/ut_utils/test_print_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import unittest
import numpy
from onnxruntime import OrtValue
from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
OrtValue as C_OrtValue)
from pyquickhelper.pycode import ExtTestCase
Expand Down Expand Up @@ -33,6 +34,25 @@ def test_print_ortvalue(self):
text = str_ortvalue(ort) # pylint: disable=W0212
self.assertEqual(expected, text)

def test_print_py_ortvalue(self):
expected = (
"device=Cpu dtype=dtype('float32') shape=(1, 4) "
"value=[0.0, 1.0, 4.0, 4.5]")
value = numpy.array([[0, 1, 4, 4.5]], dtype=numpy.float32)
ort = OrtValue.ortvalue_from_numpy(value, 'cpu')
text = str_ortvalue(ort)
self.assertEqual(expected, text)
text = str_ortvalue(ort) # pylint: disable=W0212
self.assertEqual(expected, text)

expected = (
"device=Cpu dtype=dtype('int64') shape=(100,) "
"value=[0, 1, 2, 3, 4, '...', 95, 96, 97, 98, 99]")
value = numpy.arange(100).astype(numpy.int64)
ort = OrtValue.ortvalue_from_numpy(value, 'cpu')
text = str_ortvalue(ort) # pylint: disable=W0212
self.assertEqual(expected, text)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions onnxcustom/plotting/plotting_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,7 @@ def plot_onnxs(*onx, ax=None, dpi=300, temp_dot=None, temp_img=None,
"len(onx)=%d)" % (title, len(onx)))
fig.suptitle(title)
elif len(onx) == 1:
if isinstance(title, list):
title = title[0]
ax.set_title(title)
return ax
95 changes: 84 additions & 11 deletions onnxcustom/training/base_onnx_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,25 @@ class BaseLearningOnnx:
"""

def __init__(self):
pass
self.cache_in_ = {}
self.cache_out_ = {}

def __getstate__(self):
"""
Overwrites getstate to get rid of InferenceSession.
"""
atts = [k for k in self.__dict__ if not k.endswith('_')]
state = {k: getattr(self, k) for k in atts}
onx = [k for k in self.__dict__
if k.endswith('_onnx_')]
onx = [k for k in self.__dict__ if k.endswith('_onnx_')]
for o in onx:
state[o] = getattr(self, o).SerializeToString()
onx = [k for k in self.__dict__
if k.endswith('_sess_')]
onx = [k for k in self.__dict__ if k.endswith('_sess_')]
bind = [k for k in self.__dict__ if k.endswith('_bind_')]
for k in bind:
state[k] = True
binds = [k for k in self.__dict__ if k.endswith('_binds_')]
for k in binds:
state[k] = len(getattr(self, k))
for o in onx:
state[o] = getattr(self, o).get_providers()
return state
Expand All @@ -58,8 +63,18 @@ def __setstate__(self, state):
setattr(self, k2, InferenceSession(
getattr(self, k).SerializeToString(), so,
providers=prov))
bind = k2 + "bind_"
setattr(self, bind, getattr(self, k2).io_binding()._iobinding)
for k, v in state.items():
if k.endswith('_bind_'):
k2 = k[:-5]
setattr(self, k, getattr(self, k2).io_binding()._iobinding)
elif k.endswith('_binds_'):
k2 = k[:-6]
n = v
setattr(self, k, [
getattr(self, k2).io_binding()._iobinding
for i in range(n)])
self.cache_in_ = {}
self.cache_out_ = {}
return self

def __repr_extended__(self):
Expand Down Expand Up @@ -96,7 +111,22 @@ def build_onnx_function(self, opset, device, *args):
raise NotImplementedError(
"This method must be overwritten.")

def _bind_input_ortvalue(self, name, bind, c_ortvalue, device):
def clear_binding_inputs(self, name, bind, cache=False):
"""
Clears binding and empty cache.
"""
def _cache_in(name, bind):
key = name, id(bind)
if key in self.cache_in_ and self.cache_in_[key] == 0:
return True
self.cache_in_[key] = 0
return False

if cache and _cache_in(name, bind):
return
bind.clear_binding_inputs()

def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
"""
Binds :epkg:`C_OrtValue` to the structure used by
:epkg:`InferenceSession` to run inference.
Expand All @@ -106,14 +136,41 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device):
:param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`),
it can be also a numpy array
:param device: device
:param cache: avoids binding again the data pointer did not change,
only works when c_ortvalue is of :epkg:`C_OrtValue`
"""
if isinstance(c_ortvalue, C_OrtValue):
def _cache_in(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.data_ptr()
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
return False

def _cache_np(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.__array_interface__['data'][0]
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_input(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_in(name, bind, c_ortvalue):
return
do_bind(name, bind, c_ortvalue)
elif isinstance(c_ortvalue, numpy.ndarray):
if self.device_type() != device.cpu(): # pylint: disable=E1101
raise ProviderError(
"device=%s is not CPU." % ort_device_to_string(
device))
if cache and _cache_np(name, bind, c_ortvalue):
return
bind.bind_input(
name, device, c_ortvalue.dtype, c_ortvalue.shape,
c_ortvalue.__array_interface__['data'][0])
Expand All @@ -122,19 +179,35 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device):
"Unable to bind type %r for name %r." % (
type(c_ortvalue), name))

def _bind_output_ortvalue(self, name, bind, c_ortvalue):
def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
"""
Binds :epkg:`C_OrtValue` to the structure used by
:epkg:`InferenceSession` to run inference.

:param name: str
:param bind: python structure
:param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
:param cache: avoids binding again the data pointer did not change,
only works when c_ortvalue is of :epkg:`C_OrtValue`

This method can be used for inplace computation.
"""
if isinstance(c_ortvalue, C_OrtValue):
def _cache_out(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.data_ptr()
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_output(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_out(name, bind, c_ortvalue):
return
do_bind(name, bind, c_ortvalue)
else:
raise TypeError( # pragma: no cover
"Unable to bind type %r for name %r." % (
Expand Down
11 changes: 7 additions & 4 deletions onnxcustom/training/optimizers_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,16 @@ def build_onnx_function(self):
so = SessionOptions()
so.log_severity_level = 4

n = len(self.weights_to_train)

# loss_grad
self.learning_loss.build_onnx_function(
opset, self.device, self.weight_name)

# weight update
self.learning_rate.build_onnx_function(opset, self.device)
self.learning_rate.build_onnx_function(opset, self.device, n)

# penalty
n = len(self.weights_to_train)
self.learning_penalty.build_onnx_function(opset, self.device, n)

# zero
Expand Down Expand Up @@ -422,9 +423,11 @@ def _iteration(self, data_loader, states, n_weights):

n = len(state) - n_weights
for i in range(n, len(state)):
self.learning_penalty.update_weights(self.device, state[i])
self.learning_penalty.update_weights(
i - n, self.device, state[i])
self.learning_rate.update_weights(
self.device, state[i], gradient[i], bs,
i - n, self.device, state[i],
gradient[i], bs,
None if grad is None else grad[i])

if logger is not None:
Expand Down
2 changes: 1 addition & 1 deletion onnxcustom/training/ortgradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def device_name(device):
"""
if device.device_type() == OrtDevice.cpu():
return 'Cpu'
if device.device_type() == OrtDevice.cuda():
if device.device_type() == OrtDevice.cuda(): # pragma: no cover
return 'Gpu'
raise RuntimeError( # pragma: no cover
"Unexpected value for device type %r." % device.device_type())
Expand Down
20 changes: 9 additions & 11 deletions onnxcustom/training/sgd_learning_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,20 @@ def loss_gradient( # pylint: disable=E1101
if (not hasattr(self, "loss_grad_sess_") or
not hasattr(self, "loss_grad_sess_bind_")):
raise RuntimeError( # pragma: no cover
"Attributes 'loss_grad_sess_bind_' or 'loss_grad_sess_' "
"is missing. Method 'build_onnx_function' has not been called.")
"Attributes 'loss_grad_sess_bind_' or 'loss_grad_sess_' is "
"missing. Method 'build_onnx_function' has not been called.")
bind = self.loss_grad_sess_bind_
if weight is not None:
self._bind_input_ortvalue(
"weight", self.loss_grad_sess_bind_, weight, device)
"weight", bind, weight, device, cache=True)
else:
self.loss_grad_sess_bind_.clear_binding_inputs()
self._bind_input_ortvalue(
"X1", self.loss_grad_sess_bind_, expected, device)
self._bind_input_ortvalue(
"X2", self.loss_grad_sess_bind_, predicted, device)
self.clear_binding_inputs("weight", bind, cache=True)
self._bind_input_ortvalue("X1", bind, expected, device, cache=True)
self._bind_input_ortvalue("X2", bind, predicted, device, cache=True)
self.loss_grad_sess_bind_.bind_output('Y', device)
self.loss_grad_sess_bind_.bind_output('Z', device)
self.loss_grad_sess_._sess.run_with_iobinding(
self.loss_grad_sess_bind_, None)
loss, grad = self.loss_grad_sess_bind_.get_outputs()
self.loss_grad_sess_._sess.run_with_iobinding(bind, None)
loss, grad = bind.get_outputs()
return loss, grad

@staticmethod
Expand Down
35 changes: 18 additions & 17 deletions onnxcustom/training/sgd_learning_penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def penalty_loss(self, device, loss, *weights):
"""
return loss

def update_weights(self, device, statei):
def update_weights(self, n_bind, device, statei):
"""
Returns the received loss. Updates the weight inplace.

Expand Down Expand Up @@ -129,8 +129,9 @@ def build_onnx_function(self, opset, device, n_tensors):
self.penalty_grad_sess_ = InferenceSession(
self.penalty_grad_onnx_.SerializeToString(), so,
providers=device_to_providers(device))
self.penalty_grad_sess_bind_ = (
self.penalty_grad_sess_.io_binding()._iobinding)
self.penalty_grad_sess_binds_ = [
self.penalty_grad_sess_.io_binding()._iobinding
for n in range(n_tensors)]

def penalty_loss(self, device, *inputs):
"""
Expand All @@ -144,31 +145,31 @@ def penalty_loss(self, device, *inputs):
if (not hasattr(self, "penalty_onnx_") or
not hasattr(self, "penalty_sess_bind_")):
raise RuntimeError( # pragma: no cover
"Attributes 'penalty_sess_bind_' or 'penalty_onnx_' "
"is missing. Method 'build_onnx_function' has not been called.")
"Attributes 'penalty_sess_bind_' or 'penalty_onnx_' is "
"missing. Method 'build_onnx_function' has not been called.")
if len(self.names_) != len(inputs):
raise RuntimeError(
raise RuntimeError( # pragma: no cover
"Mismatched number of inputs: %d != %d." % (
len(self.names_), len(inputs)))

for name, inp in zip(self.names_, inputs):
self._bind_input_ortvalue(
name, self.penalty_sess_bind_, inp, device)
self._bind_output_ortvalue('Y', self.penalty_sess_bind_, inputs[0])
name, self.penalty_sess_bind_, inp, device, cache=True)
self._bind_output_ortvalue(
'Y', self.penalty_sess_bind_, inputs[0], cache=True)
self.penalty_sess_._sess.run_with_iobinding(
self.penalty_sess_bind_, None)
return self.penalty_sess_bind_.get_outputs()[0]

def update_weights(self, device, statei):
def update_weights(self, n_bind, device, statei):
if (not hasattr(self, "penalty_grad_onnx_") or
not hasattr(self, "penalty_grad_sess_bind_")):
not hasattr(self, "penalty_grad_sess_binds_")):
raise RuntimeError( # pragma: no cover
"Attributes 'penalty_grad_sess_bind_' or "
"Attributes 'penalty_grad_sess_binds_' or "
"'penalty_grad_onnx_' is missing. Method "
"'build_onnx_function' has not been called.")
self._bind_input_ortvalue(
"X", self.penalty_grad_sess_bind_, statei, device)
self._bind_output_ortvalue('Y', self.penalty_grad_sess_bind_, statei)
self.penalty_grad_sess_._sess.run_with_iobinding(
self.penalty_grad_sess_bind_, None)
return self.penalty_grad_sess_bind_.get_outputs()[0] # X
bind = self.penalty_grad_sess_binds_[n_bind]
self._bind_input_ortvalue("X", bind, statei, device, cache=True)
self._bind_output_ortvalue('Y', bind, statei, cache=True)
self.penalty_grad_sess_._sess.run_with_iobinding(bind, None)
return bind.get_outputs()[0] # X
Loading