# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol
import torch
from torchtune import utils
from torchtune.models import convert_weights
from torchtune.models.mistral import (
mistral_reward_hf_to_tune,
mistral_reward_tune_to_hf,
)
from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf
from torchtune.utils._checkpointing._checkpointer_utils import (
get_path,
ModelType,
safe_torch_load,
save_config,
)
from torchtune.utils.logging import get_logger
logger = get_logger("DEBUG")
class _CheckpointerInterface(Protocol):
"""
Interface implemented by Checkpointers in torchtune.
torchtune checkpointers are designed to be composable components which can be plugged
into any training recipe. Each checkpointer supports a specific set of models and training
scenarios making these easy to understand, debug and extend. For example, the
``FullModelCheckpointer``s are used for loading and saving all of the model weights.
This checkpointer can be used for Full-Finetuning scenarios or PEFT where the output is a
merged checkpoint. In case the current suite of checkpointers are inadequate,
users are encouraged to implement their own and contribute back to torchtune.
torchtune is also designed to be "state-dict invariant". This means the checkpointer
ensures that the output checkpoint has the same format as the original checkpoint i.e.
the output checkpoint has the same keys split across the same number of files as the original
checkpoint. Being "state-dict invariant" allows users to seamlessly use torchtune checkpoints
with their favorite post-training tools from the open-source ecosystem without writing
torchtune-specific converters. To be "state-dict invariant", the ``load_checkpoint`` and
``save_checkpoint`` methods make use of the weight converters available in
``torchtune/models/<model_folder>``.
torchtune Checkpointers support two checkpointing scenarios:
* End-of-training Checkpointing. The model weights at the end of a completed training
run are written out to file. The checkpointer ensures that the output checkpoint
files have the same keys as the input checkpoint file used to begin training. The
checkpointer also ensures that the keys are partitioned across the same number of
files as the original checkpoint. This ensures that the original metadata files can
be used as is, and the output checkpoint can be used with any tool that understands
the original checkpoint format. This includes popular inference engines such as
``llama.cpp`` and ``gpt-fast``. The output state dict has the following format:
{
"key_1": weight
...
}
* Mid-training Checkpointing. In addition to the model checkpoint files, we output an
additional "recipe_state.pt" file for intermediate checkpoints. These are currently
output at the end of each epoch, and contain information such as optimizer state,
number of epochs completed etc which is needed to correctly resume a previously
interrupted training run. The recipe is responsible for constructing the state dict
with the information it needs. The checkpointer extracts the model state dict
(key = "model") and writes everything else out to "recipe_state.pt". To prevent us
from flooding ``output_dir`` with checkpoint files, the recipe state is overwritten
at the end of each epoch. The output state dicts have the following formats:
Model:
{
"key_1": weight
...
}
Recipe State:
{
"optimizer": ...,
"epoch": ...,
...
}
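A minimal sketch of how a recipe typically drives a checkpointer through this interface
is shown below. The recipe objects (``model``, ``optimizer``) and the already constructed
``checkpointer`` are assumptions for illustration only:
Example:
    >>> # load weights (and, when resuming, recipe state) once at the start of training
    >>> ckpt_dict = checkpointer.load_checkpoint()
    >>> model.load_state_dict(ckpt_dict["model"])  # "model" is utils.MODEL_KEY
    >>> # ... run an epoch of training ...
    >>> # write an intermediate checkpoint containing model weights and recipe state
    >>> checkpointer.save_checkpoint(
    ...     {
    ...         "model": model.state_dict(),
    ...         "optimizer": optimizer.state_dict(),
    ...         "epoch": 0,
    ...     },
    ...     epoch=0,
    ...     intermediate_checkpoint=True,
    ... )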
"""
def load_checkpoint(self, **kwargs) -> Dict[str, Any]:
...
def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None:
...
class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
"""
Checkpointer which reads and writes checkpoints in a format compatible with
torchtune. No conversion of weights is required.
Currently this supports reading a single checkpoint file only. This will likely change as
we add support for larger models.
"""
def __init__(
self,
checkpoint_dir: str,
checkpoint_files: List[str],
model_type: ModelType,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_dir = Path(checkpoint_dir)
self._checkpoint_path = get_path(self._checkpoint_dir, checkpoint_files[0])
if not self._checkpoint_path.suffix == ".pt":
raise ValueError(
f"Checkpoint file {self._checkpoint_path} is not a valid checkpoint file. "
"Checkpointer expects a valid .pt file."
)
self._adapter_checkpoint = (
get_path(self._checkpoint_dir, adapter_checkpoint)
if adapter_checkpoint
else None
)
self._resume_from_checkpoint = resume_from_checkpoint
self._model_type = model_type
self._output_dir = Path(output_dir)
# recipe_checkpoint contains the recipe state. This should be available if
# resume_from_checkpoint is True
self._recipe_checkpoint = None
if self._resume_from_checkpoint:
if recipe_checkpoint is None:
raise ValueError(
"If resume_from_checkpoint is True, recipe_checkpoint file must be provided."
)
self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint)
def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]:
"""
Load torchtune checkpoint from file. Currently only loading from a single file is supported.
The output state_dict has the following format, with keys other than "model" only present if
``resume_from_checkpoint`` is True:
{
"model": {
"key_1": weight
...
},
"optimizer": ...,
...
}
Args:
weights_only (bool): flag passed down to torch.load. We expose this because quantized models
cannot be loaded with weights_only=True
Returns:
Dict[str, Any]: state_dict from the input checkpoint
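Example (illustrative; assumes ``checkpointer`` is an already constructed
``FullModelTorchTuneCheckpointer``):
    >>> sd = checkpointer.load_checkpoint()
    >>> # quantized checkpoints may contain non-tensor objects, so disable weights_only
    >>> sd = checkpointer.load_checkpoint(weights_only=False)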
"""
state_dict: Dict[str, Any] = {}
state_dict[utils.MODEL_KEY] = safe_torch_load(
self._checkpoint_path, weights_only=weights_only
)
if self._adapter_checkpoint:
adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
state_dict[utils.ADAPTER_KEY] = adapter_state_dict
if self._resume_from_checkpoint:
recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False)
state_dict.update(recipe_state)
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
intermediate_checkpoint: bool = False,
) -> None:
"""
Save torchtune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe
state. The output state dicts have the following formats:
Model:
{
"key_1": weight
...
}
Recipe State:
{
"optimizer": ...,
"epoch": ...,
...
}
Args:
state_dict (Dict[str, Any]): State dict with model and (optionally) recipe state
epoch (int): Current epoch number. This is added to the checkpoint file name to ensure
we're not overwriting intermediate checkpoint files
intermediate_checkpoint (bool): If True, save an additional checkpoint file with the
recipe state
"""
self._output_dir.mkdir(exist_ok=True)
# Output file is always a .pt file with the epoch number in the name
checkpoint_file = Path.joinpath(
self._output_dir, f"torchtune_model_{epoch}"
).with_suffix(".pt")
torch.save(state_dict[utils.MODEL_KEY], checkpoint_file)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB "
f"saved to {checkpoint_file}"
)
if utils.ADAPTER_KEY in state_dict:
output_path = Path.joinpath(
self._output_dir, f"adapter_{epoch}"
).with_suffix(".pt")
torch.save(state_dict[utils.ADAPTER_KEY], output_path)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
# If the recipe state needs to be output, first remove the model state dict
if intermediate_checkpoint:
_ = state_dict.pop(utils.MODEL_KEY)
_ = state_dict.pop(utils.ADAPTER_KEY, None)
_ = state_dict.pop(utils.ADAPTER_CONFIG, None)
output_path = Path.joinpath(self._output_dir, "recipe_state.pt")
torch.save(state_dict, output_path)
logger.info(
"Recipe checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
class FullModelHFCheckpointer(_CheckpointerInterface):
"""
Checkpointer which reads and writes checkpoints in HF's format. For LoRA models this includes
saving checkpoints in a format that can be loaded into PEFT via e.g. ``from_pretrained``. An example is
the Llama-2-7b-hf model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b-hf).
A few notes about the checkpoint reading logic:
- HF checkpoint names are usually ordered by ID (e.g. 0001_of_0003, 0002_of_0003, etc.). To ensure
we read the files in the right order, we sort the checkpoint file names before reading
- Checkpoint conversion to and from HF's format requires access to model params which are
read directly from the "config.json" file. This helps ensure we either load the weights
correctly or error out in case of discrepancy between the HF checkpoint file and torchtune's
model implementations.
Args:
checkpoint_dir (str): Directory containing the checkpoint files
checkpoint_files (List[str]): List of checkpoint files to load. Since the checkpointer takes care
of sorting by file ID, the order in this list does not matter
model_type (ModelType): Model type of the model for which the checkpointer is being loaded
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
Raises:
ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None
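The example below sketches a typical construction. The directory layout and file names
are assumptions for illustration only:
Example:
    >>> checkpointer = FullModelHFCheckpointer(
    ...     checkpoint_dir="/tmp/Llama-2-7b-hf",  # assumed local copy of the HF repo
    ...     checkpoint_files=[
    ...         "pytorch_model-00002-of-00002.bin",  # order here does not matter;
    ...         "pytorch_model-00001-of-00002.bin",  # files are sorted by ID internally
    ...     ],
    ...     model_type="LLAMA2",  # looked up via ModelType[...]
    ...     output_dir="/tmp/finetune-output",
    ... )
    >>> state_dict = checkpointer.load_checkpoint()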
"""
def __init__(
self,
checkpoint_dir: str,
checkpoint_files: List[str],
model_type: ModelType,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
) -> None:
self._checkpoint_dir = Path(checkpoint_dir)
self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
self._adapter_checkpoint = (
get_path(self._checkpoint_dir, adapter_checkpoint)
if adapter_checkpoint
else None
)
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._resume_from_checkpoint = resume_from_checkpoint
# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# partition the state dict into output checkpoint files. This is updated during checkpoint
# load
self._weight_map: Dict[str, str] = None
# the config.json file contains model params needed for state dict conversion
self._config = json.loads(
Path.joinpath(self._checkpoint_dir, "config.json").read_text()
)
# save config.json to output_dir
save_config(self._output_dir, self._config)
# recipe_checkpoint contains the recipe state. This should be available if
# resume_from_checkpoint is True
self._recipe_checkpoint = None
if self._resume_from_checkpoint:
if recipe_checkpoint is None:
raise ValueError(
"If resume_from_checkpoint is True, recipe_checkpoint file must be provided."
)
self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint)
def _validate_hf_checkpoint_files(self, checkpoint_files: List[str]) -> List[Path]:
"""
Validates that the checkpoint files exist and sorts based on ID.
"""
checkpoint_paths: List[Path] = []
for f in checkpoint_files:
checkpoint_path = get_path(self._checkpoint_dir, f)
checkpoint_paths.append(checkpoint_path)
return sorted(checkpoint_paths)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load HF checkpoint files from disk.
The keys and weights from across all checkpoint files are merged into a single state_dict.
We preserve the "state_dict key" <-> "checkpoint file mapping" in weight_map so we can
write the state dict correctly in ``save_checkpoint``.
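For illustration, after loading from two checkpoint files, ``weight_map`` maps each state dict
key to the 4-digit index of the file it came from, e.g. (key names here are assumptions):
``{"model.embed_tokens.weight": "0001", ..., "lm_head.weight": "0002"}``.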
Before returning, the model state dict is converted to a torchtune-compatible format using the
appropriate weight converter (e.g. ``convert_weights.hf_to_tune``).
Returns:
state_dict (Dict[str, Any]): torchtune checkpoint state dict
Raises:
ValueError: If the values in the input state_dict are not Tensors
"""
self._weight_map = {}
# merged state_dict contains keys and weights from all the checkpoint files
merged_state_dict: Dict[str, torch.Tensor] = {}
# converted_state_dict is the final state_dict passed to the recipe after the
# keys are converted into the torchtune format. This optionally also contains
# the recipe state and adapter weights
converted_state_dict: Dict[str, Dict[str, torch.Tensor]] = {}
# _checkpoint_paths are already sorted so simply enumerate to generate the right id
for cpt_idx, cpt_path in enumerate(self._checkpoint_paths):
state_dict = safe_torch_load(cpt_path)
for key, value in state_dict.items():
# Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption
# will break recipe code
if not isinstance(value, torch.Tensor):
raise ValueError(
f"Expected all values in the state dict to be torch.Tensor. "
f"Found {type(value)} instead."
)
# idx is written in the 4 digit format (eg: 0001, 0002, etc.)
self._weight_map[key] = f"{cpt_idx+1:04}"
merged_state_dict.update(state_dict)
# delete the state_dict to free up memory; TODO check if this del is needed
del state_dict
gc.collect()
if self._model_type == ModelType.PHI3_MINI:
logger.warning(
"Converting Phi-3 Mini weights from HF format."
"Note that conversion of adapter weights into PEFT format is not supported."
)
converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
elif self._model_type == ModelType.MISTRAL_REWARD:
converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune(
merged_state_dict,
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
else:
converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
merged_state_dict,
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
if self._adapter_checkpoint:
adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
converted_state_dict[utils.ADAPTER_KEY] = adapter_state_dict
if self._resume_from_checkpoint:
recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False)
converted_state_dict.update(recipe_state)
return converted_state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
intermediate_checkpoint: bool = False,
) -> None:
"""
Save torchtune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe
state.
The state_dict is first converted back to the HF format and then partitioned based on the
``_weight_map`` into separate checkpoint files.
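For example, with two input checkpoint files, the model weights for epoch 0 are written to
``hf_model_0001_0.pt`` and ``hf_model_0002_0.pt`` in ``_output_dir``.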
Args:
state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file
epoch (int): Epoch number. Used to create the checkpoint file name
intermediate_checkpoint (bool): If True, additional checkpoint files for the recipe state
and (if applicable) adapter weights are created. Default is False
"""
self._output_dir.mkdir(exist_ok=True)
# convert the state_dict back to hf format; do this inplace
if self._model_type == ModelType.PHI3_MINI:
state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
elif self._model_type == ModelType.MISTRAL_REWARD:
state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf(
state_dict[utils.MODEL_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
else:
state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
state_dict[utils.MODEL_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
# split the state_dict into separate dicts, one for each output checkpoint file
split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
for key, weight in state_dict[utils.MODEL_KEY].items():
cpt_idx = self._weight_map[key]
if cpt_idx not in split_state_dicts:
split_state_dicts[cpt_idx] = {}
split_state_dicts[cpt_idx].update({key: weight})
# write the partitioned state dicts to the right checkpoint file
for cpt_idx, model_state_dict in split_state_dicts.items():
output_path = Path.joinpath(
self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
).with_suffix(".pt")
torch.save(model_state_dict, output_path)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
if utils.ADAPTER_KEY in state_dict:
# Save torchtune format adapter weights even if we save PEFT format
# This way we can resume no matter what (and memory footprint of adapter weights is small)
output_path = Path.joinpath(
self._output_dir, f"adapter_{epoch}"
).with_suffix(".pt")
torch.save(state_dict[utils.ADAPTER_KEY], output_path)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
if self._model_type == ModelType.PHI3_MINI:
logger.warning(
"Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead"
)
else:
state_dict[
utils.ADAPTER_KEY
] = convert_weights.tune_to_peft_adapter_weights(
state_dict[utils.ADAPTER_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
peft_output_path = Path.joinpath(
self._output_dir, "adapter_model"
).with_suffix(".bin")
torch.save(state_dict[utils.ADAPTER_KEY], peft_output_path)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {peft_output_path}"
)
if utils.ADAPTER_CONFIG in state_dict:
if self._model_type == ModelType.PHI3_MINI:
logger.warning(
"PEFT integration for Phi-3 Mini is not supported, skipping adapter config save"
)
else:
state_dict[
utils.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
state_dict[utils.ADAPTER_CONFIG]
)
output_path = Path.joinpath(self._output_dir, "adapter_config.json")
with open(output_path, "w") as f:
json.dump(state_dict[utils.ADAPTER_CONFIG], f)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
# If the recipe state needs to be output, first remove the model state dict
# and if it exists, remove the adapter state dict as well
if intermediate_checkpoint:
_ = state_dict.pop(utils.MODEL_KEY)
_ = state_dict.pop(utils.ADAPTER_KEY, None)
_ = state_dict.pop(utils.ADAPTER_CONFIG, None)
output_path = Path.joinpath(self._output_dir, "recipe_state.pt")
torch.save(state_dict, output_path)
logger.info(
"Recipe checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
class FullModelMetaCheckpointer(_CheckpointerInterface):
"""
Checkpointer which reads and writes checkpoints in Meta's format. An example is
the Llama-2-7b model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b).
Currently we support reading from a single checkpoint file only. Support for reading from
sharded checkpoints is WIP.
Args:
checkpoint_dir (str): Directory containing the checkpoint files
checkpoint_files (List[str]): List of checkpoint files to load. Currently this checkpointer only
supports loading a single checkpoint file.
model_type (ModelType): Model type of the model for which the checkpointer is being loaded
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
Raises:
ValueError: If ``checkpoint_files`` is not a list of length 1
ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None
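A sketch of a typical construction is shown below; the directory layout and file name
are assumptions for illustration only:
Example:
    >>> checkpointer = FullModelMetaCheckpointer(
    ...     checkpoint_dir="/tmp/llama-2-7b",  # assumed local copy of the Meta weights
    ...     checkpoint_files=["consolidated.00.pth"],
    ...     model_type=ModelType.LLAMA2,  # ModelType enum member; illustrative
    ...     output_dir="/tmp/finetune-output",
    ... )
    >>> state_dict = checkpointer.load_checkpoint()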
"""
def __init__(
self,
checkpoint_dir: str,
checkpoint_files: List[str],
model_type: ModelType,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_dir = Path(checkpoint_dir)
self._checkpoint_path = get_path(self._checkpoint_dir, checkpoint_files[0])
self._adapter_checkpoint = (
get_path(self._checkpoint_dir, adapter_checkpoint)
if adapter_checkpoint
else None
)
self._resume_from_checkpoint = resume_from_checkpoint
self._model_type = model_type
self._output_dir = Path(output_dir)
# recipe_checkpoint contains the recipe state. This should be available if
# resume_from_checkpoint is True
self._recipe_checkpoint = None
if self._resume_from_checkpoint:
if recipe_checkpoint is None:
raise ValueError(
"If resume_from_checkpoint is True, recipe_checkpoint file must be provided."
)
self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta-format checkpoint from file and convert to torchtune format. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
state_dict[utils.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict)
if self._adapter_checkpoint:
adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
state_dict[utils.ADAPTER_KEY] = adapter_state_dict
if self._resume_from_checkpoint:
recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False)
state_dict.update(recipe_state)
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
intermediate_checkpoint: bool = False,
) -> None:
"""
Save torchtune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe
state.
Args:
state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file
epoch (int): Epoch number. Used to create the checkpoint file name
intermediate_checkpoint (bool): If True, additional checkpoint files for the recipe state
and (if applicable) adapter weights are created. Default is False
"""
self._output_dir.mkdir(exist_ok=True)
model_state_dict = state_dict[utils.MODEL_KEY]
state_dict[utils.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict)
# Output file is always a .pt file with the epoch number in the name
checkpoint_file = Path.joinpath(
self._output_dir, f"meta_model_{epoch}"
).with_suffix(".pt")
torch.save(state_dict[utils.MODEL_KEY], checkpoint_file)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(checkpoint_file) / 1000**3:.2f} GB "
f"saved to {checkpoint_file}"
)
if utils.ADAPTER_KEY in state_dict:
output_path = Path.joinpath(
self._output_dir, f"adapter_{epoch}"
).with_suffix(".pt")
torch.save(state_dict[utils.ADAPTER_KEY], output_path)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
# If the recipe state needs to be output, first remove the model state dict
# and if it exists, remove the adapter state dict as well
if intermediate_checkpoint:
_ = state_dict.pop(utils.MODEL_KEY)
_ = state_dict.pop(utils.ADAPTER_KEY, None)
_ = state_dict.pop(utils.ADAPTER_CONFIG, None)
output_path = Path.joinpath(self._output_dir, "recipe_state.pt")
torch.save(state_dict, output_path)
logger.info(
"Recipe checkpoint of size "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)