Skip to content

Commit

Permalink
fix _retrieve_checkpoint_dirpaths
Browse files Browse the repository at this point in the history
Summary:
# Context
For checkpoint directories whose path contains `_` anywhere other than in the final `epoch_0_step_0` component (e.g. `tmp/fjad_213/epoch_0_step_0`), `_retrieve_checkpoint_dirpaths` can raise an error: it splits the full path on `_`, assuming that underscores appear only in the final path component, where they separate the epoch and step counts.

```
>> ckpt_dirpaths.sort(key=lambda x: (int(x.split("_")[1]), int(x.split("_")[3])))
ValueError: invalid literal for int() with base 10: 'tmp/tmpcinmegj2/epoch'
```

# This diff
When sorting the paths, call `os.path.basename(path)` first, so that only the `epoch_0_step_0` part of the path is parsed.

Reviewed By: galrotem

Differential Revision: D51916358

fbshipit-source-id: 1d08847ea31a6612393663c839c5fec38b076988
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Dec 7, 2023
1 parent 2c6a7c8 commit c5d913c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
Expand Up @@ -337,6 +337,21 @@ def _test_process_group_plumbing() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@patch(
    "torchtnt.framework.callbacks.base_checkpointer._retrieve_checkpoint_dirpaths",
    return_value=["epoch_1_step_10", "epoch_2_step_20"],
)
def test_ckpt_dirpaths(self, _: MagicMock) -> None:
    """
    Checks how ``_ckpt_dirpaths`` is initialised: it is expected to be
    empty when ``keep_last_n_checkpoints`` is not passed, and to equal the
    value returned by the (patched) ``_retrieve_checkpoint_dirpaths``
    when it is.
    """
    # Without keep_last_n_checkpoints, no existing checkpoint paths are tracked.
    saver_default = BaseCheckpointSaver("foo")
    self.assertEqual(saver_default._ckpt_dirpaths, [])

    # With keep_last_n_checkpoints, the patched return value is expected.
    saver_keep_n = BaseCheckpointSaver("foo", keep_last_n_checkpoints=10)
    expected_dirpaths = ["epoch_1_step_10", "epoch_2_step_20"]
    self.assertEqual(saver_keep_n._ckpt_dirpaths, expected_dirpaths)

def test_should_remove_checkpoint(self) -> None:
"""
Tests the helper function that checks if checkpoint should be removed or not
Expand Down
2 changes: 2 additions & 0 deletions tests/framework/callbacks/test_checkpoint_utils.py
Expand Up @@ -142,6 +142,7 @@ def test_retrieve_checkpoint_dirpaths(self, mock_get_filesystem: MagicMock) -> N
{"name": "tmp/epoch_1_step_10", "type": "directory"},
{"name": "tmp/epoch_2_step_10", "type": "directory"},
{"name": "tmp/epoch_0_step_5", "type": "directory"},
{"name": "tmp_12ed/epoch_0_step_6", "type": "directory"},
{"name": "tmp/epoch_0_step_3", "type": "file"},
]

Expand All @@ -151,6 +152,7 @@ def test_retrieve_checkpoint_dirpaths(self, mock_get_filesystem: MagicMock) -> N
returned_paths,
[
"tmp/epoch_0_step_5",
"tmp_12ed/epoch_0_step_6",
"tmp/epoch_0_step_10",
"tmp/epoch_1_step_10",
"tmp/epoch_2_step_10",
Expand Down
9 changes: 7 additions & 2 deletions torchtnt/framework/callbacks/_checkpoint_utils.py
Expand Up @@ -8,7 +8,7 @@
import os
import re

from typing import Any, Dict, List, Optional, Pattern
from typing import Any, Dict, List, Optional, Pattern, Tuple

from pyre_extensions import none_throws
from torch import distributed as dist
Expand Down Expand Up @@ -140,6 +140,11 @@ def _retrieve_checkpoint_dirpaths(dirpath: str) -> List[str]:
Args:
dirpath: parent directory where checkpoints are saved.
"""

def sort_fn(path: str) -> Tuple[int, int]:
    """Return the ``(epoch, step)`` sort key for a checkpoint directory path.

    Only the final path component (``epoch_<E>_step_<S>``) is parsed, so
    underscores in parent directory names do not affect the result.

    Args:
        path: path of a checkpoint directory,
            e.g. ``tmp_12ed/epoch_0_step_6``.

    Returns:
        The epoch and step counts parsed from the final path component.

    Raises:
        IndexError: if the final component has fewer than four
            ``_``-separated fields.
        ValueError: if the epoch or step field is not an integer.
    """
    # Strip trailing separators first: os.path.basename("a/epoch_0_step_0/")
    # is "", which would otherwise fail on the index lookups below.
    name = os.path.basename(path.rstrip("/"))
    fields = name.split("_")
    return (int(fields[1]), int(fields[3]))

fs = get_filesystem(dirpath)

contents = fs.ls(dirpath, detail=True)
Expand All @@ -151,7 +156,7 @@ def _retrieve_checkpoint_dirpaths(dirpath: str) -> List[str]:
ckpt_dirpaths.append(path)

# sorts by epoch, then step
ckpt_dirpaths.sort(key=lambda x: (int(x.split("_")[1]), int(x.split("_")[3])))
ckpt_dirpaths.sort(key=sort_fn)
return ckpt_dirpaths


Expand Down

0 comments on commit c5d913c

Please sign in to comment.