consolidated some batch caching code
qmac committed Apr 2, 2018
1 parent d60f664 commit a42f14d
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions pliers/transformers/base.py
@@ -236,18 +236,13 @@ def _iterate(self, stims, validation='strict', *args, **kwargs):
         results = []
         for batch in progress_bar_wrapper(batches):
             use_cache = config.get_option('cache_transformers')
-            target_inds = []
+            target_inds = {}
             non_cached = []
-            transformed_keys = set()
             for stim in batch:
                 key = hash((hash(self), hash(stim)))
-                if use_cache and (key in _cache or key in transformed_keys):
-                    target_inds.append(-1)  # signals to query cache
-                else:
-                    target_inds.append(len(non_cached))
+                if not (use_cache and (key in _cache or key in target_inds)):
+                    target_inds[key] = len(non_cached)
                     non_cached.append(stim)
-                    # Can't use _cache in case _transform fails
-                    transformed_keys.add(key)

             if len(non_cached) > 0:
                 batch_results = self._transform(non_cached, *args, **kwargs)
@@ -256,17 +251,17 @@ def _iterate(self, stims, validation='strict', *args, **kwargs):

             for i, stim in enumerate(batch):
                 key = hash((hash(self), hash(stim)))
-                if target_inds[i] == -1:
-                    results.append(_cache[key])
-                else:
-                    result = batch_results[target_inds[i]]
+                if key in target_inds:
+                    result = batch_results[target_inds[key]]
                     result = _log_transformation(stim, result, self)
                     self._propagate_context(stim, result)
                     if use_cache:
                         if isgenerator(result):
                             result = list(result)
                         _cache[key] = result
                     results.append(result)
+                else:
+                    results.append(_cache[key])
         return results

     def _transform(self, stim, *args, **kwargs):
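In short, the commit replaces the parallel target_inds list and transformed_keys set with a single dict that maps each stim's cache key to its position in the non-cached sub-batch. The following is a minimal, standalone sketch of that pattern only; it assumes a hypothetical transform_batch callable and a module-level _cache dict in place of pliers' real machinery (the actual _iterate also handles logging, context propagation, and generator results):

    # Illustrative sketch only: simplified, hypothetical names, not pliers' API.
    _cache = {}

    def iterate_batch(transformer, batch, transform_batch, use_cache=True):
        # Map each stim's cache key to its index in the non-cached sub-batch.
        target_inds = {}
        non_cached = []
        for stim in batch:
            key = hash((hash(transformer), hash(stim)))
            # Only queue stims that are neither cached nor already queued.
            if not (use_cache and (key in _cache or key in target_inds)):
                target_inds[key] = len(non_cached)
                non_cached.append(stim)

        batch_results = transform_batch(non_cached) if non_cached else []

        results = []
        for stim in batch:
            key = hash((hash(transformer), hash(stim)))
            if key in target_inds:
                # Freshly transformed in this batch; store for later lookups.
                result = batch_results[target_inds[key]]
                if use_cache:
                    _cache[key] = result
                results.append(result)
            else:
                # Seen before (earlier batch or call); serve from the cache.
                results.append(_cache[key])
        return results

The dict serves double duty: during the first loop it detects duplicates within the batch (replacing transformed_keys), and during the second loop it locates each freshly computed result (replacing the list of sentinel indices).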
