@@ -360,7 +360,12 @@ def _split_keys_evenly(keys: list, n: int) -> list:
360
360
361
361
362
362
def _load_part_state_dict (
363
- keys , checkpoint_file : Union [str , os .PathLike ], tensor_parallel_split_mapping , fliter_dict_keys , device
363
+ keys ,
364
+ checkpoint_file : Union [str , os .PathLike ],
365
+ tensor_parallel_split_mapping ,
366
+ fliter_dict_keys ,
367
+ device ,
368
+ return_numpy = False ,
364
369
):
365
370
"""load part state dict from checkpoint file.
366
371
@@ -395,7 +400,7 @@ def _load_part_state_dict(
395
400
weight = tensor_parallel_split_mapping [key ](py_safe_slice_ )
396
401
else :
397
402
weight = py_safe_slice_ [:]
398
- if device == "expected" :
403
+ if not return_numpy and device == "expected" :
399
404
with device_guard ():
400
405
weight = paddle .Tensor .__call__ (weight , zero_copy = True )
401
406
weight = weight ._copy_to (paddle .framework ._current_expected_place (), False )
@@ -407,9 +412,10 @@ def _load_part_state_dict(
407
412
or key .endswith (ASYMMETRY_QUANT_SCALE_MAX )
408
413
):
409
414
scale = f .get_tensor (key )
410
- with device_guard ():
411
- scale = paddle .Tensor .__call__ (scale , zero_copy = True )
412
- scale = scale ._copy_to (paddle .framework ._current_expected_place (), False )
415
+ if not return_numpy and device == "expected" :
416
+ with device_guard ():
417
+ scale = paddle .Tensor .__call__ (scale , zero_copy = True )
418
+ scale = scale ._copy_to (paddle .framework ._current_expected_place (), False )
413
419
scale_dict [key ] = scale
414
420
return part_state_dict , scale_dict
415
421
@@ -420,6 +426,7 @@ def load_state_dict(
420
426
fliter_dict_keys = None ,
421
427
device = "cpu" ,
422
428
ckpt_quant_stage = "O0" ,
429
+ return_numpy = False ,
423
430
):
424
431
"""
425
432
Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -455,6 +462,7 @@ def load_state_dict(
455
462
tensor_parallel_split_mapping ,
456
463
fliter_dict_keys ,
457
464
device ,
465
+ return_numpy ,
458
466
)
459
467
else :
460
468
# Load state dict in multi-thread to speed up loading
@@ -469,6 +477,7 @@ def load_state_dict(
469
477
tensor_parallel_split_mapping ,
470
478
fliter_dict_keys ,
471
479
device ,
480
+ return_numpy ,
472
481
): keys
473
482
for keys in keys_groups
474
483
}
@@ -477,7 +486,7 @@ def load_state_dict(
477
486
state_dict .update (res_state_dict )
478
487
scale_dict .update (res_scale_dict )
479
488
480
- if device == "cpu" :
489
+ if not return_numpy and device == "cpu" :
481
490
for k in list (state_dict .keys ()):
482
491
with device_guard ():
483
492
state_dict [k ] = paddle .Tensor .__call__ (state_dict .pop (k ), zero_copy = True )
@@ -3174,7 +3183,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False):
3174
3183
with safe_open (safe_model_path , framework = "np" , device = "cpu" ) as f :
3175
3184
loaded_keys = f .keys ()
3176
3185
tp_actions = cls .get_tensor_parallel_convert_actions (config , loaded_keys )
3177
- state_dict = load_state_dict (safe_model_path , tp_actions )
3186
+ state_dict = load_state_dict (safe_model_path , tp_actions , return_numpy = return_numpy )
3178
3187
else : # shard files safetensors
3179
3188
resolved_archive_file , resolved_sharded_files , sharded_metadata , is_sharded = cls ._resolve_model_file_path (
3180
3189
pretrained_model_name_or_path = folder ,
@@ -3190,10 +3199,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False):
3190
3199
shard_file ,
3191
3200
tp_actions ,
3192
3201
loaded_state_dict_keys ,
3202
+ return_numpy = return_numpy ,
3193
3203
)
3194
3204
state_dict .update (shard_state_dict )
3195
- if return_numpy :
3196
- for k in list (state_dict .keys ()):
3197
- if not isinstance (state_dict [k ], np .ndarray ):
3198
- state_dict [k ] = state_dict .pop (k ).cpu ().numpy ()
3199
3205
return state_dict
0 commit comments