[RLlib] Add on_workers_recreated callback to Algorithm. #40354

Merged
Changes from 26 of 27 commits:
2c48317 wip (sven1977, Oct 15, 2023)
adad269 Merge branch 'master' of https://github.com/ray-project/ray into issu… (sven1977, Oct 16, 2023)
24e3a4f wip (sven1977, Oct 16, 2023)
94eb316 wip (sven1977, Oct 16, 2023)
8afbcde wip (sven1977, Oct 16, 2023)
7acc898 Merge branch 'master' into issue64_add_on_worker_created_callback (sven1977, Oct 16, 2023)
9c99faa Merge branch 'master' of https://github.com/ray-project/ray into issu… (sven1977, Oct 17, 2023)
7bd9c78 wip (sven1977, Oct 17, 2023)
6c7b737 wip (sven1977, Oct 17, 2023)
0d9d72a Merge remote-tracking branch 'origin/issue64_add_on_worker_created_ca… (sven1977, Oct 17, 2023)
8378e8c Merge branch 'master' of https://github.com/ray-project/ray into issu… (sven1977, Oct 17, 2023)
317c66f wip (sven1977, Oct 17, 2023)
31839ee wip (sven1977, Oct 17, 2023)
747d46b Merge branch 'master' into issue64_add_on_worker_created_callback (sven1977, Oct 17, 2023)
aa164ec Merge branch 'master' into issue64_add_on_worker_created_callback (sven1977, Oct 18, 2023)
94fea55 fix (sven1977, Oct 18, 2023)
794f702 Merge remote-tracking branch 'origin/issue64_add_on_worker_created_ca… (sven1977, Oct 18, 2023)
6ab4763 wip (sven1977, Oct 19, 2023)
3a1d857 Merge branch 'master' into issue64_add_on_worker_created_callback (sven1977, Oct 19, 2023)
6cab9ba Merge branch 'master' into issue64_add_on_worker_created_callback (sven1977, Oct 19, 2023)
d8225d5 wip (sven1977, Oct 19, 2023)
31c2e35 Merge remote-tracking branch 'origin/issue64_add_on_worker_created_ca… (sven1977, Oct 19, 2023)
7c29ca3 Merge branch 'master' of https://github.com/ray-project/ray into issu… (sven1977, Oct 20, 2023)
6a73515 wip (sven1977, Oct 20, 2023)
1cd6a89 wip (sven1977, Oct 20, 2023)
e50a54c wip (sven1977, Oct 20, 2023)
55a40b8 Update rllib/BUILD (sven1977, Oct 20, 2023)
17 changes: 9 additions & 8 deletions rllib/BUILD
@@ -965,12 +965,12 @@ py_test(
 )
 
 # A3C
-py_test(
-    name = "test_a3c",
-    tags = ["team:rllib", "algorithms_dir"],
-    size = "large",
-    srcs = ["algorithms/a3c/tests/test_a3c.py"]
-)
+# py_test(
+#     name = "test_a3c",
+#     tags = ["team:rllib", "algorithms_dir"],
+#     size = "large",
+#     srcs = ["algorithms/a3c/tests/test_a3c.py"]
+# )
 
 # AlphaStar
 py_test(
@@ -4478,14 +4478,15 @@ py_test(
 # --------------------------------------------------------------------
 py_test_module_list(
     files = [
-        "tests/test_dnc.py",
-        "tests/test_perf.py",
+        "algorithms/a3c/tests/test_a3c.py",
         "env/wrappers/tests/test_kaggle_wrapper.py",
         "examples/env/tests/test_cliff_walking_wall_env.py",
         "examples/env/tests/test_coin_game_non_vectorized_env.py",
         "examples/env/tests/test_coin_game_vectorized_env.py",
         "examples/env/tests/test_matrix_sequential_social_dilemma.py",
         "examples/env/tests/test_wrappers.py",
+        "tests/test_dnc.py",
+        "tests/test_perf.py",
         "utils/tests/test_utils.py",
     ],
     size = "large",
16 changes: 13 additions & 3 deletions rllib/algorithms/algorithm.py
@@ -1355,11 +1355,13 @@ def remote_fn(worker):
 
     @OverrideToImplementCustomLogic
     @DeveloperAPI
-    def restore_workers(self, workers: WorkerSet):
-        """Try to restore failed workers if necessary.
+    def restore_workers(self, workers: WorkerSet) -> None:
+        """Try to sync previously failed and restarted workers with the local one.
 
         Algorithms that use custom RolloutWorkers may override this method to
-        disable default, and create custom restoration logics.
+        disable the default behavior and implement custom restoration logic. Note
+        that "restoring" does not include the actual restarting process, but merely
+        what should happen after such a restart of a (previously failed) worker.
 
         Args:
             workers: The WorkerSet to restore. This may be Rollout or Evaluation
@@ -1397,6 +1399,14 @@ def restore_workers(self, workers: WorkerSet):
                 mark_healthy=True,
             )
 
+        # Fire the callback for re-created workers.
+        self.callbacks.on_workers_recreated(
+            algorithm=self,
+            worker_set=workers,
+            worker_ids=restored,
+            is_evaluation=workers.local_worker().config.in_evaluation,
+        )
+
     @OverrideToImplementCustomLogic
     @DeveloperAPI
     def training_step(self) -> ResultDict:
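Since `restore_workers` is tagged `@OverrideToImplementCustomLogic`, an Algorithm subclass may replace the default sync-from-local restoration. A minimal sketch of such an override that keeps the callback contract introduced by this PR (the weights-only sync and the `MyPPO` name are illustrative, not part of the diff):

```python
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.evaluation.worker_set import WorkerSet


class MyPPO(PPO):
    def restore_workers(self, workers: WorkerSet) -> None:
        # Probe actors and collect the IDs of workers that just came back.
        restored = workers.probe_unhealthy_workers()
        if restored:
            # Cheaper custom sync: push only the current weights (instead of
            # the full local-worker state) to the restored remote workers.
            weights = workers.local_worker().get_weights()
            workers.foreach_worker(
                lambda w: w.set_weights(weights),
                remote_worker_ids=restored,
                local_worker=False,
            )
            # Keep the contract from this PR: fire the new callback.
            self.callbacks.on_workers_recreated(
                algorithm=self,
                worker_set=workers,
                worker_ids=restored,
                is_evaluation=workers.local_worker().config.in_evaluation,
            )
```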
11 changes: 11 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -769,6 +769,17 @@ def validate(self) -> None:
         else:
             _torch, _ = try_import_torch()
 
+        # Cannot use `framework="tf"` with the new Learner API stack.
+        if self.framework_str == "tf" and (
+            self._enable_rl_module_api or self._enable_learner_api
+        ):
+            raise ValueError(
+                "Cannot use `framework=tf` with new API stack! Either do "
+                "`config.framework('tf2')` OR set both `config.rl_module("
+                "_enable_rl_module_api=False)` and `config.training("
+                "_enable_learner_api=False)`."
+            )
+
         # Check if torch framework supports torch.compile.
         if (
             _torch is not None
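To illustrate the new check: a static-graph `tf` config now fails `validate()` as soon as either new-stack flag is on, which PPO sets by default (see the `ppo.py` change below). A quick sketch:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Fails validation: PPO turns on the RLModule and Learner APIs by default.
bad = PPOConfig().framework("tf")
# bad.validate()  # -> ValueError: "Cannot use `framework=tf` with new API stack! ..."

# OK: tf2 is supported on the new stack.
good = PPOConfig().framework("tf2")

# Also OK: keep framework="tf" but opt out of both new-stack flags.
legacy = (
    PPOConfig()
    .framework("tf")
    .rl_module(_enable_rl_module_api=False)
    .training(_enable_learner_api=False)
)
```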
73 changes: 71 additions & 2 deletions rllib/algorithms/callbacks.py
@@ -32,7 +32,7 @@
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
-    from ray.rllib.evaluation import RolloutWorker
+    from ray.rllib.evaluation import RolloutWorker, WorkerSet
 
 
 @PublicAPI
@@ -75,6 +75,70 @@ def on_algorithm_init(
         """
         pass
 
+    @OverrideToImplementCustomLogic
+    def on_workers_recreated(
+        self,
+        *,
+        algorithm: "Algorithm",
+        worker_set: "WorkerSet",
+        worker_ids: List[int],
+        is_evaluation: bool,
+        **kwargs,
+    ) -> None:
+        """Callback run after one or more workers have been recreated.
+
+        You can access (and change) the worker(s) in question via the code
+        snippet below, inside your custom override of this method. Note that
+        any "worker" inside the algorithm's `self.workers` and
+        `self.evaluation_workers` WorkerSets is an instance of a subclass of
+        EnvRunner.
+
+        .. testcode::
+            from ray.rllib.algorithms.callbacks import DefaultCallbacks
+
+            class MyCallbacks(DefaultCallbacks):
+                def on_workers_recreated(
+                    self,
+                    *,
+                    algorithm,
+                    worker_set,
+                    worker_ids,
+                    is_evaluation,
+                    **kwargs,
+                ):
+                    # Define what you would like to do on the recreated
+                    # workers:
+                    def func(w):
+                        # Here, we just set some arbitrary property to 1.
+                        if is_evaluation:
+                            w._custom_property_for_evaluation = 1
+                        else:
+                            w._custom_property_for_training = 1
+
+                    # Use the `foreach_worker` method of the worker set and
+                    # only loop through those worker IDs that have been
+                    # restarted. Note that we set `local_worker=False` to NOT
+                    # include the local worker (local workers are never
+                    # recreated; if they fail, the entire Algorithm fails).
+                    worker_set.foreach_worker(
+                        func,
+                        remote_worker_ids=worker_ids,
+                        local_worker=False,
+                    )
+
+        Args:
+            algorithm: Reference to the Algorithm instance.
+            worker_set: The WorkerSet object in which the workers in question
+                reside. You can use a `worker_set.foreach_worker(
+                remote_worker_ids=..., local_worker=False)` call to execute
+                custom code on the recreated (remote) workers. Note that the
+                local worker is never recreated, as its failure would also
+                crash the Algorithm.
+            worker_ids: The list of (remote) worker IDs that have been
+                recreated.
+            is_evaluation: Whether `worker_set` is the evaluation WorkerSet
+                (located in `Algorithm.evaluation_workers`) or not.
+        """
+        pass
+
     @OverrideToImplementCustomLogic
     def on_checkpoint_loaded(
         self,
@@ -98,7 +162,7 @@ def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None:
 
         Args:
             policy_id: ID of the newly created policy.
-            policy: the policy just created.
+            policy: The policy just created.
         """
         pass
 
@@ -494,6 +558,11 @@ def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
         for callback in self._callback_list:
             callback.on_algorithm_init(algorithm=algorithm, **kwargs)
 
+    @override(DefaultCallbacks)
+    def on_workers_recreated(self, **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_workers_recreated(**kwargs)
+
     @override(DefaultCallbacks)
     def on_checkpoint_loaded(self, *, algorithm: "Algorithm", **kwargs) -> None:
         for callback in self._callback_list:
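Because the multi-callbacks wrapper now forwards `on_workers_recreated`, callbacks stacked via `make_multi_callbacks` each receive the event in list order. A hedged sketch (both callback classes are illustrative):

```python
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks


class LogRecreations(DefaultCallbacks):
    def on_workers_recreated(self, *, worker_ids, is_evaluation, **kwargs):
        print(f"Recreated workers (evaluation={is_evaluation}): {worker_ids}")


class ResetRecreatedEnvs(DefaultCallbacks):
    def on_workers_recreated(self, *, worker_set, worker_ids, **kwargs):
        # Redo per-worker setup that was lost in the crash (illustrative).
        worker_set.foreach_worker(
            lambda w: w.foreach_env(lambda env: env.reset()),
            remote_worker_ids=worker_ids,
            local_worker=False,
        )


# Both callbacks fire whenever one or more workers get recreated.
callbacks_cls = make_multi_callbacks([LogRecreations, ResetRecreatedEnvs])
# config = config.callbacks(callbacks_cls)
```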
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/ppo.py
@@ -138,7 +138,7 @@ def __init__(self, algo_class=None):
             # Add constructor kwargs here (if any).
         }
 
-        # enable the rl module api by default
+        # Enable the RLModule API by default.
         self.rl_module(_enable_rl_module_api=True)
         self.training(_enable_learner_api=True)
 
54 changes: 54 additions & 0 deletions rllib/algorithms/tests/test_callbacks.py
@@ -3,15 +3,40 @@
 import unittest
 
 import ray
+from ray.rllib.algorithms.appo import APPOConfig
 from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
 import ray.rllib.algorithms.dqn as dqn
 from ray.rllib.algorithms.pg import PGConfig
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.utils.test_utils import framework_iterator
 
 
+class OnWorkerCreatedCallbacks(DefaultCallbacks):
+    def on_workers_recreated(
+        self,
+        *,
+        algorithm,
+        worker_set,
+        worker_ids,
+        is_evaluation,
+        **kwargs,
+    ):
+        # Store in the algorithm's counters the number of times each worker
+        # (ID'd by index and whether eval or not) has been recreated/restarted.
+        for id_ in worker_ids:
+            key = f"{'eval_' if is_evaluation else ''}worker_{id_}_recreated"
+            # Increase the counter.
+            algorithm._counters[key] += 1
+            print(f"changed {key} to {algorithm._counters[key]}")
+
+        # Execute some dummy code on each of the recreated workers.
+        results = worker_set.foreach_worker(lambda w: w.ping())
+        print(results)  # Should print "pong" once per recreated worker.
+
+
 class InitAndCheckpointRestoredCallbacks(DefaultCallbacks):
     def on_algorithm_init(self, *, algorithm, **kwargs):
         self._on_init_was_called = True
@@ -84,6 +109,35 @@ def setUpClass(cls):
     def tearDownClass(cls):
         ray.shutdown()
 
+    def test_on_workers_recreated_callback(self):
+        config = (
+            APPOConfig()
+            .environment(CartPoleCrashing)
+            .callbacks(OnWorkerCreatedCallbacks)
+            .rollouts(num_rollout_workers=2)
+            .fault_tolerance(recreate_failed_workers=True)
+        )
+
+        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
+            algo = config.build()
+            original_worker_ids = algo.workers.healthy_worker_ids()
+            for id_ in original_worker_ids:
+                self.assertTrue(algo._counters[f"worker_{id_}_recreated"] == 0)
+
+            # After building the algorithm, we should have 2 healthy (remote)
+            # workers.
+            self.assertTrue(len(original_worker_ids) == 2)
+
+            # Train a bit (and have the envs/workers crash a couple of times).
+            for _ in range(3):
+                algo.train()
+
+            # After training, each worker should have been recreated at least
+            # once.
+            new_worker_ids = algo.workers.healthy_worker_ids()
+            self.assertTrue(len(new_worker_ids) == 2)
+            for id_ in new_worker_ids:
+                self.assertTrue(algo._counters[f"worker_{id_}_recreated"] >= 1)
+            algo.stop()
+
     def test_on_init_and_checkpoint_loaded(self):
         config = (
             PPOConfig()
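For context, `CartPoleCrashing` (used by the test above) is an example env that raises at random times, so remote workers genuinely fail and get recreated. A hand-rolled analogue, assuming gymnasium; this wrapper is hypothetical, not the RLlib class:

```python
import random

import gymnasium as gym


class CrashingCartPole(gym.Wrapper):
    """Hypothetical stand-in for RLlib's CartPoleCrashing example env."""

    def __init__(self, p_crash: float = 0.005):
        super().__init__(gym.make("CartPole-v1"))
        self.p_crash = p_crash

    def step(self, action):
        # A raised error inside a remote worker counts as a worker failure;
        # with `recreate_failed_workers=True`, RLlib restarts the worker and
        # then fires `on_workers_recreated`.
        if random.random() < self.p_crash:
            raise RuntimeError("Simulated env crash.")
        return self.env.step(action)
```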
37 changes: 18 additions & 19 deletions rllib/evaluation/worker_set.py
@@ -113,7 +113,7 @@ def __init__(
                 in the returned set as well (default: True). If `num_workers`
                 is 0, always create a local worker.
             logdir: Optional logging directory for workers.
-            _setup: Whether to setup workers. This is only for testing.
+            _setup: Whether to actually set up workers. This is only for testing.
         """
         from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -635,9 +635,9 @@ def foreach_worker(
         self,
         func: Callable[[RolloutWorker], T],
         *,
-        local_worker=True,
+        local_worker: bool = True,
         # TODO(jungong) : switch to True once Algorithm is migrated.
-        healthy_only=False,
+        healthy_only: bool = False,
         remote_worker_ids: List[int] = None,
         timeout_seconds: Optional[int] = None,
         return_obj_refs: bool = False,
@@ -647,10 +647,9 @@
 
         Args:
             func: The function to call for each worker (as only arg).
-            local_worker: Whether apply func on local worker too. Default is True.
-            healthy_only: Apply func on known active workers only. By default
-                this will apply func on all workers regardless of their states.
-            remote_worker_ids: Apply func on a selected set of remote workers.
+            local_worker: Whether to apply `func` on the local worker, too. Default is True.
+            healthy_only: Apply `func` on known-to-be healthy workers only.
+            remote_worker_ids: Apply `func` on a selected set of remote workers.
             timeout_seconds: Time to wait for results. Default is None.
             return_obj_refs: whether to return ObjectRef instead of actual results.
                 Note, for fault tolerance reasons, these returned ObjectRefs should
@@ -689,20 +688,19 @@ def foreach_worker_with_id(
         self,
         func: Callable[[int, RolloutWorker], T],
         *,
-        local_worker=True,
+        local_worker: bool = True,
         # TODO(jungong) : switch to True once Algorithm is migrated.
-        healthy_only=False,
+        healthy_only: bool = False,
         remote_worker_ids: List[int] = None,
         timeout_seconds: Optional[int] = None,
     ) -> List[T]:
         """Similar to foreach_worker(), but calls the function with id of the worker too.
 
         Args:
             func: The function to call for each worker (as only arg).
-            local_worker: Whether apply func on local worker too. Default is True.
-            healthy_only: Apply func on known active workers only. By default
-                this will apply func on all workers regardless of their states.
-            remote_worker_ids: Apply func on a selected set of remote workers.
+            local_worker: Whether to apply `func` on the local worker, too. Default is True.
+            healthy_only: Apply `func` on known-to-be healthy workers only.
+            remote_worker_ids: Apply `func` on a selected set of remote workers.
             timeout_seconds: Time to wait for results. Default is None.
 
         Returns:
@@ -736,7 +734,7 @@ def foreach_worker_async(
         func: Callable[[RolloutWorker], T],
         *,
         # TODO(jungong) : switch to True once Algorithm is migrated.
-        healthy_only=False,
+        healthy_only: bool = False,
         remote_worker_ids: List[int] = None,
     ) -> int:
         """Calls the given function asynchronously with each worker as the argument.
@@ -747,9 +745,8 @@
 
         Args:
             func: The function to call for each worker (as only arg).
-            healthy_only: Apply func on known active workers only. By default
-                this will apply func on all workers regardless of their states.
-            remote_worker_ids: Apply func on a selected set of remote workers.
+            healthy_only: Apply `func` on known-to-be healthy workers only.
+            remote_worker_ids: Apply `func` on a selected set of remote workers.
 
         Returns:
             The number of async requests that are currently in-flight.
@@ -773,6 +770,7 @@ def fetch_ready_async_reqs(
         Args:
             timeout_seconds: Time to wait for results. Default is 0, meaning
                 those requests that are already ready.
+            return_obj_refs: Whether to return ObjectRef instead of actual results.
             mark_healthy: Whether to mark the worker as healthy based on call results.
 
         Returns:
@@ -888,15 +886,16 @@ def foreach_env_with_context(
 
     @DeveloperAPI
     def probe_unhealthy_workers(self) -> List[int]:
-        """Checks the unhealth workers, and try restoring their states.
+        """Checks for unhealthy workers and tries restoring their states.
 
         Returns:
-            IDs of the workers that were restored.
+            List of IDs of the workers that were restored.
         """
         return self.__worker_manager.probe_unhealthy_actors(
             timeout_seconds=self._remote_config.worker_health_probe_timeout_s
         )
 
+    # TODO (sven): Deprecate once ARS/ES have been moved to `rllib_contrib`.
     @staticmethod
     def _from_existing(
         local_worker: RolloutWorker, remote_workers: List[ActorHandle] = None
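A short usage sketch of the now fully annotated `foreach_worker` (the weight-broadcast body and the two worker IDs are illustrative; `workers` is assumed to be an existing `WorkerSet`):

```python
# Broadcast the local worker's weights to two specific healthy remote
# workers, waiting at most 10 seconds for the results.
weights = workers.local_worker().get_weights()
results = workers.foreach_worker(
    lambda w: w.set_weights(weights),
    local_worker=False,        # Skip the local worker itself.
    healthy_only=True,         # Only workers known to be healthy.
    remote_worker_ids=[1, 2],  # Restrict to these remote worker IDs.
    timeout_seconds=10,
)
```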
1 change: 1 addition & 0 deletions rllib/examples/nested_action_spaces.py
@@ -63,6 +63,7 @@
                 ),
                 "b": Box(-10.0, 10.0, (2,)),
                 "c": MultiDiscrete([3, 3]),
+                "d": Discrete(2),
             }
         ),
     },
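For reference, a standalone sketch of the extended action space (the `a` entry is elided in the diff above, so a Box stand-in is used here):

```python
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete

space = Dict(
    {
        "a": Box(-10.0, 10.0, (2,)),  # Stand-in; the real "a" entry is elided above.
        "b": Box(-10.0, 10.0, (2,)),
        "c": MultiDiscrete([3, 3]),
        "d": Discrete(2),  # The entry this PR adds.
    }
)
print(space.sample())
```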