add set_weight_decay to support custom weight decay setting #5671
Conversation
💊 CI failures summary and remediations
As of commit 755ccd5 (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
- unittest_windows_cpu_py3.9 (1/2), step "Run tests"
@xiaohu2015 Thanks for the contribution.
I've added a few comments, let me know your thoughts.
@xiaohu2015 LGTM, thanks for the awesome contributions as always! As discussed offline, this approach is a great step towards a generic solution. There are still some cons, such as the fact that the custom_key logic is always applied prior to the norm_layers, but we can investigate alternative, more flexible solutions in separate PRs. Since more investigation will be done, we can move this back to references after all the tests on this PR pass.
LGTM and the tests pass. We can move it to references and merge. :)
references/classification/train.py
Outdated
help="weight decay for bias parameters of all layers (default: None, same value as --wd)", | ||
) | ||
parser.add_argument( | ||
"--transformer-weight-decay", |
@jdsgomes Any thoughts on this name?
This option controls the weight decay for parameters related to position encodings and embeddings that are usually found in Transformer models. We will eventually revise this utility, so this is a temporary measure to unblock the work on Swin, but I was hoping you could offer a better alternative. Line 236 on this PR will also offer clarifications on what this does.
The only alternative I can think of is to introduce a generic weight decay argument with pairs of ([keys], weight decay) and then deprecate bias_weight_decay.
It will look a bit ugly in terms of implementation (at least the way I can think of using argparse) but might make more sense for the users and be more generic
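For illustration only (this is not part of the PR), one way such a generic "pairs of keys and weight decay" flag could be expressed with argparse, assuming a key=value syntax that is parsed into (key, decay) pairs:

```python
# Hedged sketch of the ([keys], weight decay) idea discussed above;
# the flag name and syntax are assumptions, not part of this PR.
import argparse


def key_decay_pair(arg: str):
    # Parse "pos_embedding=0.0" into ("pos_embedding", 0.0)
    key, value = arg.split("=", 1)
    return key, float(value)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--custom-weight-decay",
    nargs="*",
    type=key_decay_pair,
    default=[],
    help="per-key weight decay overrides, e.g. --custom-weight-decay class_token=0.0 pos_embedding=0.0",
)

args = parser.parse_args(["--custom-weight-decay", "class_token=0.0", "pos_embedding=0.0"])
print(args.custom_weight_decay)  # [('class_token', 0.0), ('pos_embedding', 0.0)]
```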
Yeap, true. Another idea that @xiaohu2015 suggested was to pass custom_keys_weight_decay directly from argparse. The reason we avoided this for now is that it exposes too much of the utility's internals. An alternative approach for the util is to accept lambdas that let users select any module they want under arbitrary conditions. We discussed offline to explore similar approaches and see if we can create something generic to eventually add to ops. In the meantime, I didn't want to block your work on Swin, which is why we went with this workaround.
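Purely to illustrate the lambda/selector idea mentioned above (this is not an existing torchvision API; the name and signature are made up for the sketch), such a utility might look like:

```python
# Hedged sketch: per-group parameter selection via callables instead of string keys.
# Function name and signature are illustrative assumptions, not part of this PR.
from torch import nn


def build_param_groups(model, default_decay, selectors):
    # selectors: list of (predicate, weight_decay); the first matching predicate wins.
    # A predicate receives (module, parameter_name) and returns True/False.
    groups = [{"params": [], "weight_decay": wd} for _, wd in selectors]
    default_group = {"params": [], "weight_decay": default_decay}
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            for group, (predicate, _) in zip(groups, selectors):
                if predicate(module, name):
                    group["params"].append(p)
                    break
            else:
                default_group["params"].append(p)
    return [g for g in groups + [default_group] if g["params"]]


# Example: zero decay for LayerNorm parameters and for anything named like an embedding.
selectors = [
    (lambda m, n: isinstance(m, nn.LayerNorm), 0.0),
    (lambda m, n: "pos_embedding" in n or "class_token" in n, 0.0),
]
# e.g. param_groups = build_param_groups(torchvision.models.vit_b_16(), 1e-4, selectors)
```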
I wonder if there is at least a better naming convention for what this flag does. Do you think it would be accurate to name it --positional-param-decay, or something else?
Because of the class_token, I would not call it --positional-param-decay, so I would keep the original name for lack of a better alternative.
After an online chat with @jdsgomes, we decided to rename it to --transformer-embedding-decay. This will do for now to unblock Swin, and we can revisit soon.
@xiaohu2015 Thanks! LGTM!
FYI I fixed a conflict with master. I'll wait for the input from Joao, but I'm happy to make any name changes in place for you to avoid back and forth.
…5671)
Summary:
* add set_weight_decay
* Update _utils.py
* refactor code
* fix import
* add set_weight_decay
* fix lint
* fix lint
* replace split_normalization_params with set_weight_decay
* simplfy the code
* refactor code
* refactor code
* fix lint
* remove unused
* Update test_ops.py
* Update train.py
* Update _utils.py
* Update train.py
* add set_weight_decay
* add set_weight_decay
* Update _utils.py
* Update test_ops.py
* Change `--transformer-weight-decay` to `--transformer-embedding-decay`

Reviewed By: NicolasHug
Differential Revision: D35393158
fbshipit-source-id: 625eec0edf01864c38c1fd826d2e8bf256e2e879
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
In this PR, we want to add a more generic method to replace split_normalization_params in torchvision.ops. The split_normalization_params method only supports disabling weight decay for the parameters of the normalization layers, but for some recent models (e.g. ViT, Swin Transformer) we need custom weight decay for other parameters as well (e.g. class token, position embedding). This PR is related to #5491. For the current implementation, the priority is: custom_keys > norm layers > others.
Two test cases:
references
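To make the priority described above concrete (custom_keys > norm layers > others), here is a minimal, self-contained sketch of how such a grouping utility could work. The function name, signature, and defaults are illustrative assumptions, not the exact code merged in this PR:

```python
# Minimal sketch (not the PR's actual implementation) of a parameter-grouping
# utility with the priority: custom keys > norm layers > others.
from torch import nn


def set_weight_decay_sketch(model, weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None):
    norm_classes = (nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm)
    custom = list(custom_keys_weight_decay or [])  # e.g. [("pos_embedding", 0.0)]

    groups = {"__other__": ([], weight_decay)}
    if norm_weight_decay is not None:
        groups["__norm__"] = ([], norm_weight_decay)
    for key, decay in custom:
        groups[key] = ([], decay)

    def visit(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            full_name = f"{prefix}.{name}" if prefix else name
            matched = next((k for k, _ in custom if k in full_name), None)
            if matched is not None:                     # 1) custom keys win
                groups[matched][0].append(p)
            elif norm_weight_decay is not None and isinstance(module, norm_classes):
                groups["__norm__"][0].append(p)         # 2) then normalization layers
            else:
                groups["__other__"][0].append(p)        # 3) default decay for the rest
        for child_name, child in module.named_children():
            visit(child, f"{prefix}.{child_name}" if prefix else child_name)

    visit(model)
    return [{"params": ps, "weight_decay": wd} for ps, wd in groups.values() if ps]


# Example usage (model and hyper-parameters are placeholders):
# model = torchvision.models.vit_b_16()
# param_groups = set_weight_decay_sketch(
#     model, weight_decay=1e-4, norm_weight_decay=0.0,
#     custom_keys_weight_decay=[("class_token", 0.0), ("pos_embedding", 0.0)],
# )
# optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
```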