
Commit 1f5adfe

Refactor Checkpointer
Several bug fixes, refactors, and feature improvements ahead of the next PR (integration with TorchFT):

1. Refactor the code for better readability.
2. Remove the time-based checkpoint condition. It is unused, can cause deadlocks when integrating with TorchFT, and removing it also simplifies the code.
3. Fix an async_with_pinned_memory bug.
4. The original keep_latest_k implementation can raise exceptions in certain cases and is also slow. Fix those bugs and purge stale checkpoints from a background thread (see the sketch below).

ghstack-source-id: cdad3b2
Pull Request resolved: #871
1 parent e742b29 commit 1f5adfe
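The purge behavior described in item 4 follows a queue-plus-background-thread pattern. A minimal sketch of that pattern is shown below; the names CheckpointPurger, schedule_purge, and _purge_loop are hypothetical and only illustrate the idea, not the exact CheckpointManager code in this PR.

import os
import queue
import shutil
import threading


class CheckpointPurger:
    """Sketch only: rank 0 enqueues stale checkpoint folders and a daemon
    thread deletes them, so the training loop never blocks on filesystem I/O."""

    def __init__(self):
        self.purge_queue = queue.Queue()  # paths to delete, or None to shut down
        self.purge_thread = threading.Thread(target=self._purge_loop, daemon=True)
        self.purge_thread.start()

    def schedule_purge(self, checkpoint_folder, keep_latest_k):
        # Sort step-* folders by step number and enqueue all but the newest k.
        steps = sorted(
            (d for d in os.listdir(checkpoint_folder) if d.startswith("step-")),
            key=lambda d: int(d.split("-")[1]),
        )
        if keep_latest_k > 0:
            for stale in steps[:-keep_latest_k]:
                self.purge_queue.put(os.path.join(checkpoint_folder, stale))

    def _purge_loop(self):
        while True:
            path = self.purge_queue.get()
            if path is None:  # shutdown sentinel
                return
            shutil.rmtree(path, ignore_errors=True)

    def close(self):
        self.purge_queue.put(None)
        self.purge_thread.join()

The purge tests below poll manager.purge_queue until it is empty and then sleep briefly, which fits this kind of design: an empty queue means every pending deletion has been picked up by the thread.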

4 files changed, +594 -209 lines changed

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
import copy
import os
import shutil
import tempfile
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from unittest import mock

import torch

from torchtitan.checkpoint import CheckpointManager


def fake_dcp_save(state, checkpoint_id):
    state = {k: v.state_dict() for k, v in state.items()}
    os.makedirs(checkpoint_id, exist_ok=True)
    torch.save(state, os.path.join(checkpoint_id, "state.pt"))


def fake_dcp_load(state, checkpoint_id):
    state["trainer"].dcp_load_is_called = 7312


def fake_async_save(state, checkpoint_id, process_group):
    def run_save():
        fake_dcp_save(state, checkpoint_id)

    # Exiting the context manager waits for run_save to complete, so the
    # returned mock future simply forwards result() to the finished future.
    with ThreadPoolExecutor(max_workers=1) as executor:
        f = executor.submit(run_save)

    mock_future = mock.Mock()
    mock_future.result = mock.Mock(side_effect=f.result)
    return mock_future


def fake_get_model_state_dict(model, *args, **kwargs):
    return model.state_dict()


@dataclass
class DummyCheckpointConfig:
    enable_checkpoint: bool = True
    folder: str = "dummy_folder"
    interval: int = 10
    async_mode: str = "disabled"
    keep_latest_k: int = 0
    model_weights_only: bool = False
    export_dtype: str = "float32"
    exclude_from_loading = []


@dataclass
class DummyJob:
    dump_folder: str = "dummy_folder"


@dataclass
class DummyJobConfig:
    checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
    job: DummyJob = field(default_factory=DummyJob)


# Dummy instances to supply as constructor arguments.
dummy_dataloader = mock.Mock()
dummy_dataloader.state_dict = mock.Mock(side_effect=lambda: {"dataloader": 1})
dummy_model_parts = [mock.Mock()]
dummy_model_parts[0].state_dict = mock.Mock(side_effect=lambda: {"model": 2})
dummy_optimizers = mock.Mock()
dummy_optimizers.state_dict = mock.Mock(side_effect=lambda: {"optimizer": 3})
dummy_lr_schedulers = mock.Mock()
dummy_lr_schedulers.state_dict = mock.Mock(side_effect=lambda: {"lr_scheduler": 4})


class TestCheckpointManager(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

        self.dummy_job = DummyJob(dump_folder=self.temp_dir)
        self.job_config = DummyJobConfig(job=self.dummy_job)
        self.checkpoint_folder = os.path.join(
            self.dummy_job.dump_folder, self.job_config.checkpoint.folder
        )
        os.makedirs(self.checkpoint_folder, exist_ok=True)
        self.trainer_state = mock.Mock()
        self.trainer_state.state_dict = mock.Mock(side_effect=lambda: {"my_state": 765})

    def tearDown(self):
        # Remove the temporary directory after each test.
        shutil.rmtree(self.temp_dir)

    @mock.patch(
        "torchtitan.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_save(self, *_):
        """Test that calling save() writes a checkpoint file to disk."""
        job_config = DummyJobConfig(job=self.dummy_job)
        manager = CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )
        step = 20
        manager.save(curr_step=step, force=True)
        state_file = self._checkpoint_id(step)
        self.assertTrue(
            os.path.exists(state_file), "The checkpoint file was not created on disk."
        )
        loaded_state = torch.load(state_file, weights_only=False)
        self.assertEqual(
            loaded_state["trainer"]["my_state"],
            765,
            "Saved state does not match expected value.",
        )

    @mock.patch(
        "torchtitan.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.checkpoint.dcp.load", side_effect=fake_dcp_load)
    @mock.patch("torchtitan.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_load(self, *_):
        """Test that load() properly reads the checkpoint file from disk and restores state."""
        job_config = DummyJobConfig(job=self.dummy_job)
        manager = CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )
        step = 30
        manager.save(curr_step=step, force=True)
        # Simulate a state change.
        manager.states["test"] = 999
        success = manager.load(step=step)
        self.assertTrue(
            success,
            "The load() method should have returned True for an existing checkpoint.",
        )
        self.assertTrue(hasattr(manager.states["trainer"], "dcp_load_is_called"))

        self.assertEqual(
            manager.states["trainer"].dcp_load_is_called,
            7312,
            "The state was not correctly restored after loading.",
        )

    @mock.patch("torchtitan.checkpoint.dist.get_rank", return_value=0)
    @mock.patch(
        "torchtitan.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_purge_stale_checkpoints_rank_zero(self, *_):
        """
        Test that when keep_latest_k is 3 and dist.get_rank() returns 0, stale checkpoints
        are purged by placing the correct paths into the purge queue.
        """
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.keep_latest_k = 3
        manager = CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )
        steps = [10, 20, 30, 40, 50]
        for s in steps:
            manager.save(curr_step=s, force=False)
        # Wait for the background purge thread to drain the queue.
        while not manager.purge_queue.empty():
            time.sleep(1)
        time.sleep(1)
        os.sync()
        expected_paths = [
            os.path.join(self.checkpoint_folder, "step-30"),
            os.path.join(self.checkpoint_folder, "step-40"),
            os.path.join(self.checkpoint_folder, "step-50"),
        ]
        for step in [10, 20]:
            self.assertFalse(
                os.path.exists(self._checkpoint_id(step)),
                "The checkpoint is not purged.",
            )

        for step in [30, 40, 50]:
            self.assertTrue(
                os.path.exists(self._checkpoint_id(step)), "The checkpoint is purged."
            )

    @mock.patch("torchtitan.checkpoint.dist.get_rank", return_value=1)
    @mock.patch(
        "torchtitan.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_purge_stale_checkpoints_rank_nonzero(self, *_):
        """
        Test that when dist.get_rank() returns a non-zero value, the purge logic does not
        place any paths in the purge queue.
        """
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.keep_latest_k = 3
        manager = CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )
        steps = [10, 20, 30, 40, 50]
        for s in steps:
            manager.save(curr_step=s, force=False)
        # Wait for the background purge thread to drain the queue.
        while not manager.purge_queue.empty():
            time.sleep(1)
        time.sleep(1)
        os.sync()

        for step in [10, 20, 30, 40, 50]:
            self.assertTrue(
                os.path.exists(self._checkpoint_id(step)), "The checkpoint is purged."
            )

    @mock.patch("torchtitan.checkpoint.dist.new_group")
    @mock.patch(
        "torchtitan.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.checkpoint.dcp.async_save", side_effect=fake_async_save)
    def test_async_save_calls_async_wait(self, *_):
        """
        Test that in async mode (AsyncMode.ASYNC), calling save() twice correctly waits
        on the previous async future via _async_wait().
        """
        # Set async_mode to "async" in the job configuration.
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.async_mode = "async"
        manager = CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )
        # First save: should schedule an async save.
        manager.save(curr_step=10, force=False)
        f = manager.async_future
        f.result.assert_not_called()
        # Second save: should wait on the first future before scheduling a new one.
        manager.save(curr_step=20, force=False)
        f.result.assert_called_once()
        f = manager.async_future
        f.result.assert_not_called()

    def _checkpoint_id(self, step):
        checkpoint_id = os.path.join(self.checkpoint_folder, f"step-{step}")
        state_file = os.path.join(checkpoint_id, "state.pt")
        return state_file


if __name__ == "__main__":
    unittest.main()
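The async path exercised by test_async_save_calls_async_wait follows a common pattern: keep the future returned by dcp.async_save and wait on it before scheduling the next save. A minimal sketch of that pattern, assuming a hypothetical AsyncCheckpointSaver wrapper (the real CheckpointManager internals are not part of this diff):

import torch.distributed.checkpoint as dcp


class AsyncCheckpointSaver:
    """Sketch only: serialize overlapping async checkpoint saves by waiting
    on the previous future before scheduling the next one."""

    def __init__(self, process_group=None):
        self.process_group = process_group
        self.async_future = None

    def _async_wait(self):
        # Block until the previous async save, if any, has finished.
        if self.async_future is not None:
            self.async_future.result()
            self.async_future = None

    def save(self, state_dict, checkpoint_id):
        self._async_wait()
        self.async_future = dcp.async_save(
            state_dict,
            checkpoint_id=checkpoint_id,
            process_group=self.process_group,
        )

Calling save() twice on such a wrapper mirrors the test above: the second call invokes result() on the first future before storing a new one.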
