[train/tune] Update checkpoint index before persisting checkpoint (#40003)

Following up on #39927, this PR changes the checkpoint ID (and thus the checkpoint directory name) to be updated just before the checkpoint is persisted. As a result, the (renamed) `_update_checkpoint_index` receives the metrics associated with the current checkpoint rather than those of the previous one.
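
To illustrate the new ordering, here is a minimal, self-contained sketch (not the actual Ray implementation). Previously, the index was incremented only after the result had been processed, so the index update saw the previous report's metrics; now the index is updated from the metrics of the checkpoint being reported, the checkpoint is persisted, and the resulting directory name is recorded in the metrics. The `MinimalStorage` class and the `report` helper below are hypothetical; only the method names `_update_checkpoint_index` and `persist_current_checkpoint` mirror the code changed in this PR.

```python
from typing import Any, Dict, Optional


class MinimalStorage:
    """Hypothetical stand-in for the storage context touched in this PR."""

    def __init__(self) -> None:
        # Starts at -1 so that the first update lands on index 0.
        self.current_checkpoint_index = -1

    def _update_checkpoint_index(self, metrics: Dict[str, Any]) -> None:
        # Default behavior: advance by one. A subclass could instead derive
        # the index from the reported metrics (e.g. a step counter).
        self.current_checkpoint_index += 1

    @property
    def checkpoint_dir_name(self) -> str:
        return f"checkpoint_{self.current_checkpoint_index:06d}"

    def persist_current_checkpoint(self, checkpoint: Any) -> str:
        # The real implementation uploads files to storage; this sketch
        # only returns the directory name that would be used.
        return self.checkpoint_dir_name


def report(
    storage: MinimalStorage, metrics: Dict[str, Any], checkpoint: Optional[Any] = None
) -> Dict[str, Any]:
    if checkpoint:
        # New ordering: bump the index *before* persisting, so the directory
        # name is derived from the metrics of the current checkpoint.
        storage._update_checkpoint_index(metrics)
        storage.persist_current_checkpoint(checkpoint)
        metrics["checkpoint_dir_name"] = storage.checkpoint_dir_name
    else:
        metrics["checkpoint_dir_name"] = None
    return metrics


storage = MinimalStorage()
print(report(storage, {"step": 3}, checkpoint=object()))  # uses checkpoint_000000
```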

---------

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Kai Fricke <krfricke@users.noreply.github.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
krfricke and justinvyu committed Oct 5, 2023
1 parent 85f3d98 commit 07f3d3a
Showing 11 changed files with 261 additions and 45 deletions.
15 changes: 9 additions & 6 deletions python/ray/train/_internal/session.py
@@ -32,6 +32,7 @@
WORKER_PID,
TIME_TOTAL_S,
RAY_CHDIR_TO_TRIAL_DIR,
CHECKPOINT_DIR_NAME,
)
from ray.train.error import SessionMisuseError
from ray.util.annotations import DeveloperAPI, PublicAPI
@@ -380,10 +381,6 @@ def _report_training_result(self, training_result: _TrainingResult) -> None:
# NOTE: This populates `train.get_checkpoint`
self.loaded_checkpoint = training_result.checkpoint

# NOTE: This is where the coordinator AND workers increment their
# checkpoint index.
self.storage._increase_checkpoint_index(training_result.metrics)

# Add result to a thread-safe queue.
self.result_queue.put(training_result, block=True)

@@ -416,20 +413,26 @@ def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None
if self.ignore_report:
return

metrics = self._auto_fill_metrics(metrics)

persisted_checkpoint = None
if checkpoint:
self.storage._update_checkpoint_index(metrics)

# Persist the reported checkpoint files to storage.
persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
else:
metrics[CHECKPOINT_DIR_NAME] = None

# Persist trial artifacts to storage.
force_artifact_sync = (
persisted_checkpoint
and self.storage.sync_config.sync_artifacts_on_checkpoint
)
self.storage.persist_artifacts(force=force_artifact_sync)

metrics = self._auto_fill_metrics(metrics)

# Set additional user metadata from the Trainer.
if persisted_checkpoint and self.metadata:
user_metadata = persisted_checkpoint.get_metadata()
4 changes: 2 additions & 2 deletions python/ray/train/_internal/storage.py
@@ -426,7 +426,7 @@ def __init__(
sync_config: Optional[SyncConfig] = None,
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
trial_dir_name: Optional[str] = None,
current_checkpoint_index: int = 0,
current_checkpoint_index: int = -1,
):
self.custom_fs_provided = storage_filesystem is not None

@@ -512,7 +512,7 @@ def _check_validation_file(self):
"to the configured storage path."
)

def _increase_checkpoint_index(self, metrics: Dict):
def _update_checkpoint_index(self, metrics: Dict):
# By default, increase by 1. This can be overridden to customize checkpoint
# directories.
self.current_checkpoint_index += 1
12 changes: 9 additions & 3 deletions python/ray/train/data_parallel_trainer.py
@@ -378,6 +378,14 @@ def _report(self, training_iterator: TrainingIterator) -> None:
result.checkpoint for result in results if result.checkpoint is not None
]
at_least_one_reported_checkpoint = len(worker_checkpoints) > 0

if at_least_one_reported_checkpoint:
# Update the coordinator's checkpoint index to the latest.
# This is what keeps the checkpoint index in line with the workers.
tune_session.storage._update_checkpoint_index(
first_worker_result.metrics
)

# Make sure that all workers uploaded to the same location.
assert all(
checkpoint.path == tune_session.storage.checkpoint_fs_path
@@ -387,14 +395,12 @@ def _report(self, training_iterator: TrainingIterator) -> None:
checkpoint = (
Checkpoint(
filesystem=tune_session.storage.storage_filesystem,
# NOTE: The checkpoint index has not been incremented yet
# at this point, which is why `checkpoint_fs_path` points
# to the most recent checkpoint.
path=tune_session.storage.checkpoint_fs_path,
)
if at_least_one_reported_checkpoint
else None
)

tracked_training_result = _TrainingResult(
checkpoint=checkpoint,
metrics=first_worker_result.metrics,
11 changes: 7 additions & 4 deletions python/ray/train/tests/util.py
@@ -1,7 +1,7 @@
import contextlib
import os
import tempfile
from typing import Any, Dict, Type
from typing import Any, Dict, Optional, Type

import ray.cloudpickle as ray_pickle
from ray.train import Checkpoint
@@ -26,12 +26,15 @@ def load_dict_checkpoint(checkpoint: Checkpoint) -> Dict[str, Any]:


def mock_storage_context(
exp_name: str = "exp_name", delete_syncer: bool = True
exp_name: str = "exp_name",
delete_syncer: bool = True,
storage_path: Optional[str] = None,
storage_context_cls: Type = StorageContext,
) -> StorageContext:
storage_path = tempfile.mkdtemp()
storage_path = storage_path or tempfile.mkdtemp()
exp_name = exp_name
trial_name = "trial_name"
storage = StorageContext(
storage = storage_context_cls(
storage_path=storage_path,
experiment_dir_name=exp_name,
trial_dir_name=trial_name,
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -62,6 +62,14 @@ py_test(
tags = ["team:ml", "exclusive", "rllib"],
)

py_test(
name = "test_api_checkpoint_integration",
size = "medium",
srcs = ["tests/test_api_checkpoint_integration.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive"],
)

py_test(
name = "test_callbacks",
size = "small",
9 changes: 0 additions & 9 deletions python/ray/tune/execution/tune_controller.py
@@ -23,7 +23,6 @@
from ray.air.execution._internal import RayActorManager, TrackedActor
from ray.train._internal.session import _FutureTrainingResult
from ray.train._internal.storage import StorageContext, _use_storage_context
from ray.train.constants import CHECKPOINT_DIR_NAME
from ray.exceptions import RayActorError, RayTaskError
from ray.tune.error import _AbortTrialExecution, _TuneStopTrialError, _TuneRestoreError
from ray.tune.execution.class_cache import _ActorClassCache
@@ -1748,14 +1747,6 @@ def _process_trial_result(self, trial, result):
result = trial.last_result
result.update(done=True)

# NOTE: This checkpoint dir name metric should only be auto-filled
# after we know the trial will save a checkpoint.
if _use_storage_context() and not is_duplicate:
trial_will_checkpoint = trial.should_checkpoint() or force_checkpoint
result[CHECKPOINT_DIR_NAME] = (
trial.storage.checkpoint_dir_name if trial_will_checkpoint else None
)

self._total_time += result.get(TIME_THIS_ITER_S, 0)

flat_result = flatten_dict(result)
4 changes: 2 additions & 2 deletions python/ray/tune/experiment/trial.py
@@ -1101,10 +1101,10 @@ def on_checkpoint(self, checkpoint: Union[_TrackedCheckpoint, _TrainingResult]):
checkpoint_result = checkpoint
assert isinstance(checkpoint_result, _TrainingResult)
self.run_metadata.checkpoint_manager.register_checkpoint(checkpoint_result)
# Increment the checkpoint index to keep the checkpoint index in sync.
# Update the checkpoint index to keep the checkpoint index in sync.
# This index will get restored when the trial is restored and will
# be passed to the Trainable as the starting checkpoint index.
self.storage._increase_checkpoint_index(checkpoint_result.metrics)
self.storage._update_checkpoint_index(checkpoint_result.metrics)
else:
self.run_metadata.checkpoint_manager.on_checkpoint(checkpoint)
self.invalidate_json_state()
2 changes: 2 additions & 0 deletions python/ray/tune/tests/test_api.py
@@ -21,6 +21,7 @@
from ray.rllib import _register_all
from ray.train._internal.session import shutdown_session
from ray.train._internal.storage import StorageContext
from ray.train.constants import CHECKPOINT_DIR_NAME
from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
from ray.tune import (
register_env,
@@ -151,6 +152,7 @@ def _function_trainable(config):
TIME_THIS_ITER_S,
TIME_TOTAL_S,
DONE, # This is ignored because FunctionAPI has different handling
CHECKPOINT_DIR_NAME,
"timestamp",
"time_since_restore",
"experiment_id",
165 changes: 165 additions & 0 deletions python/ray/tune/tests/test_api_checkpoint_integration.py
@@ -0,0 +1,165 @@
import os
import tempfile
from pathlib import Path

import pytest
import sys

import ray
from ray import train
from ray.air import ScalingConfig
from ray.train import CheckpointConfig
from ray.air.execution import FixedResourceManager
from ray.air.constants import TRAINING_ITERATION
from ray.train._internal.storage import StorageContext
from ray.tune import Trainable, register_trainable
from ray.tune.execution.tune_controller import TuneController
from ray.tune.experiment import Trial

from ray.train.tests.util import mock_storage_context

STORAGE = mock_storage_context()


@pytest.fixture(scope="function")
def ray_start_4_cpus_2_gpus_extra():
address_info = ray.init(num_cpus=4, num_gpus=2, resources={"a": 2})
yield address_info
ray.shutdown()


@pytest.mark.parametrize("trainable_type", ["class", "function", "data_parallel"])
@pytest.mark.parametrize("patch_iter", [False, True])
def test_checkpoint_freq_dir_name(
ray_start_4_cpus_2_gpus_extra, trainable_type, patch_iter, tmp_path
):
"""Test that trial checkpoint IDs are correctly set across trainable types.
This includes a current workaround to set checkpoint IDs according to reported
metrics.
"""

def num_checkpoints(trial):
return sum(
item.startswith("checkpoint_") for item in os.listdir(trial.local_path)
)

def last_checkpoint_dir(trial):
return max(
item
for item in os.listdir(trial.local_path)
if item.startswith("checkpoint_")
)

checkpoint_config = None

if trainable_type == "class":

class MyTrainable(Trainable):
def step(self):
# `training_iteration` is increased after the report, so we
# +1 here.
return {"step": self.iteration + 1}

def save_checkpoint(self, checkpoint_dir):
return {"test": self.iteration}

def load_checkpoint(self, checkpoint_dir):
pass

register_trainable("test_checkpoint_freq", MyTrainable)
checkpoint_config = CheckpointConfig(checkpoint_frequency=3)

elif trainable_type in {"function", "data_parallel"}:

def train_fn(config):
for step in range(1, 10):
if step > 0 and step % 3 == 0:
with tempfile.TemporaryDirectory() as checkpoint_dir:
(Path(checkpoint_dir) / "data.ckpt").write_text(str(step))
train.report(
{"step": step},
checkpoint=train.Checkpoint.from_directory(checkpoint_dir),
)
else:
train.report({"step": step})

if trainable_type == "function":
register_trainable("test_checkpoint_freq", train_fn)
elif trainable_type == "data_parallel":
from ray.train.data_parallel_trainer import DataParallelTrainer

trainer = DataParallelTrainer(
train_loop_per_worker=train_fn,
scaling_config=ScalingConfig(num_workers=1),
)
register_trainable("test_checkpoint_freq", trainer.as_trainable())

else:
raise RuntimeError("Invalid trainable type")

if patch_iter:

class CustomStorageContext(StorageContext):
def _update_checkpoint_index(self, metrics):
# TODO: Support auto-filled metrics for function trainables
self.current_checkpoint_index = metrics.get(
"step", self.current_checkpoint_index + 1
)

storage = mock_storage_context(
delete_syncer=False,
storage_context_cls=CustomStorageContext,
storage_path=tmp_path,
)
else:
storage = mock_storage_context(delete_syncer=False, storage_path=tmp_path)

trial = Trial(
"test_checkpoint_freq",
checkpoint_config=checkpoint_config,
storage=storage,
)
runner = TuneController(
resource_manager_factory=lambda: FixedResourceManager(),
storage=STORAGE,
checkpoint_period=0,
)
runner.add_trial(trial)

while not trial.is_saving:
runner.step()
runner.step()
assert trial.last_result[TRAINING_ITERATION] == 3
assert num_checkpoints(trial) == 1

if patch_iter:
assert last_checkpoint_dir(trial) == "checkpoint_000003"
else:
assert last_checkpoint_dir(trial) == "checkpoint_000000"

while not trial.is_saving:
runner.step()
runner.step()
assert trial.last_result[TRAINING_ITERATION] == 6
assert num_checkpoints(trial) == 2

if patch_iter:
assert last_checkpoint_dir(trial) == "checkpoint_000006"
else:
assert last_checkpoint_dir(trial) == "checkpoint_000001"

while not trial.is_saving:
runner.step()
runner.step()
assert trial.last_result[TRAINING_ITERATION] == 9
assert num_checkpoints(trial) == 3

if patch_iter:
assert last_checkpoint_dir(trial) == "checkpoint_000009"
else:
assert last_checkpoint_dir(trial) == "checkpoint_000002"


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))