Skip to content

Commit

Permalink
Fix error propagation for as_directory if to_directory fails (#40025)
Browse files Browse the repository at this point in the history
`Checkpoint.as_directory` currently downloads the checkpoint to the local filesystem if the checkpoint is remote. It uses `to_directory` to do so. If this fails, `as_directory` will raise a strange error during its `finally` clause that is unrelated to the actual error that occurred during downloading. This PR fixes the issue by only wrapping the user's code in the `try` block. This way, any exceptions during downloading get raised immediately.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
  • Loading branch information
justinvyu committed Oct 3, 2023
1 parent e9d619b commit d166024
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/ray/train/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def as_directory(self) -> Iterator[str]:
del_lock_path = _get_del_lock_path(self._get_temporary_checkpoint_dir())
open(del_lock_path, "a").close()

temp_dir = self.to_directory()
try:
temp_dir = self.to_directory()
yield temp_dir
finally:
# Always cleanup the del lock after we're done with the directory.
Expand All @@ -280,6 +280,8 @@ def as_directory(self) -> Iterator[str]:
f"Traceback:\n{traceback.format_exc()}"
)

# If there are no more lock files, that means there are no more
# readers of this directory, and we can safely delete it.
# In the edge case (process crash before del lock file is removed),
# we do not remove the directory at all.
# Since it's in /tmp, this is not that big of a deal.
Expand Down
21 changes: 21 additions & 0 deletions python/ray/train/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,27 @@ def test_as_directory_lock_cleanup(checkpoint: Checkpoint):
assert not Path(checkpoint_dir).exists()


def test_as_directory_download_error(checkpoint: Checkpoint, monkeypatch):
"""Errors during a checkpoint download should be raised directly when accessing
it with the `as_directory` context manager."""
if isinstance(checkpoint.filesystem, pyarrow.fs.LocalFileSystem):
pytest.skip(
"Local filesystem checkpoints don't download to a temp dir, so "
"there's no error handling to test."
)

error_text = "original error"

def to_directory_error(*args, **kwargs):
raise RuntimeError(error_text)

monkeypatch.setattr(checkpoint, "to_directory", to_directory_error)

with pytest.raises(RuntimeError, match=error_text):
with checkpoint.as_directory() as _:
pass


def test_metadata(checkpoint: Checkpoint):
assert checkpoint.get_metadata() == {}

Expand Down

0 comments on commit d166024

Please sign in to comment.