[train/tune] Separate storage checkpoint index bookkeeping #39927
Conversation
Signed-off-by: Kai Fricke <kai@anyscale.com>
I'm wondering if we can revert to the 2.6 indexing for class trainables only (see the suggestion below).
Then, we can come up with a more general checkpoint dir naming functionality in a follow-up -- something like a format string with arbitrary metrics: CheckpointConfig(checkpoint_dir_name="checkpoint_{training_iteration}-{loss:.2f}")
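To make the format-string idea concrete, here is a minimal sketch. Note that neither a `checkpoint_dir_name` option nor the helper below exists in Ray today; both are hypothetical illustrations of the proposed follow-up:

```python
# Hypothetical helper illustrating the proposed format-string naming.
# Not an existing Ray API.

def format_checkpoint_dir_name(template: str, metrics: dict) -> str:
    """Render a checkpoint directory name from arbitrary reported metrics."""
    return template.format(**metrics)


name = format_checkpoint_dir_name(
    "checkpoint_{training_iteration}-{loss:.2f}",
    {"training_iteration": 3, "loss": 0.12345},
)
print(name)  # checkpoint_3-0.12
```

Any metric reported by the trainable could then appear in the directory name, with standard Python format specs controlling precision and padding.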
@@ -1104,7 +1104,7 @@ def on_checkpoint(self, checkpoint: Union[_TrackedCheckpoint, _TrainingResult]):
     # Increment the checkpoint index to keep the checkpoint index in sync.
     # This index will get restored when the trial is restored and will
     # be passed to the Trainable as the starting checkpoint index.
-    self.storage.current_checkpoint_index += 1
+    self.storage._increase_checkpoint_index(checkpoint_result.metrics)
Suggested change:
-    self.storage._increase_checkpoint_index(checkpoint_result.metrics)
+    self.storage.current_checkpoint_index = checkpoint_result.metrics[TRAINING_ITERATION]
Doesn't this mean that the checkpoint manager in Trial
will get out of sync with class trainables?
Is this the only location where we access the trial checkpoint ID?
result[CHECKPOINT_DIR_NAME] = (
trial.storage.checkpoint_dir_name if trial_will_checkpoint else None
)
Ah yeah, that's right.
That is the only place it gets explicitly accessed, but the trial checkpoint id also needs to stay in sync for restore -- it'll get propagated to the new trainable that gets scheduled:
ray/python/ray/tune/experiment/trial.py, line 277 (at 5369feb):

kwargs["storage"] = trial.storage
OK - this means we should keep those in sync, and they shouldn't diverge between class and function trainables, right?
In that case, should we go with the workaround proposed in this PR?
Alternatively, we could communicate the (updated) checkpoint ID in the training result.
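As a toy sketch of that alternative (function and key names here are illustrative only, not Ray's actual implementation): the trainable includes its updated checkpoint index in the result it reports, and the driver adopts that value instead of incrementing its own counter independently, so the two can never drift apart.

```python
# Toy sketch: communicate the updated checkpoint index via the training
# result. All names below are hypothetical, not Ray internals.

CHECKPOINT_INDEX_KEY = "_checkpoint_index"  # hypothetical result key


def trainable_step(worker_state: dict) -> dict:
    """One step on the worker: bump its own index and report it."""
    worker_state["checkpoint_index"] += 1
    return {"loss": 0.1, CHECKPOINT_INDEX_KEY: worker_state["checkpoint_index"]}


def driver_on_result(driver_state: dict, result: dict) -> None:
    """Driver-side bookkeeping: adopt the reported index, never increment."""
    driver_state["checkpoint_index"] = result[CHECKPOINT_INDEX_KEY]


worker = {"checkpoint_index": 0}
driver = {"checkpoint_index": 0}
for _ in range(3):
    driver_on_result(driver, trainable_step(worker))
print(driver["checkpoint_index"] == worker["checkpoint_index"])  # True
```

The trade-off versus the PR's approach is that the index travels with every result, rather than both sides applying the same mutation rule independently.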
(ray-project#40003) Following up to #39927, this PR updates the logic of updating the checkpoint ID (and thus the checkpoint directory name) just before persisting the checkpoint. This means that the (renamed) `_update_checkpoint_index` gets the metrics associated with the current checkpoint, rather than the previous one.
Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Kai Fricke <krfricke@users.noreply.github.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>

(ray-project#39927) Checkpoint IDs are incremented in three different places: the `Trainable` (for class trainables), the `session` (for function trainables), and the `Trial` (on the driver). These are currently implicitly kept in sync. In the future, we may want to synchronize driver and trainable state via other means, or customize the checkpoint directory name to be populated by other metrics. In preparation for this, we separate out the checkpoint ID mutation into a subfunction that can be overwritten or (in a follow-up) provided or otherwise modified.
Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Victor <vctr.y.m@example.com>
Why are these changes needed?
Checkpoint IDs are incremented in three different places: the `Trainable` (for class trainables), the `session` (for function trainables), and the `Trial` (on the driver). These are currently implicitly kept in sync. In the future, we may want to synchronize driver and trainable state via other means, or customize the checkpoint directory name to be populated by other metrics. In preparation for this, we separate out the checkpoint ID mutation into a subfunction that can be overwritten or (in a follow-up) provided or otherwise modified.
Related issue number

Checks
- I've signed off every commit (using `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.