
Adding EfficientNetV2 architecture #5450

Merged: 18 commits into pytorch:main on Mar 2, 2022

Conversation

@datumbox (Contributor) commented Feb 21, 2022

Related to #2707

Adds an EfficientNetV2 implementation on top of the existing EfficientNet class.

This PR is influenced by earlier work done by @xiaohu2015 at #4910
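
For context, the gist of the change as a simplified sketch: the V2 variants mix FusedMBConv blocks in the early stages with MBConv blocks later, all consumed by the shared EfficientNet class. The class/config names and values below follow what landed in torchvision.models.efficientnet (V2-S variant), so treat them as indicative rather than exact PR code:

from functools import partial

import torch.nn as nn
from torchvision.models.efficientnet import EfficientNet, FusedMBConvConfig, MBConvConfig

# EfficientNetV2-S stages: (expand_ratio, kernel, stride, in_ch, out_ch, num_layers)
inverted_residual_setting = [
    FusedMBConvConfig(1, 3, 1, 24, 24, 2),
    FusedMBConvConfig(4, 3, 2, 24, 48, 4),
    FusedMBConvConfig(4, 3, 2, 48, 64, 4),
    MBConvConfig(4, 3, 2, 64, 128, 6),
    MBConvConfig(6, 3, 1, 128, 160, 9),
    MBConvConfig(6, 3, 2, 160, 256, 15),
]

model = EfficientNet(
    inverted_residual_setting,
    dropout=0.2,
    last_channel=1280,
    norm_layer=partial(nn.BatchNorm2d, eps=1e-03),  # TF-style eps; see the BN discussion below
)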

@facebook-github-bot commented Feb 21, 2022

💊 CI failures summary and remediations

As of commit 364da8f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@jdsgomes (Contributor) left a comment

Looks good to me so far. Just left a small nit

torchvision/models/efficientnet.py (review comment resolved)
@xiaohu2015 (Contributor) commented:


Have you checked the accuracy of the converted TF weights? In TF they used a different BN parameter:
norm_layer = partial(nn.BatchNorm2d, eps=1e-03)
This difference affects the accuracy; another gap may lie in the padding mode.
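
A minimal sketch of such a BN patch applied after building the model (assuming the efficientnet_v2_s builder this PR adds; illustration only, not code from the PR):

import torch.nn as nn
from torchvision.models import efficientnet_v2_s

model = efficientnet_v2_s()  # then load the ported TF weights into it
# PyTorch's BatchNorm2d default eps is 1e-05; switch to the TF value:
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-03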

@datumbox (Author) commented Feb 22, 2022

@xiaohu2015 Yes. You take a hit of about 0.3 points on the Small model (Acc@1 83.602, Acc@5 96.556).

The reason I haven't overwritten the BN configuration is that we plan to train the model from scratch rather than use the TF weights. A few techniques used in the paper (such as progressive learning) are not supported by our reference scripts, so depending on how close the delta from training turns out to be, we might try to close the gap by applying the proposed patch.

@jdsgomes (Contributor) left a comment

LGTM, just one small nit.

torchvision/models/efficientnet.py (review comment resolved)
@xiaohu2015 (Contributor) commented:


I have a question: you got 83.1 top-1 accuracy using the TF weights; was that measured with nn.BatchNorm2d, eps=1e-03?

@datumbox (Author) commented:

@xiaohu2015 Yes, correct. The accuracy you see in the source code (83.1) is with the TF weights and without the BN patch. After applying the BN patch, we can reach 83.6. The reason the BN patch is not in the source code is that I'm currently training the model from scratch and want to see whether it's necessary. Note that the ported TF weights were added just for my convenience, so I can run some of the tests and check that nothing is fundamentally broken in the implementation.

@xiaohu2015 (Contributor) commented:


Thanks. I ask because I used the timm weights (converted from TF) but cannot reach that accuracy even with the BN patch (about 83.0%); maybe I missed something.

@datumbox (Author) commented Feb 27, 2022

@xiaohu2015 I've just replaced the weights for the Small variant with weights trained from scratch using TorchVision's recipe. We can do better than the paper by ~0.3 points:

torchrun --nproc_per_node=1 train.py --test-only --prototype --weights EfficientNet_V2_S_Weights.IMAGENET1K_V1 --model efficientnet_v2_s -b 1
Acc@1 84.228 Acc@5 96.878

The above means that we don't have to implement TF-specific tricks to reproduce the paper, which massively simplifies our code.

Here are the results from training Medium from scratch:

torchrun --nproc_per_node=1 train.py --test-only --prototype --weights EfficientNet_V2_M_Weights.IMAGENET1K_V1 --model efficientnet_v2_m -b 1
Acc@1 85.112 Acc@5 97.156

And here is Large ported from the paper:

torchrun --nproc_per_node=1 train.py --test-only --prototype --weights EfficientNet_V2_L_Weights.IMAGENET1K_V1 --model efficientnet_v2_l
Acc@1 85.810 Acc@5 97.792
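
In summary (ImageNet-1K validation accuracy):

Variant           Acc@1    Acc@5    Weights
EfficientNetV2-S  84.228   96.878   trained from scratch (TorchVision recipe)
EfficientNetV2-M  85.112   97.156   trained from scratch (TorchVision recipe)
EfficientNetV2-L  85.810   97.792   ported from the paper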

@jdsgomes (Contributor) left a comment

LGTM, thank you!

@datumbox datumbox changed the title [WIP] Adding EfficientNetV2 architecture Adding EfficientNetV2 architecture Mar 2, 2022
@datumbox datumbox merged commit e6d82f7 into pytorch:main Mar 2, 2022
@datumbox datumbox deleted the models/efficientnet_v2 branch March 2, 2022 12:37
facebook-github-bot pushed a commit that referenced this pull request Mar 15, 2022
Summary:
* Extend the EfficientNet class to support v1 and v2.

* Refactor config/builder methods and add prototype builders

* Refactor weight info.

* Update dropouts based on TF config ref

* Update BN eps on TF base_config

* Use Conv2dNormActivation.

* Add pre-trained weights for EfficientNetV2-S

* Add Medium and Large weights

* Update stats with single batch run.

* Add accuracies in the docs.

Reviewed By: vmoens

Differential Revision: D34878984

fbshipit-source-id: 1f771dc1173dcdcf21391fb01dfa79d7c3608c5f
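
For reference, a minimal usage sketch with the builder and weight names from this PR (written against the stable torchvision >= 0.13 API; at merge time these were exposed under the prototype namespace, hence the --prototype flag in the commands above):

import torch
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

# Build EfficientNetV2-S with the from-scratch ImageNet weights discussed above.
model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
model.eval()

with torch.inference_mode():
    logits = model(torch.rand(1, 3, 384, 384))  # V2-S is evaluated at 384x384
print(logits.shape)  # torch.Size([1, 1000])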