[release/1.8] Fix performance of CUDA trilinear interpolate backward #52649

Merged
merged 1 commit on Feb 23, 2021
181 changes: 97 additions & 84 deletions aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu
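For orientation: the backward pass of trilinear upsampling takes each output-gradient voxel and scatters it into the eight surrounding input-gradient voxels with trilinear weights, which is what both kernel variants in this diff do via fastAtomicAdd. A minimal CPU reference sketch of that computation (illustration only, not ATen code; it assumes contiguous NCDHW storage and uses the align_corners=true source mapping for brevity):

#include <cstddef>
#include <vector>

// Reference sketch: for each output voxel, distribute its gradient to the
// eight nearest input voxels using trilinear weights.
void trilinear3d_backward_reference(
    const std::vector<float>& grad_out,
    std::vector<float>& grad_in,
    int nc, int d1, int h1, int w1,   // grad_in sizes; nc = N * C
    int d2, int h2, int w2,           // grad_out sizes
    float rd, float rh, float rw) {   // depth/height/width rescale factors
  // Flatten (n*c, t, y, x) into a linear offset for a contiguous buffer.
  auto idx = [](int n, int d, int h, int w, int t, int y, int x) {
    return ((static_cast<size_t>(n) * d + t) * h + y) * w + x;
  };
  for (int n = 0; n < nc; ++n)
    for (int t2 = 0; t2 < d2; ++t2)
      for (int y2 = 0; y2 < h2; ++y2)
        for (int x2 = 0; x2 < w2; ++x2) {
          // Source coordinate (align_corners=true mapping) and fractional weights.
          const float t1r = rd * t2, y1r = rh * y2, x1r = rw * x2;
          const int t1 = static_cast<int>(t1r);
          const int y1 = static_cast<int>(y1r);
          const int x1 = static_cast<int>(x1r);
          const int tp = (t1 < d1 - 1) ? 1 : 0;   // step to the next voxel, if any
          const int yp = (y1 < h1 - 1) ? 1 : 0;
          const int xp = (x1 < w1 - 1) ? 1 : 0;
          const float tl = t1r - t1, yl = y1r - y1, xl = x1r - x1;
          const float g = grad_out[idx(n, d2, h2, w2, t2, y2, x2)];
          // Accumulate into the 2x2x2 neighborhood with trilinear weights.
          for (int a = 0; a <= 1; ++a)
            for (int b = 0; b <= 1; ++b)
              for (int c = 0; c <= 1; ++c)
                grad_in[idx(n, d1, h1, w1, t1 + a * tp, y1 + b * yp, x1 + c * xp)] +=
                    (a ? tl : 1 - tl) * (b ? yl : 1 - yl) * (c ? xl : 1 - xl) * g;
        }
}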
@@ -113,31 +113,46 @@ __global__ void upsample_trilinear3d_out_frame(
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_trilinear3d_backward_out_frame(
const size_t nc_,
const int depth1,
const int height1,
const int width1,
const int depth2,
const int height2,
const int width2,
const int num_kernels,
const accscalar_t rdepth,
const accscalar_t rheight,
const accscalar_t rwidth,
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
const size_t i_numel = nc_ * depth1 * height1 * width1;
const size_t o_numel = nc_ * depth2 * height2 * width2;

for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) {
size_t index_temp = index;
const int w2 = index_temp % width2; // 0:width2-1
index_temp /= width2;
const int h2 = index_temp % height2; // 0:height2-1
index_temp /= height2;
const int t2 = index_temp % depth2; // 0:depth2-1
const int nc = index_temp / depth2;
PackedTensorAccessor64<scalar_t, 5> idata,
const PackedTensorAccessor64<scalar_t, 5> odata,
scalar_t* idata_ptr) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int depth1 = idata.size(2);
const int height1 = idata.size(3);
const int width1 = idata.size(4);
const int depth2 = odata.size(2);
const int height2 = odata.size(3);
const int width2 = odata.size(4);

const size_t i_numel = batchsize * channels * depth1 * height1 * width1;

if (index < num_kernels) {
const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
const int t2 = index / (height2 * width2); // 0:depth2-1
// special case: just copy
if (depth1 == depth2 && height1 == height2 && width1 == width2) {
const int t1 = t2;
const int h1 = h2;
const int w1 = w2;

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][t1][h1][w1];
idata[n][c][t2][h2][w2] = val;
}
}
return;
}
//
const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
rdepth, t2, align_corners, /*cubic=*/false);
const int t1 = t1r;
@@ -159,55 +174,60 @@ __global__ void upsample_trilinear3d_backward_out_frame(
const accscalar_t w1lambda = w1r - w1;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
//
const scalar_t d2val = odata[index];
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
true);
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t d2val = odata[n][c][t2][h2][w2];
const size_t nc = n * channels + c;
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
true);
}
}
}
}
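The atomic updates above compute flat offsets through an idx_3d helper that is defined elsewhere in this file and not shown in the excerpt. A plausible reading, assuming contiguous NCDHW storage with nc already folded as n * channels + c (an assumption, not part of this diff):

// Presumed shape of the idx_3d helper (assumption; not shown in this diff):
// flattens (nc, z, y, x) into a linear offset for a contiguous
// (N*C) x D x H x W buffer.
__device__ __forceinline__ size_t idx_3d(
    const size_t nc,
    const size_t depth,
    const size_t height,
    const size_t width,
    const size_t z,
    const size_t y,
    const size_t x) {
  return ((nc * depth + z) * height + y) * width + x;
}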

@@ -350,21 +370,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
// so it has to be initialized to zero.
grad_input.zero_();

// const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width;
const size_t num_kernels = grad_output.numel();
const int num_kernels = output_depth * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (num_kernels > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(),
"upsample_trilinear3d_backward_out_frame",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = grad_input.data_ptr<scalar_t>();
auto odata = grad_output.data_ptr<scalar_t>();
auto idata = grad_input.packed_accessor64<scalar_t, 5>();
auto odata = grad_output.packed_accessor64<scalar_t, 5>();
scalar_t* idata_ptr = grad_input.data_ptr<scalar_t>();

const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
input_depth, output_depth, align_corners, scales_d);
@@ -374,26 +393,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
input_width, output_width, align_corners, scales_w);

upsample_trilinear3d_backward_out_frame<scalar_t, accscalar_t>
<<<cuda::ATenCeilDiv(num_kernels, static_cast<size_t>(num_threads)),
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
nbatch * channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
num_kernels,
rdepth,
rheight,
rwidth,
align_corners,
idata,
odata);
odata,
idata_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}

} // namespace
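For a sense of the launch geometry: the grid is sized with a ceiling division so that there is one thread per element counted by num_kernels. A worked example with hypothetical sizes (numbers chosen only for illustration):

// Hypothetical sizes, for illustration only.
const int output_depth = 32, output_height = 64, output_width = 64;
const int num_kernels = output_depth * output_height * output_width;   // 131072
const int num_threads = 1024;                                          // maxThreadsPerBlock cap
const int num_blocks = (num_kernels + num_threads - 1) / num_threads;  // ceil-div, as ATenCeilDiv does -> 128
// kernel<<<num_blocks, num_threads, 0, stream>>>(...): 131072 threads, each
// covering one (t2, h2, w2) output location across the batch/channel slices.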