from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader

+from torchtitan.components.ft import FTManager
from torchtitan.components.optimizer import LRSchedulersContainer, OptimizersContainer
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools.logging import init_logger, logger
@@ -214,6 +215,19 @@ class CheckpointManager:
    3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
       with the assumption that all lr_schedulers have the same state_dict.

+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything, including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.

    Args:
        dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +237,7 @@ class CheckpointManager:
        states (Dict[str, Any]): The states that need to be saved, other than the
            previous 4 components.
        job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (FTManager): The FTManager from TorchFT.
    """

    def __init__(
@@ -233,16 +248,41 @@ def __init__(
        lr_schedulers: LRSchedulersContainer,
        states: Dict[str, Any],
        job_config: JobConfig,
+        ft_manager: FTManager,
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id

        async_mode = ckpt_config.async_mode.lower()
        self.enable_staging = (
            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
            return

        self.states = states
@@ -254,6 +294,13 @@ def __init__(
                LR_SCHEDULER: lr_schedulers,
            }
        )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

        self.staging = False
        self.sending_to_checkpoint_mp = False
@@ -264,7 +311,7 @@ def __init__(
        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
        self.interval = ckpt_config.interval
        async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
            self.pg = dist.new_group(backend="gloo")

        self.keep_latest_k = ckpt_config.keep_latest_k
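The constructor changes above register two closures with the TorchFT manager so that a recovering replica can pull model, optimizer, lr_scheduler, and train_state directly from a healthy peer instead of reading them from disk. Below is a minimal sketch of that callback pattern; ToyComponent and FakeManager are hypothetical stand-ins for the Stateful components and ft.Manager, and only set_state_dict_fns mirrors the call in the diff.

from typing import Any, Dict


class ToyComponent:
    """Mimics a Stateful component (model/optimizer/lr_scheduler/train_state)."""

    def __init__(self, value: int) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.value = state_dict["value"]


class FakeManager:
    """Hypothetical stand-in that stores the two callbacks like set_state_dict_fns would."""

    def set_state_dict_fns(self, load_fn, save_fn) -> None:
        self._load_fn, self._save_fn = load_fn, save_fn

    def recover_from(self, peer_state: Dict[str, Any]) -> None:
        # On live recovery the manager feeds a peer's flattened state into load_fn.
        self._load_fn(peer_state)


states = {"model": ToyComponent(1), "optimizer": ToyComponent(2)}


def state_dict() -> Dict[str, Any]:
    # Flatten only the components that participate in live recovery.
    return {k: v.state_dict() for k, v in states.items()}


def load_state_dict(sd: Dict[str, Any]) -> None:
    for k, v in sd.items():
        states[k].load_state_dict(v)


manager = FakeManager()
manager.set_state_dict_fns(load_state_dict, state_dict)
manager.recover_from({"model": {"value": 10}, "optimizer": {"value": 20}})
assert states["model"].value == 10 and states["optimizer"].value == 20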
@@ -339,35 +386,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
            None
        """

+        if self.ft_manager:
+            self._ft_save(curr_step)
+
        if not self._should_save(curr_step, force):
            return

        begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()

-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )

    @torch.no_grad()
    def load(self, step: int = -1) -> bool:
@@ -384,6 +440,9 @@ def load(self, step: int = -1) -> bool:
            bool: Whether the checkpoint was loaded successfully.
        """

+        if self.ft_manager:
+            self._ft_load()
+
        if not self.enable_checkpoint or not os.path.isdir(self.folder):
            return False
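Taken together, save() and load() now follow a two-tier flow: every replica writes and restores its own per-replica (dataloader) checkpoint, while only the replica with participating_rank() == 0 writes the full persistent checkpoint, which all replicas currently load. Here is a small sketch of just that gating, with hypothetical helper callbacks standing in for the checkpointer's private methods.

from typing import Callable, Optional


def two_tier_save(
    step: int,
    save_per_replica: Callable[[int], None],
    save_full: Callable[[int], None],
    participating_rank: Optional[Callable[[], int]] = None,  # None means TorchFT is disabled
) -> None:
    if participating_rank is not None:
        save_per_replica(step)  # every replica persists its own dataloader state
    if participating_rank is None or participating_rank() == 0:
        save_full(step)  # only one replica writes the full persistent checkpoint


def two_tier_load(
    load_per_replica: Callable[[], None],
    load_full: Callable[[], bool],
    ft_enabled: bool,
) -> bool:
    if ft_enabled:
        load_per_replica()  # restore this replica's dataloader first
    return load_full()  # then the shared full checkpoint


writes = []
two_tier_save(
    100,
    save_per_replica=lambda s: writes.append(f"ft step-{s}"),
    save_full=lambda s: writes.append(f"full step-{s}"),
    participating_rank=lambda: 0,
)
print(writes)  # ['ft step-100', 'full step-100']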
@@ -467,10 +526,36 @@ def _find_load_step(self, folder: str = "") -> int:
            return -1
        return max(step_counts)

+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
    def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
        folder = folder if folder else self.folder
        return os.path.join(folder, f"step-{step}")

+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
    def _states_to_load(self, step: int) -> Dict[str, Any]:
        """Determines which states to load for the given step.
@@ -491,6 +576,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
        for exclude_key in self.exclude_from_loading:
            if exclude_key not in states:
                raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
        return states_to_load

    def _save_last_step(self, curr_step: int) -> None:
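Because the per-replica checkpoint owns the dataloader state, _states_to_load() drops DATALOADER from the full checkpoint when TorchFT is enabled, so a replica's iterator position is not overwritten by another replica's. A tiny sketch of that filtering with placeholder state values:

DATALOADER = "dataloader"  # mirrors the module-level key used in the diff


def states_to_load(states: dict, exclude: set, ft_enabled: bool) -> dict:
    selected = {k: v for k, v in states.items() if k not in exclude}
    if ft_enabled:
        # The per-replica checkpoint owns the dataloader state, so it is not
        # restored from the shared full checkpoint.
        selected.pop(DATALOADER, None)
    return selected


example = {"model": object(), "optimizer": object(), DATALOADER: object()}
print(sorted(states_to_load(example, exclude=set(), ft_enabled=True)))  # ['model', 'optimizer']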
@@ -577,6 +664,7 @@ def _purge_stale_checkpoints(self):
            self.keep_latest_k > 0
            and dist.get_rank() == 0
            and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
        ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):
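For reference, the resulting on-disk layout under the configured dump folder mixes the shared step-N directories with one ft-replicat-<replica_id> subfolder per replica (spelling as in _ft_folder() above), each holding that replica's own step-N dataloader checkpoints. A sketch of how those paths are composed, assuming an illustrative dump folder of ./outputs/checkpoint:

import os

folder = "./outputs/checkpoint"  # assumed dump folder, for illustration only
ft_replica_id = 1
step = 500

full_ckpt_id = os.path.join(folder, f"step-{step}")
ft_folder = os.path.join(folder, f"ft-replicat-{ft_replica_id}")  # spelling as in the diff
per_replica_ckpt_id = os.path.join(ft_folder, f"step-{step}")

print(full_ckpt_id)  # ./outputs/checkpoint/step-500
print(per_replica_ckpt_id)  # ./outputs/checkpoint/ft-replicat-1/step-500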