
[train/tune] Separate storage checkpoint index bookkeeping #39927

Merged 1 commit into ray-project:master on Sep 29, 2023

Conversation

@krfricke (Contributor) commented Sep 28, 2023

Why are these changes needed?

Checkpoint IDs are incremented in three different places: the Trainable (for class trainables), the session (for function trainables), and the Trial (on the driver). These are currently kept in sync only implicitly. In the future, we may want to synchronize driver and trainable state via other means, or populate the checkpoint directory name from other metrics. In preparation for this, this PR separates the checkpoint ID mutation into a subfunction that can be overridden or, in a follow-up, injected or otherwise customized.
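As a rough sketch, the extracted hook plausibly looks like the following by default (a simplified stand-in for ray.train._internal.storage.StorageContext; the exact merged implementation may differ):

class StorageContext:
    current_checkpoint_index: int = 0

    def _increase_checkpoint_index(self, metrics: dict) -> None:
        # Assumed default: preserve the old behavior and bump the index by one,
        # ignoring the metrics argument.
        self.current_checkpoint_index += 1

The reproduction script below then overrides this hook so that the checkpoint index follows the reported training_iteration: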

from ray.train import CheckpointConfig, RunConfig
from ray.train._internal.storage import StorageContext
from ray.tune import Tuner
from ray.tune.experiment import Experiment
from ray.tune.trainable import Trainable


class CustomStorageContext(StorageContext):
    def _increase_checkpoint_index(self, metrics):
        # Name checkpoint directories after the reported training_iteration,
        # falling back to a simple +1 increment if it is missing.
        self.current_checkpoint_index = metrics.get(
            "training_iteration", self.current_checkpoint_index + 1
        )


class MyTrainable(Trainable):
    def step(self):
        return {"metric": self.iteration + 100}

    def save_checkpoint(self, checkpoint_dir):
        return {"test": "data"}

    def load_checkpoint(self, checkpoint_dir):
        pass


# Monkey-patch the storage context class used for new experiments.
Experiment._storage_context_cls = CustomStorageContext


tuner = Tuner(
    MyTrainable,
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(checkpoint_frequency=5),
        stop={"training_iteration": 100},
    ),
)
tuner.fit()

Result:

...
(MyTrainable pid=56464) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/kai/ray_results/MyTrainable_2023-09-27_14-37-02/MyTrainable_270ba_00000_0_2023-09-27_14-37-07/checkpoint_000080)
(MyTrainable pid=56464) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/kai/ray_results/MyTrainable_2023-09-27_14-37-02/MyTrainable_270ba_00000_0_2023-09-27_14-37-07/checkpoint_000085)
(MyTrainable pid=56464) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/kai/ray_results/MyTrainable_2023-09-27_14-37-02/MyTrainable_270ba_00000_0_2023-09-27_14-37-07/checkpoint_000090)
...

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Kai Fricke <kai@anyscale.com>
@justinvyu (Contributor) left a comment:

I'm wondering if we can revert to the 2.6 indexing for class trainables only (see the suggestion below).

Then, we can come up with a more general checkpoint dir naming functionality in a follow-up -- something like a format string with arbitrary metrics: CheckpointConfig(checkpoint_dir_name="checkpoint_{training_iteration}-{loss:.2f}")
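As a rough illustration of that idea (checkpoint_dir_name and the helper below are hypothetical, not an existing CheckpointConfig option):

def resolve_checkpoint_dir_name(template: str, metrics: dict) -> str:
    # Fill placeholders such as {training_iteration} or {loss:.2f} from the
    # metrics reported alongside the checkpoint.
    return template.format(**metrics)


print(resolve_checkpoint_dir_name(
    "checkpoint_{training_iteration}-{loss:.2f}",
    {"training_iteration": 85, "loss": 0.1234},
))
# -> checkpoint_85-0.12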

@@ -1104,7 +1104,7 @@ def on_checkpoint(self, checkpoint: Union[_TrackedCheckpoint, _TrainingResult]):
         # Increment the checkpoint index to keep the checkpoint index in sync.
         # This index will get restored when the trial is restored and will
         # be passed to the Trainable as the starting checkpoint index.
-        self.storage.current_checkpoint_index += 1
+        self.storage._increase_checkpoint_index(checkpoint_result.metrics)
@justinvyu (Contributor):
Suggested change:
-        self.storage._increase_checkpoint_index(checkpoint_result.metrics)
+        self.storage.current_checkpoint_index = checkpoint_result.metrics[TRAINING_ITERATION]

@krfricke (Contributor, author):
Doesn't this mean that the checkpoint manager in Trial will get out of sync with class trainables?

@krfricke (Contributor, author):
Is this the only location where we access the trial checkpoint ID?

            result[CHECKPOINT_DIR_NAME] = (
                trial.storage.checkpoint_dir_name if trial_will_checkpoint else None
            )

@justinvyu (Contributor):
Ah yeah, that's right.

That is the only place it gets explicitly accessed, but the trial checkpoint id also needs to stay in sync for restore -- it'll get propagated to the new trainable that gets scheduled:

kwargs["storage"] = trial.storage

@krfricke (Contributor, author):
OK - this means we should keep those in sync and shouldn't diverge between class and function trainable, right?

In that case, should we go with the workaround proposed in this PR?

Alternatively, we could communicate the (updated) checkpoint ID in the training result.
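A hypothetical illustration of that alternative (the result key and helper names below are invented for this sketch and are not real Ray Train/Tune fields):

CHECKPOINT_INDEX_KEY = "_checkpoint_index"  # invented key, not a real result field


def attach_checkpoint_index(metrics: dict, checkpoint_index: int) -> dict:
    # Worker side: report the index that was just used for the checkpoint dir.
    return {**metrics, CHECKPOINT_INDEX_KEY: checkpoint_index}


def sync_index_from_result(trial_storage, result: dict) -> None:
    # Driver side: adopt the worker's index instead of incrementing independently.
    if CHECKPOINT_INDEX_KEY in result:
        trial_storage.current_checkpoint_index = result[CHECKPOINT_INDEX_KEY]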

krfricke marked this pull request as ready for review September 29, 2023 05:05
krfricke merged commit 3135323 into ray-project:master Sep 29, 2023
58 of 61 checks passed
krfricke deleted the tune/checkpoint-index branch September 29, 2023 05:08
justinvyu added a commit that referenced this pull request Oct 5, 2023
…0003)

Following up to #39927, this PR updates the logic of updating the checkpoint ID (and thus the checkpoint directory name) just before persisting the checkpoint. This means that the (renamed) `_update_checkpoint_index` gets the metrics associated with the current checkpoint, rather than the previous one.

---------

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Kai Fricke <krfricke@users.noreply.github.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
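A rough sketch of the ordering that description implies (the function name and storage attributes below approximate the discussion in this thread and are not the exact merged code):

def persist_checkpoint(storage, checkpoint_data: dict, metrics: dict) -> str:
    # Update the index from the metrics of the checkpoint being persisted, so
    # the directory name reflects the current result rather than the previous one.
    storage._update_checkpoint_index(metrics)
    target_dir = storage.checkpoint_dir_name  # derived from the new index
    # ... persisting checkpoint_data under target_dir is elided here ...
    return target_dir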
Zandew pushed a commit to Zandew/ray that referenced this pull request Oct 10, 2023
…y-project#40003)

vymao pushed a commit to vymao/ray that referenced this pull request Oct 11, 2023
…ct#39927)

vymao pushed a commit to vymao/ray that referenced this pull request Oct 11, 2023
…y-project#40003)
