Skip to content

Commit

Permalink
[RLlib] Stats bug fix: EMA stats w/o window would lead to infinite li…
Browse files Browse the repository at this point in the history
…st mem-leak. (ray-project#45752)

Signed-off-by: yucai <yyu1@ebay.com>
  • Loading branch information
sven1977 authored and yucai committed Jun 7, 2024
1 parent 4b013f4 commit 08449dd
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __call__(
)
self.add_n_batch_items(
batch=data,
column="loss_mask",
column=Columns.LOSS_MASK,
items_to_add=mask,
num_items=len(mask),
single_agent_episode=sa_episode,
Expand Down
2 changes: 2 additions & 0 deletions rllib/connectors/connector_pipeline_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def set_state(self, state: Dict[str, Any]) -> None:
@override(ConnectorV2)
def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
merged_states = {}
if not states:
return merged_states
for i, (key, item) in enumerate(states[0].items()):
state_list = [state[key] for state in states]
conn = self.connectors[i]
Expand Down
5 changes: 5 additions & 0 deletions rllib/core/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,8 @@ class Columns:
# Postprocessing columns.
ADVANTAGES = "advantages"
VALUE_TARGETS = "value_targets"

# Loss mask. If provided in a train batch, a Learner's compute_loss_for_module
# method should respect the False-set value in here and mask out the respective
# items form the loss.
LOSS_MASK = "loss_mask"
2 changes: 1 addition & 1 deletion rllib/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def summarize_episodes(
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episode_media=dict(episode_media),
episodes_this_iter=len(new_episodes),
episodes_timesteps_total=sum(episode_lengths),
policy_reward_min=policy_reward_min,
policy_reward_max=policy_reward_max,
Expand All @@ -259,4 +258,5 @@ def summarize_episodes(
episode_return_max=max_reward,
episode_return_min=min_reward,
episode_return_mean=avg_reward,
episodes_this_iter=len(new_episodes), # deprecate in favor of `num_epsodes_...`
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
parser.add_argument("--num-workers", type=int, default=1)

# This should be >1, otherwise, remote envs make no sense.
parser.add_argument("--num-envs-per-worker", type=int, default=4)
parser.add_argument("--num-envs-per-env-runner", type=int, default=4)

parser.add_argument(
"--as-test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_cli_args():

# example-specific args
# This should be >1, otherwise, remote envs make no sense.
parser.add_argument("--num-envs-per-worker", type=int, default=4)
parser.add_argument("--num-envs-per-env-runner", type=int, default=4)

# general args
parser.add_argument(
Expand Down Expand Up @@ -134,7 +134,7 @@ def default_resource_request(
# Force sub-envs to be ray.actor.ActorHandles, so we can step
# through them in parallel.
remote_worker_envs=True,
num_envs_per_env_runner=args.num_envs_per_worker,
num_envs_per_env_runner=args.num_envs_per_env_runner,
# Use a single worker (however, with n parallelized remote envs, maybe
# even running on another node).
# Action computations occur on the "main" (GPU?) node, while
Expand All @@ -146,7 +146,7 @@ def default_resource_request(
num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
# Set the number of CPUs used by the (local) worker, aka "driver"
# to match the number of Ray remote envs.
num_cpus_for_main_process=args.num_envs_per_worker + 1,
num_cpus_for_main_process=args.num_envs_per_env_runner + 1,
)
)

Expand Down
36 changes: 20 additions & 16 deletions rllib/utils/actor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def __init__(
self._restored_actors = set()
self.add_actors(actors or [])

# Maps outstanding async requests to the ids of the actors that
# Maps outstanding async requests to the IDs of the actor IDs that
# are executing them.
self._in_flight_req_to_actor_id: Mapping[ray.ObjectRef, int] = {}

Expand Down Expand Up @@ -457,7 +457,7 @@ def _fetch_result(
calls were fired against.
remote_calls: list of remote calls to fetch.
tags: list of tags used for identifying the remote calls.
timeout_seconds: timeout for the ray.wait() call. Default is None.
timeout_seconds: Timeout (in sec) for the ray.wait() call. Default is None.
return_obj_refs: whether to return ObjectRef instead of actual results.
mark_healthy: whether to mark certain actors healthy based on the results
of these remote calls. Useful, for example, to make sure actors
Expand Down Expand Up @@ -593,10 +593,9 @@ def foreach_actor(
actors "healthy" that respond to the request within `timeout_seconds`
and are currently tagged as "unhealthy".
remote_actor_ids: Apply func on a selected set of remote actors.
timeout_seconds: Ray.get() timeout in seconds. Default is None, which will
block until all remote results have been received. Setting this to 0.0
makes all the remote calls fire-and-forget, while setting this to
None make them synchronous calls.
timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
synchronous execution).
return_obj_refs: whether to return ObjectRef instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
never be resolved with ray.get() outside of the context of this manager.
Expand Down Expand Up @@ -649,8 +648,10 @@ def foreach_actor_async(
"""Calls given functions against each actors without waiting for results.
Args:
func: A single, or a list of Callables, that get applied on the list
of specified remote actors.
func: A single Callable applied to all specified remote actors or a list
of Callables, that get applied on the list of specified remote actors.
In the latter case, both list of Callables and list of specified actors
must have the same length.
tag: A tag to identify the results from this async call.
healthy_only: If True, applies `func` only to actors currently tagged
"healthy", otherwise to all actors. If `healthy_only=False` and
Expand Down Expand Up @@ -730,41 +731,44 @@ def foreach_actor_async(
return len(remote_calls)

def _filter_calls_by_tag(
self, tags
self, tags: Union[str, List[str], Tuple[str]]
) -> Tuple[List[ray.ObjectRef], List[ActorHandle], List[str]]:
"""Return all the in flight requests that match the given tags.
"""Return all the in flight requests that match the given tags, if any.
Args:
tags: A str or a list of str. If tags is empty, return all the in flight
tags: A str or a list/tuple of str. If tags is empty, return all the in
flight requests.
Returns:
A tuple of corresponding (remote_calls, remote_actor_ids, valid_tags)
A tuple consisting of a list of the remote calls that match the tag(s),
a list of the corresponding remote actor IDs for these calls (same length),
and a list of the tags corresponding to these calls (same length).
"""
if isinstance(tags, str):
tags = {tags}
elif isinstance(tags, (list, tuple)):
tags = set(tags)
else:
raise ValueError(
f"tags must be either a str or a list of str, got {type(tags)}."
f"tags must be either a str or a list/tuple of str, got {type(tags)}."
)
remote_calls = []
remote_actor_ids = []
valid_tags = []
for call, (tag, actor_id) in self._in_flight_req_to_actor_id.items():
# the default behavior is to return all ready results.
if not len(tags) or tag in tags:
if len(tags) == 0 or tag in tags:
remote_calls.append(call)
remote_actor_ids.append(actor_id)
valid_tags.append(tag)

return remote_calls, remote_actor_ids, valid_tags

@DeveloperAPI
def fetch_ready_async_reqs(
self,
*,
tags: Union[str, List[str]] = (),
tags: Union[str, List[str], Tuple[str]] = (),
timeout_seconds: Optional[float] = 0.0,
return_obj_refs: bool = False,
mark_healthy: bool = True,
Expand Down
47 changes: 32 additions & 15 deletions rllib/utils/metrics/stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections import defaultdict
import time
import threading
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -205,8 +207,8 @@ def __init__(
self._window = window
self._ema_coeff = ema_coeff

# Timing functionality.
self._start_time = None
# Timing functionality (keep start times per thread).
self._start_times = defaultdict(lambda: None)

# Simply store ths flag for the user of this class.
self._clear_on_reduce = clear_on_reduce
Expand All @@ -227,23 +229,32 @@ def __enter__(self) -> "Stats":
"""Called when entering a context (with which users can measure a time delta).
Returns:
This Stats instance (self).
This Stats instance (self), unless another thread has already entered (and
not exited yet), in which case a copy of `self` is returned. This way, the
second thread(s) cannot mess with the original Stat's (self) time-measuring.
This also means that only the first thread to __enter__ actually logs into
`self` and the following threads' measurements are discarded (logged into
a non-referenced shim-Stats object, which will simply be garbage collected).
"""
assert self._start_time is None, "Concurrent updates not supported!"
self._start_time = time.time()
# In case another thread already is measuring this Stats (timing), simply ignore
# the "enter request" and return a clone of `self`.
thread_id = threading.get_ident()
assert self._start_times[thread_id] is None
self._start_times[thread_id] = time.perf_counter()
return self

def __exit__(self, exc_type, exc_value, tb) -> None:
"""Called when exiting a context (with which users can measure a time delta)."""
assert self._start_time is not None
time_delta = time.time() - self._start_time
thread_id = threading.get_ident()
assert self._start_times[thread_id] is not None
time_delta = time.perf_counter() - self._start_times[thread_id]
self.push(time_delta)

# Call the on_exit handler.
if self._on_exit:
self._on_exit(time_delta)

self._start_time = None
del self._start_times[thread_id]

def peek(self) -> Any:
"""Returns the result of reducing the internal values list.
Expand Down Expand Up @@ -583,31 +594,37 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]:
"""
values = values if values is not None else self.values
window = window if window is not None else self._window
inf_window = window in [None, float("inf")]

# Apply the window (if provided and not inf).
values = (
values if window is None or window == float("inf") else values[-window:]
)
values = values if inf_window else values[-window:]

# No reduction method. Return list as-is OR reduce list to len=window.
if self._reduce_method is None:
return values, values

# Special case: Internal values list is empty -> return NaN.
elif len(values) == 0:
return float("nan"), []
if self._reduce_method in ["min", "max", "mean"]:
return float("nan"), []
else:
return 0, []

# Do EMA (always a "mean" reduction; possibly using a window).
if self._ema_coeff is not None:
elif self._ema_coeff is not None:
# Perform EMA reduction over all values in internal values list.
mean_value = values[0]
for v in values[1:]:
mean_value = self._ema_coeff * v + (1.0 - self._ema_coeff) * mean_value
return mean_value, values
if inf_window:
return mean_value, [mean_value]
else:
return mean_value, values
# Do non-EMA reduction (possibly using a window).
else:
# Use the numpy/torch "nan"-prefix to ignore NaN's in our value lists.
if torch and torch.is_tensor(values[0]):
assert all(torch.is_tensor(v) for v in values), values
reduce_meth = getattr(torch, "nan" + self._reduce_method)
reduce_in = torch.stack(values)
if self._reduce_method == "mean":
Expand Down Expand Up @@ -643,7 +660,7 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]:

# For window=None|inf (infinite window) and reduce != mean, we don't have to
# keep any values, except the last (reduced) one.
if window in [None, float("inf")] and self._reduce_method != "mean":
if inf_window and self._reduce_method != "mean":
# TODO (sven): What if out values are torch tensors? In this case, we
# would have to do reduction using `torch` above (not numpy) and only
# then return the python primitive AND put the reduced new torch
Expand Down
9 changes: 7 additions & 2 deletions rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ def run_rllib_example_script_experiment(
# New stack.
if config.enable_rl_module_and_learner:
# Define compute resources used.
config.resources(num_gpus=0)
config.learners(
num_learners=args.num_gpus,
num_gpus_per_learner=1 if torch.cuda.is_available() else 0,
Expand All @@ -1490,9 +1491,13 @@ def run_rllib_example_script_experiment(
if args.no_tune:
assert not args.as_test and not args.as_release_test
algo = config.build()
for _ in range(stop.get(TRAINING_ITERATION, args.stop_iters)):
for i in range(stop.get(TRAINING_ITERATION, args.stop_iters)):
results = algo.train()
print(f"R={results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}", end="")
if ENV_RUNNER_RESULTS in results:
print(
f"iter={i} R={results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}",
end="",
)
if EVALUATION_RESULTS in results:
Reval = results[EVALUATION_RESULTS][ENV_RUNNER_RESULTS][
EPISODE_RETURN_MEAN
Expand Down
10 changes: 0 additions & 10 deletions rllib/utils/timer.py

This file was deleted.

0 comments on commit 08449dd

Please sign in to comment.