Tune test autoscaler / fix stale node detection bug (#21516)
See #21458. Currently, Tune keeps its own list of alive node IPs, but this list is only refreshed every 10 seconds and is therefore usually stale when a new node is added. As a result, the first trial scheduled on such a node is usually marked as failed. This PR adds a test confirming this behavior and removes the now-unneeded code path.

Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
3 people committed Jan 19, 2022
1 parent 1f563aa commit 8fd5b7a
Showing 9 changed files with 157 additions and 117 deletions.
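
The failure mode described in the commit message is easy to reproduce in isolation. Below is a minimal, self-contained sketch (hypothetical StaleNodeCache class and IP addresses, not code from this PR) of the refresh-period caching pattern that made a freshly added node invisible to Tune:

import time

class StaleNodeCache:
    """Caches the set of alive node IPs, refreshing at most every few seconds."""

    def __init__(self, refresh_period: float = 10.0):
        self._refresh_period = refresh_period
        self._last_refresh = float("-inf")
        self._cached_ips = set()

    def alive_ips(self, query_cluster):
        # Return the cached IPs unless the refresh period has elapsed.
        now = time.time()
        if now - self._last_refresh < self._refresh_period:
            return self._cached_ips  # possibly stale
        self._last_refresh = now
        self._cached_ips = set(query_cluster())
        return self._cached_ips

cluster_ips = {"10.0.0.1"}                 # head node only
cache = StaleNodeCache()
cache.alive_ips(lambda: cluster_ips)       # first call caches {"10.0.0.1"}

cluster_ips.add("10.0.0.2")                # the autoscaler adds a worker...
stale_view = cache.alive_ips(lambda: cluster_ips)
assert "10.0.0.2" not in stale_view        # ...but the cache does not see it yet

Any trial placed on the new node during this window therefore looks like it is running on a dead node, which is what the removed get_next_failed_trial() path reported as a failure.
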
1 change: 0 additions & 1 deletion ci/travis/build-multinode-image.py
@@ -18,7 +18,6 @@ def build_multinode_image(source_image: str, target_image: str):
f.write(f"FROM {source_image}\n")
f.write("RUN sudo apt update\n")
f.write("RUN sudo apt install -y openssh-server\n")
f.write("RUN sudo service ssh start\n")

subprocess.check_output(
f"docker build -t {target_image} .", shell=True, cwd=tempdir)

@@ -57,7 +57,8 @@
"sudo chmod 700 ~/.ssh && "
"sudo chmod 600 ~/.ssh/authorized_keys && "
"sudo chmod 600 ~/ray_bootstrap_key.pem && "
"sudo chown ray:users ~/.ssh ~/.ssh/authorized_keys && "
"sudo chown ray:users "
"~/.ssh ~/.ssh/authorized_keys ~/ray_bootstrap_key.pem && "
"{ensure_ssh} && "
"sleep 1 && "
"RAY_FAKE_CLUSTER=1 ray start --head "
@@ -131,7 +132,8 @@ def create_node_spec(head: bool,

ensure_ssh = ("((sudo apt update && sudo apt install -y openssh-server && "
"sudo service ssh start) || true)") if not bool(
int(os.environ.get("RAY_HAS_SSH", "0"))) else "true"
int(os.environ.get("RAY_HAS_SSH", "0")
or "0")) else "sudo service ssh start"

cmd_kwargs = dict(
ensure_ssh=ensure_ssh,
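
One subtlety in the new ensure_ssh expression above is the `or "0"` guard on the environment lookup, presumably to tolerate a RAY_HAS_SSH that is set but empty. A small sketch of the two parses (hypothetical value, not code from this PR):

import os

os.environ["RAY_HAS_SSH"] = ""   # set but empty, e.g. exported with no value

try:
    bool(int(os.environ.get("RAY_HAS_SSH", "0")))                # old parse
except ValueError as exc:
    print("old parse raises:", exc)                              # int("") is invalid

has_ssh = bool(int(os.environ.get("RAY_HAS_SSH", "0") or "0"))   # new parse
print("new parse treats empty as unset:", has_ssh)               # False
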
32 changes: 0 additions & 32 deletions python/ray/tune/ray_trial_executor.py
@@ -632,38 +632,6 @@ def get_running_trials(self) -> List[Trial]:
"""Returns the running trials."""
return list(self._running.values())

def get_alive_node_ips(self):
now = time.time()
if now - self._last_ip_refresh < self._refresh_period:
return self._last_ip_addresses
logger.debug("Checking ips from Ray state.")
self._last_ip_refresh = now
nodes = ray.state.nodes()
ip_addresses = set()
for node in nodes:
if node["alive"]:
ip_addresses.add(node["NodeManagerAddress"])
self._last_ip_addresses = ip_addresses
return ip_addresses

def get_current_trial_ips(self):
return {t.node_ip for t in self.get_running_trials()}

def get_next_failed_trial(self) -> Optional[Trial]:
"""Gets the first trial found to be running on a node presumed dead.
Returns:
A Trial object that is ready for failure processing. None if
no failure detected.
"""
if ray.worker._mode() != ray.worker.LOCAL_MODE:
live_cluster_ips = self.get_alive_node_ips()
if live_cluster_ips - self.get_current_trial_ips():
for trial in self.get_running_trials():
if trial.node_ip and trial.node_ip not in live_cluster_ips:
return trial
return None

def get_next_available_trial(
self, timeout: Optional[float] = None) -> Optional[Trial]:
if not self._running:
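
For comparison, an uncached version of the removed helper needs no bookkeeping at all; the sketch below reuses only the ray.state.nodes() fields that appear in the deleted code ("alive", "NodeManagerAddress") and makes no other API assumptions:

import ray

def alive_node_ips() -> set:
    """Like the removed get_alive_node_ips(), but queried fresh on every call.

    Because nothing is cached, a node added by the autoscaler is visible as
    soon as it registers with the cluster.
    """
    return {
        node["NodeManagerAddress"]
        for node in ray.state.nodes()
        if node["alive"]
    }

if __name__ == "__main__":
    ray.init()
    print(alive_node_ips())
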
96 changes: 96 additions & 0 deletions python/ray/tune/tests/test_multinode_sync.py
@@ -3,7 +3,10 @@
import unittest

import ray
from ray import tune
from ray.autoscaler._private.fake_multi_node.test_utils import DockerCluster
from ray.tune.callback import Callback
from ray.tune.trial import Trial


@ray.remote
@@ -72,6 +75,99 @@ def testClusterAutoscaling(self):

print("Node was restarted.")

def testAutoscalingNewNode(self):
"""Test that newly added nodes from autoscaling are not stale."""
self.cluster.update_config({
"provider": {
"head_resources": {
"CPU": 4,
"GPU": 0
}
},
"available_node_types": {
"ray.worker.cpu": {
"resources": {
"CPU": 4
},
"min_workers": 0, # No minimum nodes
"max_workers": 2,
},
"ray.worker.gpu": {
"min_workers": 0,
"max_workers": 0, # No GPU nodes
}
},
})
self.cluster.start()
self.cluster.connect(client=True, timeout=120)

def autoscaling_train(config):
time.sleep(120)
tune.report(1.)

tune.run(
autoscaling_train,
num_samples=3,
resources_per_trial={"cpu": 4},
fail_fast=True)

def testFaultTolerance(self):
"""Test that Tune run can recover from a failed node.
When `max_failures` is set to larger than zero.
"""

self.cluster.update_config({
"provider": {
"head_resources": {
"CPU": 4,
"GPU": 0
}
},
"available_node_types": {
"ray.worker.cpu": {
"resources": {
"CPU": 4
},
"min_workers": 0, # No minimum nodes
"max_workers": 2,
},
"ray.worker.gpu": {
"min_workers": 0,
"max_workers": 0, # No GPU nodes
}
},
})
self.cluster.start()
self.cluster.connect(client=True, timeout=120)

def train(config):
time.sleep(120)
tune.report(1.)

class FailureInjectionCallback(Callback):
def __init__(self, cluster):
self._cluster = cluster
self._killed = False

def on_step_begin(self, iteration, trials, **info):
if not self._killed and len(trials) == 3 and all(
trial.status == Trial.RUNNING for trial in trials):
self._cluster.kill_node(num=2)
self._killed = True

tune.run(
train,
num_samples=3,
resources_per_trial={"cpu": 4},
max_failures=1,
callbacks=[FailureInjectionCallback(self.cluster)],
# The following two are to be removed once we have proper setup
# for killing nodes while in ray client mode.
_remote=False,
local_dir="/tmp/ray_results/",
)


if __name__ == "__main__":
import pytest
3 changes: 0 additions & 3 deletions python/ray/tune/tests/test_trial_executor_inheritance.py
@@ -41,9 +41,6 @@ def fetch_result(self):
def get_next_available_trial(self):
return None

def get_next_failed_trial(self):
return None

def get_running_trials(self):
return []

18 changes: 10 additions & 8 deletions python/ray/tune/tests/test_trial_runner_callbacks.py
@@ -9,6 +9,7 @@

import ray
from ray import tune
from ray.exceptions import RayActorError
from ray.rllib import _register_all
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, \
@@ -67,19 +68,20 @@ def on_experiment_end(self, **info):
class _MockTrialExecutor(RayTrialExecutor):
def __init__(self):
super().__init__()
self.results = {}
self.next_trial = None
self.failed_trial = None
self.results = {}
self.should_fail_in_fetch_result = False

def fetch_result(self, trial):
return [self.results.get(trial, {})]
if self.should_fail_in_fetch_result:
raise RayActorError(
"The actor died unexpectedly before finishing this task.")
else:
return [self.results.get(trial, {})]

def get_next_available_trial(self, timeout=None):
return self.next_trial or super().get_next_available_trial()

def get_next_failed_trial(self):
return self.failed_trial or super().get_next_failed_trial()


class TrialRunnerCallbacks(unittest.TestCase):
def setUp(self):
@@ -184,7 +186,8 @@ def testCallbackSteps(self):
self.callback.state["trial_complete"]["trial"].trial_id, "two")

# Let the first trial error
self.executor.failed_trial = trials[0]
self.executor.next_trial = trials[0]
self.executor.should_fail_in_fetch_result = True
self.trial_runner.step()
self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id,
@@ -277,7 +280,6 @@ def get_positions(callbacks):
lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS)
callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
None)
print(callbacks)
first_logger_pos, last_logger_pos, syncer_pos = get_positions(
callbacks)
self.assertLess(last_logger_pos, syncer_pos)
3 changes: 0 additions & 3 deletions python/ray/tune/tests/test_trial_scheduler.py
@@ -249,9 +249,6 @@ def fetch_result(self):
def get_next_available_trial(self):
return None

def get_next_failed_trial(self):
return None

def get_running_trials(self):
return []

10 changes: 0 additions & 10 deletions python/ray/tune/trial_executor.py
@@ -178,16 +178,6 @@ def get_next_available_trial(self) -> Optional[Trial]:
"""
pass

@abstractmethod
def get_next_failed_trial(self) -> Optional[Trial]:
"""Non-blocking call that detects and returns one failed trial.
Returns:
A Trial object that is ready for failure processing. None if
no failure detected.
"""
pass

@abstractmethod
def fetch_result(self, trial: Trial) -> List[Trial]:
"""Fetches one result for the trial.
105 changes: 47 additions & 58 deletions python/ray/tune/trial_runner.py
@@ -817,65 +817,54 @@ def _get_next_trial(self):
return trial

def _process_events(self, timeout: Optional[float] = None):
with warn_if_slow("get_next_failed_trial"):
failed_trial = self.trial_executor.get_next_failed_trial()
if failed_trial:
error_msg = (
"{} (IP: {}) detected as stale. This is likely because the "
"node was lost").format(failed_trial, failed_trial.node_ip)
logger.info(error_msg)
with warn_if_slow("process_failed_trial"):
self._process_trial_failure(failed_trial, error_msg=error_msg)
else:
# TODO(ujvl): Consider combining get_next_available_trial and
# fetch_result functionality so that we don't timeout on fetch.
trial = self.trial_executor.get_next_available_trial(
timeout=timeout) # blocking
if not trial:
return
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration,
trials=self._trials,
trial=trial)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration,
trials=self._trials,
trial=trial)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.

msg = (
"Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node.")

if trial.location.hostname and (trial.location.hostname !=
get_node_ip_address()):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)
# TODO(ujvl): Consider combining get_next_available_trial and
# fetch_result functionality so that we don't timeout on fetch.
trial = self.trial_executor.get_next_available_trial(
timeout=timeout) # blocking
if not trial:
return
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
with warn_if_slow("callbacks.on_trial_restore"):
self._callbacks.on_trial_restore(
iteration=self._iteration,
trials=self._trials,
trial=trial)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration,
trials=self._trials,
trial=trial)
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using cloud checkpointing once
# API has converged.

else:
with warn_if_slow("process_trial"):
self._process_trial(trial)

# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(
trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
msg = ("Consider turning off forced head-worker trial "
"checkpoint syncs by setting sync_on_checkpoint=False"
". Note that this may result in faulty trial "
"restoration if a failure occurs while the checkpoint "
"is being synced from the worker to the head node.")

if trial.location.hostname and (trial.location.hostname !=
get_node_ip_address()):
if log_once("tune_head_worker_checkpoint"):
logger.warning(msg)

else:
with warn_if_slow("process_trial"):
self._process_trial(trial)

# `self._queued_trial_decisions` now contains a final decision
# based on all results
if trial not in self._cached_trial_decisions:
final_decision = self._queued_trial_decisions.pop(
trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)

def _process_trial(self, trial):
"""Processes a trial result.