Commit 2a98f21
fix 3d cuda launch configuration
xwang233 committed Jan 15, 2021
1 parent e4b7b1c commit 2a98f21
Showing 1 changed file with 56 additions and 46 deletions.
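
Context for the change: CUDA limits gridDim.y and gridDim.z to 65535 blocks each (only gridDim.x may go up to 2^31 - 1 on recent GPUs). The old code put output.size(1) and output.size(0) directly into the grid's y and z dimensions, so any tensor with more than 65535 planes or batches made the launch fail with an invalid-configuration error. The fix below tiles such launches into chunks of at most 65535. The limits can be confirmed with a minimal standalone check (illustrative, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // maxGridSize holds the {x, y, z} limits; y and z report 65535.
  std::printf("max grid: x=%d y=%d z=%d\n",
              prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
  return 0;
}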
aten/src/ATen/native/cuda/ReplicationPadding.cu (56 additions, 46 deletions)
@@ -130,11 +130,11 @@ template <typename scalar_t>
 __global__ void replication_pad_forward_kernel3d(
     PackedTensorAccessor64<scalar_t, 5> input,
     PackedTensorAccessor64<scalar_t, 5> output,
-    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright, int y_shift, int z_shift) {
 
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
   if (outputPointId >= (output.size(2) * output.size(3) *
       output.size(4))) {
     return;
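
Each launch now covers only a window of the plane and batch ranges, so the kernels add the launch offsets back to the block indices to recover absolute coordinates. A standalone sketch of just this indexing scheme (generic names, body elided):

// Sketch: blockIdx.y/z span at most 65535 blocks per launch; y_shift and
// z_shift translate them back into absolute plane/batch indices.
__global__ void chunked_kernel_sketch(float* out, long planeSize,
                                      int y_shift, int z_shift) {
  long outputPointId = threadIdx.x + blockIdx.x * (long)blockDim.x;
  int plane = blockIdx.y + y_shift;  // absolute plane (channel) index
  int batch = blockIdx.z + z_shift;  // absolute batch index
  if (outputPointId >= planeSize) {
    return;  // guard threads past the end of the output plane
  }
  // ... map outputPointId to (z, y, x) within the plane and write the
  //     replicated value for out[batch][plane][z][y][x], as the real kernels do ...
}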
@@ -166,10 +166,10 @@ template <typename scalar_t>
 __global__ void replication_pad_backward_kernel(
     PackedTensorAccessor64<scalar_t, 5> gradInput,
     PackedTensorAccessor64<scalar_t, 5> gradOutput,
-    int pfront, int pback, int ptop, int pbottom, int pleft, int pright) {
+    int pfront, int pback, int ptop, int pbottom, int pleft, int pright, int y_shift, int z_shift) {
   int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
-  int plane = blockIdx.y;
-  int batch = blockIdx.z;
+  int plane = blockIdx.y + y_shift;
+  int batch = blockIdx.z + z_shift;
 
   if (outputPointId >= (gradOutput.size(2) * gradOutput.size(3) *
       gradOutput.size(4))) {
@@ -687,38 +687,37 @@ void replication_pad3d_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     input.scalar_type(), "replication_pad3d_cuda", [&] {
+      at::Tensor input_ = input;
+      at::Tensor output_ = output;
       if (numInputDims == 4) {
-        auto input_ = input.unsqueeze(0);
-        auto output_ = output.unsqueeze(0);
-        auto devInput = input_.packed_accessor64<scalar_t, 5>();
-        auto devOutput = output_.packed_accessor64<scalar_t, 5>();
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
-          devOutput.size(4);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(
-              devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      } else {
-        auto devInput = input.packed_accessor64<scalar_t, 5>();
-        auto devOutput = output.packed_accessor64<scalar_t, 5>();
-        int outputPlaneSize = devOutput.size(2) * devOutput.size(3) *
-          devOutput.size(4);
-        dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-            devOutput.size(1),
-            devOutput.size(0));
-        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0,
-          at::cuda::getCurrentCUDAStream()>>>(
-              devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        input_ = input.unsqueeze(0);
+        output_ = output.unsqueeze(0);
       }
+      auto devInput = input_.packed_accessor64<scalar_t, 5>();
+      auto devOutput = output_.packed_accessor64<scalar_t, 5>();
+      int outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4);
+      int size1 = devOutput.size(1);
+      int size0 = devOutput.size(0);
+      int y_left = size1;
+      int y_shift = 0;
+      while (y_left > 0) {
+        int z_left = size0;
+        int z_shift = 0;
+        while (z_left > 0) {
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), y_left > 65535 ? 65535 : y_left, z_left > 65535 ? 65535 : z_left);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+          replication_pad_forward_kernel3d <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+              devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright, y_shift, z_shift);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          z_shift += 65535;
+          z_left -= 65535;
+        }
+        y_shift += 65535;
+        y_left -= 65535;
+      }
     }
   );
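
The launch loops above follow a general pattern: walk the logical grid in steps of 65535 along y and z, launch each chunk with a clamped grid, and pass the chunk's offset to the kernel. The same pattern in a minimal standalone form (hypothetical my_kernel; for-loops instead of the diff's countdown variables, which is equivalent):

#include <algorithm>
#include <cuda_runtime.h>

__global__ void my_kernel(int y_shift, int z_shift);  // hypothetical kernel

void tiled_launch(int outputPlaneSize, int size1, int size0, cudaStream_t stream) {
  const int kMaxGridYZ = 65535;  // hardware limit on gridDim.y and gridDim.z
  int blockSize = outputPlaneSize > 256 ? 256 : outputPlaneSize;
  int gridX = (outputPlaneSize + 255) / 256;  // ceil-div, as THCCeilDiv computes
  for (int y_shift = 0; y_shift < size1; y_shift += kMaxGridYZ) {
    for (int z_shift = 0; z_shift < size0; z_shift += kMaxGridYZ) {
      dim3 grid(gridX,
                std::min(size1 - y_shift, kMaxGridYZ),
                std::min(size0 - z_shift, kMaxGridYZ));
      my_kernel<<<grid, blockSize, 0, stream>>>(y_shift, z_shift);
    }
  }
}

Each (y_shift, z_shift) pair produces one launch, and the C10_CUDA_KERNEL_LAUNCH_CHECK() after each launch in the diff surfaces errors per chunk rather than once at the end.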
@@ -770,17 +769,28 @@ void replication_pad3d_backward_out_cuda_template(
       auto devGradInput = gradInput_.packed_accessor64<scalar_t, 5>();
       auto devGradOutput = gradOutput_.packed_accessor64<scalar_t, 5>();
-      int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) *
-        devGradOutput.size(4);
-      dim3 gridSize(THCCeilDiv(outputPlaneSize, 256),
-          devGradOutput.size(1),
-          devGradOutput.size(0));
-      dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
-      replication_pad_backward_kernel <<<gridSize, blockSize, 0,
-        at::cuda::getCurrentCUDAStream()>>>(
-            devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) * devGradOutput.size(4);
+      int size1 = devGradOutput.size(1);
+      int size0 = devGradOutput.size(0);
+      int y_left = size1;
+      int y_shift = 0;
+      while (y_left > 0) {
+        int z_left = size0;
+        int z_shift = 0;
+        while (z_left > 0) {
+          dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), y_left > 65535 ? 65535 : y_left, z_left > 65535 ? 65535 : z_left);
+          dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
+          replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+              devGradInput, devGradOutput, pfront, pback, ptop, pbottom, pleft, pright, y_shift, z_shift);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          z_shift += 65535;
+          z_left -= 65535;
+        }
+        y_shift += 65535;
+        y_left -= 65535;
+      }
     }
   );
 }
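
A shape has to exceed 65535 in size(1) or size(0) to reach the new multi-launch path. A hypothetical reproducer (not part of this commit) that would have failed with the old single-launch configuration:

#include <ATen/ATen.h>

int main() {
  // 70000 channel planes: the old code would have asked for gridDim.y = 70000.
  auto x = at::zeros({1, 70000, 2, 2, 2}, at::kCUDA);
  auto y = at::replication_pad3d(x, {1, 1, 1, 1, 1, 1});
  // With size(1) = 70000, the y loop runs twice: one launch with
  // gridDim.y = 65535 and a second with gridDim.y = 4465.
  return 0;
}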