From 4ffedecc0ef844585189623f4eecceb9c84b992d Mon Sep 17 00:00:00 2001
From: xavier dupré
Date: Fri, 14 Jan 2022 01:08:57 +0100
Subject: [PATCH 1/3] improves caching

---
 onnxcustom/training/optimizers_partial.py | 15 ++++++++++++++-
 onnxcustom/training/ortgradient.py        |  2 +-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py
index 814e64e6..1e13b8b4 100644
--- a/onnxcustom/training/optimizers_partial.py
+++ b/onnxcustom/training/optimizers_partial.py
@@ -384,6 +384,8 @@ def _iteration(self, data_loader, states, n_weights):
             "iteration begin learning_rate=%r",
             self.learning_rate)
 
+        prediction_cache = None
+        prediction_cache_shape = None
         for ib, ito in enumerate(data_loader.iter_ortvalue()):
             if len(ito) == 2:
                 (ortx, orty) = ito
@@ -397,7 +399,17 @@ def _iteration(self, data_loader, states, n_weights):
                     "[OrtGradientForwardBackwardOptimizer._iteration] "
                     "batch %d", ib)
 
-            prediction = self.train_function_.forward(states[0], training=True)
+            # forward
+            if (prediction_cache_shape is None or
+                    tuple(ortx.shape()) != prediction_cache_shape):
+                prediction_cache = None
+                prediction_cache_shape = None
+            prediction = self.train_function_.forward(
+                states[0], training=True, forward_outputs_cache=prediction_cache_shape)
+            prediction_cache = prediction
+            prediction_cache_shape = tuple(ortx.shape())
+
+            # loss
             loss, loss_gradient = self.learning_loss.loss_gradient(
                 self.device, orty, prediction[0], weight=ortw)
             n = len(state) - n_weights
@@ -414,6 +426,7 @@ def _iteration(self, data_loader, states, n_weights):
                         actual_losses if len(actual_losses) < 5
                         else actual_losses[-5:])]))
 
+            # backward
             gradient = self.train_function_.backward([loss_gradient])
 
             if len(gradient) != len(state):
diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py
index 08101309..2fad325d 100644
--- a/onnxcustom/training/ortgradient.py
+++ b/onnxcustom/training/ortgradient.py
@@ -587,7 +587,7 @@ def saved_tensors(self):
                 "No tensors was saved with save_for_backward.")
         return self.saved_tensors_
 
-    def forward(self, inputs, training=False):
+    def forward(self, inputs, training=False, forward_outputs_cache=None):
         """
         Implements forward function.
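The point of this first commit is to stop reallocating the forward output container on every batch: the training loop keeps the vector produced for one batch and hands it back to the next forward call as long as the batch shape is unchanged. Note that this commit still passes the cached shape rather than the cached vector to `forward_outputs_cache`; the next commit ("fix cache") corrects that. Below is a minimal, self-contained sketch of the caller-side pattern, where `run_forward` and plain Python lists are hypothetical stand-ins for the training agent and its OrtValueVector:

import numpy

def run_forward(x, outputs_cache=None):
    # Hypothetical stand-in for the forward pass: fill the caller's buffer
    # when one is provided, otherwise allocate a new one.
    outputs = outputs_cache if outputs_cache is not None else []
    outputs[:] = [x * 2.0]
    return outputs

prediction_cache = None
prediction_cache_shape = None
for batch in [numpy.ones((4, 3)), numpy.ones((4, 3)), numpy.ones((2, 3))]:
    if prediction_cache_shape != batch.shape:
        # shape changed: drop the cached buffer so a fresh one is allocated
        prediction_cache = None
    prediction = run_forward(batch, outputs_cache=prediction_cache)
    prediction_cache = prediction
    prediction_cache_shape = batch.shape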
From f8f8f4eff7d333914f41eafb8ada53bbadab6c96 Mon Sep 17 00:00:00 2001
From: xavier dupré
Date: Fri, 14 Jan 2022 01:14:59 +0100
Subject: [PATCH 2/3] fix cache

---
 onnxcustom/training/optimizers_partial.py | 2 +-
 onnxcustom/training/ortgradient.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py
index 1e13b8b4..7bfeaf2a 100644
--- a/onnxcustom/training/optimizers_partial.py
+++ b/onnxcustom/training/optimizers_partial.py
@@ -405,7 +405,7 @@ def _iteration(self, data_loader, states, n_weights):
                 prediction_cache = None
                 prediction_cache_shape = None
             prediction = self.train_function_.forward(
-                states[0], training=True, forward_outputs_cache=prediction_cache_shape)
+                states[0], training=True, forward_outputs_cache=prediction_cache)
             prediction_cache = prediction
             prediction_cache_shape = tuple(ortx.shape())
 
diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py
index 2fad325d..e85d647e 100644
--- a/onnxcustom/training/ortgradient.py
+++ b/onnxcustom/training/ortgradient.py
@@ -615,7 +615,7 @@ def _log(msg, *args):
                 inputs, cls._devices, cls._debug)
 
         if training:
-            forward_outputs = OrtValueVector()
+            forward_outputs = forward_outputs_cache or OrtValueVector()
             state = PartialGraphExecutionState()
             self.states_.append(state)
             if logger is not None:

From 275453f8f8b285ce4cd5655990cba14f50bf7e24 Mon Sep 17 00:00:00 2001
From: xavier dupré
Date: Fri, 14 Jan 2022 02:49:02 +0100
Subject: [PATCH 3/3] backward cache

---
 onnxcustom/training/base_onnx_function.py | 106 +++++++++++-----------
 onnxcustom/training/optimizers_partial.py |  20 +++-
 onnxcustom/training/ortgradient.py        |   4 +-
 3 files changed, 70 insertions(+), 60 deletions(-)

diff --git a/onnxcustom/training/base_onnx_function.py b/onnxcustom/training/base_onnx_function.py
index 9f41e813..bed3dd27 100644
--- a/onnxcustom/training/base_onnx_function.py
+++ b/onnxcustom/training/base_onnx_function.py
@@ -111,25 +111,48 @@ def build_onnx_function(self, opset, device, *args):
         raise NotImplementedError(
             "This method must be overwritten.")
 
+    @staticmethod
+    def _cache_in_clear(cache, name, bind):
+        key = id(bind)
+        if key in cache:
+            if name in cache[key]:
+                if cache[key][name] == 0:
+                    return True
+            cache[key][name] = 0
+            return False
+        return True
+
     def clear_binding_inputs(self, name, bind, cache=False):
         """
         Clears binding and empty cache.
         """
-        def _cache_in(name, bind):
-            key = id(bind)
-            if key in self.cache_in_:
-                if name in self.cache_in_[key]:
-                    if self.cache_in_[key][name] == 0:
-                        return True
-                self.cache_in_[key][name] = 0
-                return False
-            return True
-
-        if cache and _cache_in(name, bind):
+        if cache and self._cache_in_clear(self.cache_in_, name, bind):
             return
         bind.clear_binding_inputs()
 
-    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
+    @staticmethod
+    def _bio_cache(cache, name, bind, c_ortvalue, ptr2):
+        key = id(bind)
+        if key in cache:
+            if name in cache[key]:
+                ptr = cache[key][name]
+                if ptr == ptr2:
+                    return True
+            cache[key][name] = ptr2
+        else:
+            cache[key] = {name: ptr2}
+        return False
+
+    @staticmethod
+    def _bio_do_bind_in(name, bind, c_ortvalue):
+        bind.bind_ortvalue_input(name, c_ortvalue)
+
+    @staticmethod
+    def _bio_ptr(c):
+        return c.data_ptr()
+
+    def _bind_input_ortvalue(self, name, bind, c_ortvalue, device,
+                             cache=False):
         """
         Binds :epkg:`C_OrtValue` to the structure used by
         :epkg:`InferenceSession` to run inference.
@@ -141,35 +164,22 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
         :param device: device
         :param cache: avoids binding again if the data pointer did not change,
             only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
-            equivalent to a dictionary `{ id(bind), name: c_ort_value.data_ptr() }`.
+            equivalent to a dictionary
+            `{ id(bind), name: c_ort_value.data_ptr() }`.
         """
-        def _cache_in(name, bind, c_ortvalue, ptr2):
-            key = id(bind)
-            if key in self.cache_in_:
-                if name in self.cache_in_[key]:
-                    ptr = self.cache_in_[key][name]
-                    if ptr == ptr2:
-                        return True
-                self.cache_in_[key][name] = ptr2
-            else:
-                self.cache_in_[key] = {name: ptr2}
-            return False
-
-        def do_bind(name, bind, c_ortvalue):
-            bind.bind_ortvalue_input(name, c_ortvalue)
-
         if isinstance(c_ortvalue, C_OrtValue):
-            if cache and _cache_in(
-                    name, bind, c_ortvalue, c_ortvalue.data_ptr()):
+            if cache and self._bio_cache(
+                    self.cache_in_, name, bind, c_ortvalue,
+                    self._bio_ptr(c_ortvalue)):
                 return
-            do_bind(name, bind, c_ortvalue)
+            self._bio_do_bind_in(name, bind, c_ortvalue)
         elif isinstance(c_ortvalue, numpy.ndarray):
             if self.device_type() != device.cpu():  # pylint: disable=E1101
                 raise ProviderError(
                     "device=%s is not CPU." % ort_device_to_string(
                         device))
-            if cache and _cache_in(
-                    name, bind, c_ortvalue,
+            if cache and self._bio_cache(
+                    self.cache_in_, name, bind, c_ortvalue,
                     c_ortvalue.__array_interface__['data'][0]):
                 return
             bind.bind_input(
@@ -180,6 +190,10 @@ def do_bind(name, bind, c_ortvalue):
                 "Unable to bind type %r for name %r." % (
                     type(c_ortvalue), name))
 
+    @staticmethod
+    def _bio_do_bind_out(name, bind, c_ortvalue):
+        bind.bind_ortvalue_output(name, c_ortvalue)
+
     def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
         """
         Binds :epkg:`C_OrtValue` to the structure used by
@@ -190,31 +204,17 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
         :param name: name of the input
         :param c_ortvalue: C structure for OrtValue (:epkg:`C_OrtValue`)
         :param cache: avoids binding again if the data pointer did not change,
            only works when c_ortvalue is of :epkg:`C_OrtValue`, the cache is
-            equivalent to a dictionary `{ id(bind), name: c_ort_value.data_ptr() }`.
+            equivalent to a dictionary
+            `{ id(bind), name: c_ort_value.data_ptr() }`.
         This method can be used for inplace computation.
         """
-        def _cache_out(name, bind, c_ortvalue, ptr2):
-            key = id(bind)
-            if key in self.cache_out_:
-                if name in self.cache_out_[key]:
-                    ptr = self.cache_out_[key][name]
-                    if ptr == ptr2:
-                        return True
-                self.cache_out_[key][name] = ptr2
-            else:
-                self.cache_out_[key] = {name: ptr2}
-            return False
-
-        def do_bind(name, bind, c_ortvalue):
-            bind.bind_ortvalue_output(name, c_ortvalue)
-
         if isinstance(c_ortvalue, C_OrtValue):
-            if cache and _cache_out(
-                    name, bind, c_ortvalue,
-                    c_ortvalue.data_ptr()):
+            if cache and self._bio_cache(
+                    self.cache_out_, name, bind, c_ortvalue,
+                    self._bio_ptr(c_ortvalue)):
                 return
-            do_bind(name, bind, c_ortvalue)
+            self._bio_do_bind_out(name, bind, c_ortvalue)
         else:
             raise TypeError(  # pragma: no cover
                 "Unable to bind type %r for name %r." % (
diff --git a/onnxcustom/training/optimizers_partial.py b/onnxcustom/training/optimizers_partial.py
index 7bfeaf2a..2774484f 100644
--- a/onnxcustom/training/optimizers_partial.py
+++ b/onnxcustom/training/optimizers_partial.py
@@ -386,6 +386,7 @@ def _iteration(self, data_loader, states, n_weights):
 
         prediction_cache = None
         prediction_cache_shape = None
+        backward_outputs_cache = None
         for ib, ito in enumerate(data_loader.iter_ortvalue()):
             if len(ito) == 2:
                 (ortx, orty) = ito
@@ -399,15 +400,20 @@ def _iteration(self, data_loader, states, n_weights):
                     "[OrtGradientForwardBackwardOptimizer._iteration] "
                     "batch %d", ib)
 
+            ortx_shape = tuple(ortx.shape())
+            same_shape = (
+                prediction_cache_shape is not None and
+                ortx_shape == prediction_cache_shape)
+
             # forward
-            if (prediction_cache_shape is None or
-                    tuple(ortx.shape()) != prediction_cache_shape):
+            if prediction_cache_shape is None or same_shape:
                 prediction_cache = None
                 prediction_cache_shape = None
             prediction = self.train_function_.forward(
-                states[0], training=True, forward_outputs_cache=prediction_cache)
+                states[0], training=True,
+                forward_outputs_cache=prediction_cache)
             prediction_cache = prediction
-            prediction_cache_shape = tuple(ortx.shape())
+            prediction_cache_shape = ortx_shape
 
             # loss
             loss, loss_gradient = self.learning_loss.loss_gradient(
@@ -427,7 +433,11 @@ def _iteration(self, data_loader, states, n_weights):
                         else actual_losses[-5:])]))
 
             # backward
-            gradient = self.train_function_.backward([loss_gradient])
+            if not same_shape:
+                backward_outputs_cache = None
+            gradient = self.train_function_.backward(
+                [loss_gradient], backward_outputs_cache=backward_outputs_cache)
+            backward_outputs_cache = gradient
 
             if len(gradient) != len(state):
                 raise RuntimeError(  # pragma: no cover
diff --git a/onnxcustom/training/ortgradient.py b/onnxcustom/training/ortgradient.py
index e85d647e..32a54b4a 100644
--- a/onnxcustom/training/ortgradient.py
+++ b/onnxcustom/training/ortgradient.py
@@ -666,7 +666,7 @@ def _log(msg, *args):
             _log("end")
         return ortvalues
 
-    def backward(self, grad_outputs):
+    def backward(self, grad_outputs, backward_outputs_cache=None):
         """
         Implements backward function.
         The function returns an :epkg:`OrtValueVector`.
@@ -705,7 +705,7 @@ def _log(msg, *args):
                     _log("backward_inputs[%d].shape=%r",
                          i, backward_inputs[i].shape())
                 _log("run_backward")
-            backward_outputs = OrtValueVector()
+            backward_outputs = backward_outputs_cache or OrtValueVector()
             cls._training_agent.run_backward(
                 backward_inputs, backward_outputs, state)
             if logger is not None:  # pragma: no cover
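The third commit also rewrites the binding cache in base_onnx_function.py: the nested closures `_cache_in`, `_cache_out` and `do_bind`, which were recreated on every call, become the static helpers `_bio_cache`, `_bio_do_bind_in`, `_bio_do_bind_out` and `_bio_ptr`. They share a dictionary shaped like `{id(bind): {name: data_ptr}}`, so a binding is skipped whenever the data pointer recorded for a given (bind, name) pair has not moved. A small stand-alone sketch of that lookup, condensed from the `_bio_cache` helper added by the patch (the unused `c_ortvalue` argument is dropped, and `bind` is replaced by a plain object standing in for the IO binding):

import numpy

def _bio_cache(cache, name, bind, ptr2):
    # Return True when the pointer recorded for (bind, name) is unchanged,
    # otherwise record the new pointer and report that binding is needed.
    key = id(bind)
    if key in cache:
        if name in cache[key]:
            if cache[key][name] == ptr2:
                return True
        cache[key][name] = ptr2
    else:
        cache[key] = {name: ptr2}
    return False

cache = {}
bind = object()  # stand-in for an onnxruntime IO binding object
x = numpy.zeros((2, 2))
ptr = x.__array_interface__['data'][0]
print(_bio_cache(cache, "X", bind, ptr))  # False: first call, binding needed
print(_bio_cache(cache, "X", bind, ptr))  # True: same pointer, binding skipped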
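On the ortgradient.py side, both `forward` and `backward` now accept an optional output vector and fall back to a fresh `OrtValueVector()` only when no cache is supplied, so the same buffer is filled batch after batch; in the optimizer loop the cached vectors are dropped whenever the incoming batch shape changes. A toy sketch of that callee-side contract, where a list subclass is a hypothetical stand-in for OrtValueVector and `backward` is only an illustration, not the real training agent:

class Vector(list):
    """Hypothetical stand-in for onnxruntime's OrtValueVector."""

def backward(grad_outputs, backward_outputs_cache=None):
    # Mirrors `backward_outputs_cache or OrtValueVector()`: an empty or
    # missing cache triggers a fresh allocation, a filled one is reused.
    backward_outputs = backward_outputs_cache or Vector()
    backward_outputs[:] = [g * 0.5 for g in grad_outputs]
    return backward_outputs

buf = backward([1.0, 2.0])         # first call allocates a new vector
buf2 = backward([3.0, 4.0], buf)   # second call fills the same vector
assert buf is buf2 and buf2 == [1.5, 2.0]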