[RLlib] Eval workers use async req manager. (ray-project#27390)
Signed-off-by: Philipp Moritz <pcmoritz@gmail.com>
sven1977 authored and pcmoritz committed Aug 31, 2022
1 parent c6fe6c4 commit 49fb40c
Showing 9 changed files with 407 additions and 42 deletions.
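For orientation, here is a minimal, generic Ray sketch (illustration only, not code from this PR, with toy actor and variable names) of the difference this change is about: instead of blocking a whole evaluation round on every single eval worker, requests are fired asynchronously and only the results that are ready get collected.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class EvalWorker:
    def sample(self):
        return 1  # stand-in for a real sample batch

workers = [EvalWorker.remote() for _ in range(2)]

# Synchronous pattern: blocks until *all* workers have returned, so one
# hung or crashed worker stalls the entire evaluation round.
batches = ray.get([w.sample.remote() for w in workers])

# Asynchronous pattern (roughly what routing requests through the
# AsyncRequestsManager buys): fire requests, then only harvest the object
# refs that became ready within this round.
refs = [w.sample.remote() for w in workers]
ready, _ = ray.wait(refs, num_returns=len(refs), timeout=1.0)
ready_batches = ray.get(ready)

In the actual change, the manager is created with max_remote_requests_in_flight_per_worker=1 and return_object_refs=True, so each eval worker has at most one outstanding sample() call and worker failures surface as RayError when the returned refs are resolved.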
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -502,12 +502,12 @@ py_test(
)

py_test(
name = "learning_tests_multi_agent_cartpole_crashing_pg",
name = "learning_tests_multi_agent_cartpole_crashing_restart_sub_envs_pg",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/pg/multi-agent-cartpole-crashing-pg.yaml"],
data = ["tuned_examples/pg/multi-agent-cartpole-crashing-restart-sub-envs-pg.yaml"],
args = ["--yaml-dir=tuned_examples/pg"]
)

261 changes: 241 additions & 20 deletions rllib/algorithms/algorithm.py
@@ -48,6 +48,7 @@
from ray.rllib.execution.common import (
STEPS_TRAINED_THIS_ITER_COUNTER, # TODO: Backward compatibility.
)
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.offline import get_offline_io_resource_bundles
@@ -293,6 +294,9 @@ def default_logger_creator(config):

# Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
self.evaluation_workers: Optional[WorkerSet] = None
# If evaluation duration is "auto", use an AsyncRequestsManager to be more
# robust against eval worker failures.
self._evaluation_async_req_manager: Optional[AsyncRequestsManager] = None
# Initialize common evaluation_metrics to nan, before they become
# available. We want to make sure the metrics are always present
# (although their values may be nan), so that Tune does not complain
@@ -513,7 +517,7 @@ def setup(self, config: PartialAlgorithmConfigDict):
# Set `rollout_fragment_length` such that desired steps are divided
# equally amongst workers or - in "auto" duration mode - set it
# to a reasonably small number (10), such that a single `sample()`
# call doesn't take too much time so we can stop evaluation as soon
# call doesn't take too much time and we can stop evaluation as soon
# as possible after the train step is completed.
else:
eval_config.update(
@@ -551,6 +555,14 @@ def setup(self, config: PartialAlgorithmConfigDict):
logdir=self.logdir,
)

if self.config["enable_async_evaluation"]:
self._evaluation_async_req_manager = AsyncRequestsManager(
workers=self.evaluation_workers.remote_workers(),
max_remote_requests_in_flight_per_worker=1,
return_object_refs=True,
)
self._evaluation_weights_seq_number = 0

self.reward_estimators: Dict[str, OffPolicyEstimator] = {}
ope_types = {
"is": ImportanceSampling,
@@ -730,15 +742,6 @@ def evaluate(
episodes left to run. It's used to find out whether
evaluation should continue.
"""
# In case we are evaluating (in a thread) parallel to training,
# we may have to re-enable eager mode here (gets disabled in the
# thread).
if (
self.config.get("framework") in ["tf2", "tfe"]
and not tf.executing_eagerly()
):
tf1.enable_eager_execution()

# Call the `_before_evaluate` hook.
self._before_evaluate()

@@ -850,13 +853,13 @@ def duration_fn(num_units_done):
else:
# How many episodes have we run (across all eval workers)?
num_units_done = 0
round_ = 0
_round = 0
while True:
units_left_to_do = duration_fn(num_units_done)
if units_left_to_do <= 0:
break

round_ += 1
_round += 1
try:
batches = ray.get(
[
@@ -908,7 +911,7 @@ def duration_fn(num_units_done):
env_steps_this_iter += _env_steps

logger.info(
f"Ran round {round_} of parallel evaluation "
f"Ran round {_round} of parallel evaluation "
f"({num_units_done}/{duration if not auto else '?'} "
f"{unit} done)"
)
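As a side note on the loop above (illustration only, not from the PR): the `duration_fn` contract used by both `evaluate()` and the new `_evaluate_async()` is simply "given the number of units already done, return how many are left", and the loop stops once that number drops to zero or below. A plain-Python toy version:

# Toy version of the default duration function and the surrounding loop;
# names and numbers are illustrative only.
def make_default_duration_fn(duration):
    return lambda num_units_done: duration - num_units_done

duration_fn = make_default_duration_fn(duration=10)
num_units_done = 0
while duration_fn(num_units_done) > 0:
    num_units_done += 4  # pretend each round completes 4 episodes
print(num_units_done)  # -> 12; the loop exits once the budget is used up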
@@ -941,6 +944,198 @@ def duration_fn(num_units_done):
# Also return the results here for convenience.
return self.evaluation_metrics

@ExperimentalAPI
def _evaluate_async(
self,
duration_fn: Optional[Callable[[int], int]] = None,
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
Args:
duration_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
episodes left to run. It's used to find out whether
evaluation should continue.
"""
# How many episodes/timesteps do we need to run?
# In "auto" mode (only for parallel eval + training): Run as long
# as training lasts.
unit = self.config["evaluation_duration_unit"]
eval_cfg = self.config["evaluation_config"]
rollout = eval_cfg["rollout_fragment_length"]
num_envs = eval_cfg["num_envs_per_worker"]
auto = self.config["evaluation_duration"] == "auto"
duration = (
self.config["evaluation_duration"]
if not auto
else (self.config["evaluation_num_workers"] or 1)
* (1 if unit == "episodes" else rollout)
)

# Call the `_before_evaluate` hook.
self._before_evaluate()

# Put weights only once into the object store and use the same object
# ref to sync them to all workers.
self._evaluation_weights_seq_number += 1
weights_ref = ray.put(self.workers.local_worker().get_weights())
# TODO(Jun): Make sure this cannot block for e.g. 1h. Implement solution via
# connectors.
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.evaluation_workers,
timeout_seconds=eval_cfg.get("sync_filters_on_rollout_workers_timeout_s"),
)

if self.config["custom_eval_function"]:
raise ValueError(
"`custom_eval_function` not supported in combination "
"with `enable_async_evaluation=True` config setting!"
)
if self.evaluation_workers is None and (
self.workers.local_worker().input_reader is None
or self.config["evaluation_num_workers"] == 0
):
raise ValueError(
"Evaluation w/o eval workers (calling Algorithm.evaluate() w/o "
"evaluation specifically set up) OR evaluation without input reader "
"OR evaluation with only a local evaluation worker "
"(`evaluation_num_workers=0`) not supported in combination "
"with `enable_async_evaluation=True` config setting!"
)

agent_steps_this_iter = 0
env_steps_this_iter = 0

logger.info(f"Evaluating current policy for {duration} {unit}.")

all_batches = []

# Default duration function: keep running until the configured number of
# units (episodes or timesteps) has been completed.
if duration_fn is None:

def duration_fn(num_units_done):
return duration - num_units_done

def remote_fn(worker, w_ref, w_seq_no):
# Pass in the seq-no so that eval workers may skip the weight update if no
# new weights have arrived since the last call to `remote_fn` (sample).
worker.set_weights(weights=w_ref, weights_seq_no=w_seq_no)
batch = worker.sample()
metrics = worker.get_metrics()
return batch, metrics, w_seq_no

rollout_metrics = []

# How many episodes have we run (across all eval workers)?
num_units_done = 0
_round = 0
errors = []

while len(self._evaluation_async_req_manager.workers) > 0:
units_left_to_do = duration_fn(num_units_done)
if units_left_to_do <= 0:
break

_round += 1
# Fire off new requests to all available eval workers, then collect
# any evaluation results and metrics that are already ready.
self._evaluation_async_req_manager.call_on_all_available(
remote_fn=remote_fn,
fn_args=[weights_ref, self._evaluation_weights_seq_number],
)
ready_requests = self._evaluation_async_req_manager.get_ready()

batches = []
i = 0
for actor, requests in ready_requests.items():
for req in requests:
try:
batch, metrics, seq_no = ray.get(req)
# Ignore results if the weights seq-number does not match (i.e. they
# stem from a previous evaluation step) OR if we have already reached
# the configured duration (e.g. the number of episodes to evaluate
# for).
if seq_no == self._evaluation_weights_seq_number and (
i * (1 if unit == "episodes" else rollout * num_envs)
< units_left_to_do
):
batches.append(batch)
rollout_metrics.extend(metrics)
except RayError as e:
errors.append(e)
self._evaluation_async_req_manager.remove_workers(actor)

i += 1

_agent_steps = sum(b.agent_steps() for b in batches)
_env_steps = sum(b.env_steps() for b in batches)

# 1 episode per returned batch.
if unit == "episodes":
num_units_done += len(batches)
# Make sure all batches are exactly one episode.
for ma_batch in batches:
ma_batch = ma_batch.as_multi_agent()
for batch in ma_batch.policy_batches.values():
assert np.sum(batch[SampleBatch.DONES])
# n timesteps per returned batch.
else:
num_units_done += _agent_steps if self._by_agent_steps else _env_steps

if self.reward_estimators:
all_batches.extend(batches)

agent_steps_this_iter += _agent_steps
env_steps_this_iter += _env_steps

logger.info(
f"Ran round {_round} of parallel evaluation "
f"({num_units_done}/{duration if not auto else '?'} "
f"{unit} done)"
)

num_recreated_workers = 0
if errors:
num_recreated_workers = self.try_recover_from_step_attempt(
error=errors[0],
worker_set=self.evaluation_workers,
ignore=eval_cfg.get("ignore_worker_failures"),
recreate=eval_cfg.get("recreate_failed_workers"),
)

metrics = summarize_episodes(
rollout_metrics,
keep_custom_metrics=eval_cfg["keep_per_episode_custom_metrics"],
)

metrics["num_recreated_workers"] = num_recreated_workers

metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
# TODO: Remove this key at some point. Here for backward compatibility.
metrics["timesteps_this_iter"] = env_steps_this_iter

if self.reward_estimators:
# Compute off-policy estimates
metrics["off_policy_estimator"] = {}
total_batch = concat_samples(all_batches)
for name, estimator in self.reward_estimators.items():
estimates = estimator.estimate(total_batch)
metrics["off_policy_estimator"][name] = estimates

# Evaluation does not run for every step.
# Save evaluation metrics on the trainer, so they can be attached to
# subsequent step results as the latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

# Return evaluation results.
return self.evaluation_metrics
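To make the gating in the loop above easier to follow: a returned batch is only counted if it was produced with the current weights version and while there was still evaluation budget left. Below is a dependency-free sketch of that acceptance check (hypothetical helper name and made-up numbers, not code from the PR):

# Simplified mirror of the `seq_no == ... and i * (...) < units_left_to_do`
# check; `accept_result` is a hypothetical name used only for illustration.
def accept_result(result_seq_no, current_seq_no, index, unit,
                  rollout, num_envs, units_left_to_do):
    units_per_batch = 1 if unit == "episodes" else rollout * num_envs
    return (result_seq_no == current_seq_no
            and index * units_per_batch < units_left_to_do)

# A batch sampled with stale weights (seq-no 3 vs. current 4) is dropped:
assert not accept_result(3, 4, index=0, unit="episodes",
                         rollout=10, num_envs=1, units_left_to_do=5)
# A fresh batch that still fits into the remaining budget is kept:
assert accept_result(4, 4, index=2, unit="timesteps",
                     rollout=10, num_envs=1, units_left_to_do=100)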

@OverrideToImplementCustomLogic
@DeveloperAPI
def training_step(self) -> ResultDict:
@@ -2211,6 +2406,11 @@ def try_recover_from_step_attempt(self, error, worker_set, ignore, recreate) ->
self.train_exec_impl = self.execution_plan(
worker_set, self.config, **self._kwargs_for_execution_plan()
)
elif self._evaluation_async_req_manager is not None and worker_set is getattr(
self, "evaluation_workers", None
):
self._evaluation_async_req_manager.remove_workers(removed_workers)
self._evaluation_async_req_manager.add_workers(new_workers)

return len(new_workers)
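For readers unfamiliar with the request-manager bookkeeping touched above (toy code, hypothetical names, not the RLlib implementation): when eval workers are recreated, the manager has to drop the failed actors' in-flight requests and start tracking the replacements, roughly like this registry:

# Toy stand-in for the remove_workers()/add_workers() calls above.
in_flight = {"worker_1": ["req_a"], "worker_2": ["req_b"]}

def remove_workers(removed):
    for w in removed:
        in_flight.pop(w, None)       # forget pending requests of failed workers

def add_workers(new):
    for w in new:
        in_flight.setdefault(w, [])  # start tracking recreated workers

remove_workers(["worker_2"])  # worker_2 crashed mid-evaluation
add_workers(["worker_3"])     # its replacement joins the pool
print(sorted(in_flight))      # -> ['worker_1', 'worker_3']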

@@ -2359,6 +2559,15 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
Returns:
The results dict from the training iteration.
"""
# In case we are training (in a thread) parallel to evaluation,
# we may have to re-enable eager mode here (gets disabled in the
# thread).
if (
self.config.get("framework") in ["tf2", "tfe"]
and not tf.executing_eagerly()
):
tf1.enable_eager_execution()

results = None
# Create a step context ...
with TrainIterCtx(algo=self) as train_iter_ctx:
@@ -2399,16 +2608,31 @@ def _run_one_evaluation(
Returns:
The results dict from the evaluation call.
"""
eval_results = {"evaluation": {}}
eval_results = {
"evaluation": {
"episode_reward_max": np.nan,
"episode_reward_min": np.nan,
"episode_reward_mean": np.nan,
}
}
eval_results["evaluation"]["num_recreated_workers"] = 0

eval_func_to_use = (
self._evaluate_async
if self.config["enable_async_evaluation"]
else self.evaluate
)

num_recreated = 0

try:
if self.config["evaluation_duration"] == "auto":
assert (
train_future is not None
and self.config["evaluation_parallel_to_training"]
)
unit = self.config["evaluation_duration_unit"]
eval_results = self.evaluate(
eval_results = eval_func_to_use(
duration_fn=functools.partial(
self._automatic_evaluation_duration_fn,
unit,
@@ -2419,7 +2643,8 @@ def _run_one_evaluation(
)
# Run the evaluation function only once per training iteration.
else:
eval_results = self.evaluate()
eval_results = eval_func_to_use()

# In case of any failures, try to ignore/recover the failed evaluation workers.
except Exception as e:
num_recreated = self.try_recover_from_step_attempt(
@@ -2631,10 +2856,6 @@ def _make_workers(
logdir=self.logdir,
)

@Deprecated(new="Trainer.try_recover_from_step_attempt()", error=False)
def _try_recover(self):
return self.try_recover_from_step_attempt()

@staticmethod
@Deprecated(new="Trainer.validate_config()", error=False)
def _validate_config(config, trainer_or_none):

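Finally, a sketch of the kind of config fragment that exercises the new code path. Only keys that appear in this diff are used; the values are illustrative, and whether the keys exist depends on the RLlib version:

# Illustrative config fragment (values are examples, not defaults).
eval_config_fragment = {
    "enable_async_evaluation": True,          # route eval sampling through the AsyncRequestsManager
    "evaluation_num_workers": 2,              # remote eval workers to sample with
    "evaluation_duration": "auto",            # evaluate for as long as the parallel train step runs
    "evaluation_duration_unit": "episodes",   # or "timesteps"
    "evaluation_parallel_to_training": True,  # required for "auto" duration
}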