Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check sha256 of weights #7219

Merged
merged 9 commits into from
May 3, 2023
Merged

Check sha256 of weights #7219

merged 9 commits into from
May 3, 2023

Conversation

adamjstewart
Copy link
Contributor

@adamjstewart adamjstewart commented Feb 10, 2023

@NicolasHug @nps1ngh @pmeier would this be sufficient to close #7210? This should be fully backwards compatible in case someone wants to use weights without a sha256 hash suffix. We'll rename all of our weights in TorchGeo to add the suffix.

Copy link
Collaborator

@pmeier pmeier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it up to @NicolasHug to review this with respect to torch.hub, since he is the specialist here. From my naive understanding, this should be an innocent addition that could improve security.

@adamjstewart
Copy link
Contributor Author

Ping @NicolasHug

@NicolasHug
Copy link
Member

After the release please @adamjstewart :) #7210 (comment)

@adamjstewart
Copy link
Contributor Author

Would love to get this into the release but if it's past the feature freeze it can wait.

@pmeier
Copy link
Collaborator

pmeier commented Feb 21, 2023

@adamjstewart Branch cut was last Friday, so it won't make it into the next release.

@adamjstewart
Copy link
Contributor Author

@NicolasHug congrats on the new release! I know there's always fallout from new releases, but once that settles down, a gentle reminder to review this PR. I rebased so we can test it against the latest version.

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7219

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 28 New Failures, 2 Unrelated Failures

As of commit c9a1b8a:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base 8811c91:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@adamjstewart
Copy link
Contributor Author

Ping @NicolasHug

@adamjstewart
Copy link
Contributor Author

Ping ping @NicolasHug

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping and for your patience @adamjstewart .

PR looks good, I agree it should be enough to enable sha checks on torchgeo's end.

Before mergning, do you mind updating the rest of the weights.get_state_dict(...) calls to use check_hash everywhere?

The remaining ones should be these:

(pt) ➜  vision git:(weights/hash) ✗ git grep "get_state_dict(pro.*ss)"
torchvision/models/detection/faster_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/faster_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/faster_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/fcos.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/keypoint_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/mask_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/mask_rcnn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/retinanet.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/retinanet.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/ssd.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/detection/ssdlite.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/optical_flow/raft.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/googlenet.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/inception.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/mobilenetv2.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/mobilenetv3.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/resnet.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/quantization/shufflenetv2.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/deeplabv3.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/deeplabv3.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/deeplabv3.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/fcn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/fcn.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/segmentation/lraspp.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/video/mvit.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/video/resnet.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/video/s3d.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/models/video/swin_transformer.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/prototype/models/depth/stereo/crestereo.py:        model.load_state_dict(weights.get_state_dict(progress=progress))
torchvision/prototype/models/depth/stereo/raft_stereo.py:        model.load_state_dict(weights.get_state_dict(progress=progress))

If that makes it easier for you, feel free to just revert the changes you made to the calls to get_state_dict() in this PR and we can address them all in a later PR.

@adamjstewart
Copy link
Contributor Author

@NicolasHug I believe I got all the remaining ones, let me know if I missed anything.

@pmeier pmeier requested a review from NicolasHug April 28, 2023 19:37
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @adamjstewart , the PR LGTM. Our CI is a bit toasted right now so I'll try to merge it some time later. Nothing more to do on your side though :)

Thanks!

@adamjstewart adamjstewart deleted the weights/hash branch May 3, 2023 12:47
facebook-github-bot pushed a commit that referenced this pull request May 15, 2023
Summary: Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>

Reviewed By: vmoens

Differential Revision: D45523936

fbshipit-source-id: 3febf3a0f410cc1af38cfac91d18c7e83213bd4f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

WeightsEnum: use checksums
4 participants