Avoid NaN values in torch.cdist backward for p<1
kurtamohler committed Oct 2, 2020
1 parent 6acd7b6 commit 476adff
Showing 3 changed files with 20 additions and 2 deletions.
6 changes: 5 additions & 1 deletion aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -104,7 +104,11 @@ struct Dist {
 
   // Special general pnorm derivative if p is less than two
   struct lttdist_calc {
-    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
+    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
+      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
+      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
+      return result;
+    }
   };
 
   // Two norm
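For p < 1 the exponent p - 1 is negative, so whenever a coordinate difference is exactly zero the factor diff.abs().pow(p - Vec(1)) evaluates to infinity; multiplying it by sign(0) = 0 then yields NaN instead of a usable gradient. The added blendv defines the contribution of those coordinates to be zero, mirroring how the dist == 0 case is already handled. Below is a minimal sketch of the failure and of the guard, written with plain PyTorch tensor ops rather than the vectorized kernel above; the variable names are illustrative, not taken from the source.

import torch

p = 0.5
diff = torch.tensor([0.0, 1.0])   # one coordinate pair identical, one differing by 1
dist = torch.tensor(1.0)          # overall p-norm distance, nonzero in this example
grad = 1.0                        # incoming gradient

# Unguarded elementwise term: at diff == 0 this is sign(0) * 0**(p-1) = 0 * inf = nan.
grad_term = torch.sign(diff) * diff.abs().pow(p - 1) * grad / dist.pow(p - 1)
print(grad_term)                  # tensor([nan, 1.])

# Guard from the commit: define the term as 0 wherever diff == 0 (only needed for p < 1).
if p < 1:
    grad_term = torch.where(diff == 0, torch.zeros_like(grad_term), grad_term)
print(grad_term)                  # tensor([0., 1.])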
4 changes: 3 additions & 1 deletion aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -50,7 +50,9 @@ struct dists {
 
   // Special case backward when p is less than two
   struct lt_two {
-    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1); }
+    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) {
+      return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1));
+    }
   };
 
   // Two norm
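The CUDA kernel enforces the same rule, but because it operates on one scalar element at a time it can fold the new check directly into the existing ternary (diff == 0.0 && p < 1) instead of blending a per-lane mask the way the vectorized CPU kernel does with Vec::blendv.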
12 changes: 12 additions & 0 deletions test/test_autograd.py
@@ -6112,6 +6112,18 @@ def _test_euclidean_large_cdist(sizex, sizey=None):
         _test_cdist_for_size((1, 1), (S, 1))
         _test_euclidean_large_cdist((2000, 5))
 
+    # Ensure that cdist backward with p<1 does not produce NaNs
+    def test_cdist_grad_p_lt_1_no_nan(self, device):
+        for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
+            x = torch.randn(1, 2, device=device)
+            y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
+            x.requires_grad = True
+            y.requires_grad = True
+            result = torch.cdist(x, y, p=p)
+            result.backward(torch.ones_like(result))
+            self.assertFalse(torch.isnan(x.grad).any())
+            self.assertFalse(torch.isnan(y.grad).any())
+
     def test_cdist_same_inputs(self, device):
         # Test to detect issues in cdist gradient calculation
         # When the distances are 0
