Fix torch.cdist backward CUDA error due to illegal gridDim setting #51569

Closed · wants to merge 8 commits
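Context for the change: CUDA caps gridDim.y at 65,535 blocks, but the old backward launch derived a single grid.y value from dist.numel(), so a large enough distance matrix produced an illegal launch configuration. A minimal reproducer sketch (my own, not part of the PR, assuming the default compute mode routes the backward pass through the patched cdist_backward_kernel_cuda_impl):

import torch

x = torch.randn(4, 513, 32, device='cuda', requires_grad=True)
y = torch.randn(4, 513, 32, device='cuda')

# dist has 4 * 513 * 513 = 1,052,676 elements; the old launch asked for
# grid_y = ceil(1,052,676 / 16) = 65,793 blocks, above the 65,535 limit.
dist = torch.cdist(x, y, p=1)
dist.mean().backward()   # hit the CUDA launch error before this fix
print(x.grad.shape)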
12 changes: 8 additions & 4 deletions aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -132,7 +132,7 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t
template <typename scalar_t, typename F>
__global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist, int64_t gs,
const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
- const int y = blockIdx.y * blockDim.y + threadIdx.y;
+ const int y = (blockIdx.y * gridDim.z + blockIdx.z) * blockDim.y + threadIdx.y;
const int init = blockIdx.x * blockDim.x + threadIdx.x;
if (y >= count || init >= m) {
return;
@@ -335,12 +335,16 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
const int block_x = 64;
const int block_y = 16;
const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
- const int grid_y = (dist.numel() + block_y - 1) / block_y;

- const dim3 grid(grid_x, grid_y);
+ const int64_t count = dist.numel();
+ const int64_t grid_temp = (count + block_y - 1) / block_y;
+
+ const int grid_y = (grid_temp - 1) / 65535 + 1;
+ const int grid_z = (grid_temp - 1) / grid_y + 1;
+
+ const dim3 grid(grid_x, grid_y, grid_z);
const dim3 block(block_x, block_y);

- const int64_t count = dist.numel();
const int64_t r_size = r1 * r2;
const int64_t l1_size = r1 * m;
const int64_t l2_size = r2 * m;
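The new launch math above replaces the single oversized grid_y with a grid_y * grid_z factorization, and the kernel rebuilds the flat block-row index as (blockIdx.y * gridDim.z + blockIdx.z). A small sanity check of that arithmetic (my own sketch mirroring the C++ integer divisions, not code from the PR):

def launch_dims(count, block_y=16, limit=65535):
    grid_temp = (count + block_y - 1) // block_y   # blocks needed along the row axis
    grid_y = (grid_temp - 1) // limit + 1          # number of slices, kept within the limit
    grid_z = (grid_temp - 1) // grid_y + 1         # blocks per slice
    assert grid_y <= limit and grid_z <= limit
    assert grid_y * grid_z >= grid_temp            # every row index is still covered
    return grid_y, grid_z

# dist.numel() = 4 * 513 * 513 from the new test below:
print(launch_dims(4 * 513 * 513))   # (2, 32897), versus a single grid_y of 65,793 before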
23 changes: 23 additions & 0 deletions test/test_torch.py
@@ -3819,6 +3819,29 @@ def test_cdist_norm_batch(self, device):
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)

@onlyCUDA
def test_cdist_cuda_backward(self, device):
    for l1 in [1, 511, 513]:
        for l2 in [1, 511, 513]:
            for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
                x2 = x1.clone().detach_().requires_grad_()
                y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
                y2 = y1.clone().detach_().requires_grad_()
                if p == 2:
                    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                        z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
                        z2 = self._brute_cdist(x2, y2, p=2).mean()
                        z1.backward()
                        z2.backward()
                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
                else:
                    z1 = torch.cdist(x1, y1, p=p).mean()
                    z2 = self._brute_cdist(x2, y2, p=p).mean()
                    z1.backward()
                    z2.backward()
                    self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
                    self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)

@tf32_on_and_off(0.005)
def test_cdist_large(self, device):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
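A side note on the test sizes (my own inference, not stated in the PR): with a batch of 4, l1 = l2 = 511 keeps the old single grid_y just under the 65,535 cap, while 513 pushes it over, so the loop exercises both sides of the boundary that triggered the bug.

# Hypothetical check of that claim:
for l in (511, 513):
    count = 4 * l * l                     # dist.numel() for inputs of shape (4, l, 32)
    old_grid_y = (count + 16 - 1) // 16   # the pre-fix launch computation
    print(l, old_grid_y, old_grid_y <= 65535)   # 511 -> 65281 True, 513 -> 65793 False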