Fix cuda launch error in reflection_pad2d #56451

Closed · wants to merge 5 commits
Changes from 3 commits
73 changes: 48 additions & 25 deletions aten/src/ATen/native/cuda/ReflectionPad.cu
@@ -6,6 +6,7 @@
#include <ATen/Utils.h>
// keeping THC headers for gpuAtomicAdd
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

#include <thrust/pair.h>

@@ -45,12 +46,12 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
int64_t input_dim_x, int64_t input_dim_y,
int64_t output_dim_x, int64_t output_dim_y,
int64_t pad_l, int64_t pad_t,
int64_t output_xy) {
int64_t output_xy, int y_shift, int z_shift, int nplane) {
// 3D grid of 1D blocks
auto input_offset =
(blockIdx.y + blockIdx.z * gridDim.y) * input_dim_x * input_dim_y;
((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * input_dim_x * input_dim_y;
auto output_offset =
(blockIdx.y + blockIdx.z * gridDim.y) * output_dim_x * output_dim_y;
((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * output_dim_x * output_dim_y;

auto output_x = output_xy % output_dim_x;
auto output_y = output_xy / output_dim_x;
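
Aside (not part of the diff): once the launch is split into chunks, gridDim.y holds the size of the current chunk rather than nplane, so the old offset formula would address the wrong plane; the new one rebuilds the global (plane, batch) index from the chunk offsets y_shift/z_shift and the true plane count. A quick check with purely illustrative numbers:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative only: nplane = 70000 splits the y grid dimension into chunks
  // of 65535 and 4465; consider the second chunk (y_shift = 65535, gridDim.y = 4465).
  // A block at blockIdx.y = 10, blockIdx.z = 1 should address plane 65545 of batch 1.
  int64_t nplane = 70000, y_shift = 65535, z_shift = 0;
  int64_t block_y = 10, block_z = 1, grid_dim_y = 4465;                    // chunk size, not nplane
  int64_t old_plane = block_y + block_z * grid_dim_y;                      // 4475 -> wrong slab
  int64_t new_plane = (block_y + y_shift) + (block_z + z_shift) * nplane;  // 65545 + 70000 = 135545, as intended
  std::printf("old: %lld  new: %lld\n", (long long)old_plane, (long long)new_plane);
  return 0;
}
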
@@ -110,7 +111,7 @@ template<typename scalar_t>
__global__ void reflection_pad2d_out_kernel(
scalar_t * input, scalar_t * output,
int64_t input_dim_x, int64_t input_dim_y,
int pad_t, int pad_b, int pad_l, int pad_r) {
int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
auto output_dim_x = input_dim_x + pad_l + pad_r;
auto output_dim_y = input_dim_y + pad_t + pad_b;
@@ -120,7 +121,7 @@ __global__ void reflection_pad2d_out_kernel(
input_dim_x, input_dim_y,
output_dim_x, output_dim_y,
pad_l, pad_t,
output_xy);
output_xy, y_shift, z_shift, nplane);

output[index_pair.second] = input[index_pair.first];
}
@@ -130,7 +131,7 @@ template <typename scalar_t>
__global__ void reflection_pad2d_backward_out_kernel(
scalar_t * grad_input, scalar_t * grad_output,
int64_t input_dim_x, int64_t input_dim_y,
int pad_t, int pad_b, int pad_l, int pad_r) {
int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
auto output_dim_x = input_dim_x + pad_l + pad_r;
auto output_dim_y = input_dim_y + pad_t + pad_b;
@@ -140,7 +141,7 @@ __global__ void reflection_pad2d_backward_out_kernel(
input_dim_x, input_dim_y,
output_dim_x, output_dim_y,
pad_l, pad_t,
output_xy);
output_xy, y_shift, z_shift, nplane);

gpuAtomicAdd(&grad_input[index_pair.first], grad_output[index_pair.second]);
}
@@ -258,19 +259,30 @@ void reflection_pad2d_out_template(

Tensor input = input_.contiguous();

int output_plane_size = output_h * output_w;
int64_t output_plane_size = output_h * output_w;
dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
dim3 grid_size(
(int) std::ceil(output_plane_size/256.0), nplane, nbatch);

int64_t size_y = nplane;
int64_t size_z = nbatch;

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
input.scalar_type(), "reflection_pad2d_out_template", [&] {
reflection_pad2d_out_kernel<<<
grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
input_w, input_h,
pad_t, pad_b, pad_l, pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();

for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

dim3 grid_size(THCCeilDiv(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);
Review comment (Collaborator):
please use cuda::ATenCeilDiv here, don't include legacy header
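
For concreteness, a sketch of what the suggested change could look like — only the affected lines, assuming the at::cuda::ATenCeilDiv helper and its ATen/cuda/ATenCeilDiv.h header (the exact header path is an assumption, not part of this diff):

// Hypothetical revision per the review suggestion: drop <THC/THCDeviceUtils.cuh>
// and compute the x grid dimension with the ATen helper instead of THCCeilDiv.
#include <ATen/cuda/ATenCeilDiv.h>  // assumed header for at::cuda::ATenCeilDiv

dim3 grid_size(
    at::cuda::ATenCeilDiv(output_plane_size, static_cast<int64_t>(256)),
    block_y_size, block_z_size);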


reflection_pad2d_out_kernel<<<
grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
input_w, input_h,
pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
}
);
}
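
Aside (not part of the diff): the loops above exist because CUDA caps gridDim.y and gridDim.z at 65535, so plane/batch extents beyond that limit have to be covered by several launches whose offsets are passed into the kernel. A minimal standalone sketch of the same chunking pattern, with hypothetical names and a dummy kernel:

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

// Dummy kernel: writes one value per element of a (nbatch, nplane, plane_elems)
// buffer, recovering the global plane/batch index from the chunk offsets.
__global__ void fill_planes(float* data, int64_t plane_elems,
                            int y_shift, int z_shift, int64_t nplane) {
  int64_t plane = blockIdx.y + y_shift;   // global plane index
  int64_t batch = blockIdx.z + z_shift;   // global batch index
  int64_t i = threadIdx.x + int64_t(blockIdx.x) * blockDim.x;
  if (i < plane_elems) {
    data[(batch * nplane + plane) * plane_elems + i] = 1.0f;
  }
}

// Launch in chunks of at most 65535 blocks along y (planes) and z (batches).
void fill_all(float* data, int64_t plane_elems, int64_t nplane, int64_t nbatch) {
  const int64_t kMaxGridDim = 65535;
  dim3 block(256);
  unsigned blocks_x = static_cast<unsigned>((plane_elems + 255) / 256);
  for (int64_t y = 0; y < nplane; y += kMaxGridDim) {
    int64_t ysz = std::min(nplane - y, kMaxGridDim);
    for (int64_t z = 0; z < nbatch; z += kMaxGridDim) {
      int64_t zsz = std::min(nbatch - z, kMaxGridDim);
      dim3 grid(blocks_x, static_cast<unsigned>(ysz), static_cast<unsigned>(zsz));
      fill_planes<<<grid, block>>>(data, plane_elems,
                                   static_cast<int>(y), static_cast<int>(z), nplane);
    }
  }
}

The backward template below repeats the same pattern, accumulating into grad_input with gpuAtomicAdd instead of writing output directly.
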
@@ -319,19 +331,30 @@ void reflection_pad2d_backward_out_template(

Tensor grad_output = grad_output_.contiguous();

int output_plane_size = output_h * output_w;
int64_t output_plane_size = output_h * output_w;
dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);
dim3 grid_size(
(int) std::ceil(output_plane_size/256.0), nplane, nbatch);

int64_t size_y = nplane;
int64_t size_z = nbatch;

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {
reflection_pad2d_backward_out_kernel<<<
grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(),
input_w, input_h,
pad_t, pad_b, pad_l, pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();

for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

dim3 grid_size(THCCeilDiv(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

reflection_pad2d_backward_out_kernel<<<
grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(),
input_w, input_h,
pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
}
);
}
21 changes: 21 additions & 0 deletions test/test_nn.py
@@ -12213,6 +12213,27 @@ def test_ReflectionPad_empty(self, device, dtype):
inp = torch.randn(3, 0, 10, 10, device=device, dtype=dtype)
mod(inp)

@onlyCUDA # Test if CPU and GPU results match
def test_ReflectionPad2d_large(self, device):
shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])  # channel / batch dims exceed the 65535 CUDA grid-dimension limit
pad = (1, 2, 3, 4)
for shape in shapes:
x = torch.randn(shape, device=device, requires_grad=True)
ref_x = x.detach().cpu().requires_grad_()

out = F.pad(x, pad, mode='reflect')
ref_out = F.pad(ref_x, pad, mode='reflect')

self.assertEqual(out, ref_out)

g = torch.randn_like(out)
ref_g = g.cpu()

out.backward(g)
ref_out.backward(ref_g)

self.assertEqual(x.grad, ref_x.grad)


@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double)