From 476adff9ceb9e1e863af4da647d5780f6a49ed99 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Thu, 1 Oct 2020 21:26:55 -0500
Subject: [PATCH] Avoid NaN values in torch.cdist backward for p<1

---
 aten/src/ATen/native/cpu/DistanceOpsKernel.cpp |  6 +++++-
 aten/src/ATen/native/cuda/DistanceKernel.cu    |  4 +++-
 test/test_autograd.py                          | 12 ++++++++++++
 3 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
index 114ca93dae26..34911a2975e4 100644
--- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -104,7 +104,11 @@ struct Dist {
 
   // Special general pnorm derivative if p is less than two
   struct lttdist_calc {
-    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
+    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
+      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
+      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
+      return result;
+    }
   };
 
   // Two norm
diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu
index 385cac5c79e8..c43a2ae9877e 100644
--- a/aten/src/ATen/native/cuda/DistanceKernel.cu
+++ b/aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -50,7 +50,9 @@ struct dists {
 
   // Special case backward when p is less than two
   struct lt_two {
-    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1); }
+    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) {
+      return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1));
+    }
   };
 
   // Two norm
diff --git a/test/test_autograd.py b/test/test_autograd.py
index d6661b4662fe..6f5734e2ceb3 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6112,6 +6112,18 @@ def _test_euclidean_large_cdist(sizex, sizey=None):
         _test_cdist_for_size((1, 1), (S, 1))
         _test_euclidean_large_cdist((2000, 5))
 
+    # Ensure that cdist backward with p<1 does not produce NaNs
+    def test_cdist_grad_p_lt_1_no_nan(self, device):
+        for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
+            x = torch.randn(1, 2, device=device)
+            y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
+            x.requires_grad = True
+            y.requires_grad = True
+            result = torch.cdist(x, y, p=p)
+            result.backward(torch.ones_like(result))
+            self.assertFalse(torch.isnan(x.grad).any())
+            self.assertFalse(torch.isnan(y.grad).any())
+
     def test_cdist_same_inputs(self, device):
         # Test to detect issues in cdist gradient calculation
         # When the distances are 0
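
Why the NaN appears: for 0 < p < 1 the general p-norm backward multiplies
sign(diff) by |diff|^(p-1). When the two points agree in some coordinate,
diff == 0 and |diff|^(p-1) overflows to inf, so the product evaluates to
0 * inf = NaN. The masked case in both kernels sets that term to 0 instead.
A minimal illustration of the offending term, using only public torch ops
(a sketch, not part of the diff):

    import torch

    p = torch.tensor(0.5)
    diff = torch.tensor(0.0)
    # The per-coordinate factor of the p<1 backward formula at diff == 0:
    # 0 ** (p - 1) overflows to inf, and sign(0) * inf gives NaN.
    term = torch.sign(diff) * diff.abs().pow(p - 1)
    print(term)  # tensor(nan) -- the value the patch masks out to 0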
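
For readers who don't want to parse the vectorized code, here is a plain
scalar transcription of the patched rule. The function name is made up for
illustration; the real code is the Vec256 C++ kernel and the CUDA kernel
above, where the extra zeroing is done by Vec::blendv and by the extended
ternary, respectively.

    def lt_two_backward_scalar(diff, grad, dist, p):
        # Scalar sketch of the patched p < 2 backward: return 0 when the whole
        # distance is 0, or when this coordinate's diff is 0 with p < 1, which
        # is exactly the case that used to produce 0 * inf = NaN.
        if dist == 0.0 or (diff == 0.0 and p < 1):
            return 0.0
        sign = (diff > 0) - (diff < 0)
        return sign * abs(diff) ** (p - 1) * grad / dist ** (p - 1)

    print(lt_two_backward_scalar(0.0, 1.0, 1.0, 0.5))  # 0.0 instead of NaN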
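
A quick end-to-end check of the user-visible behavior that the new
test_cdist_grad_p_lt_1_no_nan guards, with fixed inputs instead of randn
(again just a sketch for local verification, not part of the patch):

    import torch

    # Two points that agree in their second coordinate, so diff == 0 there.
    x = torch.tensor([[0.5, 1.0]], requires_grad=True)
    y = torch.tensor([[1.5, 1.0]], requires_grad=True)

    d = torch.cdist(x, y, p=0.5)
    d.backward(torch.ones_like(d))

    # Before the fix the matching coordinate's gradient was NaN; with the fix
    # it is 0, and no NaNs appear in either gradient.
    print(x.grad, y.grad)
    assert not torch.isnan(x.grad).any() and not torch.isnan(y.grad).any()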