Redefine should_save_checkpoint in state #8
Description
🚀 Feature
Currently the state API has a should_save_checkpoint method, which has a couple of issues:
- CheckpointUtil assumes that all workers will return the same value from should_save_checkpoint.
- CheckpointUtil chooses the worker with rank == 0 to be the "representative" that loads the checkpoint, then leans on sync() to broadcast the state to the other workers. This may not be the correct choice that generalizes to different use cases. The "correct" logic should be to choose the worker with the "most-tenured" state (i.e. the most up-to-date state) to broadcast the state.
Motivation
The checkpoint feature in torchelastic has many caveats (see above). Cleaning up this logic would make it clear to users how to implement their state objects. It would also make it easier to reason about loading and saving checkpoints, and about how that interacts with their implementations of the sync(), load, and save methods in the state class.
Pitch
Here's one way we could achieve this:
- Define a get_most_tenured API that the user has to implement to return the rank of the worker with the most "up to date" state, which should be shared with the other workers on a rendezvous event.
- Add helpers to broadcast state objects to the workers; these helpers can be called in the sync() method. For instance:
    def get_most_tenured_rank():
        # get the rank that has the most up to date state
        # or just return a consistent rank
        pass

    def sync():
        most_tenured_rank = get_most_tenured_rank()
        dist_util.broadcast_state(state, most_tenured_rank)
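To make the selection step concrete, here is a minimal sketch of how the "most tenured" rank could be picked once every worker knows how far each peer has progressed. The function name select_most_tenured_rank and the progress_by_rank mapping are assumptions for illustration, not part of the current torchelastic API. In a real job the per-rank progress values would first have to be exchanged between workers, for example with a collective such as torch.distributed.all_gather.

```python
from typing import Dict


def select_most_tenured_rank(progress_by_rank: Dict[int, int]) -> int:
    """Return the rank whose state has made the most progress.

    Hypothetical helper: progress_by_rank maps each worker rank to a
    monotonically increasing progress counter (e.g. completed steps).
    Ties go to the lowest rank so that every worker, given the same
    input, arrives at the same answer.
    """
    if not progress_by_rank:
        raise ValueError("progress_by_rank must not be empty")
    # Order by highest progress first, then by lowest rank.
    return min(progress_by_rank, key=lambda rank: (-progress_by_rank[rank], rank))


# Hypothetical rendezvous: rank 0 just joined with fresh state,
# while ranks 1 and 2 have already completed 120 steps.
progress = {0: 0, 1: 120, 2: 120}
chosen = select_most_tenured_rank(progress)
print(chosen)
```

In this scenario, broadcasting from rank 0 would overwrite the progress of the other workers with a fresh state, which is the failure mode described above. The deterministic tie-break matters because all workers must agree on the source rank before a broadcast can take place.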
Alternatives
- Can bake the most_tenured_rank concept into the checkpoint util.
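As a rough illustration of this alternative, the sketch below simulates a util that picks the source rank itself, so the user does not have to implement the selection. Everything here is hypothetical: ToyState and ToyCheckpointUtil are stand-ins invented for this example (not the torchelastic classes), the step field is an assumed measure of tenure, and the "broadcast" is simulated in a single process by copying objects.

```python
import copy
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyState:
    """Hypothetical stand-in for a user-defined state object."""
    step: int = 0
    weights: List[float] = field(default_factory=list)


class ToyCheckpointUtil:
    """Hypothetical util that owns the most-tenured-rank decision."""

    def most_tenured_rank(self, states: Dict[int, ToyState]) -> int:
        # Highest step wins; the lowest rank breaks ties deterministically.
        return min(states, key=lambda rank: (-states[rank].step, rank))

    def sync_all(self, states: Dict[int, ToyState]) -> int:
        source = self.most_tenured_rank(states)
        for rank in states:
            if rank != source:
                # Simulated broadcast: copy the source state to this rank.
                states[rank] = copy.deepcopy(states[source])
        return source


# Rank 0 is a freshly joined worker; rank 1 has trained for 40 steps.
states = {0: ToyState(), 1: ToyState(step=40, weights=[0.5, 0.25])}
source = ToyCheckpointUtil().sync_all(states)
```

The trade-off is that the util then needs some generic notion of tenure (here, a step counter), whereas a user-implemented hook leaves that definition to each state class.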