Skip to content
Permalink
Browse files

[rllib/tune] Cache get_preprocessor() calls, default max_failures to 0 (#6211)

  • Loading branch information
ericl authored and richardliaw committed Nov 21, 2019
1 parent d3227f2 commit 7559fdb1418e4c35f09e2edbc7c6762b3889f278
Showing with 12 additions and 2 deletions.
  1. +1 −1 python/ray/tune/tune.py
  2. +11 −1 rllib/models/model.py
@@ -74,7 +74,7 @@ def run(run_or_experiment,
checkpoint_score_attr=None,
global_checkpoint_period=10,
export_formats=None,
max_failures=3,
max_failures=0,
restore=None,
search_alg=None,
scheduler=None,
@@ -232,6 +232,10 @@ def restore_original_dimensions(obs, obs_space, tensorlib=tf):
return obs


# Cache of preprocessors, for if the user is calling unpack obs often.
_cache = {}


def _unpack_obs(obs, space, tensorlib=tf):
"""Unpack a flattened Dict or Tuple observation array/tensor.
@@ -243,7 +247,13 @@ def _unpack_obs(obs, space, tensorlib=tf):

if (isinstance(space, gym.spaces.Dict)
or isinstance(space, gym.spaces.Tuple)):
prep = get_preprocessor(space)(space)
if id(space) in _cache:
prep = _cache[id(space)]
else:
prep = get_preprocessor(space)(space)
# Make an attempt to cache the result, if enough space left.
if len(_cache) < 999:
_cache[id(space)] = prep
if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
raise ValueError(
"Expected flattened obs shape of [None, {}], got {}".format(

0 comments on commit 7559fdb

Please sign in to comment.
You can’t perform that action at this time.