[air] pyarrow.fs persistence: Remove dependence on rank 0 worker checkpoint reporting #38523

Merged
merged 5 commits into ray-project:master on Aug 17, 2023

Conversation

@justinvyu justinvyu (Contributor) commented Aug 16, 2023

Why are these changes needed?

This PR removes the need for the rank 0 worker to report a checkpoint in order for a checkpoint to be tracked by Train/Tune.

Before (ray <= 2.6):

```python
def train_fn_per_worker(config):
    ...
    tmpdir = tempfile.mkdtemp()
    if session.get_world_rank() == 0:
        # write to tmpdir
        # global rank 0 MUST report. otherwise it's as if you didn't checkpoint
        checkpoint = Checkpoint.from_directory(...)
    else:
        # create an "empty" checkpoint...
        # otherwise, if you just reported None, we throw an error
        # even worse, if you report a dict checkpoint here... unknown territory
        checkpoint = Checkpoint.from_directory(...)
    session.report(..., checkpoint=checkpoint)
```

After:

```python
def train_fn_per_worker(config):
    ...
    # ANY combination of workers can report a checkpoint
    if train.get_context().get_world_rank() in [2, 4, 6]:
        with tempfile.TemporaryDirectory() as tempdir:
            # write to tempdir
            train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
    else:
        train.report(metrics)
```

Note: the reported metrics are still pulled from the global rank 0 worker (same behavior as before). This PR does not remove that restriction.
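
For illustration, a minimal sketch of what that note means in practice. The metric values and function body here are made up; the point is only that the dict reported by global rank 0 is the one Train/Tune records as the trial's result:

```python
import ray.train as train


def train_fn_per_worker(config):
    rank = train.get_context().get_world_rank()
    # Every worker calls report(), but the metrics surfaced to Tune come from
    # the global rank 0 worker; other ranks' metric dicts are not the ones
    # recorded for the trial.
    train.report({"loss": 0.1 * rank, "rank": rank})
```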

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

```python
# All workers reported a checkpoint to the same fs path, so there's
# no need to report multiple checkpoints to Tune.
at_least_one_reported_checkpoint = any(
    result.checkpoint is not None for result in results
)
```
Contributor

Is `results` gathered synchronously from all workers (i.e., is this a barrier)?

Contributor Author

Yep. Results are returned by the training iterator via `ray.get([worker.get_next_result.remote() for worker in workers])`.
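
A rough, self-contained sketch of that gather pattern (`_Worker` and `_TrainingResult` are illustrative stand-ins, not Ray Train's actual internal classes):

```python
from dataclasses import dataclass
from typing import Optional

import ray


@dataclass
class _TrainingResult:
    metrics: dict
    checkpoint: Optional[str] = None


@ray.remote
class _Worker:
    def __init__(self, rank: int):
        self.rank = rank

    def get_next_result(self) -> _TrainingResult:
        # Every rank returns a result; only some ranks attach a checkpoint.
        ckpt = "some/fs/path" if self.rank in (2, 4) else None
        return _TrainingResult(metrics={"rank": self.rank}, checkpoint=ckpt)


workers = [_Worker.remote(rank) for rank in range(8)]

# ray.get() on the whole list blocks until every worker has produced its
# result for this iteration, which is what makes it act like a barrier.
results = ray.get([worker.get_next_result.remote() for worker in workers])

at_least_one_reported_checkpoint = any(
    result.checkpoint is not None for result in results
)
assert at_least_one_reported_checkpoint
```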

Contributor

@ericl ericl left a comment

This change makes sense to me, but could you elaborate on the motivation/use cases for this? Are there new edge cases this would expose, or is it a pretty straightforward win-win?

@ericl ericl added the @author-action-required label Aug 16, 2023
```python
        path=tune_session.storage.checkpoint_fs_path,
    )
    if at_least_one_reported_checkpoint
    else None
```
Contributor

Should we validate/assert that all the checkpoints point to the same path?
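
For reference, a hypothetical shape of such a check, written against the `results` and checkpoint objects from the hunk above (not code from this PR):

```python
def _validate_single_checkpoint_path(results) -> None:
    # Every worker that reported a checkpoint should point at the same
    # filesystem path as the one the driver registers with Tune.
    paths = {r.checkpoint.path for r in results if r.checkpoint is not None}
    assert len(paths) <= 1, (
        f"Workers reported checkpoints at different paths: {paths}"
    )
```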

@justinvyu
Contributor Author

I think it's purely win-win here. The motivating use case is with deepspeed checkpoints:

  • Lightning saves multiple ranks' DeepSpeed checkpoint shards to a single directory, and it's hard to de-aggregate them, so our integration should just have each local rank 0 worker upload the directory instead of every worker doing it. The global rank 0 worker may not be a local rank 0 worker, and without this PR, the checkpoint wouldn't get tracked by Tune internally if global rank 0 doesn't report (see the sketch below).

I think this is better for UX as well -- we don't enforce which worker can report a checkpoint.

@woshiyyya can elaborate more here.
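
A minimal sketch of that local-rank-0 pattern under the new behavior (the DeepSpeed save call is elided, and `ckpt_dir` and `metrics` are placeholders; this is not the actual Lightning integration code):

```python
import ray.train as train
from ray.train import Checkpoint


def train_fn_per_worker(config):
    ...
    metrics = {"loss": 0.0}  # placeholder
    ckpt_dir = "/tmp/deepspeed_ckpt"  # hypothetical node-local directory

    # Every local rank writes its own DeepSpeed shard into ckpt_dir, e.g.:
    # engine.save_checkpoint(ckpt_dir)  # elided

    if train.get_context().get_local_rank() == 0:
        # Only one worker per node reports the directory. With this PR, the
        # checkpoint is tracked even if global rank 0 is not one of the
        # reporting workers.
        train.report(metrics, checkpoint=Checkpoint.from_directory(ckpt_dir))
    else:
        train.report(metrics)
```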

@justinvyu justinvyu added the tests-ok label and removed the @author-action-required label Aug 17, 2023
@justinvyu justinvyu requested a review from ericl August 17, 2023 16:00
```python
)

checkpoint = (
    NewCheckpoint(
```
Contributor Author

One consequence of this: if users end up using our new framework checkpoints, the checkpoint loses its class here and always ends up as a generic checkpoint.
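
To illustrate the point (a hypothetical subclass; this shows the behavior being described, not code from the PR):

```python
from ray.train import Checkpoint


class MyFrameworkCheckpoint(Checkpoint):
    """Hypothetical framework-specific subclass with extra helpers."""

    def load_model(self):
        ...


# A worker may report a MyFrameworkCheckpoint, but because the driver rebuilds
# the checkpoint from the storage path alone, what gets tracked is a plain
# Checkpoint; the subclass (and its helpers) is not preserved.
tracked = Checkpoint(path="/tmp/exp/trial/checkpoint_000000")
assert type(tracked) is Checkpoint
```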

Member

@woshiyyya woshiyyya left a comment


Thanks @justinvyu! This change is mostly for distributed checkpointing. In DeepSpeed, one worker on each node should report the checkpoint, and that worker is not necessarily global rank 0.

This change gives us a more general and flexible interface: as long as one of the workers reports a checkpoint, we can always track it correctly.

@ericl ericl merged commit e1208a7 into ray-project:master Aug 17, 2023
42 of 44 checks passed
@justinvyu justinvyu deleted the air/persistence/rank0_ckpt branch August 17, 2023 19:04
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023

…eckpoint reporting (ray-project#38523)
vymao pushed a commit to vymao/ray that referenced this pull request Oct 11, 2023

…eckpoint reporting (ray-project#38523)