
[Init] Make sure shape mismatches are caught early (huggingface#2847)
Improve init
patrickvonplaten committed Mar 28, 2023
1 parent 7cabf90 commit b4c5db9
Showing 1 changed file with 7 additions and 0 deletions.
models/modeling_utils.py: 7 additions, 0 deletions
@@ -579,10 +579,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" those weights or else make sure your checkpoint file is correct."
)

empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
accepts_dtype = "dtype" in set(
inspect.signature(set_module_tensor_to_device).parameters.keys()
)

if empty_state_dict[param_name].shape != param.shape:
raise ValueError(
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if accepts_dtype:
set_module_tensor_to_device(
model, param_name, param_device, value=param, dtype=torch_dtype
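The idea behind the added check can be illustrated outside of diffusers with a minimal standalone sketch. The helper name check_state_dict_shapes and the toy Linear model below are hypothetical, used only to show the fail-fast behavior of comparing checkpoint shapes against the model's expected state dict before any weights are copied:

import torch

def check_state_dict_shapes(model: torch.nn.Module, state_dict: dict) -> None:
    """Fail fast if a checkpoint tensor does not match the model's expected shape."""
    expected = model.state_dict()
    for name, tensor in state_dict.items():
        if name in expected and expected[name].shape != tensor.shape:
            raise ValueError(
                f"{name}: expected shape {tuple(expected[name].shape)}, got {tuple(tensor.shape)}"
            )

# A checkpoint saved from a differently sized layer errors before any weight is copied.
model = torch.nn.Linear(4, 2)
bad_ckpt = {"weight": torch.zeros(3, 4), "bias": torch.zeros(3)}
check_state_dict_shapes(model, bad_ckpt)  # raises ValueError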

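For reference, the workaround the new error message points to looks roughly like the call below. This is a hedged sketch, not taken from the commit: it assumes a diffusers model class such as UNet2DConditionModel, and "path/to/checkpoint" is a placeholder for your own checkpoint:

from diffusers import UNet2DConditionModel

# Opt in explicitly to keeping the randomly initialized weights wherever the
# checkpoint shapes disagree, instead of raising the new ValueError.
model = UNet2DConditionModel.from_pretrained(
    "path/to/checkpoint",
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
)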