# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import abc
import bisect
import logging
import os
from datetime import timedelta
from typing import Any, cast, Iterable, List, Literal, Optional, Union
import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks._checkpoint_utils import (
    _delete_checkpoint,
    _metadata_exists,
    _sort_by_metric_value,
    _sort_by_recency,
    get_best_checkpoint_path,
    get_checkpoint_dirpaths,
    get_latest_checkpoint_path,
)
from torchtnt.framework.callbacks.checkpointer_types import (
    BestCheckpointConfig,
    RestoreOptions,
)
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast
from torchtnt.utils.fsspec import get_filesystem
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

logger: logging.Logger = logging.getLogger(__name__)


class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
"""
Abstract base class for file-based state_dict checkpointing. This class can be used as the base of a checkpointing callback, and handles
checkpointing frequency logic, checkpoint naming, checkpoint purging / upkeep, and process group synchronization. There are only two methods
that need to be implemented by subclasses:
1) ``_checkpoint_impl`` which implements the checkpoint saving logic, given the relevant checkpoint items and path.
2) ``restore`` which implements restoring the checkpoint given the relevant checkpoint path.
The subclass may override the ``metadata_fname`` attribute to specify the filename of the metadata file that will be written within the checkpoint directory.
This will be used by this base class to ensure the integrity of the checkpoint.
Args:
dirpath: Parent directory to save checkpoints to.
save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.
Note:
If torch.distributed is available and default process group is initialized, the constructor will call a collective operation for rank 0 to broadcast the dirpath to all other ranks
Note:
This class assumes checkpoint items are saved in the directory provided in ``_checkpoint_impl`` and will be in the form of ``<dirpath>/<epoch>-<step>/``. Checkpoint contents
should be stored within this directory, as deleting and retrieving latest checkpoint relies on reading the <epoch>-<step> directory name within <dirpath>
Note:
If best_checkpoint_config is enabled, the attribute must be on the unit upon checkpoint time, and must be castable to "float". This value must be maintained by the unit, and updated
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends. If the metric value is None, the
checkpoint will be saved, without the metric value in the checkpoint name
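
    Example (illustrative sketch only; ``MyCheckpointer``, its save/load logic, and the chosen paths are hypothetical,
    not part of this library)::

        class MyCheckpointer(BaseCheckpointer):
            metadata_fname = ".metadata"

            def _checkpoint_impl(self, state, unit, *, checkpoint_path, hook) -> bool:
                # save the unit's state under ``checkpoint_path`` and write ``self.metadata_fname``
                # inside it so the checkpoint is considered complete
                ...
                return True

            @staticmethod
            def restore(path, unit, *, train_dataloader=None, process_group=None, restore_options=None) -> None:
                # load the previously saved state from ``path`` back into ``unit``
                ...

        checkpointer = MyCheckpointer(
            "/tmp/checkpoints",
            save_every_n_train_steps=100,
            keep_last_n_checkpoints=3,
        )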
"""
metadata_fname: Optional[str] = None

    def __init__(
        self,
        dirpath: str,
        *,
        save_every_n_train_steps: Optional[int] = None,
        save_every_n_epochs: Optional[int] = None,
        save_every_n_eval_epochs: Optional[int] = None,
        keep_last_n_checkpoints: Optional[int] = None,
        best_checkpoint_config: Optional[BestCheckpointConfig] = None,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> None:
        if save_every_n_train_steps is not None and save_every_n_train_steps <= 0:
            raise ValueError(
                f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or positive number, but received {save_every_n_train_steps}"
            )
        if save_every_n_epochs is not None and save_every_n_epochs <= 0:
            raise ValueError(
                f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
            )
        if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
            raise ValueError(
                f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
            )

        self._best_checkpoint_config = best_checkpoint_config
        if best_checkpoint_config and best_checkpoint_config.mode not in {"min", "max"}:
            raise ValueError(
                f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received {best_checkpoint_config.mode}"
            )

        self._save_every_n_train_steps = save_every_n_train_steps
        self._save_every_n_epochs = save_every_n_epochs
        self._save_every_n_eval_epochs = save_every_n_eval_epochs
        self._keep_last_n_checkpoints = keep_last_n_checkpoints

        self._ckpt_dirpaths: List[str] = []
        if self._keep_last_n_checkpoints:
            metric_name = (
                None
                if not best_checkpoint_config
                else best_checkpoint_config.monitored_metric
            )
            ckpt_dirpaths = get_checkpoint_dirpaths(
                dirpath=dirpath,
                metadata_fname=self.metadata_fname,
                metric_name=metric_name,
                process_group=process_group,
            )

            # sort by metric value if doing best checkpoint, else by recency
            if best_checkpoint_config:
                self._ckpt_dirpaths = _sort_by_metric_value(
                    ckpt_dirpaths, mode=best_checkpoint_config.mode
                )
            else:
                self._ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)

        self._process_group: Optional[dist.ProcessGroup] = None
        self._setup_gloo_pg(process_group)
        self._pg_wrapper = PGWrapper(process_group)

        # sync dirpaths from rank 0
        self._sync_dirpath_to_all_ranks(dirpath)

    def _setup_gloo_pg(self, process_group: Optional[dist.ProcessGroup]) -> None:
        """
        Sets up a gloo process group to be used for any collectives called during
        checkpointing. If the global process group is already gloo, no action is required.
        Gloo is used over nccl for better compatibility.
        """
        if not dist.is_initialized():
            # there can be no process group
            return

        if process_group is None:
            # use global process group
            process_group = dist.group.WORLD

        # we create a new gloo process group if a different backend is being used
        if dist.get_backend(process_group) != dist.Backend.GLOO:
            rank_zero_info("Creating new gloo process group for checkpointing.")
            self._process_group = dist.new_group(
                timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
            )
        else:
            self._process_group = process_group

    def _sync_dirpath_to_all_ranks(self, dirpath: str) -> None:
        if not (dist.is_available() and dist.is_initialized()):
            self._dirpath: str = dirpath
            return

        dirpath_container = [dirpath] if self._pg_wrapper.get_rank() == 0 else [""]
        # broadcast directory from global rank 0
        dist.broadcast_object_list(dirpath_container, src=0, group=self._process_group)
        updated_dirpath = dirpath_container[0]
        if updated_dirpath != dirpath:
            logger.warning(f"Updating dirpath to match rank 0: {updated_dirpath}")

        self._dirpath: str = updated_dirpath

    @property
    def dirpath(self) -> str:
        """Returns parent directory to save to."""
        return self._dirpath

    def _generate_checkpoint_and_upkeep(
        self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
    ) -> bool:
        """
        Implementation for saving checkpoint while taking care of checkpoint
        name generation and cleanup of oldest checkpoints.

        Args:
            state: Current state of the trainer.
            unit: Current training unit.
            hook: Hook at which checkpoint is being saved.

        Returns:
            True if checkpoint was successfully saved. False otherwise.
        """
        # 1) generate checkpoint name
        unit = cast(TTrainUnit, unit)
        num_steps_completed = unit.train_progress.num_steps_completed
        if state.entry_point == EntryPoint.FIT:
            eval_unit = cast(TEvalUnit, unit)
            num_steps_completed += eval_unit.eval_progress.num_steps_completed
        epoch = unit.train_progress.num_epochs_completed
        checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed)

        # 1.1) Make sure that the last checkpoint does not already exist
        if hook == "on_train_end" and self._does_checkpoint_exist(
            checkpoint_path, process_group=self._process_group
        ):
            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
            return False

        # 1.2) If there is a tracked metric, add it to the checkpoint path
        metric_value = self._get_tracked_metric_value(unit)
        if metric_value is not None:
            metric_name = none_throws(self._best_checkpoint_config).monitored_metric
            checkpoint_path += f"_{metric_name}={metric_value}"

        # 2) Determine if checkpoint should be saved
        if not self._should_save_checkpoint(metric_value):
            return False

        # 3) try to save checkpoint
        if not self._checkpoint_impl(
            state, unit, checkpoint_path=checkpoint_path, hook=hook
        ):
            return False

        # 4) remove the oldest/worst checkpoint if applicable
        if self._should_remove_checkpoint():
            self._remove_checkpoint(state)

        # 5) update the tracked list of checkpoint paths
        if self._best_checkpoint_config and (metric_value is not None):
            metric_mode = none_throws(self._best_checkpoint_config).mode
            # insert the checkpoint path at the correct index to preserve ordering
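            # The list stays ordered from worst to best (index 0 is the first to be purged):
            # ascending metric values for mode="max", descending for mode="min".
            # Worked example for mode="min": existing values [0.9, 0.5, 0.2], new value 0.4 ->
            # reversed keys are [0.2, 0.5, 0.9], bisect returns 1, idx = 3 - 1 = 2, so the new
            # path is inserted between 0.5 and 0.2, preserving the descending order.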
            keys = [
                float(os.path.basename(x).split("=")[-1]) for x in self._ckpt_dirpaths
            ]
            if metric_mode == "min":
                keys.reverse()
            # Use bisect.bisect() to find the insertion point
            idx = bisect.bisect(keys, metric_value)
            if metric_mode == "min":
                idx = len(self._ckpt_dirpaths) - idx
            self._ckpt_dirpaths.insert(idx, checkpoint_path)
        elif not self._best_checkpoint_config:  # no metric to track
            self._ckpt_dirpaths.append(checkpoint_path)

        return True

    def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
        """
        If the checkpointer has a tracked metric, look up the value in the unit using reflection, and cast it to float.

        Args:
            unit: The training unit to look for the tracked metric in.

        Returns:
            The value of the tracked metric, or None if there is no best_checkpoint config defined.

        Raises:
            RuntimeError: If the unit does not have the attribute specified in the best_checkpoint config,
                or if the value cannot be cast to a float.
        """
        if not self._best_checkpoint_config:
            return None

        monitored_metric_name = self._best_checkpoint_config.monitored_metric
        if not hasattr(unit, monitored_metric_name):
            raise RuntimeError(
                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
            )

        metric_value_f = None
        if (metric_value := getattr(unit, monitored_metric_name)) is not None:
            try:
                metric_value_f = float(metric_value)
            except ValueError as e:
                raise RuntimeError(
                    f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
                    "can be converted to float and is not a multi-element tensor value."
                ) from e

        return metric_value_f

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        # clean up the difference if a surplus of checkpoints exists
        keep_last_n_checkpoints = self._keep_last_n_checkpoints
        if (
            keep_last_n_checkpoints
            and len(self._ckpt_dirpaths) > keep_last_n_checkpoints
        ):
            logger.warning(
                " ".join(
                    [
                        f"{len(self._ckpt_dirpaths)} checkpoints found in {self._dirpath}.",
                        f"Deleting {len(self._ckpt_dirpaths) - keep_last_n_checkpoints} oldest",
                        "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
                    ]
                )
            )
            for _ in range(len(self._ckpt_dirpaths) - keep_last_n_checkpoints):
                self._remove_checkpoint(state)

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
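        # Checkpoint every ``save_every_n_train_steps`` steps; e.g. with a value of 100 this
        # saves after steps 100, 200, 300, and so on.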
        num_steps_completed = unit.train_progress.num_steps_completed
        save_every_n_train_steps = self._save_every_n_train_steps
        if (
            save_every_n_train_steps is None
            or num_steps_completed % save_every_n_train_steps != 0
        ):
            return
        self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_step_end")

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        epoch = unit.train_progress.num_epochs_completed
        save_every_n_epochs = self._save_every_n_epochs
        if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
            return
        self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end")

    def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
        epoch = unit.eval_progress.num_epochs_completed
        save_every_n_eval_epochs = self._save_every_n_eval_epochs
        if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0:
            return
        self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")

    @abc.abstractmethod
    def _checkpoint_impl(
        self,
        state: State,
        unit: AppStateMixin,
        *,
        checkpoint_path: str,
        hook: str,
    ) -> bool:
        """
        Implementation of saving checkpoint.

        Args:
            state: current application state
            unit: current unit
            checkpoint_path: path to save checkpoint
            hook: name of callback hook that triggered this function call

        Returns:
            Whether a new checkpoint was created.
        """
        ...

    def _should_save_checkpoint(self, metric_value: Optional[float] = None) -> bool:
        """
        Whether a new checkpoint should be saved.
        """
        keep_last_n_checkpoints = self._keep_last_n_checkpoints
        if not keep_last_n_checkpoints:
            # always save candidate checkpoint if no limit of checkpoints is specified
            return True

        if len(self._ckpt_dirpaths) < keep_last_n_checkpoints:
            # limit of checkpoints has not been reached
            return True

        best_checkpoint_config = self._best_checkpoint_config
        if not best_checkpoint_config:
            # we always save the latest checkpoint
            return True

        # otherwise, we need to determine if we should overwrite the worst checkpoint
        assert metric_value
        ckpt_value = float(self._ckpt_dirpaths[0].split("=")[-1])
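        # index 0 holds the worst of the currently kept checkpoints; e.g. for
        # ".../epoch_1_step_50_val_loss=0.25" the parsed ckpt_value is 0.25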
        # do not checkpoint if candidate is worse than the existing one
        if best_checkpoint_config.mode == "min" and metric_value > ckpt_value:
            return False
        elif best_checkpoint_config.mode == "max" and metric_value < ckpt_value:
            return False

        # the candidate checkpoint is better than the existing one, so we must overwrite it
        return True

    def _should_remove_checkpoint(self) -> bool:
        """
        Whether the oldest / worst checkpoint should be removed.

        This is called after the candidate checkpoint is saved, so it has already been determined
        that the candidate checkpoint was worth saving.
        """
        keep_last_n_checkpoints = self._keep_last_n_checkpoints
        return (
            keep_last_n_checkpoints is not None
            and len(self._ckpt_dirpaths) >= keep_last_n_checkpoints
        )

    def _remove_checkpoint(self, state: State) -> None:
        # remove oldest checkpoint directory
        oldest_ckpt_path = self._ckpt_dirpaths.pop(0)
        with get_timing_context(state, f"{self.__class__.__name__}.delete_checkpoint"):
            if self._pg_wrapper.get_rank() == 0:
                # only delete on rank 0
                _delete_checkpoint(oldest_ckpt_path)
            self._pg_wrapper.barrier()

    @staticmethod
    @abc.abstractmethod
    def restore(
        path: str,
        unit: AppStateMixin,
        *,
        train_dataloader: Optional[Iterable[TTrainData]] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        restore_options: Optional[RestoreOptions] = None,
    ) -> None:
        """Method to restore checkpoint state from a path.

        There are additional flags offered should the user want to skip loading the train and eval progress.
        By default, the train and eval progress are restored, if applicable.

        Args:
            path: Path of the checkpoint to restore.
            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
            train_dataloader: An optional train dataloader to restore.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
            restore_options: Controls what to filter when restoring the state.
        """
        ...

    @classmethod
    def restore_from_latest(
        cls,
        dirpath: str,
        unit: AppStateMixin,
        *,
        train_dataloader: Optional[Iterable[TTrainData]] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        restore_options: Optional[RestoreOptions] = None,
        **kwargs: Any,
    ) -> bool:
        """
        Given a parent directory where checkpoints are saved, restore the checkpoint state from the latest checkpoint in the directory.

        There are additional flags offered should the user want to skip loading the train and eval progress.
        By default, the train and eval progress are restored, if applicable.

        Args:
            dirpath: Parent directory from which to get the latest checkpoint.
            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
            train_dataloader: An optional train dataloader to restore.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
            restore_options: Controls what to filter when restoring the state.

        Returns:
            True if the latest checkpoint directory was found and successfully restored, otherwise False.
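
        Example (illustrative; ``MyCheckpointer`` is a hypothetical subclass and ``my_unit`` a user-defined unit)::

            restored = MyCheckpointer.restore_from_latest("/tmp/checkpoints", my_unit)
            if not restored:
                print("No checkpoint found, starting from scratch.")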
"""
path = get_latest_checkpoint_path(
dirpath, metadata_fname=cls.metadata_fname, process_group=process_group
)
if path is None:
return False
logger.info(f"Restoring from path: {path}")
cls.restore(
path,
unit,
train_dataloader=train_dataloader,
process_group=process_group,
restore_options=restore_options,
**kwargs,
)
return True

    @classmethod
    def restore_from_best(
        cls,
        dirpath: str,
        unit: AppStateMixin,
        metric_name: str,
        mode: Literal["min", "max"],
        *,
        train_dataloader: Optional[Iterable[TTrainData]] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        restore_options: Optional[RestoreOptions] = None,
        **kwargs: Any,
    ) -> bool:
        """
        Given a parent directory where checkpoints are saved, restore the checkpoint state from the best checkpoint in the directory.

        There are additional flags offered should the user want to skip loading the train and eval progress.
        By default, the train and eval progress are restored, if applicable.

        Args:
            dirpath: Parent directory from which to get the best checkpoint.
            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
            metric_name: Name of the metric to use to find the best checkpoint.
            mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
            train_dataloader: An optional train dataloader to restore.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
            restore_options: Controls what to filter when restoring the state.

        Returns:
            True if the best checkpoint directory was found and successfully restored, otherwise False.
        """
        best_checkpoint_path = get_best_checkpoint_path(
            dirpath,
            metric_name=metric_name,
            mode=mode,
            metadata_fname=cls.metadata_fname,
            process_group=process_group,
        )
        if best_checkpoint_path is None:
            rank_zero_warn(
                f"No checkpoints with metric name {metric_name} were found in {dirpath}. Not loading any checkpoint.",
                logger=logger,
            )
            return False

        rank_zero_info(f"Loading checkpoint from {best_checkpoint_path}")

        cls.restore(
            best_checkpoint_path,
            unit,
            train_dataloader=train_dataloader,
            process_group=process_group,
            restore_options=restore_options,
            **kwargs,
        )
        return True

    @rank_zero_read_and_broadcast
    def _does_checkpoint_exist(
        self, checkpoint_path: str, process_group: Optional[dist.ProcessGroup] = None
    ) -> bool:
        """
        Checks whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
        If the checkpointer doesn't have a metadata file, this function will always return False.
        """
        metadata_fname = self.metadata_fname
        if not metadata_fname:
            return False

        fs = get_filesystem(checkpoint_path)
        return _metadata_exists(fs, checkpoint_path, metadata_fname)


def _get_save_path(dirpath: str, epoch: int, step: int) -> str:
    # TODO: discuss whether this path should be customized
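    # e.g. _get_save_path("/tmp/checkpoints", 2, 100) -> "/tmp/checkpoints/epoch_2_step_100"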
    return os.path.join(dirpath, f"epoch_{epoch}_step_{step}")