
Commit c1322be

Update loading numpy format state dict (#10679)
* [UC] fix _load_state_dict_into_model
* update load numpy state_dict
1 parent ea2976f commit c1322be


paddlenlp/transformers/model_utils.py

Lines changed: 17 additions & 11 deletions
@@ -360,7 +360,12 @@ def _split_keys_evenly(keys: list, n: int) -> list:
 
 
 def _load_part_state_dict(
-    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
+    keys,
+    checkpoint_file: Union[str, os.PathLike],
+    tensor_parallel_split_mapping,
+    fliter_dict_keys,
+    device,
+    return_numpy=False,
 ):
     """load part state dict from checkpoint file.
 
@@ -395,7 +400,7 @@ def _load_part_state_dict(
                 weight = tensor_parallel_split_mapping[key](py_safe_slice_)
             else:
                 weight = py_safe_slice_[:]
-            if device == "expected":
+            if not return_numpy and device == "expected":
                 with device_guard():
                     weight = paddle.Tensor.__call__(weight, zero_copy=True)
                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
@@ -407,9 +412,10 @@ def _load_part_state_dict(
                 or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
             ):
                 scale = f.get_tensor(key)
-                with device_guard():
-                    scale = paddle.Tensor.__call__(scale, zero_copy=True)
-                scale = scale._copy_to(paddle.framework._current_expected_place(), False)
+                if not return_numpy and device == "expected":
+                    with device_guard():
+                        scale = paddle.Tensor.__call__(scale, zero_copy=True)
+                    scale = scale._copy_to(paddle.framework._current_expected_place(), False)
                 scale_dict[key] = scale
     return part_state_dict, scale_dict
 
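The three hunks above thread a `return_numpy` flag into `_load_part_state_dict`: when it is set, safetensors slices stay as numpy arrays and the paddle Tensor conversion (plus the copy to the expected device) is skipped for both weights and quantization scales. Below is a minimal sketch of that guard pattern, using a hypothetical `convert_weight` helper rather than the actual PaddleNLP function:

import numpy as np
import paddle

def convert_weight(weight: np.ndarray, device: str, return_numpy: bool = False):
    # When numpy output is requested, hand back the raw array untouched.
    if return_numpy:
        return weight
    # Otherwise build a paddle Tensor and, mirroring the diff's
    # `device == "expected"` branch, copy it to the expected place.
    tensor = paddle.to_tensor(weight, place=paddle.CPUPlace())
    if device == "expected":
        tensor = tensor._copy_to(paddle.framework._current_expected_place(), False)
    return tensor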

@@ -420,6 +426,7 @@ def load_state_dict(
     fliter_dict_keys=None,
     device="cpu",
     ckpt_quant_stage="O0",
+    return_numpy=False,
 ):
     """
     Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -455,6 +462,7 @@ def load_state_dict(
                     tensor_parallel_split_mapping,
                     fliter_dict_keys,
                     device,
+                    return_numpy,
                 )
         else:
             # Load state dict in multi-thread to speed up loading
@@ -469,6 +477,7 @@ def load_state_dict(
                         tensor_parallel_split_mapping,
                         fliter_dict_keys,
                         device,
+                        return_numpy,
                     ): keys
                     for keys in keys_groups
                 }
@@ -477,7 +486,7 @@ def load_state_dict(
                     state_dict.update(res_state_dict)
                     scale_dict.update(res_scale_dict)
 
-    if device == "cpu":
+    if not return_numpy and device == "cpu":
         for k in list(state_dict.keys()):
             with device_guard():
                 state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
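With the flag threaded through both the single-thread and multi-thread paths, `load_state_dict` also skips the final CPU Tensor materialization when `return_numpy` is set. A hedged usage sketch; the checkpoint path below is a placeholder, not a file shipped with this change:

import numpy as np

# Placeholder path; any local safetensors checkpoint would do.
state_dict = load_state_dict("model.safetensors", return_numpy=True)

# Values are expected to stay numpy arrays instead of paddle Tensors.
assert all(isinstance(v, np.ndarray) for v in state_dict.values())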
@@ -3174,7 +3183,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False):
             with safe_open(safe_model_path, framework="np", device="cpu") as f:
                 loaded_keys = f.keys()
             tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
-            state_dict = load_state_dict(safe_model_path, tp_actions)
+            state_dict = load_state_dict(safe_model_path, tp_actions, return_numpy=return_numpy)
         else:  # shard files safetensors
             resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
                 pretrained_model_name_or_path=folder,
@@ -3190,10 +3199,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False):
                     shard_file,
                     tp_actions,
                     loaded_state_dict_keys,
+                    return_numpy=return_numpy,
                 )
                 state_dict.update(shard_state_dict)
-            if return_numpy:
-                for k in list(state_dict.keys()):
-                    if not isinstance(state_dict[k], np.ndarray):
-                        state_dict[k] = state_dict.pop(k).cpu().numpy()
     return state_dict
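Since `load_state_dict` now honors `return_numpy` directly, `load_tp_checkpoint` drops its post-hoc conversion loop: weights are no longer materialized as paddle Tensors only to be copied back to numpy. A short illustrative call, where `folder`, `cls`, and `config` are placeholders for a checkpoint directory, a model class, and its config:

# Previously this call loaded paddle Tensors and converted them back to
# numpy via the loop removed above; now load_state_dict keeps the numpy
# arrays from the start, avoiding the extra device round trip.
state_dict = load_tp_checkpoint(folder, cls, config, return_numpy=True)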
