
Bugfix for MNASNet #1224

Merged (44 commits) on Sep 23, 2019
Conversation

1e100 (Contributor) commented Aug 10, 2019

The original implementation I submitted contained a bug which affects all MNASNet variants other than 1.0. The bug is that the first few layers (the stem) also need to be scaled by the width multiplier, along with the rest of the network. This fixes the issue and brings the implementation fully in sync with Google's TPU reference code. I have compared the ONNX dump of this model against TFLite's hosted model and ensured that all layer configurations line up exactly.
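
To illustrate the fix, here is a minimal sketch of how the stem channel counts depend on the width multiplier (the rounding helper and the base channel counts of 32 and 16 are simplified stand-ins, not the exact torchvision code):

def _round_to_multiple_of(val, divisor=8):
    # Round channel counts to a multiple of 8, as the TF reference code does.
    return max(divisor, int(val + divisor / 2) // divisor * divisor)

def stem_channels(alpha):
    # v1 (buggy): the stem always used the base 32 and 16 channels.
    # v2 (fixed): the stem scales with the width multiplier like every other layer.
    return _round_to_multiple_of(32 * alpha), _round_to_multiple_of(16 * alpha)

# For alpha=0.5 the fixed stem uses 16 and 8 channels instead of 32 and 16.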

Because only the MNASNet 0.5 checkpoint was affected, I have also trained a slightly better checkpoint for it. I was unable to train this to the same accuracy with Torchvision's reference training code (and not for lack of trying), and had to use label smoothing and EMA to get this result. The final checkpoint is derived from EMA.

Even so, the accuracy is a bit lower than Google's result (67.83 for this model vs 68.03 for TFLite). The posted number for the TF TPU implementation is 68.9, but that model uses SE on some of its layers, which my implementation does not.

1e100 (Contributor, Author) commented Aug 12, 2019

Sure, we can make it backward compatible. Of the two proposed solutions I like solution 1 better (fewer moving parts), but I'll implement whatever we decide here.

Maybe I should take this opportunity and add squeeze and excitation as well. The authors use it in some of the larger variants of the model. Accuracy will go up a bit if I do that, and then we won't have to version it again in the future.

Is there a release schedule BTW? I wanted this fix to be ready before the release (models unfortunately take forever to train), and noticed the release had been cut a few days ago, so the PR didn't make it.

fmassa (Member) commented Aug 28, 2019

Hi @1e100 ,

I just got back from holidays, sorry for the delay in replying.

> Sure, we can make it backward compatible. Of the two proposed solutions I like solution 1 better (fewer moving parts), but I'll implement whatever we decide here.

I'm not yet clear on the best solution. I feel that this is something that needs to be carefully considered, because model versioning is going to be a big topic.

> Maybe I should take this opportunity and add squeeze and excitation as well.

I feel that this should be sent in a separate PR.

> Is there a release schedule BTW?

We will be cutting a new release of torchvision in the next 2-3 weeks, with minor fixes and improvements.

I'm also tagging @ailzhang for handling BC-breaking changes within hub, and @cpuhrsch @vincentqb and @zhangguanheng66 for torchaudio and torchtext model versioning in the future.

soumith (Member) commented Aug 28, 2019

Using a _version counter like BatchNorm does makes a lot of sense to me.

zhangguanheng66 (Contributor) commented Aug 28, 2019

I had to handle BC breakage for nn.MultiheadAttention. To extend the capability of the module, I added four extra attributes. We used the hasattr function to check for the existence of an attribute, and it seems fine; users don't have to do anything. A _version counter is another way, and more common, I believe.

For the second option, we have to instruct users to use _load_from_state_dict, which is not hard but needs to be made clear in the docs.
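
A minimal sketch of the hasattr approach (the attribute name and the two code paths are illustrative, not the actual nn.MultiheadAttention internals):

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Attribute introduced in a newer release; module instances serialized
        # with an older release won't have it.
        self.uses_new_path = True

    def forward(self, x):
        # hasattr keeps old serialized modules working with no user action.
        if not hasattr(self, "uses_new_path") or not self.uses_new_path:
            return x        # legacy code path (placeholder)
        return x * 1.0      # new code path (placeholder)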

1e100 (Contributor, Author) commented Sep 7, 2019

OK, had some time today to look into this, aiming to get it done over the weekend.

Basically, it seems that the following simple logic would satisfy the backward compat requirements:

  1. Add a _version = 2 class attribute.
  2. When the model is loaded via _load_from_state_dict(), if the state dict's metadata contains version 1 (the default), modify the network definition so that the stem is configured as it was in the initial version, perhaps show a warning as nn.MultiheadAttention does, and then set MNASNet._version = 1, turning the model into a de facto previous-version model. If version 2 is found, do nothing and simply load the new checkpoint.

Seems like a pretty straightforward fix to me.
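
A minimal sketch of that logic (simplified; the _load_from_state_dict hook and the "version" metadata key are PyTorch's, but the stem-rebuilding helper is hypothetical and this is not the exact code that ended up in torchvision):

import warnings
import torch.nn as nn

class MNASNet(nn.Module):
    # Version 2 corresponds to the fixed, alpha-scaled stem.
    _version = 2

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha
        # ... full architecture omitted in this sketch ...

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version == 1:
            # Reconfigure the stem exactly as the initial release defined it so
            # that v1 checkpoints still load; the instance then behaves (and is
            # re-saved) as a de facto v1 model.
            self._rebuild_v1_stem()  # hypothetical helper, not in torchvision
            self._version = 1
            warnings.warn(
                "Loaded a v1 MNASNet checkpoint into the current implementation; "
                "the model was reconfigured to match the old definition.",
                UserWarning)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)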

1e100 (Contributor, Author) commented Sep 7, 2019

OK, @fmassa, here's the first cut of the requested changes. CI seems to be failing on something CUDA-related. Let me know if this is what you had in mind!

fmassa (Member) left a comment

This is looking very good, I like it!

I have a few more comments, let me know what you think.

Also, I am thinking that we might still need to add an option somewhere (maybe in the mnasnet_0_5 function) that initially raises a warning if the user doesn't pass an argument, saying that the default behavior will change in a new version, so that we don't break BC right away for users.
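
A minimal sketch of what such a warning could look like (the wrapper and the scaled_stem parameter are illustrative assumptions, not what was ultimately merged):

import warnings
from torchvision.models.mnasnet import MNASNet

def mnasnet0_5(pretrained=False, scaled_stem=None, **kwargs):
    # Hypothetical keyword argument: if the caller does not choose explicitly,
    # warn that the default stem configuration will change in a future release.
    if scaled_stem is None:
        warnings.warn(
            "The default MNASNet 0.5 stem configuration will change in a "
            "future release to match the reference implementation; pass "
            "scaled_stem explicitly to silence this warning.", FutureWarning)
    model = MNASNet(0.5, **kwargs)
    # (pretrained-weight loading omitted from this sketch)
    return model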

torchvision/models/mnasnet.py (excerpt under review):

    self.layers[idx] = layer

    # The model is now identical to v1, and must be saved as such.
    MNASNet._version = 1
A Member commented:

This modifies all instances of MNASNet, not just the one being called. This could have some unexpected effects; maybe you meant to do something like self._version instead?

1e100 (Contributor, Author) replied on Sep 14, 2019:

D'oh! You're right. Changed, and verified it works with this code:

#!/usr/bin/env python3

import torch
import torchvision

# NOTE: v1 checkpoint
ckpt = torch.load("mnasnet0.5_top1_67.592-7c6cb539b9.pth")
m = torchvision.models.MNASNet(0.5)
m.load_state_dict(ckpt)
print("Loaded old")
torch.save(m.state_dict(), "resaved.pth")
print("Re-saved")
ckpt = torch.load("resaved.pth")
m = torchvision.models.MNASNet(0.5)
m.load_state_dict(ckpt)
print("Re-loaded")

@@ -139,16 +149,58 @@ def _initialize_weights(self):
                 nn.init.ones_(m.weight)
                 nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0.01)
+                nn.init.kaiming_uniform_(m.weight, mode="fan_out",
A Member commented:
This changes the initialization scheme; does it yield better performance?

1e100 (Contributor, Author) replied:

It may have very slightly improved the top-1 of the MNASNet-B1 0.5 that I trained for this PR, but I'm not sure. The purpose of the change is that initialization is now identical to the reference TensorFlow code (which also uses a variance-scaling initializer, a.k.a. Kaiming uniform). It is certainly not worse than before.

Two further review threads on torchvision/models/mnasnet.py were resolved.
1e100 (Contributor, Author) commented Sep 14, 2019

@fmassa I've addressed your feedback, PTAL

fmassa (Member) commented Sep 19, 2019

Sorry for the delay in replying; I've made a few more comments.

Remove unused member var as per review.

1e100 (Contributor, Author) left a comment

I've addressed the feedback, PTAL

fmassa (Member) left a comment

Thanks a lot!

I'll upload the weights and update the PR

fmassa (Member) commented Sep 20, 2019

@1e100 I couldn't push the updated path without a force push on your master branch.

Can you update the link in the PR to

https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth

And let me know?

1e100 (Contributor, Author) commented Sep 20, 2019

@fmassa done!

fmassa merged commit 367e851 into pytorch:master on Sep 23, 2019

fmassa (Member) commented Sep 23, 2019

Thanks a lot!
