
add set_weight_decay to support custom weight decay setting #5671

Merged: 30 commits into pytorch:main, Apr 1, 2022
Conversation

@xiaohu2015 (Contributor) commented Mar 24, 2022

In this PR, we add a more generic method to replace `split_normalization_params` in `torchvision.ops`. The `split_normalization_params` method only supports disabling weight decay for the parameters of normalization layers, but some recent models (e.g. ViT, Swin Transformer) need custom weight decay for other parameters as well (e.g. the class token and position embedding). This PR is related to #5491.

In the current implementation, the priority is: custom keys > norm layers > others.
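To make that priority order concrete, here is a hypothetical, simplified sketch of how such a utility could walk the model and bucket parameters. The function body is illustrative only, not the PR's actual implementation; the signature mirrors the names used in this thread:

```python
import torch
from torch import nn


def set_weight_decay(model, weight_decay, norm_weight_decay=None,
                     norm_classes=None, custom_keys_weight_decay=None):
    # Illustrative sketch: custom keys take precedence over norm layers,
    # which take precedence over the default weight decay.
    if norm_classes is None:
        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
    norm_classes = tuple(norm_classes)

    params = {"other": []}
    if norm_weight_decay is not None:
        params["norm"] = []
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, _wd in custom_keys_weight_decay:
            params[key] = []
            custom_keys.append(key)

    def _add_params(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            full_name = f"{prefix}.{name}" if prefix else name
            # 1) custom keys have the highest priority
            for key in custom_keys:
                if key in full_name:
                    params[key].append(p)
                    break
            else:
                # 2) then normalization layers, 3) then everything else
                if norm_weight_decay is not None and isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)
        for child_name, child in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            _add_params(child, prefix=child_prefix)

    _add_params(model)

    wd_map = {"other": weight_decay}
    if norm_weight_decay is not None:
        wd_map["norm"] = norm_weight_decay
    if custom_keys_weight_decay is not None:
        wd_map.update(dict(custom_keys_weight_decay))
    return [{"params": v, "weight_decay": wd_map[k]} for k, v in params.items() if v]
```

Matching a custom key by substring against the full parameter name is one possible design; the real utility may match differently.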

Two test cases:

    model = models.mobilenet_v3_large()
    params_1 = ops._utils.split_normalization_params(model)
    params_2 = set_weight_decay(model, 0.2, norm_weight_decay=0.1)
    params_2.sort(key=lambda x: x["weight_decay"])
    params_3 = set_weight_decay(model, 0.2)
    
    assert len(params_1[0]) == len(params_2[0]["params"]) == 92
    assert len(params_1[1]) == len(params_2[1]["params"]) == 82
    assert len(params_3) == 1
    assert len(params_3[0]["params"]) == 174
    assert params_2[0]["weight_decay"] == 0.1
    assert params_2[1]["weight_decay"] == 0.2
    assert params_3[0]["weight_decay"] == 0.2
    @pytest.mark.parametrize("norm_weight_decay", [None, 0.2])
    @pytest.mark.parametrize("norm_layer", [None, nn.LayerNorm])
    @pytest.mark.parametrize("custom_keys_weight_decay", [None, [("class_token", 0.3), ("pos_embedding", 0.4)]])
    def test_set_weight_decay(self, norm_weight_decay, norm_layer, custom_keys_weight_decay):
        model = models.VisionTransformer(
            image_size=224,
            patch_size=16,
            num_layers=1,
            num_heads=2,
            hidden_dim=8,
            mlp_dim=4,
        )
        param_groups = ops._utils.set_weight_decay(
            model,
            0.1,
            norm_weight_decay=norm_weight_decay,
            norm_classes=None if norm_layer is None else [norm_layer],
            custom_keys_weight_decay=custom_keys_weight_decay,
        )

        if norm_weight_decay is None and custom_keys_weight_decay is None:
            assert len(param_groups) == 1
            assert len(param_groups[0]["params"]) == 20

        if norm_weight_decay is not None and custom_keys_weight_decay is None:
            assert len(param_groups) == 2
            param_groups.sort(key=lambda x: x["weight_decay"])
            assert len(param_groups[0]["params"]) == 14
            assert len(param_groups[1]["params"]) == 6

        if norm_weight_decay is not None and custom_keys_weight_decay is not None:
            assert len(param_groups) == 4
            param_groups.sort(key=lambda x: x["weight_decay"])
            assert len(param_groups[0]["params"]) == 12
            assert len(param_groups[1]["params"]) == 6
            assert len(param_groups[2]["params"]) == 1
            assert len(param_groups[3]["params"]) == 1
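The groups returned by the utility are ordinary optimizer parameter groups, so they can be passed straight to any `torch.optim` optimizer. A minimal hand-built sketch (the model and hyper-parameters here are illustrative, and the groups are constructed by hand in the same shape `set_weight_decay` returns):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Hand-built groups in the same shape that set_weight_decay returns:
# each dict carries its own "weight_decay", overriding the optimizer default.
param_groups = [
    {"params": list(model[0].parameters()), "weight_decay": 0.2},  # linear layer
    {"params": list(model[1].parameters()), "weight_decay": 0.0},  # norm layer
]
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
```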


@facebook-github-bot commented Mar 24, 2022

💊 CI failures summary and remediations

As of commit 755ccd5 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See CircleCI build unittest_windows_cpu_py3.9 (1/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

test/test_utils.py::test_flow_to_image_errors[i...Flow should be of dtype torch.float] PASSED [ 98%]
test/test_utils.py::test_draw_keypoints_colored[red] PASSED              [ 98%]
test/test_utils.py::test_draw_keypoints_colored[#FF00FF] PASSED          [ 98%]
test/test_utils.py::test_draw_keypoints_colored[colors2] PASSED          [ 98%]
test/test_utils.py::test_draw_keypoints_errors PASSED                    [ 98%]
test/test_utils.py::test_flow_to_image[True] PASSED                      [ 98%]
test/test_utils.py::test_flow_to_image[False] PASSED                     [ 98%]
test/test_utils.py::test_flow_to_image_errors[input_flow0-Input flow should have shape] PASSED [ 98%]
test/test_utils.py::test_flow_to_image_errors[input_flow1-Input flow should have shape] PASSED [ 98%]
test/test_utils.py::test_flow_to_image_errors[input_flow2-Input flow should have shape] PASSED [ 98%]
test/test_utils.py::test_flow_to_image_errors[input_flow3-Input flow should have shape] PASSED [ 98%]
test/test_utils.py::test_flow_to_image_errors[input_flow4-Flow should be of dtype torch.float] PASSED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[RATRACE_wave_f_nm_np1_fr_goo_37.avi] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[TrumanShow_wave_f_nm_np1_fr_med_26.avi] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[v_SoccerJuggling_g23_c01.avi] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[v_SoccerJuggling_g24_c01.avi] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[R6llTwEh07w.mp4] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[SOX5yA1l24A.mp4] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_frame_reading[WUzgd7C1pWA.mp4] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_seek_reading[C:\\Users\\circleci\\project\\test\\assets\\videos\\v_SoccerJuggling_g23_c01.avi-8.0-True] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_seek_reading[C:\\Users\\circleci\\project\\test\\assets\\videos\\v_SoccerJuggling_g23_c01.avi-8.0-False] SKIPPED [ 98%]
test/test_video_gpu_decoder.py::TestVideoGPUDecoder::test_seek_reading[C:\\Users\\circleci\\project\\test\\assets\\videos\\v_SoccerJuggling_g24_c01.avi-8.0-True] SKIPPED [ 98%]

See CircleCI build unittest_windows_cpu_py3.8 (2/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

(identical test output to the py3.9 build above)

This comment was automatically generated by Dr. CI.

@xiaohu2015 xiaohu2015 marked this pull request as draft March 24, 2022 11:51
@xiaohu2015 xiaohu2015 marked this pull request as ready for review March 27, 2022 07:08
@datumbox (Contributor) left a comment
@xiaohu2015 Thanks for the contribution.

I've added a few comments, let me know your thoughts.

Review comments (outdated, resolved) on: references/classification/train.py, torchvision/ops/_utils.py
@datumbox commented
@xiaohu2015 LGTM, thanks for the awesome contributions as always!

As discussed offline, this approach is a great step towards a generic solution. There are still some cons, such as the fact that the custom_key logic is always applied prior to the norm layers, but we can investigate more flexible alternatives in separate PRs. Since more investigation will be done, we can move this back to references after all the tests on this PR pass.

@datumbox commented
LGTM and the tests pass. We can move it to references and merge. :)

        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-weight-decay",
Contributor comment:
@jdsgomes Any thoughts on this name?

This option controls the weight decay for parameters related to position encodings and embeddings that are usually found in Transformer models. We will eventually revise this utility, so this is a temporary measure to unblock the work on Swin, but I was hoping you could offer a better alternative. Line 236 of this PR will also clarify what this does.

Contributor reply:
The only alternative I can think of is to introduce a generic weight decay argument with pairs of ([keys], weight decay) and then deprecate bias_weight_decay.

Contributor reply:
It will look a bit ugly in terms of implementation (at least the way I can think of using argparse) but might make more sense for the users and be more generic
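The generic pairs-of-(key, weight decay) idea could be expressed in argparse roughly like this. The flag name `--custom-weight-decay` and the `KEY=FLOAT` syntax are hypothetical illustrations, not what the PR ships:

```python
import argparse


def parse_custom_weight_decay(value):
    # Parse "class_token=0.3" into ("class_token", 0.3); the syntax is hypothetical.
    key, _, wd = value.partition("=")
    if not key or not wd:
        raise argparse.ArgumentTypeError(f"expected KEY=FLOAT, got {value!r}")
    return key, float(wd)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--custom-weight-decay",
    type=parse_custom_weight_decay,
    action="append",          # allow the flag to be repeated for several keys
    metavar="KEY=FLOAT",
    help="per-key weight decay, e.g. --custom-weight-decay class_token=0.3",
)

args = parser.parse_args(["--custom-weight-decay", "class_token=0.3",
                          "--custom-weight-decay", "pos_embedding=0.4"])
print(args.custom_weight_decay)  # [('class_token', 0.3), ('pos_embedding', 0.4)]
```

Repeating the flag with `action="append"` keeps each pair separate, which is one way around the "ugly" parsing the comment above anticipates.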

@datumbox commented Apr 1, 2022:
Yeah, true. Another idea that @xiaohu2015 suggested was to pass custom_keys_weight_decay directly from argparse. We avoided this for now because it exposes too much of the utility's internals. An alternative approach for the util is to accept lambdas that let users select any module they want under arbitrary conditions. We discussed offline to explore similar approaches and see if we can eventually create something generic to add to ops. But in the meantime, I didn't want to block your work on Swin, which is why we went with this workaround.

I wonder if there is at least a better naming convention for what this flag does. Do you think it would be accurate to name it --positional-param-decay or something else?

Contributor reply:
Because of the class_token I would not call it positional-param-decay, so I would keep the original name for lack of a better alternative.

Contributor reply:
After an online chat with @jdsgomes, we decided to rename to --transformer-embedding-decay. This will do for now to unblock Swin and we can review soon.

@datumbox (Contributor) left a comment
@xiaohu2015 Thanks! LGTM!

FYI I fixed a conflict with master. I'll wait for input from Joao, but I'm happy to make any name changes in place for you to avoid back and forth.

@datumbox datumbox merged commit 3925946 into pytorch:main Apr 1, 2022
@xiaohu2015 xiaohu2015 deleted the patch-1 branch April 2, 2022 02:04
facebook-github-bot pushed a commit that referenced this pull request Apr 6, 2022
…5671)

Summary:
* add set_weight_decay

* Update _utils.py

* refactor code

* fix import

* add set_weight_decay

* fix lint

* fix lint

* replace split_normalization_params with set_weight_decay

* simplfy the code

* refactor code

* refactor code

* fix lint

* remove unused

* Update test_ops.py

* Update train.py

* Update _utils.py

* Update train.py

* add set_weight_decay

* add set_weight_decay

* Update _utils.py

* Update test_ops.py

* Change `--transformer-weight-decay` to `--transformer-embedding-decay`

Reviewed By: NicolasHug

Differential Revision: D35393158

fbshipit-source-id: 625eec0edf01864c38c1fd826d2e8bf256e2e879

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>