[tune] Remove temporary checkpoint directories after restore #37173

Merged · 4 commits · Jul 7, 2023
46 changes: 44 additions & 2 deletions python/ray/tune/tests/test_function_api.py
@@ -14,7 +14,12 @@
from ray.tune.logger import NoopLogger
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.trainable.util import TrainableUtil
from ray.tune.trainable import with_parameters, wrap_function, FuncCheckpointUtil
from ray.tune.trainable import (
    with_parameters,
    wrap_function,
    FuncCheckpointUtil,
    FunctionTrainable,
)
from ray.tune.result import DEFAULT_METRIC
from ray.tune.schedulers import ResourceChangingScheduler

@@ -287,10 +292,11 @@ def train(config, checkpoint_dir=None):

        new_trainable2 = wrapped(logger_creator=self.logger_creator)
        new_trainable2.restore_from_object(checkpoint_obj)
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
        checkpoint_obj = new_trainable2.save_to_object()
        new_trainable2.train()
        result = new_trainable2.train()
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
        new_trainable2.stop()
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
        assert result[TRAINING_ITERATION] == 4
@@ -596,6 +602,42 @@ def train(config):
        self.assertEqual(trial_2.last_result["m"], 8 + 9)


def test_restore_from_object_delete(tmp_path):
    """Test that temporary checkpoint directories are deleted after restoring.

    `FunctionTrainable.restore_from_object` creates a temporary checkpoint directory.
    This directory is kept around because we don't control how the user interacts
    with the checkpoint - they might load it several times, or not at all.

    Once a new checkpoint is tracked in the status reporter, there is no need to
    keep the temporary directory around anymore. This test asserts that the
    temporary checkpoint directories are then deleted.
    """
    # Create 2 checkpoints
    cp_1 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=1, override=True)
    cp_2 = TrainableUtil.make_checkpoint_dir(str(tmp_path), index=2, override=True)

    # Instantiate function trainable
    trainable = FunctionTrainable()
    trainable._logdir = str(tmp_path)
    trainable._status_reporter.set_checkpoint(cp_1)

    # Save to object and restore. This will create a temporary checkpoint directory.
    cp_obj = trainable.save_to_object()
    trainable.restore_from_object(cp_obj)

    # Assert there is at least one `checkpoint_tmpxxxxx` directory in the logdir
    assert any(path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir())

    # Track a new checkpoint. This should delete the temporary checkpoint directory.
    trainable._status_reporter.set_checkpoint(cp_2)

    # Directory should have been deleted
    assert not any(
        path.name.startswith("checkpoint_tmp") for path in tmp_path.iterdir()
    )


if __name__ == "__main__":
    import pytest

12 changes: 12 additions & 0 deletions python/ray/tune/trainable/function_trainable.py
@@ -155,6 +155,10 @@ def __init__(
        # to throw an error if `tune.report()` is called as well
        self._air_session_has_reported = False

        # Temporary checkpoint directory used for restoring in `restore_from_object`.
        # When the next checkpoint is saved, we will remove this directory.
        self._tmp_restore_dir = None

    def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
        self._trial_name = trial_name
        self._trial_id = trial_id
@@ -229,6 +233,11 @@ def set_checkpoint(self, checkpoint, is_new=True):
        if is_new:
            self._fresh_checkpoint = True

        # Delete temporary checkpoint folder from `restore_from_object`, if set.
        if self._tmp_restore_dir:
            shutil.rmtree(self._tmp_restore_dir, ignore_errors=True)
            self._tmp_restore_dir = None

    def has_new_checkpoint(self):
        return self._fresh_checkpoint

@@ -519,6 +528,9 @@ def restore_from_object(self, obj):
        checkpoint.to_directory(self.temp_checkpoint_dir)

        self.restore(self.temp_checkpoint_dir)
        # Set tmp restore dir - this directory will be deleted once a new checkpoint
        # is written or set.
        self._status_reporter._tmp_restore_dir = self.temp_checkpoint_dir

    def cleanup(self):
        # Trigger thread termination
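For readers skimming the diff, here is a minimal sketch of the cleanup pattern this change introduces (illustrative only; `_ReporterSketch` is a made-up stand-in, not Ray's `_StatusReporter`): the restore path records the temporary directory it created, and the next call to `set_checkpoint` removes it.

import shutil


class _ReporterSketch:
    def __init__(self):
        # Set by the restore path; cleaned up when the next checkpoint arrives.
        self._tmp_restore_dir = None

    def set_checkpoint(self, checkpoint):
        # ... track `checkpoint` as usual ...
        # Then delete the leftover temporary restore directory, if any.
        if self._tmp_restore_dir:
            shutil.rmtree(self._tmp_restore_dir, ignore_errors=True)
            self._tmp_restore_dir = None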
@@ -5,7 +5,7 @@

import ray
from ray import tune
from ray.air.config import RunConfig, ScalingConfig, FailureConfig
from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.train.examples.pytorch.tune_cifar_torch_pbt_example import train_func
from ray.train.torch import TorchConfig, TorchTrainer
from ray.tune.schedulers import PopulationBasedTraining
@@ -70,6 +70,7 @@
    run_config=RunConfig(
        stop={"training_iteration": 1} if args.smoke_test else None,
        failure_config=FailureConfig(max_failures=-1),
        checkpoint_config=CheckpointConfig(num_to_keep=10),
        callbacks=[FailureInjectorCallback(time_between_checks=90), ProgressCallback()],
    ),
)
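The `num_to_keep` setting added above caps how many checkpoints Tune keeps on disk. As a standalone illustration (the surrounding trainer and tuner setup from the release test is assumed and not shown here), the relevant configuration on its own looks like this:

from ray.air.config import CheckpointConfig, RunConfig

run_config = RunConfig(
    # Keep only the 10 most recent checkpoints on disk; older ones are deleted.
    checkpoint_config=CheckpointConfig(num_to_keep=10),
)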