[air] pyarrow.fs persistence: Don't automatically delete the local checkpoint #38507

Merged · 6 commits · Aug 17, 2023
17 changes: 6 additions & 11 deletions python/ray/train/_internal/storage.py
@@ -530,9 +530,9 @@ def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
         "Current" is defined by the `current_checkpoint_index` attribute of the
         storage context.
-        This method copies the checkpoint files to the storage location,
-        drops a marker at the storage path to indicate that the checkpoint
-        is completely uploaded, then deletes the original checkpoint directory.
+        This method copies the checkpoint files to the storage location.
+        It's up to the user to delete the original checkpoint files if desired.
+        For example, the original directory is typically a local temp directory.
         Args:
@@ -561,17 +561,12 @@ def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
             destination_filesystem=self.storage_filesystem,
         )

-        # Delete local checkpoint files.
-        # TODO(justinvyu): What if checkpoint.path == self.checkpoint_fs_path?
-        # TODO(justinvyu): What if users don't want to delete the local checkpoint?
-        checkpoint.filesystem.delete_dir(checkpoint.path)
-
-        uploaded_checkpoint = Checkpoint(
+        persisted_checkpoint = Checkpoint(
             filesystem=self.storage_filesystem,
             path=self.checkpoint_fs_path,
         )
-        logger.debug(f"Checkpoint successfully created at: {uploaded_checkpoint}")
-        return uploaded_checkpoint
+        logger.debug(f"Checkpoint successfully created at: {persisted_checkpoint}")
+        return persisted_checkpoint

     @property
     def experiment_fs_path(self) -> str:
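For context, the copy that remains is the `pyarrow.fs`-based transfer configured by the `destination_filesystem` argument above, and such copies never remove the source. A minimal sketch of the same contract using the public `pyarrow.fs.copy_files` helper (the paths and filesystems here are illustrative, not Ray's actual values):

    import pyarrow.fs

    # Upload a local checkpoint directory to remote storage. copy_files reads
    # from the source and writes to the destination; it never deletes the
    # source, so cleanup of the local directory is left to the caller.
    pyarrow.fs.copy_files(
        source="/tmp/my_checkpoint",                 # illustrative local path
        destination="bucket/exp/checkpoint_000000",  # illustrative checkpoint_fs_path
        source_filesystem=pyarrow.fs.LocalFileSystem(),
        destination_filesystem=pyarrow.fs.S3FileSystem(),
    )

Dropping the `delete_dir` call also dissolves the two TODOs it carried: with no deletion, persisting a checkpoint whose `checkpoint.path` already equals `self.checkpoint_fs_path` no longer risks destroying the just-uploaded files, and users who want to keep the local copy simply keep it.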
35 changes: 19 additions & 16 deletions python/ray/train/tests/test_new_persistence.py
@@ -165,26 +165,29 @@ def train_fn(config):
     for i in range(start, config.get("num_iterations", 5)):
         time.sleep(0.25)

-        temp_dir = tempfile.mkdtemp()
-        with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
-            pickle.dump({"iter": i}, f)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
+                pickle.dump({"iter": i}, f)

-        artifact_file_name = f"artifact-iter={i}.txt"
-        if in_trainer:
-            rank = train.get_context().get_world_rank()
-            artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"
+            artifact_file_name = f"artifact-iter={i}.txt"
+            if in_trainer:
+                rank = train.get_context().get_world_rank()
+                artifact_file_name = f"artifact-rank={rank}-iter={i}.txt"

-            checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
-            with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
-                pickle.dump({"iter": i}, f)
+                checkpoint_file_name = f"checkpoint_shard-rank={rank}.pkl"
+                with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
+                    pickle.dump({"iter": i}, f)

-        with open(artifact_file_name, "w") as f:
-            f.write(f"{i}")
+            with open(artifact_file_name, "w") as f:
+                f.write(f"{i}")

-        train.report(
-            {"iter": i, _SCORE_KEY: i},
-            checkpoint=NewCheckpoint.from_directory(temp_dir),
-        )
+            train.report(
+                {"iter": i, _SCORE_KEY: i},
+                checkpoint=NewCheckpoint.from_directory(temp_dir),
+            )
+            # `train.report` should not have deleted this!
+            assert os.path.exists(temp_dir)

         if i in config.get("fail_iters", []):
             raise RuntimeError(f"Failing on iter={i}!!")
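Distilled from the updated test, the caller-side pattern this PR enables: the training function owns the checkpoint directory's lifetime, because `train.report` now only copies the files. A hedged sketch of that pattern outside the test harness (the public `ray.train.Checkpoint` stands in for the test's `NewCheckpoint` alias, and the metric names are illustrative):

    import os
    import pickle
    import tempfile

    from ray import train
    from ray.train import Checkpoint

    def train_fn(config):
        for i in range(config.get("num_iterations", 5)):
            # The directory outlives train.report(), so its cleanup is
            # handled explicitly by the context manager, not by Ray.
            with tempfile.TemporaryDirectory() as temp_dir:
                with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
                    pickle.dump({"iter": i}, f)
                train.report({"iter": i}, checkpoint=Checkpoint.from_directory(temp_dir))
                # Still present: report() only uploaded a copy.
                assert os.path.exists(temp_dir)

The switch from `tempfile.mkdtemp()` to `tempfile.TemporaryDirectory()` is the test's half of the new contract: since Ray no longer deletes the directory, the `with` block does.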
6 changes: 6 additions & 0 deletions python/ray/tune/trainable/trainable.py
@@ -468,6 +468,12 @@ def get_state(self):
     def _create_checkpoint_dir(
         self, checkpoint_dir: Optional[str] = None
     ) -> Optional[str]:
+        if _use_storage_context():
+            # NOTE: There's no need to supply the checkpoint directory inside
+            # the local trial dir, since it'll get persisted to the right location.
+            checkpoint_dir = tempfile.mkdtemp()
+            return checkpoint_dir
+
         # Create checkpoint_xxxxx directory and drop checkpoint marker
         checkpoint_dir = TrainableUtil.make_checkpoint_dir(
             checkpoint_dir or self.logdir, index=self.iteration, override=True
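In effect, under the new storage context the local checkpoint directory is disposable scratch space rather than a numbered `checkpoint_xxxxx` directory under the trial logdir. A rough sketch of the two code paths, with `make_checkpoint_dir_stub` as a hypothetical stand-in for `TrainableUtil.make_checkpoint_dir`:

    import os
    import tempfile
    from typing import Optional

    def make_checkpoint_dir_stub(base: str, index: int) -> str:
        # Hypothetical stand-in for TrainableUtil.make_checkpoint_dir.
        path = os.path.join(base, f"checkpoint_{index:06d}")
        os.makedirs(path, exist_ok=True)
        return path

    def create_checkpoint_dir(
        use_storage_context: bool,
        logdir: str,
        iteration: int,
        checkpoint_dir: Optional[str] = None,
    ) -> str:
        if use_storage_context:
            # New path: any scratch temp dir works, because the storage
            # context later copies the checkpoint to its persistent location.
            return tempfile.mkdtemp()
        # Legacy path: a checkpoint_<iteration> dir under the trial logdir.
        return make_checkpoint_dir_stub(checkpoint_dir or logdir, index=iteration)

Together with the `storage.py` change above, this closes the loop: checkpoints are written to a temp directory, persisted by copy, and the local copy's fate is left to the code that created it.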