Fix replication_pad for cuda launch configuration #50565

Closed
wants to merge 11 commits
272 changes: 140 additions & 132 deletions
aten/src/ATen/native/cuda/ReplicationPadding.cu
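Context for the diff: CUDA caps gridDim.y and gridDim.z at 65535 blocks each (only gridDim.x can go up to 2^31 - 1 on current devices). The old code launched one block per (plane, batch) pair directly on the y and z grid axes, so any tensor with more than 65535 channels or batch entries produced an invalid launch configuration. The fix tiles those two axes in chunks of at most 65535 and hands the tile origins (y_shift, z_shift) to the kernels. A minimal sketch of the pattern, with my_kernel, nplanes, and nbatch as hypothetical stand-ins:

  // Sketch only: tile the grid so gridDim.y and gridDim.z never exceed 65535.
  #include <algorithm>
  #include <cstdint>
  #include <cuda_runtime.h>

  __global__ void my_kernel(int y_shift, int z_shift) {
    int plane = blockIdx.y + y_shift;  // logical plane index
    int batch = blockIdx.z + z_shift;  // logical batch index
    // ... process (plane, batch) ...
  }

  void launch_tiled(int64_t nplanes, int64_t nbatch, cudaStream_t stream) {
    const int64_t kMaxGridYZ = 65535;  // hardware cap on gridDim.y and gridDim.z
    for (int64_t y = 0; y < nplanes; y += kMaxGridYZ) {
      int64_t y_size = std::min(nplanes - y, kMaxGridYZ);
      for (int64_t z = 0; z < nbatch; z += kMaxGridYZ) {
        int64_t z_size = std::min(nbatch - z, kMaxGridYZ);
        // grid.x carries the per-plane element count in the real kernels
        dim3 grid(1, static_cast<unsigned>(y_size), static_cast<unsigned>(z_size));
        my_kernel<<<grid, 256, 0, stream>>>(static_cast<int>(y), static_cast<int>(z));
      }
    }
  }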
@@ -30,11 +30,11 @@ template <typename scalar_t>
 __global__ void replication_pad_forward_kernel1d(
     PackedTensorAccessor64<scalar_t, 3> input,
     PackedTensorAccessor64<scalar_t, 3> output,
-    int padL, int padR) {
+    int padL, int padR, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= output.size(2)) {
     return;
   }
@@ -53,11 +53,11 @@ template <typename scalar_t>
 __global__ void replication_pad_backward_kernel(
     PackedTensorAccessor64<scalar_t, 3> gradInput,
     PackedTensorAccessor64<scalar_t, 3> gradOutput,
-    int padL, int padR) {
+    int padL, int padR, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= gradOutput.size(2)) {
     return;
   }
@@ -76,11 +76,11 @@ template <typename scalar_t>
 __global__ void replication_pad_forward_kernel2d(
     PackedTensorAccessor64<scalar_t, 4> input,
     PackedTensorAccessor64<scalar_t, 4> output,
-    int padT, int padB, int padL, int padR) {
+    int padT, int padB, int padL, int padR, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= output.size(2) * output.size(3)) {
     return;
   }
@@ -103,11 +103,11 @@ template <typename scalar_t>
 __global__ void replication_pad_backward_kernel(
     PackedTensorAccessor64<scalar_t, 4> gradInput,
     PackedTensorAccessor64<scalar_t, 4> gradOutput,
-    int padT, int padB, int padL, int padR) {
+    int padT, int padB, int padL, int padR, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= gradOutput.size(2) * gradOutput.size(3)) {
     return;
   }
@@ -130,11 +130,11 @@ template <typename scalar_t>
 __global__ void replication_pad_forward_kernel3d(
     PackedTensorAccessor64<scalar_t, 5> input,
     PackedTensorAccessor64<scalar_t, 5> output,
-    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= (output.size(2) * output.size(3) *
       output.size(4))) {
     return;
@@ -166,10 +166,10 @@ template <typename scalar_t>
 __global__ void replication_pad_backward_kernel(
     PackedTensorAccessor64<scalar_t, 5> gradInput,
     PackedTensorAccessor64<scalar_t, 5> gradOutput,
-    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright, int y_shift, int z_shift) {
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
 
   if (outputPointId >= (gradOutput.size(2) * gradOutput.size(3) *
       gradOutput.size(4))) {
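All six kernels above change the same way: each takes the tile origin and rebuilds its logical indices as plane = blockIdx.y + y_shift and batch = blockIdx.z + z_shift; the kernel bodies are otherwise untouched. As a worked example with a hypothetical size, 70000 planes are covered by two launches per batch tile: one with gridDim.y = 65535 and y_shift = 0, then one with gridDim.y = 70000 - 65535 = 4465 and y_shift = 65535.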
@@ -249,34 +249,32 @@ void replication_pad1d_out_cuda_template(
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     input.scalar_type(), "replication_pad1d_cuda", [&] {
+      at::Tensor input_ = input;
+      at::Tensor output_ = output;
       if (numInputDims == 2) {
-        auto input_ = input.unsqueeze(0);
-        auto output_ = output.unsqueeze(0);
-        auto devInput = input_.packed_accessor64<scalar_t, 3>();
-        auto devOutput = output_.packed_accessor64<scalar_t, 3>();
-
-        int outputPlaneSize = devOutput.size(2);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        auto devInput = input.packed_accessor64<scalar_t, 3>();
-        auto devOutput = output.packed_accessor64<scalar_t, 3>();
-
-        int outputPlaneSize = devOutput.size(2);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        input_ = input.unsqueeze(0);
+        output_ = output.unsqueeze(0);
       }
+
+      auto devInput = input_.packed_accessor64<scalar_t, 3>();
+      auto devOutput = output_.packed_accessor64<scalar_t, 3>();
+
+      int64_t outputPlaneSize = devOutput.size(2);
+      int64_t size1 = devOutput.size(1);
+      int64_t size0 = devOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
+
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
+            at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     }
   );
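Besides the tiling loops, the rewrite hoists input_ and output_ out of the if, so the duplicated then/else launch code collapses into a single path over the (possibly unsqueezed) 3-D views. A hedged repro sketch for the failure this addresses, assuming a CUDA build of ATen and shapes chosen only for illustration:

  // Hypothetical repro, not taken from the PR: before the fix this asked for
  // gridDim.z = 70000 > 65535 and tripped C10_CUDA_KERNEL_LAUNCH_CHECK().
  #include <ATen/ATen.h>

  int main() {
    at::Tensor x = at::randn({70000, 4, 8}, at::kCUDA);  // batch dim > 65535
    at::Tensor y = at::replication_pad1d(x, {1, 1});     // pad width by 1 on each side
    return 0;
  }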
@@ -330,16 +328,23 @@ void replication_pad1d_backward_out_cuda_template(
       auto devGradInput = gradInput_.packed_accessor64<scalar_t, 3>();
       auto devGradOutput = gradOutput_.packed_accessor64<scalar_t, 3>();
 
-      int outputPlaneSize = devGradOutput.size(2);
-      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-          devGradOutput.size(1),
-          devGradOutput.size(0));
-      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+      int64_t outputPlaneSize = devGradOutput.size(2);
+      int64_t size1 = devGradOutput.size(1);
+      int64_t size0 = devGradOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
 
-      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
-        at::cuda::getCurrentCUDAStream()>>>(devGradInput, devGradOutput,
-        padL, padR);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            devGradInput, devGradOutput, padL, padR, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     });
 }
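The new static_cast<int64_t>(256) is needed because THCCeilDiv is a template whose two arguments must share a type, and outputPlaneSize is now int64_t. The helper is ordinary ceiling division, along these lines (illustrative sketch; the real definition lives in the THC utility headers):

  template <typename T>
  T ceil_div(T a, T b) { return (a + b - 1) / b; }
  // e.g. ceil_div<int64_t>(1000, 256) == 4: four 256-thread blocks cover 1000 outputs.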

@@ -398,36 +403,31 @@ void replication_pad2d_out_cuda_template(
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     input.scalar_type(), "replication_pad2d_cuda", [&] {
+      at::Tensor input_ = input;
+      at::Tensor output_ = output;
       if (numInputDims == 3) {
-        auto input_ = input.unsqueeze(0);
-        auto output_ = output.unsqueeze(0);
-        auto devInput = input_.packed_accessor64<scalar_t, 4>();
-        auto devOutput = output_.packed_accessor64<scalar_t, 4>();
-
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel2d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(
-            devInput, devOutput, padT, padB, padL, padR);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        auto devInput = input.packed_accessor64<scalar_t, 4>();
-        auto devOutput = output.packed_accessor64<scalar_t, 4>();
-
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel2d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput,
-          padT, padB, padL, padR);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        input_ = input.unsqueeze(0);
+        output_ = output.unsqueeze(0);
       }
+      auto devInput = input_.packed_accessor64<scalar_t, 4>();
+      auto devOutput = output_.packed_accessor64<scalar_t, 4>();
+
+      int64_t outputPlaneSize = devOutput.size(2) * devOutput.size(3);
+      int64_t size1 = devOutput.size(1);
+      int64_t size0 = devOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
+
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_forward_kernel2d <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            devInput, devOutput, padT, padB, padL, padR, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     }
   );
@@ -490,15 +490,23 @@ void replication_pad2d_backward_out_cuda_template(
       auto devGradInput = gradInput_.packed_accessor64<scalar_t, 4>();
       auto devGradOutput = gradOutput_.packed_accessor64<scalar_t, 4>();
 
-      int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3);
-      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-          devGradOutput.size(1),
-          devGradOutput.size(0));
-      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
-        at::cuda::getCurrentCUDAStream()>>>(devGradInput, devGradOutput,
-        padT, padB, padL, padR);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      int64_t outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3);
+      int64_t size1 = devGradOutput.size(1);
+      int64_t size0 = devGradOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
+
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            devGradInput, devGradOutput, padT, padB, padL, padR, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     }
   );
 }
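The 3d templates below repeat the same transformation one last time; the only difference is that outputPlaneSize becomes a product of three spatial sizes. That product only grows gridDim.x, whose 2^31 - 1 limit is far harder to hit, so the x axis needs no tiling.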
@@ -652,38 +660,32 @@ void replication_pad3d_out_cuda_template(
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     input.scalar_type(), "replication_pad3d_cuda", [&] {
+      at::Tensor input_ = input;
+      at::Tensor output_ = output;
       if (numInputDims == 4) {
-        auto input_ = input.unsqueeze(0);
-        auto output_ = output.unsqueeze(0);
-        auto devInput = input_.packed_accessor64<scalar_t, 5>();
-        auto devOutput = output_.packed_accessor64<scalar_t, 5>();
-
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
-          devOutput.size(4);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(
-            devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        auto devInput = input.packed_accessor64<scalar_t, 5>();
-        auto devOutput = output.packed_accessor64<scalar_t, 5>();
-
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
-          devOutput.size(4);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(
-            devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        input_ = input.unsqueeze(0);
+        output_ = output.unsqueeze(0);
       }
+
+      auto devInput = input_.packed_accessor64<scalar_t, 5>();
+      auto devOutput = output_.packed_accessor64<scalar_t, 5>();
+
+      int64_t outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4);
+      int64_t size1 = devOutput.size(1);
+      int64_t size0 = devOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
+
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_forward_kernel3d <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     }
   );
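In the 3d path the overflow can come just as easily from the channel axis, which maps to gridDim.y. Another hedged sketch with illustrative shapes, assuming the same CUDA build of ATen:

  // Hypothetical: 70000 channels pushed gridDim.y past 65535 before the fix.
  at::Tensor x = at::randn({1, 70000, 2, 2, 2}, at::kCUDA);
  at::Tensor y = at::replication_pad3d(x, {1, 1, 1, 1, 1, 1});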
@@ -735,17 +737,23 @@ void replication_pad3d_backward_out_cuda_template(
       auto devGradInput = gradInput_.packed_accessor64<scalar_t, 5>();
       auto devGradOutput = gradOutput_.packed_accessor64<scalar_t, 5>();
 
-      int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) *
-        devGradOutput.size(4);
-      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-          devGradOutput.size(1),
-          devGradOutput.size(0));
-      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-
-      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
-        at::cuda::getCurrentCUDAStream()>>>(
-          devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      int64_t outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) * devGradOutput.size(4);
+      int64_t size1 = devGradOutput.size(1);
+      int64_t size0 = devGradOutput.size(0);
+
+      for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
+        int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
+        for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
+          int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
+
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+
+          replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright, block_y, block_z);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }
+      }
     }
   );
 }