[air] pyarrow.fs persistence (4/n): Introduce a simplified checkpoint manager #37962

Merged
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -302,6 +302,14 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_checkpoint_manager",
size = "small",
srcs = ["tests/test_checkpoint_manager.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib"]
)

py_test(
name = "test_data_parallel_trainer",
size = "medium",
190 changes: 190 additions & 0 deletions python/ray/train/_internal/checkpoint_manager.py
@@ -0,0 +1,190 @@
import logging
import numbers
from typing import Any, Callable, Dict, List, Optional, Tuple

from ray._private.dict import flatten_dict
from ray.air.config import MAX
from ray.air._internal.util import is_nan
from ray.train import CheckpointConfig
from ray.train._internal.storage import _delete_fs_path
from ray.train.checkpoint import Checkpoint


logger = logging.getLogger(__name__)


class _TrainingResult:
"""A (checkpoint, metrics) result reported by the user."""

def __init__(self, checkpoint: Checkpoint, metrics: Dict[str, Any]):
self.checkpoint = checkpoint
self.metrics = metrics


def _insert_into_sorted_list(list: List[Any], item: Any, key: Callable[[Any], Any]):
"""Insert an item into a sorted list with a custom key function.

Examples:

>>> list = []
>>> _insert_into_sorted_list(list, {"a": 1, "b": 0}, lambda x: x["a"])
>>> list
[{'a': 1, 'b': 0}]
>>> _insert_into_sorted_list(list, {"a": 3, "b": 1}, lambda x: x["a"])
>>> list
[{'a': 1, 'b': 0}, {'a': 3, 'b': 1}]
>>> _insert_into_sorted_list(list, {"a": 4, "b": 2}, lambda x: x["a"])
>>> list
[{'a': 1, 'b': 0}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
>>> _insert_into_sorted_list(list, {"a": 1, "b": 3}, lambda x: x["a"])
>>> list
[{'a': 1, 'b': 0}, {'a': 1, 'b': 3}, {'a': 3, 'b': 1}, {'a': 4, 'b': 2}]
"""
i = 0
while i < len(list):
# Insert to the right of all duplicates.
if key(list[i]) > key(item):
break
i += 1
list.insert(i, item)


class _CheckpointManager:
"""Checkpoint manager that handles checkpoint book-keeping for a trial.

The main purpose of this abstraction is to keep the top K checkpoints based on
recency/a user-provided metric.

NOTE: This class interacts with `_TrainingResult` objects, which are
(checkpoint, metrics) pairs. This is to order checkpoints by metrics.

Args:
checkpoint_config: Defines how many and which checkpoints to keep.
"""

def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
self._checkpoint_config = checkpoint_config or CheckpointConfig()

# List of checkpoints ordered by ascending score.
self._checkpoint_results: List[_TrainingResult] = []

# The latest registered checkpoint.
# It should never be deleted immediately upon registration,
# even if it does not rank among the top K checkpoints by score.
self._latest_checkpoint_result: _TrainingResult = None

if (
self._checkpoint_config.num_to_keep is not None
and self._checkpoint_config.num_to_keep <= 0
):
raise ValueError(
f"`num_to_keep` must >= 1, got: "
f"{self._checkpoint_config.num_to_keep}"
)

def register_checkpoint(self, checkpoint_result: _TrainingResult):
"""Register new checkpoint and add to bookkeeping.

This method will register a new checkpoint and add it to the internal
bookkeeping logic. This means the checkpoint manager will decide if
this checkpoint should be kept, and if older or worse performing
checkpoints should be deleted.

Args:
checkpoint: Tracked checkpoint object to add to bookkeeping.
"""
self._latest_checkpoint_result = checkpoint_result

if self._checkpoint_config.checkpoint_score_attribute is not None:
# If we're ordering by a score, insert the checkpoint
# so that the list remains sorted.
_insert_into_sorted_list(
self._checkpoint_results,
checkpoint_result,
key=self._get_checkpoint_score,
)
else:
# If no metric is provided, just append (ordering by time of registration).
self._checkpoint_results.append(checkpoint_result)

if self._checkpoint_config.num_to_keep is not None:
# Delete the bottom (N - K) checkpoints
worst_results = set(
self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
)
# Except for the latest checkpoint.
results_to_delete = worst_results - {self._latest_checkpoint_result}

# Update internal state before actually deleting them.
self._checkpoint_results = [
checkpoint_result
for checkpoint_result in self._checkpoint_results
if checkpoint_result not in results_to_delete
]

for checkpoint_result in results_to_delete:
checkpoint = checkpoint_result.checkpoint
logger.debug("Deleting checkpoint: ", checkpoint)
_delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

def _get_checkpoint_score(
self, checkpoint: _TrainingResult
) -> Tuple[bool, numbers.Number]:
"""Get the score for a checkpoint, according to checkpoint config.

If `mode="min"`, the metric is negated so that the lowest score is
treated as the best.

Returns:
A tuple of (not_is_nan: bool, score: numbers.Number).
Sorting by this tuple orders: NaN values < float("-inf") < valid numeric metrics.
"""
checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
if checkpoint_score_attribute:
flat_metrics = flatten_dict(checkpoint.metrics)
try:
checkpoint_result = flat_metrics[checkpoint_score_attribute]
except KeyError:
valid_keys = list(flat_metrics.keys())
logger.error(
f"Result dict has no key: {checkpoint_score_attribute}. "
f"checkpoint_score_attr must be set to a key in the "
f"result dict. Valid keys are: {valid_keys}"
)
checkpoint_result = float("-inf")
else:
checkpoint_result = float("-inf")

checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
order_factor = 1.0 if checkpoint_score_order == MAX else -1.0

checkpoint_score = order_factor * checkpoint_result

if not isinstance(checkpoint_score, numbers.Number):
raise ValueError(
f"Unable to persist checkpoint for "
f"checkpoint_score_attribute: "
f"{checkpoint_score_attribute} with value "
f"{checkpoint_score}. "
f"This attribute must be numerical."
)
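# Note: tuples compare elementwise and False < True, so (False, float("-inf"))
# sorts below (True, score) for every numeric score; this is what pushes NaN
# results to the bottom of the ordering.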

return (
(not is_nan(checkpoint_score), checkpoint_score)
if not is_nan(checkpoint_score)
else (False, float("-inf"))
)

@property
def best_checkpoint_result(self) -> Optional[_TrainingResult]:
return self._checkpoint_results[-1] if self._checkpoint_results else None

@property
def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
return self._latest_checkpoint_result

@property
def best_checkpoint_results(self) -> List[_TrainingResult]:
if self._checkpoint_config.num_to_keep is None:
return self._checkpoint_results
return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]
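
Not part of the diff: a minimal usage sketch of the new manager, mirroring test_keep_latest_checkpoint below. The temporary directories and scores are illustrative only.

import tempfile

from ray.train import CheckpointConfig
from ray.train._internal.checkpoint_manager import _CheckpointManager, _TrainingResult
from ray.train.checkpoint import Checkpoint

# Keep the two highest-scoring checkpoints; the most recent registration is
# always retained until a newer one displaces it.
manager = _CheckpointManager(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="score",
        checkpoint_score_order="max",
    )
)

for score in [3.0, 2.0, 1.0]:
    # Illustrative: each checkpoint lives in its own temporary directory.
    manager.register_checkpoint(
        _TrainingResult(
            checkpoint=Checkpoint.from_directory(tempfile.mkdtemp()),
            metrics={"score": score},
        )
    )

# The two best results are kept, ordered by ascending score.
assert [r.metrics["score"] for r in manager.best_checkpoint_results] == [2.0, 3.0]
# The latest registration is still tracked even though it is outside the top 2.
assert manager.latest_checkpoint_result.metrics["score"] == 1.0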
191 changes: 191 additions & 0 deletions python/ray/train/tests/test_checkpoint_manager.py
@@ -0,0 +1,191 @@
from pathlib import Path
Contributor Author: This was adapted from ray/air/test_checkpoint_manager.py

import random
from typing import List

import pytest

from ray.train import CheckpointConfig
from ray.train._internal.checkpoint_manager import (
_CheckpointManager,
_TrainingResult,
)
from ray.train.checkpoint import Checkpoint


@pytest.fixture
def checkpoint_paths(tmp_path):
checkpoint_paths = []
for i in range(10):
checkpoint_path = tmp_path / f"ckpt_{i}"
checkpoint_path.mkdir()
(checkpoint_path / "dummy.txt").write_text(f"{i}")
checkpoint_paths.append(checkpoint_path)

yield [str(path) for path in checkpoint_paths]


def test_unlimited_checkpoints(checkpoint_paths: List[str]):
manager = _CheckpointManager(checkpoint_config=CheckpointConfig(num_to_keep=None))

for i in range(10):
manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[i]),
metrics={"iter": i},
)
)

assert len(manager.best_checkpoint_results) == 10


def test_limited_checkpoints(checkpoint_paths: List[str]):
manager = _CheckpointManager(checkpoint_config=CheckpointConfig(num_to_keep=2))

for i in range(10):
manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[i]),
metrics={"iter": i},
)
)

assert len(manager.best_checkpoint_results) == 2

# Keep the latest checkpoints if no metric is given.
assert {
tracked_checkpoint.metrics["iter"]
for tracked_checkpoint in manager.best_checkpoint_results
} == {8, 9}

# The first 8 checkpoints should be deleted.
for i in range(8):
assert not Path(checkpoint_paths[i]).exists()

assert Path(checkpoint_paths[8]).exists()
assert Path(checkpoint_paths[9]).exists()


@pytest.mark.parametrize("order", ["min", "max"])
def test_keep_checkpoints_by_score(order, checkpoint_paths):
num_to_keep = 2
score_attribute = "score"

manager = _CheckpointManager(
checkpoint_config=CheckpointConfig(
num_to_keep=num_to_keep,
checkpoint_score_attribute=score_attribute,
checkpoint_score_order=order,
)
)

scores = []
for i in range(10):
score = random.random()
manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[i]),
metrics={"iter": i, score_attribute: score},
)
)
scores.append(score)

sorted_scores = sorted(scores, reverse=order == "max")
assert set(sorted_scores[:num_to_keep]) == {
tracked_checkpoint.metrics[score_attribute]
for tracked_checkpoint in manager.best_checkpoint_results
}

# Make sure the bottom checkpoints are deleted.
best_checkpoint_iters = {
tracked_checkpoint.metrics["iter"]
for tracked_checkpoint in manager.best_checkpoint_results
}
for i, checkpoint_path in enumerate(checkpoint_paths):
if i in best_checkpoint_iters or i == 9:
# The checkpoint should only exist if it's one of the top K or the latest.
assert Path(checkpoint_path).exists()
else:
assert not Path(checkpoint_path).exists()


def test_keep_latest_checkpoint(checkpoint_paths):
manager = _CheckpointManager(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
checkpoint_score_attribute="score",
checkpoint_score_order="max",
)
)

manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[0]),
metrics={"score": 3.0},
)
)
manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[1]),
metrics={"score": 2.0},
)
)
manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[2]),
metrics={"score": 1.0},
)
)

assert len(manager.best_checkpoint_results) == 2

# The checkpoint with the lowest score is the latest registration.
assert manager.latest_checkpoint_result.metrics["score"] == 1.0

# The latest checkpoint should not be deleted yet, even though it has the lowest score.
assert Path(checkpoint_paths[2]).exists()

manager.register_checkpoint(
_TrainingResult(
checkpoint=Checkpoint.from_directory(checkpoint_paths[3]),
metrics={"score": 0.0},
)
)
# A newer checkpoint came in. Even though the new one has a lower score, there are
# already num_to_keep better checkpoints, so the previous one should be deleted.
assert not Path(checkpoint_paths[2]).exists()

# Quick sanity check to make sure that the new checkpoint is kept.
assert manager.latest_checkpoint_result.metrics["score"] == 0.0
assert Path(checkpoint_paths[3]).exists()

# The first two checkpoints (with the two best scores) should still exist.
assert Path(checkpoint_paths[0]).exists()
assert Path(checkpoint_paths[1]).exists()


@pytest.mark.parametrize(
"metrics",
[
{"nested": {"sub": {"attr": 5}}},
{"nested": {"sub/attr": 5}},
{"nested/sub": {"attr": 5}},
{"nested/sub/attr": 5},
],
)
def test_nested_get_checkpoint_score(metrics):
manager = _CheckpointManager(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
checkpoint_score_attribute="nested/sub/attr",
checkpoint_score_order="max",
)
)

tracked_checkpoint = _TrainingResult(checkpoint=None, metrics=metrics)
assert manager._get_checkpoint_score(tracked_checkpoint) == (True, 5.0)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))