[Tune] Add timeout for experiment checkpoint syncing to cloud (#30855)
#28155 introduced a sync timeout for trainable checkpoint syncing to the cloud, in case the sync operation (which defaults to pyarrow) hangs. This PR adds a similar timeout for experiment checkpoint cloud syncing.
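For context, this is the user-facing knob that the change now honors for experiment checkpoints as well. The snippet below is a minimal, illustrative sketch assuming the Ray 2.x `tune.run`/`SyncConfig` API; the bucket URI, the trainable, and the numeric values are placeholders, not library defaults.

from ray import tune
from ray.tune import SyncConfig


def train_fn(config):
    # Trivial placeholder trainable; real training code goes here.
    pass


# Placeholder values: sync to the cloud every 5 minutes and abandon any
# sync attempt (trial or experiment checkpoint) that runs past 30 minutes.
sync_config = SyncConfig(
    upload_dir="s3://my-bucket/tune-results",
    sync_period=300,
    sync_timeout=1800,
)

tune.run(train_fn, sync_config=sync_config)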

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
justinvyu committed Dec 7, 2022
1 parent 185c8a5 commit ed5b9e5
Showing 4 changed files with 213 additions and 19 deletions.
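At the core of the `syncer.py` changes is a join-with-timeout on the background sync thread, with the timeout budget counted from when the sync started rather than from when `wait()` is called. The standalone sketch below is a simplified illustration of that pattern, not the exact implementation:

import threading
import time
from typing import Optional


class _BackgroundTask:
    """Simplified sketch: run a function in a daemon thread, join with a budget."""

    def __init__(self, fn):
        self._fn = fn
        self._thread: Optional[threading.Thread] = None
        self._start_time = float("-inf")

    def start(self, *args, **kwargs):
        self._thread = threading.Thread(
            target=self._fn, args=args, kwargs=kwargs, daemon=True
        )
        self._thread.start()
        self._start_time = time.time()

    def wait(self, timeout: Optional[float] = None):
        if not self._thread:
            return
        # Count the budget from the task's start time, not from this call.
        remaining = None
        if timeout is not None:
            remaining = max(timeout - (time.time() - self._start_time), 0)
        self._thread.join(timeout=remaining)
        if self._thread.is_alive():
            self._thread = None
            raise TimeoutError(f"Task did not finish within {timeout} seconds.")
        self._thread = None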
13 changes: 10 additions & 3 deletions python/ray/tune/execution/trial_runner.py
@@ -261,11 +261,18 @@ def _serialize_and_write():

synced = False
if self._syncer:
# Todo: Implement sync_timeout for experiment-level syncing
# (it is currently only used for trainable-to-cloud syncing)
if force:
# Wait until previous sync command finished
self._syncer.wait()
try:
self._syncer.wait()
except TimeoutError as e:
logger.warning(
"The previous sync of the experiment checkpoint to the cloud "
f"timed out: {str(e)}. Tune will continue to retry syncing. "
"If this warning keeps showing up, consider diagnosing the "
"reason behind the hanging sync operation, or increase the "
"`sync_timeout` in `SyncConfig`."
)
synced = self._syncer.sync_up(
local_dir=self._local_dir,
remote_dir=self._remote_dir,
91 changes: 75 additions & 16 deletions python/ray/tune/syncer.py
@@ -2,6 +2,7 @@
from functools import partial
import threading
from typing import (
Any,
Callable,
Dict,
List,
@@ -89,8 +90,6 @@ class SyncConfig:
storage syncing. Defaults to True.
sync_period: Syncing period for syncing between nodes.
sync_timeout: Timeout after which running sync processes are aborted.
Currently only affects trial-to-cloud syncing.
"""

upload_dir: Optional[str] = None
@@ -141,11 +140,16 @@ def __init__(self, fn: Callable):
self._fn = fn
self._process = None
self._result = {}
self._start_time = float("-inf")

@property
def is_running(self):
return self._process and self._process.is_alive()

@property
def start_time(self):
return self._start_time

def start(self, *args, **kwargs):
if self.is_running:
return False
@@ -162,13 +166,31 @@ def entrypoint():
self._result["result"] = result

self._process = threading.Thread(target=entrypoint)
self._process.daemon = True
self._process.start()
self._start_time = time.time()

def wait(self):
def wait(self, timeout: Optional[float] = None) -> Any:
"""Waits for the backgrond process to finish running. Waits until the
background process has run for at least `timeout` seconds, counting from
the time when the process was started."""
if not self._process:
return
return None

time_remaining = None
if timeout:
elapsed = time.time() - self.start_time
time_remaining = max(timeout - elapsed, 0)

self._process.join(timeout=time_remaining)

if self._process.is_alive():
self._process = None
raise TimeoutError(
f"{getattr(self._fn, '__name__', str(self._fn))} did not finish "
f"running within the timeout of {timeout} seconds."
)

self._process.join()
self._process = None

exception = self._result.get("exception")
@@ -199,10 +221,21 @@ class Syncer(abc.ABC):
The base class also exposes an API to only kick off syncs every ``sync_period``
seconds.
Args:
sync_period: The minimum time in seconds between sync operations, as
used by ``sync_up/down_if_needed``.
sync_timeout: The maximum time to wait for a sync process to finish before
issuing a new sync operation, e.g. used by ``wait`` when launching
asynchronous sync tasks.
"""

def __init__(self, sync_period: float = 300.0):
def __init__(
self,
sync_period: float = DEFAULT_SYNC_PERIOD,
sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
):
self.sync_period = sync_period
self.sync_timeout = sync_timeout
self.last_sync_up_time = float("-inf")
self.last_sync_down_time = float("-inf")

@@ -279,7 +312,8 @@ def wait(self):
"""Wait for asynchronous sync command to finish.
You should implement this method if you spawn asynchronous syncing
processes.
processes. This method should time out after `sync_timeout` and
raise a `TimeoutError`.
"""
pass

@@ -357,23 +391,37 @@ def _repr_html_(self) -> str:
class _BackgroundSyncer(Syncer):
"""Syncer using a background process for asynchronous file transfer."""

def __init__(self, sync_period: float = 300.0):
super(_BackgroundSyncer, self).__init__(sync_period=sync_period)
def __init__(
self,
sync_period: float = DEFAULT_SYNC_PERIOD,
sync_timeout: float = DEFAULT_SYNC_TIMEOUT,
):
super(_BackgroundSyncer, self).__init__(
sync_period=sync_period, sync_timeout=sync_timeout
)
self._sync_process = None
self._current_cmd = None

def _should_continue_existing_sync(self):
"""Returns whether a previous sync is still running within the timeout."""
return (
self._sync_process
and self._sync_process.is_running
and time.time() - self._sync_process.start_time < self.sync_timeout
)

def sync_up(
self, local_dir: str, remote_dir: str, exclude: Optional[List] = None
) -> bool:
if self._sync_process and self._sync_process.is_running:
if self._should_continue_existing_sync():
logger.warning(
f"Last sync still in progress, "
f"skipping sync up of {local_dir} to {remote_dir}"
)
return False
elif self._sync_process:
try:
self._sync_process.wait()
self.wait()
except Exception as e:
logger.warning(f"Last sync command failed: {e}")

@@ -392,15 +440,15 @@ def _sync_up_command(
def sync_down(
self, remote_dir: str, local_dir: str, exclude: Optional[List] = None
) -> bool:
if self._sync_process and self._sync_process.is_running:
if self._should_continue_existing_sync():
logger.warning(
f"Last sync still in progress, "
f"skipping sync down of {remote_dir} to {local_dir}"
)
return False
elif self._sync_process:
try:
self._sync_process.wait()
self.wait()
except Exception as e:
logger.warning(f"Last sync command failed: {e}")

@@ -415,11 +463,16 @@ def _sync_down_command(self, uri: str, local_path: str) -> Tuple[Callable, Dict]
raise NotImplementedError

def delete(self, remote_dir: str) -> bool:
if self._sync_process and self._sync_process.is_running:
if self._should_continue_existing_sync():
logger.warning(
f"Last sync still in progress, skipping deletion of {remote_dir}"
)
return False
elif self._sync_process:
try:
self.wait()
except Exception as e:
logger.warning(f"Last sync command failed: {e}")

self._current_cmd = self._delete_command(uri=remote_dir)
self.retry()
@@ -432,8 +485,12 @@ def _delete_command(self, uri: str) -> Tuple[Callable, Dict]:
def wait(self):
if self._sync_process:
try:
self._sync_process.wait()
self._sync_process.wait(timeout=self.sync_timeout)
except Exception as e:
# Let `TimeoutError` pass through, to be handled separately
# from errors thrown by the sync operation
if isinstance(e, TimeoutError):
raise e
raise TuneError(f"Sync process failed: {e}") from e
finally:
self._sync_process = None
@@ -482,7 +539,9 @@ def get_node_to_storage_syncer(sync_config: SyncConfig) -> Optional[Syncer]:
return None

if sync_config.syncer == "auto":
return _DefaultSyncer(sync_period=sync_config.sync_period)
return _DefaultSyncer(
sync_period=sync_config.sync_period, sync_timeout=sync_config.sync_timeout
)

if isinstance(sync_config.syncer, Syncer):
return sync_config.syncer
41 changes: 41 additions & 0 deletions python/ray/tune/tests/test_syncer.py
@@ -3,6 +3,7 @@
import shutil
import subprocess
import tempfile
import time
from typing import List, Optional
from unittest.mock import patch

@@ -336,6 +337,11 @@ class FakeSyncProcess:
def is_running(self):
return True

@property
def start_time(self):
# Infinite start time: the sync process is never considered timed out
return float("inf")

syncer = _DefaultSyncer(sync_period=60)
syncer._sync_process = FakeSyncProcess()
assert not syncer.sync_up_if_needed(
@@ -364,6 +370,41 @@ def wait(self):
)


def test_syncer_hanging_sync_with_timeout(temp_data_dirs):
"""Check that syncing times out when the sync process is hanging."""
tmp_source, tmp_target = temp_data_dirs

def _hanging_sync_up_command(*args, **kwargs):
time.sleep(200)

class _HangingSyncer(_DefaultSyncer):
def _sync_up_command(
self, local_path: str, uri: str, exclude: Optional[List] = None
):
return _hanging_sync_up_command, {}

syncer = _HangingSyncer(sync_period=60, sync_timeout=10)

def sync_up():
return syncer.sync_up(
local_dir=tmp_source, remote_dir="memory:///test/test_syncer_timeout"
)

with freeze_time() as frozen:
assert sync_up()
frozen.tick(5)
# 5 seconds - initial sync hasn't reached the timeout yet
# It should continue running without launching a new sync
assert not sync_up()
frozen.tick(5)
# Reached the timeout - start running a new sync command
assert sync_up()
frozen.tick(20)
# We're 10 seconds past the timeout; waiting should raise a timeout error
with pytest.raises(TimeoutError):
syncer.wait()


def test_syncer_not_running_sync_last_failed(caplog, temp_data_dirs):
"""Check that new sync is issued if old sync completed"""
caplog.set_level(logging.WARNING)
87 changes: 87 additions & 0 deletions python/ray/tune/tests/test_trial_runner_3.py
@@ -1,5 +1,6 @@
import time
from collections import Counter
import logging
import os
import pickle
import shutil
@@ -8,6 +9,8 @@
import unittest
from unittest.mock import patch

from freezegun import freeze_time

import ray
from ray.air import CheckpointConfig
from ray.rllib import _register_all
@@ -883,6 +886,90 @@ def should_checkpoint(self):
# happen as the experiment finishes before it is triggered
assert syncer.sync_up_counter == 4

def getHangingSyncer(self, sync_period: float, sync_timeout: float):
def _hanging_sync_up_command(*args, **kwargs):
time.sleep(200)

from ray.tune.syncer import _DefaultSyncer

class HangingSyncer(_DefaultSyncer):
def __init__(self, sync_period: float, sync_timeout: float):
super(HangingSyncer, self).__init__(
sync_period=sync_period, sync_timeout=sync_timeout
)
self.sync_up_counter = 0

def sync_up(
self, local_dir: str, remote_dir: str, exclude: list = None
) -> bool:
self.sync_up_counter += 1
return super(HangingSyncer, self).sync_up(local_dir, remote_dir, exclude)

def _sync_up_command(self, local_path: str, uri: str, exclude: list = None):
return _hanging_sync_up_command, {}

return HangingSyncer(sync_period=sync_period, sync_timeout=sync_timeout)

def testForcedCloudCheckpointSyncTimeout(self):
"""Test that trial runner experiment checkpointing with forced cloud syncing
times out correctly when the sync process hangs."""
ray.init(num_cpus=3)

syncer = self.getHangingSyncer(sync_period=60, sync_timeout=0.5)
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
sync_config=SyncConfig(upload_dir="fake", syncer=syncer),
remote_checkpoint_dir="fake",
)
# Checkpoint for the first time starts the first sync in the background
runner.checkpoint(force=True)
assert syncer.sync_up_counter == 1

buffer = []
logger = logging.getLogger("ray.tune.execution.trial_runner")
with patch.object(logger, "warning", lambda x: buffer.append(x)):
# The second checkpoint will log a warning about the previous sync
# timing out. Then, it will launch a new sync process in the background.
runner.checkpoint(force=True)
assert any(
"sync of the experiment checkpoint to the cloud timed out" in x
for x in buffer
)
assert syncer.sync_up_counter == 2

def testPeriodicCloudCheckpointSyncTimeout(self):
"""Test that trial runner experiment checkpointing with the default periodic
cloud syncing times out and retries correctly when the sync process hangs."""
ray.init(num_cpus=3)

sync_period = 60
syncer = self.getHangingSyncer(sync_period=sync_period, sync_timeout=0.5)
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir,
sync_config=SyncConfig(upload_dir="fake", syncer=syncer),
remote_checkpoint_dir="fake",
)

with freeze_time() as frozen:
runner.checkpoint()
assert syncer.sync_up_counter == 1

frozen.tick(sync_period / 2)
# Cloud sync has already timed out, but we shouldn't retry until
# the next sync_period
runner.checkpoint()
assert syncer.sync_up_counter == 1

frozen.tick(sync_period / 2)
# We've now reached the sync_period - a new sync process should be
# started, with the old one timing out
buffer = []
logger = logging.getLogger("ray.tune.syncer")
with patch.object(logger, "warning", lambda x: buffer.append(x)):
runner.checkpoint()
assert any("did not finish running within the timeout" in x for x in buffer)
assert syncer.sync_up_counter == 2


class SearchAlgorithmTest(unittest.TestCase):
@classmethod
