
Commit 8f658d7

don't produce invalid grid configs (#166973) (#167158)
Proper fix for #164048; fixes gather too; reverts #164049.
Pull Request resolved: #166974
Approved by: https://github.com/eqy
1 parent 3d27d95 commit 8f658d7
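
CUDA caps gridDim.y (and gridDim.z) at 65535; only gridDim.x goes up to 2^31 - 1. The old vectorized-gather launcher set grid.y directly to ceil_div(slice_size_in_bytes, max_num_threads * Alignment), so a sufficiently large slice asked for more y-blocks than the hardware allows and the launch failed with an invalid-configuration error. As a rough illustration (assuming 256 threads per block and 16-byte vectors), a 256 MiB slice needs ceil(268435456 / 4096) = 65536 y-blocks, one past the limit. The fix clamps grid.y to the device's maxGridSize[1] and turns the kernel body into a grid-stride loop in y, so whatever the clamp leaves uncovered is picked up in later iterations.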

File tree: 3 files changed, +15 -10 lines

3 files changed

+15
-10
lines changed

aten/src/ATen/native/cuda/IndexKernel.cu

Lines changed: 1 addition & 3 deletions
@@ -73,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
 
   char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
   char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
-
   if (is_gather_like && num_indices==1) {
     const size_t element_size = iter.element_size(0);
     constexpr size_t alignment = 16;
@@ -83,11 +82,10 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
       auto ind_dim_size = index_size[0];
       auto inp_stride_bytes = index_stride[0];
       auto out_stride_bytes = iter.strides(0)[1];
-      if (iter.numel() == 0) return;
       at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
                                                   slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
       return;
-      }
+    }
   }
 
   auto sizes = std::array<int64_t, MAX_DIMS>{};
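
This hunk is the revert half of the change: the early "if (iter.numel() == 0) return;" guard that #164049 added as a stopgap is removed, along with stray blank lines and a misindented closing brace. The substantive fix is in the launcher below.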

aten/src/ATen/native/cuda/IndexKernelUtils.cu

Lines changed: 8 additions & 5 deletions
@@ -14,10 +14,11 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
     ind = (ind < 0) ? ind + ind_dim_size : ind;
   }
   CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
-  int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
-  if (off >= slice_size) return;
-  auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
-  at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
+  // off is guaranteed to be within int32 limits
+  for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
+    auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
+    at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
+  }
 }
 
 
@@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
   auto num_threads = at::round_up(
       at::ceil_div(slice_size_in_bytes, Alignment),
       static_cast<int64_t>(C10_WARP_SIZE));
-  dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
+  uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+  grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
+  dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
   auto block = std::min(max_num_threads, num_threads);
   vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
       ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
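
The same pattern, reduced to a self-contained sketch: compute the natural y-extent, clamp it to the device's maxGridSize[1], and let a y-direction grid-stride loop in the kernel cover whatever the clamp cut off. All names here (copy_rows, launch_copy_rows) are hypothetical illustration code, not PyTorch internals.

// Sketch only: clamp gridDim.y to the device limit, then grid-stride in y.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>

__global__ void copy_rows(const float* in, float* out, int64_t row_elems) {
  // One row per blockIdx.x; the y dimension tiles the row. Because gridDim.y
  // may be smaller than the row needs, each thread strides by the full
  // y extent (blockDim.x * gridDim.y) until the row is covered.
  const float* src = in + static_cast<int64_t>(blockIdx.x) * row_elems;
  float* dst = out + static_cast<int64_t>(blockIdx.x) * row_elems;
  for (int64_t off = static_cast<int64_t>(blockDim.x) * blockIdx.y + threadIdx.x;
       off < row_elems;
       off += static_cast<int64_t>(blockDim.x) * gridDim.y) {
    dst[off] = src[off];
  }
}

void launch_copy_rows(const float* in, float* out, int num_rows, int64_t row_elems) {
  constexpr int threads = 256;
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Naive y-extent; for very long rows this exceeds 65535 and, used directly,
  // would make the launch fail with cudaErrorInvalidConfiguration.
  int64_t want_y = (row_elems + threads - 1) / threads;
  // Clamp to the hardware maximum so the config is always valid.
  uint32_t grid_y = static_cast<uint32_t>(std::min<int64_t>(want_y, prop.maxGridSize[1]));
  dim3 grid(static_cast<uint32_t>(num_rows), grid_y, 1);
  copy_rows<<<grid, threads>>>(in, out, row_elems);
}

int main() {
  const int rows = 2;
  const int64_t row_elems = 1 << 20;
  float *in, *out;
  cudaMalloc(&in, rows * row_elems * sizeof(float));
  cudaMalloc(&out, rows * row_elems * sizeof(float));
  launch_copy_rows(in, out, rows, row_elems);
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}

One difference from the kernel above: PyTorch keeps off in int32_t because the caller bounds the slice size (the in-source comments note this guarantee); the sketch uses int64_t so it stays correct for arbitrary row lengths.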

test/test_scatter_gather_ops.py

Lines changed: 6 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM)
+    (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM, serialTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA,
      toleranceOverride, tol,)
@@ -65,10 +65,12 @@ def test_gather(self, device, dtype):
             actual = torch.gather(src, 2, idx)
             self.assertEqual(actual, expected, atol=0, rtol=0)
 
+    @serialTest()
     @dtypes(torch.int8, torch.bfloat16)
     def test_gather_large(self, device, dtype):
         # test larger shapes to check vectorized implementation
-        for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)):
+        for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100), (4, 4, 16384 * 8192)):
+            torch.cuda.empty_cache()
             src = make_tensor((m, k), device=device, dtype=dtype)
             alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype)
             discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src)
@@ -111,6 +113,8 @@ def test_gather_large(self, device, dtype):
             self.assertEqual(res_ind, ref, atol=0, rtol=0)
             res_gather = torch.gather(misaligned1, dim=dim, index=ind)
             self.assertEqual(res_gather, ref, atol=0, rtol=0)
+            del src, alloc0, alloc1, alloc2
+            del discontig, misaligned, misaligned1
         # test gather along 1st dim that can accidentally trigger fast path
         # because due to index dimension in the gather dim being 1
         # an unexpected squashing in tensorIterator happens
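
The new (4, 4, 16384 * 8192) shape is the regression case: its slices are large enough to push the old, unclamped launcher's computed gridDim.y past the 65535 limit (the failure in #164048). Since tensors of that shape run to a gibibyte or more each, the test also frees its intermediates with del, calls torch.cuda.empty_cache() at the top of each iteration, and is marked @serialTest() so it is not scheduled alongside other tests that could push the device out of memory.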
