[WIP] Allow autocast for 1.6 #2384

Merged
merged 13 commits into from Jul 9, 2020

mcarilli
Contributor

@mcarilli mcarilli commented Jul 2, 2020

This PR introduces a minimal set of diffs that enable torchvision models to work with torch.cuda.amp.autocast. Custom C++ ops (roi_align and nms) require the most attention.

In later modifications to PyTorch, I'll allow external libs to use PyTorch's internal autocast utilities, after which this code can be made cleaner. For 1.6, however, copy-pasting some utilities is the best we can do.

Should close pytorch/pytorch#37735.

@mcarilli mcarilli changed the title from "Allow autocast for 1.6" to "[WIP] Allow autocast for 1.6" on Jul 3, 2020
Member

@fmassa fmassa left a comment

Thanks for the PR!

Can you add a full model-forward test in a similar location to

vision/test/test_models.py

Lines 273 to 293 in 5247f7b

@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
    model.cuda()
    model.eval()
    input_shape = (3, 300, 300)
    x = torch.rand(input_shape, device='cuda')
    model_input = [x]
    out = model(model_input)
    self.assertIs(model_input[0], x)
    self.assertEqual(len(out), 1)
    self.assertTrue("boxes" in out[0])
    self.assertTrue("scores" in out[0])
    self.assertTrue("labels" in out[0])
    # now switch to cpu and make sure it works
    model.cpu()
    x = x.cpu()
    out_cpu = model([x])
    self.assertTrue("boxes" in out_cpu[0])
    self.assertTrue("scores" in out_cpu[0])
    self.assertTrue("labels" in out_cpu[0])

and also add a forward-backward test for roi_align after

vision/test/test_ops.py

Lines 290 to 291 in 5247f7b

def _test_boxes_shape(self):
    self._helper_boxes_shape(ops.roi_align)

(and we will clean it up later on when adding support for the other ops to autocast)

torchvision/ops/poolers.py (resolved review thread)

codecov bot commented Jul 3, 2020

Codecov Report

Merging #2384 into master will increase coverage by 0.70%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #2384      +/-   ##
==========================================
+ Coverage   70.65%   71.36%   +0.70%     
==========================================
  Files          94       94              
  Lines        7897     8328     +431     
  Branches     1241     1385     +144     
==========================================
+ Hits         5580     5943     +363     
- Misses       1934     1972      +38     
- Partials      383      413      +30     
Impacted Files                                 Coverage Δ
torchvision/ops/poolers.py                     97.02% <100.00%> (ø)
torchvision/io/image.py                        71.73% <0.00%> (+0.77%) ⬆️
torchvision/transforms/functional_tensor.py    65.41% <0.00%> (+0.83%) ⬆️
torchvision/transforms/functional.py           81.91% <0.00%> (+2.09%) ⬆️
torchvision/transforms/transforms.py           78.25% <0.00%> (+2.23%) ⬆️

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 86b6c3e...f3451c9.

torchvision/csrc/ROIAlign.h (outdated, resolved review thread)
@fmassa
Member

fmassa commented Jul 7, 2020

Test failures seem related

Member

@fmassa fmassa left a comment

Thanks a lot Michael!

@fmassa fmassa merged commit 0a8586c into pytorch:master Jul 9, 2020
fmassa pushed a commit to fmassa/vision-1 that referenced this pull request Jul 9, 2020
* Fixes Xiao's repro

* Ports nms to use full dispatcher

* Move HIPGuard to nms_cuda

* clang-format

* run models in test_models.py on GPU if available

* Francisco's comment, also disable cuda model tests to see if CPU alone still passes

* cuda tests now pass locally, although still not comparing to saved numerics

* add note for thing to ask francisco

* Allow cuda and cpu tests to share a data file

* ignore suffix if unneeded

* Skip autocast numerics checks for a few models

* Add roi_align test

Co-authored-by: Michael Carilli <mcarilli@nvidia.com>
fmassa added a commit that referenced this pull request Jul 9, 2020
Co-authored-by: mcarilli <mcarilli@gmail.com>
Co-authored-by: Michael Carilli <mcarilli@nvidia.com>
@fmassa fmassa mentioned this pull request Sep 1, 2020
@fmassa fmassa mentioned this pull request Oct 13, 2020
6 tasks
Successfully merging this pull request may close these issues.

torch.cuda.amp.autocast not working with torchvision.models.detection.maskrcnn
3 participants