diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu index 8ac7abca1824..e1f583e8be6e 100644 --- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -113,31 +113,46 @@ __global__ void upsample_trilinear3d_out_frame( template C10_LAUNCH_BOUNDS_1(1024) __global__ void upsample_trilinear3d_backward_out_frame( - const size_t nc_, - const int depth1, - const int height1, - const int width1, - const int depth2, - const int height2, - const int width2, + const int num_kernels, const accscalar_t rdepth, const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - scalar_t* __restrict__ idata, - const scalar_t* __restrict__ odata) { - const size_t i_numel = nc_ * depth1 * height1 * width1; - const size_t o_numel = nc_ * depth2 * height2 * width2; - - for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) { - size_t index_temp = index; - const int w2 = index_temp % width2; // 0:width2-1 - index_temp /= width2; - const int h2 = index_temp % height2; // 0:height2-1 - index_temp /= height2; - const int t2 = index_temp % depth2; // 0:depth2-1 - const int nc = index_temp / depth2; + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata, + scalar_t* idata_ptr) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + + const int batchsize = idata.size(0); + const int channels = idata.size(1); + const int depth1 = idata.size(2); + const int height1 = idata.size(3); + const int width1 = idata.size(4); + const int depth2 = odata.size(2); + const int height2 = odata.size(3); + const int width2 = odata.size(4); + + const size_t i_numel = batchsize * channels * depth1 * height1 * width1; + + if (index < num_kernels) { + const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1 + const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1 + const int t2 = index / (height2 * width2); // 0:depth2-1 + // special case: just copy + if (depth1 == depth2 && height1 == height2 && width1 == width2) { + const int t1 = t2; + const int h1 = h2; + const int w1 = w2; + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const scalar_t val = odata[n][c][t1][h1][w1]; + idata[n][c][t2][h2][w2] = val; + } + } + return; + } + // const accscalar_t t1r = area_pixel_compute_source_index( rdepth, t2, align_corners, /*cubic=*/false); const int t1 = t1r; @@ -159,55 +174,60 @@ __global__ void upsample_trilinear3d_backward_out_frame( const accscalar_t w1lambda = w1r - w1; const accscalar_t w0lambda = static_cast(1) - w1lambda; // - const scalar_t d2val = odata[index]; - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1, h1, w1), - i_numel, - static_cast(t0lambda * h0lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p), - i_numel, - static_cast(t0lambda * h0lambda * w1lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1), - i_numel, - static_cast(t0lambda * h1lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p), - i_numel, - static_cast(t0lambda * h1lambda * w1lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1), - i_numel, - static_cast(t1lambda * h0lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p), - i_numel, - static_cast(t1lambda * h0lambda * w1lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1), - i_numel, - static_cast(t1lambda * h1lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p), - i_numel, - static_cast(t1lambda * h1lambda * w1lambda * d2val), - true); + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; ++c) { + const scalar_t d2val = odata[n][c][t2][h2][w2]; + const size_t nc = n * channels + c; + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1, h1, w1), + i_numel, + static_cast(t0lambda * h0lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p), + i_numel, + static_cast(t0lambda * h0lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1), + i_numel, + static_cast(t0lambda * h1lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p), + i_numel, + static_cast(t0lambda * h1lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1), + i_numel, + static_cast(t1lambda * h0lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p), + i_numel, + static_cast(t1lambda * h0lambda * w1lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1), + i_numel, + static_cast(t1lambda * h1lambda * w0lambda * d2val), + true); + fastAtomicAdd( + idata_ptr, + idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p), + i_numel, + static_cast(t1lambda * h1lambda * w1lambda * d2val), + true); + } + } } } @@ -350,21 +370,20 @@ static void upsample_trilinear3d_backward_out_cuda_template( // so it has to be initialized to zero. grad_input.zero_(); - // const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width; - const size_t num_kernels = grad_output.numel(); + const int num_kernels = output_depth * output_height * output_width; const int num_threads = std::min( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_kernels > 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_output.scalar_type(), "upsample_trilinear3d_backward_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.data_ptr(); - auto odata = grad_output.data_ptr(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); + scalar_t* idata_ptr = grad_input.data_ptr(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners, scales_d); @@ -374,26 +393,20 @@ static void upsample_trilinear3d_backward_out_cuda_template( input_width, output_width, align_corners, scales_w); upsample_trilinear3d_backward_out_frame - <<(num_threads)), + <<>>( - nbatch * channels, - input_depth, - input_height, - input_width, - output_depth, - output_height, - output_width, + num_kernels, rdepth, rheight, rwidth, align_corners, idata, - odata); + odata, + idata_ptr); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - } } } // namespace