Skip to content

Commit

Permalink
reuse code
Browse files Browse the repository at this point in the history
  • Loading branch information
CaoE committed Jan 17, 2023
1 parent 12c0dff commit c0c3a5e
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 464 deletions.
17 changes: 8 additions & 9 deletions aten/src/ATen/native/UpSampleBicubic2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ static void upsample_bicubic2d_backward_out_frame(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales_w);
at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, [&](int64_t start, int64_t end) {
opmath_t* buffer_data_ptr = nullptr;
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (std::is_same<scalar_t, BFloat16>::value) {
buffer_data_ptr = new opmath_t[input_slice_size];
memset(buffer_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
}
for (const auto i : c10::irange(start, end)) {
scalar_t* in = idata + i * input_slice_size;
Expand Down Expand Up @@ -168,7 +170,7 @@ static void upsample_bicubic2d_backward_out_frame(
for (const auto ii : c10::irange(4)) {
for (const auto jj : c10::irange(4)) {
upsample_increment_value_bounded<opmath_t>(
buffer_data_ptr == nullptr ? reinterpret_cast<opmath_t*>(in) : buffer_data_ptr,
acc_data_ptr == nullptr ? reinterpret_cast<opmath_t*>(in) : acc_data_ptr,
input_width,
input_height,
input_x - 1 + ii,
Expand All @@ -178,13 +180,10 @@ static void upsample_bicubic2d_backward_out_frame(
}
}
}
if (buffer_data_ptr != nullptr) {
apply_grad_input(buffer_data_ptr, in, input_slice_size);
if (acc_data_ptr != nullptr) {
apply_grad_input(acc_data_ptr, in, input_slice_size);
}
}
if (buffer_data_ptr != nullptr) {
delete []buffer_data_ptr;
}
});
}

Expand Down

0 comments on commit c0c3a5e

Please sign in to comment.