
[tune] Add timeout to retry_fn to catch hanging syncs #28155

Merged
13 commits merged on Sep 2, 2022
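For context: the diff below relies on ray.tune.utils.util.retry_fn accepting a new timeout argument and returning a bool indicating success, but the helper itself is not part of this excerpt. A minimal sketch that is consistent with the tests in test_utils.py further down, assuming a per-attempt, thread-based timeout, could look like this:

import threading
import time
from typing import Callable, Optional, Type


def retry_fn(
    fn: Callable[[], None],
    exception_type: Type[Exception] = Exception,
    num_retries: int = 3,
    sleep_time: int = 1,
    timeout: Optional[float] = None,
) -> bool:
    # Sketch only: run each attempt in a daemon thread so a hanging sync
    # can be abandoned after `timeout` seconds instead of blocking forever.
    errored = threading.Event()

    def _try_fn():
        try:
            fn()
        except exception_type:
            errored.set()

    for _ in range(num_retries):
        errored.clear()
        thread = threading.Thread(target=_try_fn, daemon=True)
        thread.start()
        thread.join(timeout=timeout)

        if thread.is_alive():
            # Attempt timed out; report failure to the caller.
            return False
        if not errored.is_set():
            return True
        time.sleep(sleep_time)
    return False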
15 changes: 11 additions & 4 deletions python/ray/tune/execution/ray_trial_executor.py
@@ -217,7 +217,7 @@ def __init__(

self._has_cleaned_up_pgs = False
self._reuse_actors = reuse_actors
# The maxlen will be updated when `set_max_pending_trials()` is called
# The maxlen will be updated when `setup(max_pending_trials)` is called
self._cached_actor_pg = deque(maxlen=1)
self._pg_manager = _PlacementGroupManager(prefix=_get_tune_pg_prefix())
self._staged_trials = set()
@@ -235,16 +235,20 @@ def __init__(
self._buffer_max_time_s = float(
os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
)
self._trainable_kwargs = {}

def set_max_pending_trials(self, max_pending: int) -> None:
def setup(
self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
) -> None:
if len(self._cached_actor_pg) > 0:
logger.warning(
"Cannot update maximum number of queued actors for reuse "
"during a run."
)
else:
self._cached_actor_pg = deque(maxlen=max_pending)
self._pg_manager.set_max_staging(max_pending)
self._cached_actor_pg = deque(maxlen=max_pending_trials)
self._pg_manager.set_max_staging(max_pending_trials)
self._trainable_kwargs = trainable_kwargs or {}

def set_status(self, trial: Trial, status: str) -> None:
"""Sets status and checkpoints metadata if needed.
@@ -377,6 +381,9 @@ def _setup_remote_runner(self, trial):
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
kwargs["custom_syncer"] = trial.custom_syncer

if self._trainable_kwargs:
kwargs.update(self._trainable_kwargs)

# Throw a meaningful error if trainable does not use the
# new API
sig = inspect.signature(trial.get_trainable_cls())
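In short, RayTrialExecutor.setup() replaces set_max_pending_trials() and additionally accepts keyword arguments that are merged into the constructor kwargs of every trainable actor. A small usage sketch (the concrete values are illustrative only):

from ray.tune.execution.ray_trial_executor import RayTrialExecutor

executor = RayTrialExecutor(reuse_actors=False)
# Formerly: executor.set_max_pending_trials(8)
executor.setup(
    max_pending_trials=8,
    trainable_kwargs={"sync_timeout": 1800},  # forwarded to each trainable
)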
11 changes: 9 additions & 2 deletions python/ray/tune/execution/trial_runner.py
@@ -198,6 +198,8 @@ def _serialize_and_write():
exclude = ["*/checkpoint_*"]

if self._syncer:
# Todo: Implement sync_timeout for experiment-level syncing
# (it is currently only used for trainable-to-cloud syncing)
if force:
# Wait until previous sync command finished
self._syncer.wait()
@@ -341,7 +343,13 @@ def __init__(
else:
# Manual override
self._max_pending_trials = int(max_pending_trials)
self.trial_executor.set_max_pending_trials(self._max_pending_trials)

sync_config = sync_config or SyncConfig()

self.trial_executor.setup(
max_pending_trials=self._max_pending_trials,
trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
)

self._metric = metric

@@ -385,7 +393,6 @@ def __init__(
if self._local_checkpoint_dir:
os.makedirs(self._local_checkpoint_dir, exist_ok=True)

sync_config = sync_config or SyncConfig()
self._remote_checkpoint_dir = remote_checkpoint_dir

self._syncer = get_node_to_storage_syncer(sync_config)
6 changes: 6 additions & 0 deletions python/ray/tune/syncer.py
@@ -40,6 +40,9 @@
# Syncing period for syncing checkpoints between nodes or to cloud.
DEFAULT_SYNC_PERIOD = 300

# Default sync timeout after which syncing processes are aborted
DEFAULT_SYNC_TIMEOUT = 1800

_EXCLUDE_FROM_SYNC = [
"./checkpoint_-00001",
"./checkpoint_tmp*",
@@ -85,6 +88,8 @@ class SyncConfig:
is asynchronous and best-effort. This does not affect persistent
storage syncing. Defaults to True.
sync_period: Syncing period for syncing between nodes.
sync_timeout: Timeout after which running sync processes are aborted.
Currently only affects trial-to-cloud syncing.

"""

@@ -93,6 +98,7 @@

sync_on_checkpoint: bool = True
sync_period: int = DEFAULT_SYNC_PERIOD
sync_timeout: int = DEFAULT_SYNC_TIMEOUT

def _repr_html_(self) -> str:
"""Generate an HTML representation of the SyncConfig.
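From a user's perspective, the new field is simply set on the SyncConfig passed to the experiment; the trial runner then forwards it to every trainable (see the trial_runner.py change above). A hedged usage sketch; the trainable function and bucket URI are placeholders, not part of this PR:

from ray import tune
from ray.air import session


def train_fn(config):
    session.report({"score": 1})  # trivial placeholder trainable


tune.run(
    train_fn,
    sync_config=tune.SyncConfig(
        upload_dir="s3://my-bucket/tune-results",  # placeholder bucket
        sync_timeout=600,  # abort a hanging trial-to-cloud sync after 10 min
    ),
)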
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_ray_trial_executor.py
@@ -499,7 +499,7 @@ def testHasResourcesForTrialWithCaching(self):

executor = RayTrialExecutor(reuse_actors=True)
executor._pg_manager = pgm
executor.set_max_pending_trials(1)
executor.setup(max_pending_trials=1)

def train(config):
yield 1
36 changes: 35 additions & 1 deletion python/ray/tune/tests/test_trainable.py
@@ -1,14 +1,16 @@
import json
import os
import tempfile
import time
from typing import Dict, Union
from unittest.mock import patch

import pytest

import ray
from ray import tune
from ray.air import session, Checkpoint
from ray.air._internal.remote_storage import download_from_uri
from ray.air._internal.remote_storage import download_from_uri, upload_to_uri
from ray.tune.trainable import wrap_function


@@ -188,6 +190,38 @@ def test_checkpoint_object_no_sync(tmpdir):
trainable.restore_from_object(obj)


@pytest.mark.parametrize("hanging", [True, False])
def test_sync_timeout(tmpdir, hanging):
orig_upload_fn = upload_to_uri

def _hanging_upload(*args, **kwargs):
time.sleep(200 if hanging else 0)
orig_upload_fn(*args, **kwargs)

trainable = SavingTrainable(
"object", remote_checkpoint_dir="memory:///test/location", sync_timeout=0.1
)

with patch("ray.air.checkpoint.upload_to_uri", _hanging_upload):
trainable.save()

check_dir = tmpdir / "check_save_obj"

try:
download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
except FileNotFoundError:
hung = True
else:
hung = False

assert hung == hanging

if hanging:
assert not check_dir.exists()
else:
assert check_dir.listdir()


if __name__ == "__main__":
import sys

113 changes: 81 additions & 32 deletions python/ray/tune/tests/test_utils.py
@@ -1,45 +1,94 @@
import unittest
import time

import pytest

from ray.tune.search.variant_generator import format_vars
from ray.tune.utils.util import retry_fn


class TuneUtilsTest(unittest.TestCase):
    def testFormatVars(self):
        # Format brackets correctly
        self.assertTrue(
            format_vars(
                {
                    ("a", "b", "c"): 8.1234567,
                    ("a", "b", "d"): [7, 8],
                    ("a", "b", "e"): [[[3, 4]]],
                }
            ),
            "c=8.12345,d=7_8,e=3_4",
        )
        # Sorted by full keys, but only last key is reported
        self.assertTrue(
            format_vars(
                {
                    ("a", "c", "x"): [7, 8],
                    ("a", "b", "x"): 8.1234567,
                }
            ),
            "x=8.12345,x=7_8",
        )
        # Filter out invalid chars. It's ok to have empty keys or values.
        self.assertTrue(
            format_vars(
                {
                    ("a c?x"): " <;%$ok ",
                    ("some"): " ",
                }
            ),
            "a_c_x=ok,some=",
        )


def test_format_vars():
    # Format brackets correctly
    assert (
        format_vars(
            {
                ("a", "b", "c"): 8.1234567,
                ("a", "b", "d"): [7, 8],
                ("a", "b", "e"): [[[3, 4]]],
            }
        )
        == "c=8.1235,d=7_8,e=3_4"
    )
    # Sorted by full keys, but only last key is reported
    assert (
        format_vars(
            {
                ("a", "c", "x"): [7, 8],
                ("a", "b", "x"): 8.1234567,
            }
        )
        == "x=8.1235,x=7_8"
    )
    # Filter out invalid chars. It's ok to have empty keys or values.
    assert (
        format_vars(
            {
                ("a c?x",): " <;%$ok ",
                ("some",): " ",
            }
        )
        == "a_c_x=ok,some="
    )


def test_retry_fn_repeat(tmpdir):
success = tmpdir / "success"
marker = tmpdir / "marker"

def _fail_once():
if marker.exists():
success.write_text(".", encoding="utf-8")
return
marker.write_text(".", encoding="utf-8")
raise RuntimeError("Failing")

assert not success.exists()
assert not marker.exists()

assert retry_fn(
fn=_fail_once,
exception_type=RuntimeError,
sleep_time=0,
)

assert success.exists()
assert marker.exists()


def test_retry_fn_timeout(tmpdir):
success = tmpdir / "success"
marker = tmpdir / "marker"

def _fail_once():
if not marker.exists():
marker.write_text(".", encoding="utf-8")
raise RuntimeError("Failing")
time.sleep(5)
success.write_text(".", encoding="utf-8")
return

assert not success.exists()
assert not marker.exists()

assert not retry_fn(
fn=_fail_once, exception_type=RuntimeError, sleep_time=0, timeout=0.1
)

assert not success.exists()
assert marker.exists()


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
35 changes: 27 additions & 8 deletions python/ray/tune/trainable/trainable.py
@@ -101,8 +101,9 @@ def __init__(
logger_creator: Callable[[Dict[str, Any]], "Logger"] = None,
remote_checkpoint_dir: Optional[str] = None,
custom_syncer: Optional[Syncer] = None,
sync_timeout: Optional[int] = None,
):
"""Initialize an Trainable.
"""Initialize a Trainable.

Sets up logging and points ``self.logdir`` to a directory in which
training outputs should be placed.
@@ -120,6 +121,7 @@ def __init__(
which is different from **per checkpoint** directory.
custom_syncer: Syncer used for synchronizing data from Ray nodes
to external storage.
sync_timeout: Timeout after which sync processes are aborted.
"""

self._experiment_id = uuid.uuid4().hex
@@ -171,6 +173,7 @@ def __init__(

self.remote_checkpoint_dir = remote_checkpoint_dir
self.custom_syncer = custom_syncer
self.sync_timeout = sync_timeout

@property
def uses_cloud_checkpointing(self):
@@ -512,12 +515,18 @@ def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool:
return True

checkpoint = Checkpoint.from_directory(checkpoint_dir)
retry_fn(
lambda: checkpoint.to_uri(self._storage_path(checkpoint_dir)),
checkpoint_uri = self._storage_path(checkpoint_dir)
if not retry_fn(
lambda: checkpoint.to_uri(checkpoint_uri),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not upload checkpoint even after 3 retries: "
f"{checkpoint_uri}"
)
return True

def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
@@ -546,12 +555,17 @@ def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
return True

checkpoint = Checkpoint.from_uri(external_uri)
retry_fn(
if not retry_fn(
lambda: checkpoint.to_directory(local_dir),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not download checkpoint even after 3 retries: "
f"{external_uri}"
)

return True

@@ -719,12 +733,17 @@ def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
self.custom_syncer.wait_or_retry()
else:
checkpoint_uri = self._storage_path(checkpoint_dir)
retry_fn(
if not retry_fn(
lambda: _delete_external_checkpoint(checkpoint_uri),
subprocess.CalledProcessError,
num_retries=3,
sleep_time=1,
)
timeout=self.sync_timeout,
):
logger.error(
f"Could not delete checkpoint even after 3 retries: "
f"{checkpoint_uri}"
)

if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
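Taken together, a trainable constructed with a sync_timeout now logs and abandons a hanging cloud sync instead of blocking save(), restore, or checkpoint deletion indefinitely. A minimal sketch mirroring the test above; the subclass name and URI are placeholders:

from ray import tune


class MyTrainable(tune.Trainable):  # placeholder subclass
    def step(self):
        return {"score": 1}

    def save_checkpoint(self, checkpoint_dir):
        return {"state": 1}  # in-memory checkpoint data


trainable = MyTrainable(
    remote_checkpoint_dir="s3://my-bucket/exp/trial_0",  # placeholder URI
    sync_timeout=60,
)
trainable.save()  # returns even if the upload hangs; an error is logged instead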