-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Add on_workers_recreated
callback to Algorithm.
#40354
Changes from 12 commits
2c48317
adad269
24e3a4f
94eb316
8afbcde
7acc898
9c99faa
7bd9c78
6c7b737
0d9d72a
8378e8c
317c66f
31839ee
747d46b
aa164ec
94fea55
794f702
6ab4763
3a1d857
6cab9ba
d8225d5
31c2e35
7c29ca3
6a73515
1cd6a89
e50a54c
55a40b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,39 @@ | |
import unittest | ||
|
||
import ray | ||
from ray.rllib.algorithms.appo import APPOConfig | ||
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks | ||
import ray.rllib.algorithms.dqn as dqn | ||
from ray.rllib.algorithms.pg import PGConfig | ||
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing | ||
from ray.rllib.evaluation.episode import Episode | ||
from ray.rllib.examples.env.random_env import RandomEnv | ||
from ray.rllib.utils.test_utils import framework_iterator | ||
|
||
|
||
class OnWorkerCreatedCallbacks(DefaultCallbacks): | ||
def on_workers_recreated( | ||
self, | ||
*, | ||
algorithm, | ||
worker_set, | ||
worker_ids, | ||
is_evaluation, | ||
**kwargs, | ||
): | ||
# Store in the algorithm object's counters the number of times this worker | ||
# (ID'd by index and whether eval or not) has been recreated/restarted. | ||
for id_ in worker_ids: | ||
key = f"{'eval_' if is_evaluation else ''}worker_{id_}_recreated" | ||
# Increase the counter. | ||
algorithm._counters[key] += 1 | ||
print(f"changed {key} to {algorithm._counters[key]}") | ||
|
||
# Execute some dummy code on each of the recreated workers. | ||
results = worker_set.foreach_worker(lambda w: w.ping()) | ||
print(results) # should print "pong" n times (one for each recreated worker). | ||
|
||
|
||
class EpisodeAndSampleCallbacks(DefaultCallbacks): | ||
def __init__(self): | ||
super().__init__() | ||
|
@@ -74,6 +99,40 @@ def setUpClass(cls): | |
def tearDownClass(cls): | ||
ray.shutdown() | ||
|
||
def test_on_workers_recreated_callback(self): | ||
config = ( | ||
APPOConfig() | ||
.environment(CartPoleCrashing) | ||
.callbacks(OnWorkerCreatedCallbacks) | ||
.rollouts(num_rollout_workers=2) | ||
.fault_tolerance(recreate_failed_workers=True) | ||
) | ||
|
||
for _ in framework_iterator(config, frameworks=("tf2", "torch")): | ||
algo = config.build() | ||
original_worker_ids = algo.workers.healthy_worker_ids() | ||
for id_ in original_worker_ids: | ||
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0) | ||
|
||
# After building the algorithm, we should have 2 healthy (remote) workers. | ||
self.assertTrue(len(original_worker_ids) == 2) | ||
|
||
# Train a bit (and have the envs/workers crash a couple of times). | ||
for _ in range(3): | ||
algo.train() | ||
|
||
# After training, each new worker should have been recreated. | ||
worker_ids_2 = algo.workers.healthy_worker_ids() | ||
for id_ in worker_ids_2: | ||
# A newly created worker: Its recreated counter should be 1. | ||
if id_ not in original_worker_ids: | ||
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would this not become flaky in case a worker gets recreated more than once? We iterate three times, so I'd expect that workers might be recreated after each iteration. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's fine: a recreated worker gets a new ID, which is why we check which healthy worker IDs are new right before this loop. |
||
# Still an original worker, recreated counter should still be 0. | ||
else: | ||
self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0) | ||
|
||
algo.stop() | ||
|
||
def test_episode_and_sample_callbacks(self): | ||
config = ( | ||
PGConfig() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,7 +113,7 @@ def __init__( | |
in the returned set as well (default: True). If `num_workers` | ||
is 0, always create a local worker. | ||
logdir: Optional logging directory for workers. | ||
_setup: Whether to setup workers. This is only for testing. | ||
_setup: Whether to actually set up workers. This is only for testing. | ||
""" | ||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig | ||
|
||
|
@@ -635,9 +635,8 @@ def foreach_worker( | |
self, | ||
func: Callable[[RolloutWorker], T], | ||
*, | ||
local_worker=True, | ||
# TODO(jungong) : switch to True once Algorithm is migrated. | ||
healthy_only=False, | ||
local_worker: bool = True, | ||
healthy_only: bool = True, | ||
remote_worker_ids: List[int] = None, | ||
timeout_seconds: Optional[int] = None, | ||
return_obj_refs: bool = False, | ||
|
@@ -647,10 +646,9 @@ def foreach_worker( | |
|
||
Args: | ||
func: The function to call for each worker (as only arg). | ||
local_worker: Whether apply func on local worker too. Default is True. | ||
healthy_only: Apply func on known active workers only. By default | ||
this will apply func on all workers regardless of their states. | ||
remote_worker_ids: Apply func on a selected set of remote workers. | ||
local_worker: Whether to apply `func` on the local worker too. Default is True. | ||
healthy_only: Apply `func` on known-to-be healthy workers only. | ||
remote_worker_ids: Apply `func` on a selected set of remote workers. | ||
timeout_seconds: Time to wait for results. Default is None. | ||
return_obj_refs: whether to return ObjectRef instead of actual results. | ||
Note, for fault tolerance reasons, these returned ObjectRefs should | ||
|
@@ -689,20 +687,18 @@ def foreach_worker_with_id( | |
self, | ||
func: Callable[[int, RolloutWorker], T], | ||
*, | ||
local_worker=True, | ||
# TODO(jungong) : switch to True once Algorithm is migrated. | ||
healthy_only=False, | ||
local_worker: bool = True, | ||
healthy_only: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just out of curiosity: Why can we switch to True and what migration was Jun's comment directed at? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good one :) I think now that Algorithm uses the ActorManager (within WorkerSet), we can switch the default behavior to only operate on the healthy workers. Which makes sense: You wouldn't want to "ping" an already failed worker with an |
||
remote_worker_ids: List[int] = None, | ||
timeout_seconds: Optional[int] = None, | ||
) -> List[T]: | ||
"""Similar to foreach_worker(), but calls the function with id of the worker too. | ||
|
||
Args: | ||
func: The function to call for each worker (with the worker's ID and the worker as args). | ||
local_worker: Whether apply func on local worker too. Default is True. | ||
healthy_only: Apply func on known active workers only. By default | ||
this will apply func on all workers regardless of their states. | ||
remote_worker_ids: Apply func on a selected set of remote workers. | ||
local_worker: Whether to apply `func` on the local worker too. Default is True. | ||
healthy_only: Apply `func` on known-to-be healthy workers only. | ||
remote_worker_ids: Apply `func` on a selected set of remote workers. | ||
timeout_seconds: Time to wait for results. Default is None. | ||
|
||
Returns: | ||
|
@@ -735,8 +731,7 @@ def foreach_worker_async( | |
self, | ||
func: Callable[[RolloutWorker], T], | ||
*, | ||
# TODO(jungong) : switch to True once Algorithm is migrated. | ||
healthy_only=False, | ||
healthy_only: bool = True, | ||
remote_worker_ids: List[int] = None, | ||
) -> int: | ||
"""Calls the given function asynchronously with each worker as the argument. | ||
|
@@ -747,9 +742,8 @@ def foreach_worker_async( | |
|
||
Args: | ||
func: The function to call for each worker (as only arg). | ||
healthy_only: Apply func on known active workers only. By default | ||
this will apply func on all workers regardless of their states. | ||
remote_worker_ids: Apply func on a selected set of remote workers. | ||
healthy_only: Apply `func` on known-to-be healthy workers only. | ||
remote_worker_ids: Apply `func` on a selected set of remote workers. | ||
|
||
Returns: | ||
The number of async requests that are currently in-flight. | ||
|
@@ -773,6 +767,7 @@ def fetch_ready_async_reqs( | |
Args: | ||
timeout_seconds: Time to wait for results. Default is 0, meaning | ||
those requests that are already ready. | ||
return_obj_refs: Whether to return ObjectRef instead of actual results. | ||
mark_healthy: Whether to mark the worker as healthy based on call results. | ||
|
||
Returns: | ||
|
@@ -888,15 +883,16 @@ def foreach_env_with_context( | |
|
||
@DeveloperAPI | ||
def probe_unhealthy_workers(self) -> List[int]: | ||
"""Checks the unhealth workers, and try restoring their states. | ||
"""Checks for unhealthy workers and tries restoring their states. | ||
|
||
Returns: | ||
IDs of the workers that were restored. | ||
List of IDs of the workers that were restored. | ||
""" | ||
return self.__worker_manager.probe_unhealthy_actors( | ||
timeout_seconds=self._remote_config.worker_health_probe_timeout_s | ||
) | ||
|
||
# TODO (sven): Deprecate once ARS/ES have been moved to `rllib_contrib`. | ||
@staticmethod | ||
def _from_existing( | ||
local_worker: RolloutWorker, remote_workers: List[ActorHandle] = None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't use code-blocks anymore; we use `testcode` and `testoutput` these days!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh no! :D Will fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done