Skip to content

Commit

Permalink
Improves performance of caching (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jan 9, 2022
1 parent e48306e commit d10231f
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions onnxcustom/training/base_onnx_function.py
Expand Up @@ -116,11 +116,14 @@ def clear_binding_inputs(self, name, bind, cache=False):
Clears binding and empty cache.
"""
def _cache_in(name, bind):
key = name, id(bind)
if key in self.cache_in_ and self.cache_in_[key] == 0:
return True
self.cache_in_[key] = 0
return False
key = id(bind)
if key in self.cache_in_:
if name in self.cache_in_[key]:
if self.cache_in_[key][name] == 0:
return True
self.cache_in_[key][name] = 0
return False
return True

if cache and _cache_in(name, bind):
return
Expand All @@ -139,37 +142,34 @@ def _bind_input_ortvalue(self, name, bind, c_ortvalue, device, cache=False):
:param cache: avoids binding again the data pointer did not change,
only works when c_ortvalue is of :epkg:`C_OrtValue`
"""
def _cache_in(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.data_ptr()
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
return False

def _cache_np(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.__array_interface__['data'][0]
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
def _cache_in(name, bind, c_ortvalue, ptr2):
key = id(bind)
if key in self.cache_in_:
if name in self.cache_in_[key]:
ptr = self.cache_in_[key][name]
if ptr == ptr2:
return True
self.cache_in_[key][name] = ptr2
else:
self.cache_in_[key] = {name: ptr2}
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_input(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_in(name, bind, c_ortvalue):
if cache and _cache_in(
name, bind, c_ortvalue, c_ortvalue.data_ptr()):
return
do_bind(name, bind, c_ortvalue)
elif isinstance(c_ortvalue, numpy.ndarray):
if self.device_type() != device.cpu(): # pylint: disable=E1101
raise ProviderError(
"device=%s is not CPU." % ort_device_to_string(
device))
if cache and _cache_np(name, bind, c_ortvalue):
if cache and _cache_in(
name, bind, c_ortvalue,
c_ortvalue.__array_interface__['data'][0]):
return
bind.bind_input(
name, device, c_ortvalue.dtype, c_ortvalue.shape,
Expand All @@ -192,20 +192,25 @@ def _bind_output_ortvalue(self, name, bind, c_ortvalue, cache=False):
This method can be used for inplace computation.
"""
def _cache_out(name, bind, c_ortvalue):
key = name, id(bind)
ptr = self.cache_in_.get(key, 0)
ptr2 = c_ortvalue.data_ptr()
if ptr == ptr2:
return True
self.cache_in_[key] = ptr2
def _cache_out(name, bind, c_ortvalue, ptr2):
key = id(bind)
if key in self.cache_out_:
if name in self.cache_out_[key]:
ptr = self.cache_out_[key][name]
if ptr == ptr2:
return True
self.cache_out_[key][name] = ptr2
else:
self.cache_out_[key] = {name: ptr2}
return False

def do_bind(name, bind, c_ortvalue):
bind.bind_ortvalue_output(name, c_ortvalue)

if isinstance(c_ortvalue, C_OrtValue):
if cache and _cache_out(name, bind, c_ortvalue):
if cache and _cache_out(
name, bind, c_ortvalue,
c_ortvalue.data_ptr()):
return
do_bind(name, bind, c_ortvalue)
else:
Expand Down

0 comments on commit d10231f

Please sign in to comment.