[air] pyarrow.fs persistence: Remove dependence on rank 0 worker checkpoint reporting #38523
Conversation
# All workers reported a checkpoint to the same fs path, so there's
# no need to report multiple checkpoints to Tune.
at_least_one_reported_checkpoint = any(
    result.checkpoint is not None for result in results
Is `results` synchronously gathered from all workers (i.e., is there a barrier)?
Yep. Results are returned by the training iterator via a `ray.get([worker.get_next_result.remote() for worker in workers])`.
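For illustration, a minimal self-contained sketch of the gather pattern described above, assuming a group of Ray actors with a `get_next_result` method (the `Worker` stub and its return value are placeholders, not the actual Train internals):

```python
import ray

ray.init()

@ray.remote
class Worker:
    def get_next_result(self):
        # In Ray Train this would block until the worker's next report;
        # here it just returns a stub result.
        return {"checkpoint": None, "metrics": {}}

workers = [Worker.remote() for _ in range(4)]

# ray.get on the full list blocks until every worker has returned a
# result, so the gather acts as a barrier across the worker group.
results = ray.get([worker.get_next_result.remote() for worker in workers])
```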
This change makes sense to me, but could you elaborate on the motivation/use cases for this? Are there new edge cases this would expose, or is it a pretty straightforward win-win?
        path=tune_session.storage.checkpoint_fs_path,
    )
    if at_least_one_reported_checkpoint
    else None
Should we validate/assert all the checkpoints are pointing to the same path?
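For concreteness, a hedged sketch of what such a validation could look like, assuming each non-None `result.checkpoint` exposes a `path` attribute (the `_Result`/`_Checkpoint` stubs below are hypothetical stand-ins for the real result objects):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class _Checkpoint:
    """Stub; assumes the real checkpoint object exposes a `path`."""
    path: str

@dataclass
class _Result:
    """Stub for a per-worker training result; `checkpoint` may be None."""
    checkpoint: Optional[_Checkpoint]

results = [
    _Result(_Checkpoint("/mnt/shared/exp/checkpoint_000001")),
    _Result(None),
    _Result(_Checkpoint("/mnt/shared/exp/checkpoint_000001")),
]

# All reporting workers should have written to the same fs path;
# more than one distinct path would indicate a bug.
reported_paths = {
    r.checkpoint.path for r in results if r.checkpoint is not None
}
assert len(reported_paths) <= 1, (
    f"Workers reported checkpoints to different paths: {reported_paths}"
)
```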
I think it's purely a win-win here. The motivating use case is with DeepSpeed checkpoints.
I think this is better for UX as well -- we don't enforce which worker can report a checkpoint. @woshiyyya can elaborate more here.
)

checkpoint = (
    NewCheckpoint(
One consequence of this: if users end up using our new framework checkpoints, it'll lose the class here and always end up as a generic checkpoint.
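A self-contained sketch of the behavior being pointed out, using stand-in classes (`Checkpoint`/`TorchCheckpoint` here are illustrative stubs, not the real classes):

```python
class Checkpoint:
    """Stand-in for the generic checkpoint class."""
    def __init__(self, path: str):
        self.path = path

class TorchCheckpoint(Checkpoint):
    """Stand-in for a framework-specific subclass a worker might report."""

# A worker reports the framework-specific subclass...
reported = TorchCheckpoint("/mnt/shared/exp/checkpoint_000001")

# ...but the driver rebuilds the tracked checkpoint from the shared
# storage path, so the subclass information does not survive:
tracked = Checkpoint(reported.path)
assert type(tracked) is Checkpoint  # generic, no longer TorchCheckpoint
```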
Thanks @justinvyu! This change is mostly for distributed checkpointing. In DeepSpeed, one worker on each node should report checkpoints, not necessarily the global rank 0 worker.
This change gives us a more general and flexible interface: as long as one of the workers reports a checkpoint, we can always correctly track it.
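A hedged sketch of the per-node reporting pattern described above, in the new-style API shown in the PR description; picking the reporter via `get_local_rank() == 0` is my assumption about how one would select one worker per node:

```python
import tempfile

from ray import train
from ray.train import Checkpoint

def train_fn_per_worker(config):
    metrics = {"loss": 0.0}  # placeholder metrics
    # DeepSpeed-style: the local-rank-0 worker on *each node* writes and
    # reports its checkpoint shards; no worker needs to be global rank 0.
    if train.get_context().get_local_rank() == 0:
        with tempfile.TemporaryDirectory() as tmpdir:
            # ... write this node's checkpoint shards into tmpdir ...
            train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))
    else:
        train.report(metrics)
```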
Why are these changes needed?
This PR removes the need for the rank 0 worker to report a checkpoint in order for a checkpoint to be tracked by Train/Tune.

Before (ray <= 2.6):

```python
def train_fn_per_worker(config):
    ...
    tmpdir = tempfile.mkdtemp()
    if session.get_world_rank() == 0:
        # write to tmpdir
        # global rank 0 MUST report. otherwise it's as if you didn't checkpoint
        checkpoint = Checkpoint.from_directory(...)
    else:
        # create an "empty" checkpoint...
        # otherwise, if you just reported None, we throw an error
        # even worse, if you report a dict checkpoint here... unknown territory
        checkpoint = Checkpoint.from_directory(...)
    session.report(..., checkpoint)
```

After:

```python
def train_fn_per_worker(config):
    ...
    # ANY combination of workers can report a checkpoint
    if train.get_context().get_world_rank() in [2, 4, 6]:
        with tempfile.TemporaryDirectory() as tempdir:
            # write to tmpdir
            train.report(metrics, Checkpoint.from_directory(tempdir))
    else:
        train.report(metrics)
```

Note: the reported *metrics* are still pulled from the global rank 0 worker (same behavior as before). This PR does not remove that restriction.
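To make that note concrete, a hedged sketch (same assumed `ray.train` API as above): every worker calls `train.report`, but only the global rank 0 worker's metrics dict is the one Train/Tune tracks.

```python
from ray import train

def train_fn_per_worker(config):
    rank = train.get_context().get_world_rank()
    # Every worker reports a metrics dict, but per the note above, the
    # values tracked by Train/Tune come from the global rank 0 worker;
    # the other workers' metrics are not surfaced.
    train.report({"loss": 0.1 * rank, "world_rank": rank})
```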
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.