broadcast does-last-checkpoint-exist from rank0 (#653)
Summary:
Pull Request resolved: #653

Broadcast the result of the checkpoint-existence check from rank0 instead of performing the check on all ranks (a minimal sketch of the pattern is shown below the commit metadata).

Reviewed By: JKSenthil

Differential Revision: D52083564

fbshipit-source-id: bfc8e855134dbac7798201bc38dc24c3f93c4e1c
galrotem authored and facebook-github-bot committed Dec 14, 2023
1 parent 37b8a7a commit e2e611d
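
For readers unfamiliar with the pattern: rank_zero_read_and_broadcast is a torchtnt decorator that runs the wrapped function on rank 0 only and broadcasts its return value to all other ranks. The sketch below is an illustrative stand-in, not the library's implementation; it assumes the default process group (so rank 0 is both group-local and global rank 0) and that torch.distributed was initialized via init_process_group when running distributed.

    # Illustrative sketch of a rank-zero read-and-broadcast decorator.
    # NOT the torchtnt implementation; assumes the default process group.
    import functools
    from typing import Any, Callable, TypeVar

    import torch.distributed as dist

    T = TypeVar("T")


    def rank_zero_read_and_broadcast_sketch(fn: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            ret = None
            rank = dist.get_rank() if dist.is_initialized() else 0
            if rank == 0:
                # Only rank 0 performs the (potentially expensive) read.
                ret = fn(*args, **kwargs)
            if dist.is_initialized():
                # Share rank 0's result with every other rank.
                obj_list = [ret]
                dist.broadcast_object_list(obj_list, src=0)
                ret = obj_list[0]
            return ret

        return wrapper

Applied to _does_checkpoint_exist, this means only one rank issues the filesystem exists() call, and every rank receives the same boolean, so all ranks take the same branch when deciding whether to skip the final checkpoint.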
Showing 2 changed files with 46 additions and 9 deletions.
23 changes: 23 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -463,3 +463,26 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
                 f"epoch_{max_epochs}_step_{dataset_len // batch_size * max_epochs}",
                 os.listdir(temp_dir),
             )
+
+    def test_does_checkpoint_exist(self) -> None:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with open(os.path.join(temp_dir, ".metadata"), "w"):
+                pass
+            bc = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_train_steps=2,
+                keep_last_n_checkpoints=1,
+            )
+            # checkpointer doesn't have a metadata_fname
+            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
+            self.assertFalse(does_checkpoint_exist)
+
+            # checkpointer has metadata_fname and the file exists
+            bc.metadata_fname = ".metadata"
+            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
+            self.assertTrue(does_checkpoint_exist)
+
+            # checkpointer has metadata_fname but the file doesn't exist
+            os.remove(os.path.join(temp_dir, ".metadata"))
+            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
+            self.assertFalse(does_checkpoint_exist)
32 changes: 23 additions & 9 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -16,6 +16,7 @@
     _delete_checkpoint,
     get_checkpoint_dirpaths,
     get_latest_checkpoint_path,
+    rank_zero_read_and_broadcast,
 )
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.state import EntryPoint, State
@@ -134,15 +135,12 @@ def _generate_checkpoint_and_upkeep(
         epoch = unit.train_progress.num_epochs_completed
         checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed)
 
-        # 1.5) If metadata_fname is set, ensure metadata file doesn't exist on final checkpoint
-        if hook == "on_train_end" and self.metadata_fname:
-            metadata_filepath = os.path.join(checkpoint_path, self.metadata_fname)
-            fs = get_filesystem(metadata_filepath)
-            if fs.exists(metadata_filepath):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
-                )
-                return False
+        # 1.5) At the end of training, skip saving if the final checkpoint already exists
+        if hook == "on_train_end" and self._does_checkpoint_exist(
+            checkpoint_path, self._process_group
+        ):
+            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
+            return False
 
         # 2) save checkpoint
         success = self._checkpoint_impl(
@@ -307,6 +305,22 @@ def restore_from_latest(
         )
         return True
 
+    @rank_zero_read_and_broadcast
+    def _does_checkpoint_exist(
+        self, checkpoint_path: str, process_group: Optional[dist.ProcessGroup] = None
+    ) -> bool:
+        """
+        Check whether a checkpoint already exists by verifying that the optional metadata file is present in the directory.
+        If the checkpointer has no metadata file configured, this always returns False.
+        """
+        metadata_fname = self.metadata_fname
+        if not metadata_fname:
+            return False
+
+        metadata_filepath = os.path.join(checkpoint_path, metadata_fname)
+        fs = get_filesystem(metadata_filepath)
+        return fs.exists(metadata_filepath)
+
 
 def _get_save_path(dirpath: str, epoch: int, step: int) -> str:
     # TODO: discuss whether this path should be customized
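
A side note on the helper used above: get_filesystem resolves a path (local or remote) to a filesystem object with an fsspec-style interface, which is what allows the same exists() check to work for local directories and remote URLs. Assuming that behavior, here is a self-contained equivalent of the metadata check written directly against fsspec; the function name checkpoint_has_metadata is made up for illustration.

    # Self-contained sketch of the metadata-file check using fsspec directly.
    # checkpoint_has_metadata is a hypothetical name; the commit performs the
    # equivalent logic inside _does_checkpoint_exist.
    import os

    from fsspec.core import url_to_fs


    def checkpoint_has_metadata(checkpoint_path: str, metadata_fname: str) -> bool:
        if not metadata_fname:
            # No metadata file configured: treat the checkpoint as absent.
            return False
        metadata_filepath = os.path.join(checkpoint_path, metadata_fname)
        fs, _ = url_to_fs(metadata_filepath)  # handles local paths and URLs like s3://
        return fs.exists(metadata_filepath)

Combined with the rank-zero broadcast decorator sketched earlier, only rank 0 would execute this check and every rank would observe the same result.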
