Fix keys name for Transformer #2529
Conversation
Seems like the PR currently doesn't resolve issues with HF models loaded through …

To solve this, I made …

Other calls that do load model …

Other calls that use …
OK, IMO the PR is ready for merging, but maybe someone should sanity-check the few commits I've done. This key patching is signaled to the user with an … In the case of …, it looks like this: …
Sounds good to me.
What does this PR do?
This PR solves an issue arising from #2489. That PR renamed a key in the decoder transformer from `self.mutihead_attn` to `self.multihead_attn`. Doing so breaks the loading of existing state dicts, since loading no longer recognizes the previous key. To solve this problem, I introduce a new function called `map_old_state_dict_weights`, which is applied within `torch_recovery`, `average_checkpoints`, and `_load_from_state_dict`. The first covers loading a checkpoint; the second covers averaging multiple checkpoints (the issue here is that every checkpoint must have the same keys before averaging); and the latter covers loading the state_dict directly from the object, which can be the case in our codebase (i.e. bypassing the checkpointer).
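For illustration, here is a minimal sketch of what a key-remapping helper of this kind could look like. The signature, the mapping format, and the usage lines are assumptions made for the example, not the PR's actual implementation.

```python
from typing import Dict

import torch


def map_old_state_dict_weights(
    state_dict: Dict[str, torch.Tensor],
    mapping: Dict[str, str],
) -> Dict[str, torch.Tensor]:
    """Rename deprecated keys so checkpoints saved before a rename still load.

    `mapping` pairs an old key fragment with its replacement, e.g.
    {"mutihead_attn": "multihead_attn"} for the rename from #2489.
    """
    for old_fragment, new_fragment in mapping.items():
        # Snapshot the keys first: the dict is mutated while we iterate.
        for key in list(state_dict.keys()):
            if old_fragment in key:
                state_dict[key.replace(old_fragment, new_fragment)] = state_dict.pop(key)
    return state_dict


# Hypothetical usage before handing the dict to load_state_dict:
# ckpt = torch.load("model.ckpt", map_location="cpu")
# ckpt = map_old_state_dict_weights(ckpt, {"mutihead_attn": "multihead_attn"})
# model.load_state_dict(ckpt)
```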
Before submitting
PR review
Reviewer checklist