
Conversation

malfet
Contributor

@malfet malfet commented Aug 6, 2020

In clip_coordinates replace minimum(maximum(in)) composition with clamp_max(clamp_min(in))
Swap order of clamp_min operands to clamp NaNs in grid to 0

Fixes #42616
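
For illustration only (plain Python, not the ATen vectorized code, with made-up values): the reason operand order matters is that a comparison-based maximum keeps or discards a NaN depending on which operand holds it.

    nan = float('nan')
    # A comparison with NaN is always False, so a comparison-based max keeps
    # its first operand when the other one is NaN, and vice versa.
    print(max(nan, 0.0))   # nan - NaN survives when it is the first operand
    print(max(0.0, nan))   # 0.0 - NaN is replaced by the bound

If the vectorized clamp_min behaves analogously, putting the bound first is what makes a NaN coordinate come out as 0 instead of propagating.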

@malfet malfet requested review from ezyang and ssnl August 6, 2020 21:55
Contributor

@facebook-github-bot facebook-github-bot left a comment


@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@malfet malfet added the module: nn (Related to torch.nn) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Aug 6, 2020
@dr-ci

dr-ci bot commented Aug 6, 2020

💊 CI failures summary and remediations

As of commit a58b4b2 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build (1/2)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Aug 10 18:12:30 fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914
DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:376 
 
real	0m25.152s 
user	0m0.075s 
sys	0m0.053s 
Aug 10 18:11:29 ++ export BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build 
Aug 10 18:11:29 ++ BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build 
Aug 10 18:11:29 ++ git submodule sync 
Aug 10 18:11:29 ++ git submodule update -q --init --recursive 
Aug 10 18:12:30 fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914 
Aug 10 18:12:32 Unable to checkout 'f015d698006c4a11be15b1ebb75b3b9bb317b914' in submodule path 'third_party/tensorpipe' 

See CircleCI build binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build (2/2)

Step: "Checkout pytorch/builder repo" (full log | diagnosis details | 🔁 rerun)

fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914
+ sleep 2 
+ git submodule update --init --recursive 
fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914 
Unable to checkout 'f015d698006c4a11be15b1ebb75b3b9bb317b914' in submodule path 'third_party/tensorpipe' 
+ sleep 4 
+ git submodule update --init --recursive 
fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914 
Unable to checkout 'f015d698006c4a11be15b1ebb75b3b9bb317b914' in submodule path 'third_party/tensorpipe' 
+ sleep 8 
+ git submodule update --init --recursive 
fatal: reference is not a tree: f015d698006c4a11be15b1ebb75b3b9bb317b914 
Unable to checkout 'f015d698006c4a11be15b1ebb75b3b9bb317b914' in submodule path 'third_party/tensorpipe' 

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.


This comment has been revised 16 times.

@ezyang
Contributor

ezyang commented Aug 6, 2020

cc @emcastillo @ngimel

@malfet malfet added the module: NaNs and Infs (Problems related to NaN and Inf handling in floating point) label Aug 7, 2020
@malfet malfet requested a review from ngimel August 7, 2020 00:45
@malfet malfet force-pushed the malfet/fix-grid-sample-with-nans branch 2 times, most recently from 6a5ea42 to c9185c6 Compare August 7, 2020 16:37
Contributor

@facebook-github-bot facebook-github-bot left a comment


@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ssnl
Collaborator

ssnl commented Aug 7, 2020

Wait. Isn't NaN propagation what we are supposed to do?

@ngimel
Collaborator

ngimel commented Aug 7, 2020

clamp has to propagate NaNs, so we can't change it. I'm not sure what the expected behavior of grid_sample with NaNs in the grid is - a RuntimeError, or silently ignoring them?
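
For reference, a small check of the propagating clamp behavior described here (a sketch, assuming the public torch.clamp semantics ngimel refers to; the values are illustrative):

    import torch

    x = torch.tensor([float('nan'), -2.0, 5.0])
    # Finite values are clipped into [0, 1]; NaN passes through unchanged.
    print(torch.clamp(x, min=0.0, max=1.0))  # tensor([nan, 0., 1.])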

@ssnl
Collaborator

ssnl commented Aug 7, 2020

I thought we decided to propagate NaN as much as we could. I suppose it is more important in grid_sample, which is often used as a layer/op in a neural network. Yet this PR stops propagation IIRC.

@ngimel
Collaborator

ngimel commented Aug 7, 2020

But if there are NaNs in the grid you cannot reasonably propagate them; you can only raise a RuntimeError?

@ssnl
Collaborator

ssnl commented Aug 7, 2020

Why can't we? If there is nan for a location in the grid, shouldn't the output for that location just be nan?

@malfet malfet force-pushed the malfet/fix-grid-sample-with-nans branch from c9185c6 to 3f3a793 Compare August 7, 2020 20:38
@malfet
Contributor Author

malfet commented Aug 7, 2020

Propagating NaNs from the grid into the output would be quite expensive, and before this patch it would almost certainly result in a segfault.
It's easy to add this logic, but it would unnecessarily slow down the execution of the operator in normal cases.
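
A rough sketch of the kind of call that exercises this path (illustrative shapes and values, not the actual regression test added in this PR):

    import torch

    inp = torch.rand(1, 1, 4, 4)
    grid = torch.zeros(1, 4, 4, 2)
    grid[0, 1, 1, 0] = float('nan')  # poison one sampling coordinate
    # Before this patch a NaN coordinate could be cast to an arbitrary integer
    # index in the CPU kernel and segfault; afterwards the call should return
    # normally, with the NaN coordinate clamped to 0.
    out = torch.nn.functional.grid_sample(inp, grid, padding_mode='reflection', align_corners=True)
    print(out.shape)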

@malfet malfet force-pushed the malfet/fix-grid-sample-with-nans branch from 3f3a793 to 82437bf Compare August 7, 2020 20:58
Contributor

@facebook-github-bot facebook-github-bot left a comment


@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ssnl
Collaborator

ssnl commented Aug 8, 2020

Hmm sure. Maybe worth saying so in comments then? The previous commit says something like "it is important to clamp NaN", which was unintuitive to me.

Collaborator

@ngimel ngimel left a comment


Can you please clarify in the docs what happens for nan/inf values in the grid? Interestingly, the original issue was about a very large value turning into inf; it looks like that should have been handled before by the minimum/maximum clipping?
This fix is for CPU only, what happens on CUDA?

@@ -203,7 +203,8 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/true> {
   }

   inline Vec clip_coordinates(const Vec &in) const {
-    return minimum(Vec(max_val), maximum(in, Vec(0)));
+    // Invert order of clamp_min operands in order to clamp Nans to zero
Collaborator


nice!

malfet added 2 commits August 10, 2020 11:03
In `clip_coordinates` replace `minimum(maximum(in))` composition with `clamp_max(clamp_min(in))` and swap order of `clamp_min` operands to clamp Nans in grid to 0
@malfet malfet force-pushed the malfet/fix-grid-sample-with-nans branch from 82437bf to a58b4b2 Compare August 10, 2020 18:09
@malfet malfet requested a review from apaszke as a code owner August 10, 2020 18:09
@malfet
Contributor Author

malfet commented Aug 10, 2020

Can you please clarify in the docs what happens for nan/inf values in the grid?

Done

Interestingly, the original issue was about a very large value turning into inf; it looks like that should have been handled before by the minimum/maximum clipping?

It was turning into nan in reflection mode, because inf - inf == nan.

This fix is for CPU only, what happens on CUDA?

Somehow it's already working like that on CUDA (perhaps the GPU automatically turns NaN into 0 when casting it to int). Extended the test to cover that functionality.
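
Regarding the inf - inf point above, a one-line illustration:

    import math

    inf = float('inf')
    # Once a reflected coordinate has overflowed to inf, subtractions of the
    # form inf - inf yield NaN, which the old minimum/maximum-based clipping
    # did not remove.
    print(inf - inf, math.isnan(inf - inf))  # nan True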

Contributor

@facebook-github-bot facebook-github-bot left a comment


@malfet has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator

ngimel commented Aug 10, 2020

Great, can you add that to the docs note too - that large values can be converted to nan due to overflow during internal operations on the grid and the subsequent inf - inf?

@xwang233
Collaborator

FYI, the cuda fix is at #35506.

Perhaps we can merge the test in this PR with this one and remove onlyCUDA?

pytorch/test/test_nn.py

Lines 10098 to 10139 in 3cf2551

@onlyCUDA
def test_grid_sample_large(self, device):
    def issue_35202():
        input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True)
        coords = torch.tensor([[-10059144, 67680944], [67680944, 67680944]], dtype=torch.float, device=device)
        coords = coords.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 1)
        result = torch.nn.functional.grid_sample(input_tensor, coords)
        self.assertEqual(result, torch.tensor([[[[0., 0.]]]], dtype=torch.float, device=device))
        result.backward(torch.ones_like(result))
        torch.cuda.synchronize()
    issue_35202()

    def issue_24823_1(dtype):
        image = torch.arange(27, 0, -1, dtype=dtype, device=device).view(1, 1, 3, 3, 3)
        image.requires_grad_()
        grid = torch.nn.functional.affine_grid(
            torch.tensor([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]], dtype=dtype, device=device),
            (1, 1, 3, 3, 3))
        grid[:, 1, 1, 1, 0] = float('inf')
        result = torch.nn.functional.grid_sample(image, grid, padding_mode='zeros')
        self.assertEqual(result, torch.tensor([[[[[27., 26., 25.], [24., 23., 22.], [21., 20., 19.]],
                                                 [[18., 17., 16.], [15., 0., 13.], [12., 11., 10.]],
                                                 [[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]]]]],
                                               device=device, dtype=dtype))
        result.backward(torch.ones_like(result))
        expected_grad = torch.ones_like(image)
        expected_grad[0, 0, 1, 1, 1] = 0
        self.assertTrue(torch.allclose(image.grad, expected_grad, atol=1e-3))
    issue_24823_1(torch.half)
    issue_24823_1(torch.float)
    issue_24823_1(torch.double)

    def issue_24823_2():
        param = torch.tensor([[[-1.0e+20, 0.0, 0.0], [0.0, -1.0e+20, 0.0]]], dtype=torch.float, device=device)
        img = torch.zeros((1, 1, 4, 4), dtype=torch.float, device=device, requires_grad=True)
        grid = torch.nn.functional.affine_grid(param, img.size())
        result = torch.nn.functional.grid_sample(img, grid)
        self.assertEqual(result, torch.zeros(1, 1, 4, 4, device=device, dtype=torch.float))
        result.backward(torch.ones_like(result))
        torch.cuda.synchronize()
    issue_24823_2()

(issue_24823_1(torch.half) can be removed, since arange with half precision is not supported on CPU)

@facebook-github-bot
Contributor

@malfet merged this pull request in 3cf2551.

Labels
Merged
module: NaNs and Infs - Problems related to NaN and Inf handling in floating point
module: nn - Related to torch.nn
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

Segfault in torch.nn.functional.grid_sample in reflection mode
7 participants