[RLlib] Stats bug fix: EMA stats w/o window would lead to infinite list mem-leak. #45752

Merged
@@ -310,7 +310,7 @@ def __call__(
)
self.add_n_batch_items(
batch=data,
column="loss_mask",
column=Columns.LOSS_MASK,
Collaborator: Great! More constants! This will reduce our errors everywhere in the lib.

items_to_add=mask,
num_items=len(mask),
single_agent_episode=sa_episode,
2 changes: 2 additions & 0 deletions rllib/connectors/connector_pipeline_v2.py
@@ -218,6 +218,8 @@ def set_state(self, state: Dict[str, Any]) -> None:
@override(ConnectorV2)
def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
merged_states = {}
if not states:
return merged_states
for i, (key, item) in enumerate(states[0].items()):
state_list = [state[key] for state in states]
conn = self.connectors[i]
5 changes: 5 additions & 0 deletions rllib/core/columns.py
@@ -58,3 +58,8 @@ class Columns:
# Postprocessing columns.
ADVANTAGES = "advantages"
VALUE_TARGETS = "value_targets"

# Loss mask. If provided in a train batch, a Learner's compute_loss_for_module
# method should respect the False-set value in here and mask out the respective
# items from the loss.
LOSS_MASK = "loss_mask"
2 changes: 1 addition & 1 deletion rllib/evaluation/metrics.py
@@ -242,7 +242,6 @@ def summarize_episodes(
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episode_media=dict(episode_media),
episodes_this_iter=len(new_episodes),
episodes_timesteps_total=sum(episode_lengths),
policy_reward_min=policy_reward_min,
policy_reward_max=policy_reward_max,
@@ -259,4 +258,5 @@
episode_return_max=max_reward,
episode_return_min=min_reward,
episode_return_mean=avg_reward,
episodes_this_iter=len(new_episodes), # deprecate in favor of `num_episodes_...`
)
@@ -36,7 +36,7 @@
parser.add_argument("--num-workers", type=int, default=1)

# This should be >1, otherwise, remote envs make no sense.
parser.add_argument("--num-envs-per-worker", type=int, default=4)
parser.add_argument("--num-envs-per-env-runner", type=int, default=4)

parser.add_argument(
"--as-test",
@@ -37,7 +37,7 @@ def get_cli_args():

# example-specific args
# This should be >1, otherwise, remote envs make no sense.
parser.add_argument("--num-envs-per-worker", type=int, default=4)
parser.add_argument("--num-envs-per-env-runner", type=int, default=4)

# general args
parser.add_argument(
@@ -134,7 +134,7 @@ def default_resource_request(
# Force sub-envs to be ray.actor.ActorHandles, so we can step
# through them in parallel.
remote_worker_envs=True,
num_envs_per_env_runner=args.num_envs_per_worker,
num_envs_per_env_runner=args.num_envs_per_env_runner,
# Use a single worker (however, with n parallelized remote envs, maybe
# even running on another node).
# Action computations occur on the "main" (GPU?) node, while
@@ -146,7 +146,7 @@
num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
# Set the number of CPUs used by the (local) worker, aka "driver"
# to match the number of Ray remote envs.
num_cpus_for_main_process=args.num_envs_per_worker + 1,
num_cpus_for_main_process=args.num_envs_per_env_runner + 1,
)
)

36 changes: 20 additions & 16 deletions rllib/utils/actor_manager.py
@@ -266,7 +266,7 @@ def __init__(
self._restored_actors = set()
self.add_actors(actors or [])

# Maps outstanding async requests to the ids of the actors that
# Maps outstanding async requests to the IDs of the actors that
# are executing them.
self._in_flight_req_to_actor_id: Mapping[ray.ObjectRef, int] = {}

@@ -457,7 +457,7 @@ def _fetch_result(
calls were fired against.
remote_calls: list of remote calls to fetch.
tags: list of tags used for identifying the remote calls.
timeout_seconds: timeout for the ray.wait() call. Default is None.
timeout_seconds: Timeout (in sec) for the ray.wait() call. Default is None.
return_obj_refs: whether to return ObjectRef instead of actual results.
mark_healthy: whether to mark certain actors healthy based on the results
of these remote calls. Useful, for example, to make sure actors
@@ -593,10 +593,9 @@ def foreach_actor(
actors "healthy" that respond to the request within `timeout_seconds`
and are currently tagged as "unhealthy".
remote_actor_ids: Apply func on a selected set of remote actors.
timeout_seconds: Ray.get() timeout in seconds. Default is None, which will
block until all remote results have been received. Setting this to 0.0
makes all the remote calls fire-and-forget, while setting this to
None make them synchronous calls.
timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
synchronous execution).
Collaborator: Great description!

return_obj_refs: whether to return ObjectRef instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
never be resolved with ray.get() outside of the context of this manager.
@@ -649,8 +648,10 @@ def foreach_actor_async(
"""Calls given functions against each actors without waiting for results.

Args:
func: A single, or a list of Callables, that get applied on the list
of specified remote actors.
func: A single Callable applied to all specified remote actors, or a list
of Callables that get applied on the list of specified remote actors.
In the latter case, the list of Callables and the list of specified actors
must have the same length.
tag: A tag to identify the results from this async call.
healthy_only: If True, applies `func` only to actors currently tagged
"healthy", otherwise to all actors. If `healthy_only=False` and
@@ -730,41 +731,44 @@ def foreach_actor_async(
return len(remote_calls)

def _filter_calls_by_tag(
self, tags
self, tags: Union[str, List[str], Tuple[str]]
) -> Tuple[List[ray.ObjectRef], List[ActorHandle], List[str]]:
"""Return all the in flight requests that match the given tags.
"""Return all the in flight requests that match the given tags, if any.

Args:
tags: A str or a list of str. If tags is empty, return all the in flight
tags: A str or a list/tuple of str. If tags is empty, return all the in
flight requests.

Returns:
A tuple of corresponding (remote_calls, remote_actor_ids, valid_tags)

A tuple consisting of a list of the remote calls that match the tag(s),
Collaborator: Interesting. Did we have this all this time? So we can tag certain tasks and check whether these tasks have been worked on?

Contributor Author: Correct. This was always there, albeit rarely used b/c we don't really have any async algos on the Learner API yet.

You can send async requests to the ActorManager with a tag, then - later - fetch the async results from the manager using this tag, kind of as a label to say: I only want these results; the others - even if already ready - I don't care about right now and will fetch them later.

a list of the corresponding remote actor IDs for these calls (same length),
and a list of the tags corresponding to these calls (same length).
"""
if isinstance(tags, str):
tags = {tags}
elif isinstance(tags, (list, tuple)):
tags = set(tags)
else:
raise ValueError(
f"tags must be either a str or a list of str, got {type(tags)}."
f"tags must be either a str or a list/tuple of str, got {type(tags)}."
)
remote_calls = []
remote_actor_ids = []
valid_tags = []
for call, (tag, actor_id) in self._in_flight_req_to_actor_id.items():
# the default behavior is to return all ready results.
if not len(tags) or tag in tags:
if len(tags) == 0 or tag in tags:
Collaborator: Does this mean we throw away all other results not having this tag? What if we have, e.g., 2 tags, we call this function, and we afterwards need the two tags separately (maybe the results are from different sampling regimes) - can we distinguish the results by the valid_tags?

remote_calls.append(call)
remote_actor_ids.append(actor_id)
valid_tags.append(tag)

return remote_calls, remote_actor_ids, valid_tags

@DeveloperAPI
def fetch_ready_async_reqs(
self,
*,
tags: Union[str, List[str]] = (),
tags: Union[str, List[str], Tuple[str]] = (),
timeout_seconds: Optional[float] = 0.0,
return_obj_refs: bool = False,
mark_healthy: bool = True,
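To make the tagging mechanism discussed in the comments above concrete, here is a simplified, self-contained sketch of the tag-filtering idea (not the actual manager code; `RequestRef` and `in_flight` are made-up names for illustration). In-flight requests are stored together with the tag they were submitted under, so results can later be fetched selectively per tag; an empty `tags` argument means "return everything". The returned `valid_tags` list has the same length as the returned calls, which is what lets a caller split fetched results per tag afterwards.

```python
from typing import Dict, List, Tuple, Union

RequestRef = str  # stand-in for ray.ObjectRef in this sketch

# Maps each in-flight request to the (tag, actor_id) it was submitted with.
in_flight: Dict[RequestRef, Tuple[str, int]] = {
    "ref-1": ("sample", 0),
    "ref-2": ("train", 1),
    "ref-3": ("sample", 2),
}


def filter_by_tag(
    tags: Union[str, List[str], Tuple[str, ...]],
) -> Tuple[List[RequestRef], List[int], List[str]]:
    tag_set = {tags} if isinstance(tags, str) else set(tags)
    calls, actor_ids, valid_tags = [], [], []
    for call, (tag, actor_id) in in_flight.items():
        # An empty tag set means: return all in-flight requests.
        if len(tag_set) == 0 or tag in tag_set:
            calls.append(call)
            actor_ids.append(actor_id)
            valid_tags.append(tag)
    return calls, actor_ids, valid_tags


print(filter_by_tag("sample"))  # -> only the two "sample"-tagged requests
print(filter_by_tag(()))        # -> all in-flight requests
```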
47 changes: 32 additions & 15 deletions rllib/utils/metrics/stats.py
@@ -1,4 +1,6 @@
from collections import defaultdict
import time
import threading
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
@@ -205,8 +207,8 @@ def __init__(
self._window = window
self._ema_coeff = ema_coeff

# Timing functionality.
self._start_time = None
# Timing functionality (keep start times per thread).
self._start_times = defaultdict(lambda: None)

# Simply store this flag for the user of this class.
self._clear_on_reduce = clear_on_reduce
@@ -227,23 +229,32 @@ def __enter__(self) -> "Stats":
"""Called when entering a context (with which users can measure a time delta).

Returns:
This Stats instance (self).
This Stats instance (self), unless another thread has already entered (and
not exited yet), in which case a copy of `self` is returned. This way, the
second thread(s) cannot mess with the original Stat's (self) time-measuring.
This also means that only the first thread to __enter__ actually logs into
`self` and the following threads' measurements are discarded (logged into
a non-referenced shim-Stats object, which will simply be garbage collected).
"""
assert self._start_time is None, "Concurrent updates not supported!"
self._start_time = time.time()
# In case another thread already is measuring this Stats (timing), simply ignore
# the "enter request" and return a clone of `self`.
thread_id = threading.get_ident()
assert self._start_times[thread_id] is None
self._start_times[thread_id] = time.perf_counter()
return self

def __exit__(self, exc_type, exc_value, tb) -> None:
"""Called when exiting a context (with which users can measure a time delta)."""
assert self._start_time is not None
time_delta = time.time() - self._start_time
thread_id = threading.get_ident()
assert self._start_times[thread_id] is not None
time_delta = time.perf_counter() - self._start_times[thread_id]
self.push(time_delta)

# Call the on_exit handler.
if self._on_exit:
self._on_exit(time_delta)

self._start_time = None
del self._start_times[thread_id]

def peek(self) -> Any:
"""Returns the result of reducing the internal values list.
@@ -583,31 +594,37 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]:
"""
values = values if values is not None else self.values
window = window if window is not None else self._window
inf_window = window in [None, float("inf")]

# Apply the window (if provided and not inf).
values = (
values if window is None or window == float("inf") else values[-window:]
)
values = values if inf_window else values[-window:]

# No reduction method. Return list as-is OR reduce list to len=window.
if self._reduce_method is None:
return values, values

# Special case: Internal values list is empty -> return NaN.
elif len(values) == 0:
return float("nan"), []
if self._reduce_method in ["min", "max", "mean"]:
return float("nan"), []
else:
return 0, []

# Do EMA (always a "mean" reduction; possibly using a window).
if self._ema_coeff is not None:
elif self._ema_coeff is not None:
# Perform EMA reduction over all values in internal values list.
mean_value = values[0]
for v in values[1:]:
mean_value = self._ema_coeff * v + (1.0 - self._ema_coeff) * mean_value
return mean_value, values
if inf_window:
return mean_value, [mean_value]
else:
return mean_value, values
# Do non-EMA reduction (possibly using a window).
else:
# Use the numpy/torch "nan"-prefix to ignore NaN's in our value lists.
if torch and torch.is_tensor(values[0]):
assert all(torch.is_tensor(v) for v in values), values
reduce_meth = getattr(torch, "nan" + self._reduce_method)
reduce_in = torch.stack(values)
if self._reduce_method == "mean":
@@ -643,7 +660,7 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]:

# For window=None|inf (infinite window) and reduce != mean, we don't have to
# keep any values, except the last (reduced) one.
if window in [None, float("inf")] and self._reduce_method != "mean":
if inf_window and self._reduce_method != "mean":
# TODO (sven): What if our values are torch tensors? In this case, we
# would have to do reduction using `torch` above (not numpy) and only
# then return the python primitive AND put the reduced new torch
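The core of the EMA mem-leak fix can be summarized in a short standalone sketch (not the actual Stats code): with `ema_coeff` set and an infinite window, each reduction now retains only the single running EMA value rather than the ever-growing raw values list.

```python
from typing import List, Tuple


def ema_reduce(values: List[float], ema_coeff: float = 0.01) -> Tuple[float, List[float]]:
    # Same EMA recursion as in Stats._reduced_values().
    mean_value = values[0]
    for v in values[1:]:
        mean_value = ema_coeff * v + (1.0 - ema_coeff) * mean_value
    # Infinite window: keep only the reduced value (previously, the full
    # `values` list was retained and grew without bound).
    return mean_value, [mean_value]


print(ema_reduce([1.0, 2.0, 3.0]))
```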
9 changes: 7 additions & 2 deletions rllib/utils/test_utils.py
@@ -1465,6 +1465,7 @@ def run_rllib_example_script_experiment(
# New stack.
if config.enable_rl_module_and_learner:
# Define compute resources used.
config.resources(num_gpus=0)
config.learners(
num_learners=args.num_gpus,
num_gpus_per_learner=1 if torch.cuda.is_available() else 0,
@@ -1490,9 +1491,13 @@
if args.no_tune:
assert not args.as_test and not args.as_release_test
algo = config.build()
for _ in range(stop.get(TRAINING_ITERATION, args.stop_iters)):
for i in range(stop.get(TRAINING_ITERATION, args.stop_iters)):
results = algo.train()
print(f"R={results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}", end="")
if ENV_RUNNER_RESULTS in results:
print(
f"iter={i} R={results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}",
end="",
)
if EVALUATION_RESULTS in results:
Reval = results[EVALUATION_RESULTS][ENV_RUNNER_RESULTS][
EPISODE_RETURN_MEAN
10 changes: 0 additions & 10 deletions rllib/utils/timer.py

This file was deleted.
