
Allow incompatible shapes in load_state_dict(strict=False) #24139

Open
ppwwyyxx opened this issue Aug 10, 2019 · 7 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: serialization Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ppwwyyxx
Contributor

ppwwyyxx commented Aug 10, 2019

🚀 Feature

Right now, module.load_state_dict(strict=False) allows the following:

  • loading a dict with missing parameters
  • loading a dict with more parameters than needed

It also returns an object listing which parameters are missing and which are unexpected.

But it will throw an error if a parameter has the same name but a different shape. It would be good to allow this case as well, and return information about the mismatched parameters.
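The current behavior can be reproduced with a small sketch (the layer sizes here are arbitrary, chosen only to force a shape mismatch):

```python
import torch.nn as nn

src = nn.Linear(10, 5)   # checkpoint weight: shape (5, 10)
dst = nn.Linear(10, 3)   # model weight:      shape (3, 10)

# strict=False tolerates missing and unexpected keys, but a
# same-named parameter with a different shape still raises.
err = None
try:
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as e:
    err = e

print("raised:", err is not None)
```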

Motivation

This will help with work in transfer learning, for example.

Pitch

User can write

ret = model.load_state_dict(torch.load("pretrained.pkl"), strict=False)
for x in ret.incompatible_keys:
    logger.warning(f"{x} is not loaded because it has shape xx in the checkpoint but shape yy in the model")

Alternatives

Users currently have to manually modify the state_dict for such use cases.
UPDATE: doing this manually is error-prone, because some modules (e.g. quantization observers) are designed to support loading checkpoints with incompatible shapes. It's hard for users to distinguish those from the unsupported incompatible shapes.
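The manual workaround usually looks something like the following sketch (the helper name is illustrative, not a PyTorch API): drop every checkpoint entry whose shape disagrees with the model, and report what was skipped.

```python
import torch.nn as nn

def filter_compatible(model, state_dict):
    """Keep only checkpoint entries whose shape matches the model.

    Returns (compatible, skipped), where `skipped` maps each dropped
    key to its (checkpoint_shape, model_shape) pair. Illustrative only;
    it cannot tell which mismatches a module would handle by design.
    """
    own = model.state_dict()
    compatible, skipped = {}, {}
    for k, v in state_dict.items():
        if k in own and own[k].shape == v.shape:
            compatible[k] = v
        elif k in own:
            skipped[k] = (tuple(v.shape), tuple(own[k].shape))
    return compatible, skipped
```

The filtered dict can then be passed to `model.load_state_dict(compatible, strict=False)` without triggering the shape-mismatch error.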

Some related discussions at #8282 (comment)

@jerryzh168 jerryzh168 added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: serialization Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Aug 12, 2019
@elbaro

elbaro commented Jan 11, 2020

I am in the same situation. My model has a parameter of dynamic tensor shape and rank, which cannot be determined at the initialization phase.

But load_state_dict(strict=False) shouldn't do this silently.
For example, if a model parameter has rank A and the state_dict entry has rank B, with A != B, then load_state_dict does not know whether it should overwrite A with B or convert B to A.

For example, load_state_dict has this conversion:

# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(param.shape) == 0 and len(input_param.shape) == 1:
    input_param = input_param[0]

A suggestion is load_state_dict(strict=False, overwrite_tensor_shape=True).

@zetyquickly
Contributor

Hello everyone,

Here I want to add that there is also a case where the model has a None value for a particular key but the state dict has a tensor for that key. So I think we need to check not only shape differences but also the presence of the tensor in the model.

This case can occur when you have the state dict of a model that has collected statistics for quantization but hasn't been quantized yet (e.g. some_layer.observer)

@leo-mao

leo-mao commented Jul 10, 2020

long for this feature for my transfer learning task

@ppwwyyxx
Contributor Author

@elbaro note that the original feature proposal was not to ignore incompatible shapes silently. It should return information about mismatches. This would allow users to handle them manually: overwrite, ignore, reshape, transpose, etc., which would be better than just an "overwrite_..." option.
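To make the idea concrete, a per-key policy could sit on top of such a mismatch report; this is a hypothetical sketch (the helper name and the `(key, ckpt_shape, model_shape)` tuple format are assumptions, not a proposed API):

```python
import torch

def apply_policy(model, state_dict, mismatched, policy):
    """For each (key, ckpt_shape, model_shape) in `mismatched`, ask
    `policy` what to do: 'ignore' drops the checkpoint entry,
    'transpose' loads the transposed tensor instead."""
    fixed = dict(state_dict)
    for key, ckpt_shape, model_shape in mismatched:
        action = policy(key, ckpt_shape, model_shape)
        if action == "ignore":
            del fixed[key]
        elif action == "transpose":
            fixed[key] = fixed[key].t()
    model.load_state_dict(fixed, strict=False)
```

The point is that once the mismatch information is surfaced, the decision (overwrite, ignore, reshape, transpose) stays with the user instead of being baked into a single flag.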

@tcapelle

this please!

@ZFTurbo

ZFTurbo commented Jun 21, 2023

I did this hack:

import numpy as np
import torch

def load_not_compatible_weights(model, weights):
    new_model = model.state_dict()
    old_model = torch.load(weights)

    for el in new_model:
        if el not in old_model:
            print('Match not found for {}!'.format(el))
            continue
        print('Match found for {}!'.format(el))
        if new_model[el].shape == old_model[el].shape:
            print('Action: Just copy weights!')
            new_model[el] = old_model[el]
        elif len(new_model[el].shape) != len(old_model[el].shape):
            print('Action: Different number of dimensions! Skip it')
        else:
            print('Shape is different: {} != {}'.format(
                tuple(new_model[el].shape), tuple(old_model[el].shape)))
            ln = len(new_model[el].shape)
            max_shape = [max(new_model[el].shape[i], old_model[el].shape[i]) for i in range(ln)]
            slices_old = tuple(slice(0, old_model[el].shape[i]) for i in range(ln))
            slices_new = tuple(slice(0, new_model[el].shape[i]) for i in range(ln))
            # Copy the old weights into the corner of a zero tensor large
            # enough for both shapes, then crop to the new shape.
            max_matrix = np.zeros(max_shape, dtype=np.float32)
            max_matrix[slices_old] = old_model[el].cpu().numpy()
            new_model[el] = torch.from_numpy(max_matrix)[slices_new]
    model.load_state_dict(new_model)

@ppwwyyxx
Contributor Author

ppwwyyxx commented Sep 17, 2023

The workaround for this issue that we implemented a few years ago is at https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py#L280-L344

Note that it's quite nontrivial (unlike the comment above), as there are a few corner cases that need to be handled. That's why it's better to build this into each module's own load_state_dict, and why this feature request is important.

7 participants