Avoid NaN values in torch.cdist backward for p<1 #45720

Closed
6 changes: 5 additions & 1 deletion aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -104,7 +104,11 @@ struct Dist {

   // Special general pnorm derivative if p is less than two
   struct lttdist_calc {
-    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
+    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
+      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
+      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
+      return result;
+    }
   };
 
   // Two norm
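Why the mask is needed (a minimal sketch, not part of the diff): for p < 1 the exponent p - 1 is negative, so IEEE `pow` sends a zero coordinate difference to `inf`, and `sign(0) * inf` is `0 * inf`, which is NaN. The `blendv` above zeroes exactly those lanes:

```python
import torch

p = 0.5
diff = torch.tensor(0.0)  # a coordinate where x and y agree exactly

print(diff.abs().pow(p - 1))                     # tensor(inf): 0 ** -0.5
print(torch.sign(diff) * diff.abs().pow(p - 1))  # tensor(nan): 0 * inf
```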
4 changes: 3 additions & 1 deletion aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -50,7 +50,9 @@ struct dists {

   // Special case backward when p is less than two
   struct lt_two {
-    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1); }
+    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) {
+      return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1));
+    }
   };
 
   // Two norm
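A rough Python analogue of the guarded backward in both kernels (an illustration with a made-up helper name, not the shipped code; `torch.where` stands in for the CUDA branch and the CPU `blendv` mask):

```python
import torch

def lt_two_backward(diff, grad, dist, p):
    # Unguarded term: NaN when dist == 0, or when diff == 0 and p < 1.
    term = torch.sign(diff) * diff.abs().pow(p - 1) * grad / dist.pow(p - 1)
    # Zero out both problematic cases, as the patched kernels do.
    mask = (dist == 0) | ((diff == 0) & (p < 1))
    return torch.where(mask, torch.zeros_like(term), term)

# diff == 0 with p = 0.5 now yields 0 instead of NaN:
print(lt_two_backward(torch.tensor(0.0), 1.0, torch.tensor(1.0), 0.5))
```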
12 changes: 12 additions & 0 deletions test/test_autograd.py
@@ -6112,6 +6112,18 @@ def _test_euclidean_large_cdist(sizex, sizey=None):
         _test_cdist_for_size((1, 1), (S, 1))
         _test_euclidean_large_cdist((2000, 5))

+    # Ensure that cdist backward with p<1 does not produce NaNs
+    def test_cdist_grad_p_lt_1_no_nan(self, device):
+        for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
+            x = torch.randn(1, 2, device=device)
+            y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
+            x.requires_grad = True
+            y.requires_grad = True
+            result = torch.cdist(x, y, p=p)
+            result.backward(torch.ones_like(result))
+            self.assertFalse(torch.isnan(x.grad).any())
+            self.assertFalse(torch.isnan(y.grad).any())
+
     def test_cdist_same_inputs(self, device):
         # Test to detect issues in cdist gradient calculation
         # When the distances are 0
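For reference, the failure the new test pins down can be reproduced standalone; offsetting y from x by (1, 0) makes the second coordinate difference exactly zero, which is the case the kernels now guard:

```python
import torch

x = torch.tensor([[0., 0.]], requires_grad=True)
y = torch.tensor([[1., 0.]], requires_grad=True)  # second coordinate equals x's

d = torch.cdist(x, y, p=0.5)
d.backward(torch.ones_like(d))
print(x.grad)  # second entry was NaN before this change; now it is 0
```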