Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 53 additions & 53 deletions onnxcustom/training/base_onnx_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,48 @@ def build_onnx_function(self, opset, device, *args):
raise NotImplementedError(
"This method must be overwritten.")

@staticmethod
def _cache_in_clear(cache, name, bind):
    """
    Tells whether the input binding for *name* on *bind* was already
    cleared.

    Returns True when the caller may skip the clearing (the cached
    marker for this ``(bind, name)`` pair is already ``0``, or nothing
    was ever cached for it); otherwise records the cleared state
    (marker set to ``0``) and returns False so the caller performs the
    actual clearing.
    """
    entry = cache.get(id(bind))
    if entry is not None and name in entry:
        if entry[name] == 0:
            # already marked as cleared, nothing to do
            return True
        entry[name] = 0
        return False
    # unknown bind or unknown name: nothing recorded, skip
    return True

def clear_binding_inputs(self, name, bind, cache=False):
    """
    Clears the input bindings of *bind*.

    :param name: input name
    :param bind: IO binding to clear
    :param cache: when True, consults the internal cache and skips the
        call if this ``(bind, name)`` pair was already cleared
    """
    skip = cache and self._cache_in_clear(self.cache_in_, name, bind)
    if not skip:
        bind.clear_binding_inputs()

def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
@staticmethod
def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
    """
    Tells whether binding can be skipped.

    Returns True when the data pointer *ptr2* recorded for
    ``(id(bind), name)`` is unchanged (the previous binding is still
    valid); otherwise stores *ptr2* in the cache and returns False so
    the caller binds again.
    """
    key = id(bind)
    entry = cache.get(key)
    if entry is None:
        # first time this binding object is seen
        cache[key] = {name: ptr2}
        return False
    if name in entry and entry[name] == ptr2:
        # same pointer as last time: binding can be reused
        return True
    entry[name] = ptr2
    return False

@staticmethod
def _bio_do_bind_in(name, bind, c_ortvalue):
    # Attaches *c_ortvalue* as input *name* on the IO binding *bind*.
    # Kept as a separate one-line helper so the binding call itself is
    # isolated from the caching logic in the caller.
    bind.bind_ortvalue_input(name, c_ortvalue)

@staticmethod
def _bio_ptr(c):
    # Returns the raw data pointer of a C_OrtValue; used by callers as
    # the cache value to detect whether rebinding can be skipped.
    return c.data_ptr()

def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
cache=False):
"""
Binds :epkg:`C_OrtValue` to the structure used by
:epkg:`InferenceSession` to run inference.
Expand All @@ -141,35 +164,22 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
:param device: device
:param cache: avoids binding again if the data pointer did not change,
only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
equivalent to a dictionary `{ id(bind), name: c_ort_value.data_ptr() }`.
equivalent to a dictionary
`{ id(bind), name: c_ort_value.data_ptr() }`.
"""
def _cache_in(name, bind, c_ortvalue, ptr2):
key = id(bind)
if key in self.cache_in_:
if name in self.cache_in_[key]:
ptr = self.cache_in_[key][name]
if ptr == ptr2:
return True
self.cache_in_[key][name] = ptr2
else:
self.cache_in_[key] = {name: ptr2}
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_input(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_in(
name, bind, c_ortvalue, c_ortvalue.data_ptr()):
if cache and self._bio_cache(
self.cache_in_, name, bind, c_ortvalue,
self._bio_ptr(c_ortvalue)):
return
do_bind(name, bind, c_ortvalue)
self._bio_do_bind_in(name, bind, c_ortvalue)
elif isinstance(c_ortvalue, numpy.ndarray):
if self.device_type() != device.cpu(): # pylint: disable=E1101
raise ProviderError(
"device=%s is not CPU." % ort_device_to_string(
device))
if cache and _cache_in(
name, bind, c_ortvalue,
if cache and self._bio_cache(
self.cache_in_, name, bind, c_ortvalue,
c_ortvalue.__array_interface__['data'][0]):
return
bind.bind_input(
Expand All @@ -180,6 +190,10 @@ def do_bind(name, bind, c_ortvalue):
"Unable to bind type %r for name %r." % (
type(c_ortvalue), name))

@staticmethod
def _bio_do_bind_out(name, bind, c_ortvalue):
    # Attaches *c_ortvalue* as output *name* on the IO binding *bind*;
    # output-side counterpart of _bio_do_bind_in.
    bind.bind_ortvalue_output(name, c_ortvalue)

def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
"""
Binds :epkg:`C_OrtValue` to the structure used by
Expand All @@ -190,31 +204,17 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
:param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
:param cache: avoids binding again if the data pointer did not change,
only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
equivalent to a dictionary `{ id(bind), name: c_ort_value.data_ptr() }`.
equivalent to a dictionary
`{ id(bind), name: c_ort_value.data_ptr() }`.

This method can be used for inplace computation.
"""
def _cache_out(name, bind, c_ortvalue, ptr2):
key = id(bind)
if key in self.cache_out_:
if name in self.cache_out_[key]:
ptr = self.cache_out_[key][name]
if ptr == ptr2:
return True
self.cache_out_[key][name] = ptr2
else:
self.cache_out_[key] = {name: ptr2}
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_output(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_out(
name, bind, c_ortvalue,
c_ortvalue.data_ptr()):
if cache and self._bio_cache(
self.cache_out_, name, bind, c_ortvalue,
self._bio_ptr(c_ortvalue)):
return
do_bind(name, bind, c_ortvalue)
self._bio_do_bind_out(name, bind, c_ortvalue)
else:
raise TypeError( # pragma: no cover
"Unable to bind type %r for name %r." % (
Expand Down
27 changes: 25 additions & 2 deletions onnxcustom/training/optimizers_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ def _iteration(self, data_loader, states, n_weights):
"iteration begin learning_rate=%r",
self.learning_rate)

prediction_cache = None
prediction_cache_shape = None
backward_outputs_cache = None
for ib, ito in enumerate(data_loader.iter_ortvalue()):
if len(ito) == 2:
(ortx, orty) = ito
Expand All @@ -397,7 +400,22 @@ def _iteration(self, data_loader, states, n_weights):
"[OrtGradientForwardBackwardOptimizer._iteration] "
"batch %d", ib)

prediction = self.train_function_.forward(states[0], training=True)
ortx_shape = tuple(ortx.shape())
same_shape = (
prediction_cache_shape is not None and
ortx_shape == prediction_cache_shape)

# forward
if prediction_cache_shape is None or same_shape:
prediction_cache = None
prediction_cache_shape = None
prediction = self.train_function_.forward(
states[0], training=True,
forward_outputs_cache=prediction_cache)
prediction_cache = prediction
prediction_cache_shape = ortx_shape

# loss
loss, loss_gradient = self.learning_loss.loss_gradient(
self.device, orty, prediction[0], weight=ortw)
n = len(state) - n_weights
Expand All @@ -414,7 +432,12 @@ def _iteration(self, data_loader, states, n_weights):
actual_losses if len(actual_losses) < 5
else actual_losses[-5:])]))

gradient = self.train_function_.backward([loss_gradient])
# backward
if not same_shape:
backward_outputs_cache = None
gradient = self.train_function_.backward(
[loss_gradient], backward_outputs_cache=backward_outputs_cache)
backward_outputs_cache = gradient

if len(gradient) != len(state):
raise RuntimeError( # pragma: no cover
Expand Down
8 changes: 4 additions & 4 deletions onnxcustom/training/ortgradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def saved_tensors(self):
"No tensors was saved with save_for_backward.")
return self.saved_tensors_

def forward(self, inputs, training=False):
def forward(self, inputs, training=False, forward_outputs_cache=None):
"""
Implements forward function.

Expand Down Expand Up @@ -615,7 +615,7 @@ def _log(msg, *args):
inputs, cls._devices, cls._debug)

if training:
forward_outputs = OrtValueVector()
forward_outputs = forward_outputs_cache or OrtValueVector()
state = PartialGraphExecutionState()
self.states_.append(state)
if logger is not None:
Expand Down Expand Up @@ -666,7 +666,7 @@ def _log(msg, *args):
_log("end")
return ortvalues

def backward(self, grad_outputs):
def backward(self, grad_outputs, backward_outputs_cache=None):
"""
Implements backward function. The function returns
an :epkg:`OrtValueVector`.
Expand Down Expand Up @@ -705,7 +705,7 @@ def _log(msg, *args):
_log("backward_inputs[%d].shape=%r",
i, backward_inputs[i].shape())
_log("run_backward")
backward_outputs = OrtValueVector()
backward_outputs = backward_outputs_cache or OrtValueVector()
cls._training_agent.run_backward(
backward_inputs, backward_outputs, state)
if logger is not None: # pragma: no cover
Expand Down