From d10231f2b205068566934a1a106470078f20a167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 9 Jan 2022 15:40:43 +0100 Subject: [PATCH] Improves performance of caching (#37) --- onnxcustom/training/base_onnx_function.py | 67 ++++++++++++----------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/onnxcustom/training/base_onnx_function.py b/onnxcustom/training/base_onnx_function.py index d3ec03d0..e6308f40 100644 --- a/onnxcustom/training/base_onnx_function.py +++ b/onnxcustom/training/base_onnx_function.py @@ -116,11 +116,14 @@ def clear_binding_inputs(self, name, bind, cache=False): Clears binding and empty cache. """ def _cache_in(name, bind): - key = name, id(bind) - if key in self.cache_in_ and self.cache_in_[key] == 0: - return True - self.cache_in_[key] = 0 - return False + key = id(bind) + if key in self.cache_in_: + if name in self.cache_in_[key]: + if self.cache_in_[key][name] == 0: + return True + self.cache_in_[key][name] = 0 + return False + return True if cache and _cache_in(name, bind): return @@ -139,29 +142,24 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False): :param cache: avoids binding again the data pointer did not change, only works when c_ortvalue is of :epkg:`C_OrtValue` """ - def _cache_in(name, bind, c_ortvalue): - key = name, id(bind) - ptr = self.cache_in_.get(key, 0) - ptr2 = c_ortvalue.data_ptr() - if ptr == ptr2: - return True - self.cache_in_[key] = ptr2 - return False - - def _cache_np(name, bind, c_ortvalue): - key = name, id(bind) - ptr = self.cache_in_.get(key, 0) - ptr2 = c_ortvalue.__array_interface__['data'][0] - if ptr == ptr2: - return True - self.cache_in_[key] = ptr2 + def _cache_in(name, bind, c_ortvalue, ptr2): + key = id(bind) + if key in self.cache_in_: + if name in self.cache_in_[key]: + ptr = self.cache_in_[key][name] + if ptr == ptr2: + return True + self.cache_in_[key][name] = ptr2 + else: + self.cache_in_[key] = {name: ptr2} return False def do_bind(name, bind, c_ortvalue): bind.bind_ortvalue_input(name, c_ortvalue) if isinstance(c_ortvalue, C_OrtValue): - if cache and _cache_in(name, bind, c_ortvalue): + if cache and _cache_in( + name, bind, c_ortvalue, c_ortvalue.data_ptr()): return do_bind(name, bind, c_ortvalue) elif isinstance(c_ortvalue, numpy.ndarray): @@ -169,7 +167,9 @@ def do_bind(name, bind, c_ortvalue): raise ProviderError( "device=%s is not CPU." % ort_device_to_string( device)) - if cache and _cache_np(name, bind, c_ortvalue): + if cache and _cache_in( + name, bind, c_ortvalue, + c_ortvalue.__array_interface__['data'][0]): return bind.bind_input( name, device, c_ortvalue.dtype, c_ortvalue.shape, @@ -192,20 +192,25 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False): This method can be used for inplace computation. """ - def _cache_out(name, bind, c_ortvalue): - key = name, id(bind) - ptr = self.cache_in_.get(key, 0) - ptr2 = c_ortvalue.data_ptr() - if ptr == ptr2: - return True - self.cache_in_[key] = ptr2 + def _cache_out(name, bind, c_ortvalue, ptr2): + key = id(bind) + if key in self.cache_out_: + if name in self.cache_out_[key]: + ptr = self.cache_out_[key][name] + if ptr == ptr2: + return True + self.cache_out_[key][name] = ptr2 + else: + self.cache_out_[key] = {name: ptr2} return False def do_bind(name, bind, c_ortvalue): bind.bind_ortvalue_output(name, c_ortvalue) if isinstance(c_ortvalue, C_OrtValue): - if cache and _cache_out(name, bind, c_ortvalue): + if cache and _cache_out( + name, bind, c_ortvalue, + c_ortvalue.data_ptr()): return do_bind(name, bind, c_ortvalue) else: