Allow incompatible shapes in load_state_dict(strict=False) #24139
Comments
I am in the same situation. My model has a parameter whose tensor shape and rank are dynamic and cannot be determined at the initialization phase.
A suggestion is:
Hello everyone. I want to add that there is also another case. It might occur when you have the state dict of a model that has collected statistics for quantization but hasn't been quantized yet (e.g. …).
I've been longing for this feature for my transfer learning task.
@elbaro note that the original feature proposal was not to ignore incompatible shapes silently. It should return information about the mismatches. This would allow users to handle them manually (overwrite, ignore, reshape, transpose, etc.) and would be better than just an "overwrite_..." option.
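As a sketch of that manual handling, here is one way a caller could react to mismatch information today: filter out the shape-mismatched entries before loading, and keep the skipped keys so the caller can decide what to do with them. `load_matching_only` is a made-up helper, not a PyTorch API.

```python
import torch
import torch.nn as nn

def load_matching_only(model, checkpoint_state):
    """Drop checkpoint entries whose shape differs from the model's,
    load the rest with strict=False, and report what was skipped."""
    model_state = model.state_dict()
    skipped = [k for k, v in checkpoint_state.items()
               if k in model_state and model_state[k].shape != v.shape]
    filtered = {k: v for k, v in checkpoint_state.items() if k not in skipped}
    model.load_state_dict(filtered, strict=False)
    return skipped

# bias (shape [2]) matches; weight ([2, 8] vs [2, 4]) does not
skipped = load_matching_only(nn.Linear(8, 2), nn.Linear(4, 2).state_dict())
print(skipped)  # ['weight']
```

The caller is then free to overwrite, reshape, or transpose the skipped entries and load again.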
This, please!
I did this hack:

```python
import numpy as np
import torch

def load_not_compatible_weights(model, weights):
    new_model = model.state_dict()
    old_model = torch.load(weights)
    for el in new_model:
        if el in old_model:
            print('Match found for {}!'.format(el))
            if new_model[el].shape == old_model[el].shape:
                print('Action: Just copy weights!')
                new_model[el] = old_model[el]
            elif len(new_model[el].shape) != len(old_model[el].shape):
                print('Action: Different dimension! Too lazy to write the code... Skip it')
            else:
                print('Shape is different: {} != {}'.format(
                    tuple(new_model[el].shape), tuple(old_model[el].shape)))
                ln = len(new_model[el].shape)
                max_shape = []
                slices_old = []
                slices_new = []
                for i in range(ln):
                    max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i]))
                    slices_old.append(slice(0, old_model[el].shape[i]))
                    slices_new.append(slice(0, new_model[el].shape[i]))
                slices_old = tuple(slices_old)
                slices_new = tuple(slices_new)
                # Put the old weights into the corner of a zero tensor that is
                # large enough in every dimension, then crop to the new shape.
                max_matrix = np.zeros(max_shape, dtype=np.float32)
                max_matrix[slices_old] = old_model[el].cpu().numpy()
                new_model[el] = torch.from_numpy(max_matrix)[slices_new]
        else:
            print('Match not found for {}!'.format(el))
    model.load_state_dict(new_model)
```
The workaround for this issue that we implemented a few years ago is at https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py#L280-L344 Note that it's quite nontrivial (unlike the above comment), as there are a few classes of corner cases that must be handled. That's why it's better to build this into each module's own …
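A minimal sketch of that per-module approach, assuming the elided method is `Module._load_from_state_dict` (the hook PyTorch calls per submodule during loading) and assuming old and new parameters have the same rank. `ResizableLinear` is a made-up example class, not part of any library:

```python
import torch
import torch.nn as nn

class ResizableLinear(nn.Linear):
    """Made-up example: copy the overlapping region of a mismatched
    checkpoint tensor and keep the current init elsewhere."""

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Parameters only; buffers would need the same treatment.
        for name, param in self.named_parameters(recurse=False):
            key = prefix + name
            if key in state_dict and state_dict[key].shape != param.shape:
                old = state_dict[key]
                new = param.detach().clone()
                # Overlapping region (assumes equal rank).
                region = tuple(slice(0, min(a, b)) for a, b in zip(new.shape, old.shape))
                new[region] = old[region]
                state_dict[key] = new  # now shape-compatible
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

model = ResizableLinear(8, 2)
model.load_state_dict(nn.Linear(4, 2).state_dict())  # would normally raise
```

Because each module decides for itself how to reconcile shapes, quantization observers and similar special cases can opt in without any change to callers.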
🚀 Feature
Right now, `module.load_state_dict(state_dict, strict=False)` allows missing and unexpected keys, and it returns an object containing information about which parameters are missing or unexpected.
But it will throw an error if a parameter has the same name but a different shape. It would be good to also allow this behavior, and return information about the mismatched parameters.
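To illustrate the current behavior: `strict` only governs missing and unexpected keys, so even with `strict=False` a size mismatch still raises.

```python
import torch.nn as nn

src = nn.Linear(4, 2)   # checkpoint weight: [2, 4]
dst = nn.Linear(8, 2)   # model weight: [2, 8]
try:
    # strict=False tolerates missing/unexpected keys, but not shape mismatches
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as e:
    print('still raises:', type(e).__name__)
```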
Motivation
This would help with transfer learning, for example, where a pretrained checkpoint often differs from the new model in only a few layer shapes.
Pitch
A user could write:
Alternatives
Users currently have to modify the state_dict manually for such use cases.
UPDATE: it's error-prone for users to do this manually, because some modules (e.g. quantization observers) are designed to support loading checkpoints with incompatible shapes. It's hard for users to distinguish those from the unsupported incompatible shapes.
Some related discussions at #8282 (comment)