[train] make Trainable storage optional #38853

Merged 2 commits on Aug 25, 2023
35 changes: 23 additions & 12 deletions python/ray/tune/trainable/trainable.py
@@ -177,8 +177,7 @@ def __init__(

         self._storage = storage

-        if _use_storage_context():
-            assert storage
+        if _use_storage_context() and storage:
             assert storage.trial_fs_path
             logger.debug(f"StorageContext on the TRAINABLE:\n{storage}")
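
A minimal, hypothetical sketch of what this relaxed guard permits: constructing a Trainable subclass by hand, without passing a StorageContext. The subclass, config, and checkpoint format below are illustrative only, and the exact constructor signature can differ between Ray versions.

import os

from ray.tune import Trainable


class MyTrainable(Trainable):
    """Hypothetical subclass used only to illustrate manual construction."""

    def setup(self, config):
        self.value = 0

    def step(self):
        self.value += 1
        return {"score": self.value}

    def save_checkpoint(self, checkpoint_dir):
        # Write state into the provided directory and return its path,
        # matching the str-return branch handled in save() (second hunk below).
        with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
            f.write(str(self.value))
        return checkpoint_dir


# With this change, no `storage` argument is required: the StorageContext
# checks in __init__ only run when a storage object is actually provided.
trainable = MyTrainable(config={})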

@@ -525,17 +524,29 @@ def save(
                     )

                 local_checkpoint = NewCheckpoint.from_directory(checkpoint_dir)
-                persisted_checkpoint = self._storage.persist_current_checkpoint(
-                    local_checkpoint
-                )
-                # The checkpoint index needs to be incremented.
-                # NOTE: This is no longer using "iteration" as the folder indexing
-                # to be consistent with fn trainables.
-                self._storage.current_checkpoint_index += 1
-
-                checkpoint_result = _TrainingResult(
-                    checkpoint=persisted_checkpoint, metrics=self._last_result.copy()
-                )
+                if self._storage:
+                    persisted_checkpoint = self._storage.persist_current_checkpoint(
+                        local_checkpoint
+                    )
+                    # The checkpoint index needs to be incremented.
+                    # NOTE: This is no longer using "iteration" as the folder indexing
+                    # to be consistent with fn trainables.
+                    self._storage.current_checkpoint_index += 1
+
+                    checkpoint_result = _TrainingResult(
+                        checkpoint=persisted_checkpoint,
+                        metrics=self._last_result.copy(),
+                    )
+                else:
+                    # `storage=None` only happens when initializing the
+                    # Trainable manually, outside of Tune/Train.
+                    # In this case, no storage is set, so the default behavior
+                    # is to just not upload anything and report a local checkpoint.
+                    # This is fine for the main use case of local debugging.
+                    checkpoint_result = _TrainingResult(
+                        checkpoint=local_checkpoint, metrics=self._last_result.copy()
+                    )

             else:
                 checkpoint_result: _TrainingResult = checkpoint_dict_or_path
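
Continuing the hypothetical sketch from the first hunk: with no StorageContext set, the new else branch means save() produces a result backed by the local checkpoint directory and nothing is uploaded (assuming the new storage-context code path is active).

trainable.train()          # runs one step() and records its metrics
result = trainable.save()  # with storage=None, the checkpoint stays local

Keeping a plain local checkpoint here is what lets a manually constructed Trainable be debugged locally without wiring up any storage, as the comment in the diff notes.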