apply diff 52351 (#52649)
xwang233 committed Feb 23, 2021
1 parent 02b61b4 commit d6943ea
Showing 1 changed file with 97 additions and 84 deletions.
181 changes: 97 additions & 84 deletions aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu
@@ -113,31 +113,46 @@ __global__ void upsample_trilinear3d_out_frame(
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_trilinear3d_backward_out_frame(
const size_t nc_,
const int depth1,
const int height1,
const int width1,
const int depth2,
const int height2,
const int width2,
const int num_kernels,
const accscalar_t rdepth,
const accscalar_t rheight,
const accscalar_t rwidth,
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
const size_t i_numel = nc_ * depth1 * height1 * width1;
const size_t o_numel = nc_ * depth2 * height2 * width2;

for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) {
size_t index_temp = index;
const int w2 = index_temp % width2; // 0:width2-1
index_temp /= width2;
const int h2 = index_temp % height2; // 0:height2-1
index_temp /= height2;
const int t2 = index_temp % depth2; // 0:depth2-1
const int nc = index_temp / depth2;
PackedTensorAccessor64<scalar_t, 5> idata,
const PackedTensorAccessor64<scalar_t, 5> odata,
scalar_t* idata_ptr) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int depth1 = idata.size(2);
const int height1 = idata.size(3);
const int width1 = idata.size(4);
const int depth2 = odata.size(2);
const int height2 = odata.size(3);
const int width2 = odata.size(4);

const size_t i_numel = batchsize * channels * depth1 * height1 * width1;

if (index < num_kernels) {
const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
const int t2 = index / (height2 * width2); // 0:depth2-1
// special case: just copy
if (depth1 == depth2 && height1 == height2 && width1 == width2) {
const int t1 = t2;
const int h1 = h2;
const int w1 = w2;

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][t1][h1][w1];
idata[n][c][t2][h2][w2] = val;
}
}
return;
}
//
const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
rdepth, t2, align_corners, /*cubic=*/false);
const int t1 = t1r;
@@ -159,55 +174,60 @@ __global__ void upsample_trilinear3d_backward_out_frame(
const accscalar_t w1lambda = w1r - w1;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
//
const scalar_t d2val = odata[index];
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
true);
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t d2val = odata[n][c][t2][h2][w2];
const size_t nc = n * channels + c;
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
true);
fastAtomicAdd(
idata_ptr,
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
i_numel,
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
true);
}
}
}
}

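For reference, each thread of the rewritten backward kernel handles one output location (t2, h2, w2) and scatters that output gradient into the eight surrounding input voxels with trilinear weights, once per (n, c) pair. The host-side C++ sketch below mirrors that scatter for a single output value; it assumes idx_3d flattens an (N*C, D, H, W) buffer as ((nc * D + t) * H + h) * W + w and uses plain additions in place of fastAtomicAdd. The helper names flat_idx and scatter_one_output_voxel are illustrative, not part of the file.

// Host-side sketch (not the kernel itself): scatter one output-gradient value
// into the 8 neighboring input voxels with trilinear weights.
#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed flattening for an (N*C, D, H, W) buffer, mirroring the role of idx_3d.
static size_t flat_idx(size_t nc, int D, int H, int W, int t, int h, int w) {
  return ((nc * D + t) * static_cast<size_t>(H) + h) * W + w;
}

// grad_in: flattened (N*C, D1, H1, W1) buffer; d2val: grad_out[n][c][t2][h2][w2].
// t1/h1/w1 are the lower-corner source indices, *1p the +1 offsets (0 at borders),
// and *0l/*1l the trilinear weights, as computed in the kernel.
void scatter_one_output_voxel(
    std::vector<float>& grad_in, size_t nc, int D1, int H1, int W1,
    int t1, int h1, int w1, int t1p, int h1p, int w1p,
    float t0l, float t1l, float h0l, float h1l, float w0l, float w1l,
    float d2val) {
  grad_in[flat_idx(nc, D1, H1, W1, t1,       h1,       w1      )] += t0l * h0l * w0l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1,       h1,       w1 + w1p)] += t0l * h0l * w1l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1,       h1 + h1p, w1      )] += t0l * h1l * w0l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1,       h1 + h1p, w1 + w1p)] += t0l * h1l * w1l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1 + t1p, h1,       w1      )] += t1l * h0l * w0l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1 + t1p, h1,       w1 + w1p)] += t1l * h0l * w1l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1 + t1p, h1 + h1p, w1      )] += t1l * h1l * w0l * d2val;
  grad_in[flat_idx(nc, D1, H1, W1, t1 + t1p, h1 + h1p, w1 + w1p)] += t1l * h1l * w1l * d2val;
}

int main() {
  // Tiny example: one (n, c) slice of a 2x2x2 input grid receiving one output value.
  std::vector<float> grad_in(8, 0.f);
  scatter_one_output_voxel(grad_in, /*nc=*/0, 2, 2, 2,
                           /*t1=*/0, /*h1=*/0, /*w1=*/0, /*t1p=*/1, /*h1p=*/1, /*w1p=*/1,
                           0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, /*d2val=*/1.f);
  for (float v : grad_in) std::printf("%g ", v);  // each of the 8 corners gets 0.125
  std::printf("\n");
  return 0;
}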
@@ -350,21 +370,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
// so it has to be initialized to zero.
grad_input.zero_();

// const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width;
const size_t num_kernels = grad_output.numel();
const int num_kernels = output_depth * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

if (num_kernels > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(),
"upsample_trilinear3d_backward_out_frame",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = grad_input.data_ptr<scalar_t>();
auto odata = grad_output.data_ptr<scalar_t>();
auto idata = grad_input.packed_accessor64<scalar_t, 5>();
auto odata = grad_output.packed_accessor64<scalar_t, 5>();
scalar_t* idata_ptr = grad_input.data_ptr<scalar_t>();

const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
input_depth, output_depth, align_corners, scales_d);
Expand All @@ -374,26 +393,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
input_width, output_width, align_corners, scales_w);

upsample_trilinear3d_backward_out_frame<scalar_t, accscalar_t>
<<<cuda::ATenCeilDiv(num_kernels, static_cast<size_t>(num_threads)),
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
nbatch * channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
num_kernels,
rdepth,
rheight,
rwidth,
align_corners,
idata,
odata);
odata,
idata_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}

} // namespace
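The launcher above precomputes rdepth, rheight, and rwidth with area_pixel_compute_scale, and the kernel converts each output coordinate into a fractional source coordinate through area_pixel_compute_source_index. Below is a rough standalone sketch of that mapping, assuming it follows PyTorch's usual align_corners semantics and ignoring the optional scales_d/h/w override; compute_scale and source_index are hypothetical names, not the helpers themselves.

#include <algorithm>
#include <cstdio>

// align_corners=true maps the corner samples of both grids onto each other;
// otherwise pixel centers are aligned and negative source indices clamp to 0.
double compute_scale(int in_size, int out_size, bool align_corners) {
  if (align_corners) return out_size > 1 ? double(in_size - 1) / (out_size - 1) : 0.0;
  return double(in_size) / out_size;
}

double source_index(double scale, int dst, bool align_corners) {
  if (align_corners) return scale * dst;
  return std::max(scale * (dst + 0.5) - 0.5, 0.0);
}

int main() {
  // Example: upsampling depth 4 -> 7; where does output slice t2 = 3 read from?
  const double rdepth = compute_scale(4, 7, /*align_corners=*/false);
  const double t1r = source_index(rdepth, 3, false);
  const int t1 = static_cast<int>(t1r);   // lower source slice
  const double t1lambda = t1r - t1;       // weight of the upper slice
  std::printf("t1r=%f t1=%d t1lambda=%f\n", t1r, t1, t1lambda);
  return 0;
}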