-
Notifications
You must be signed in to change notification settings - Fork 21.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: resolves #16158 Pull Request resolved: #19630 Differential Revision: D15335765 Pulled By: ezyang fbshipit-source-id: 03dd590c715a65c20ac99674a5d77179cd4a50fc
- Loading branch information
1 parent
7ffc37e
commit 3479777
Showing
38 changed files
with
2,268 additions
and
2,392 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/TensorUtils.h> | ||
#include <ATen/cuda/CUDAApplyUtils.cuh> | ||
|
||
#include <math.h> | ||
|
||
namespace at { | ||
namespace native { | ||
|
||
/* TODO: move this to a common place */ | ||
template <typename scalar_t> | ||
__device__ inline scalar_t min(scalar_t a, scalar_t b) { | ||
return a < b ? a : b; | ||
} | ||
|
||
template <typename scalar_t> | ||
__device__ inline scalar_t max(scalar_t a, scalar_t b) { | ||
return a > b ? a : b; | ||
} | ||
|
||
static inline void upsample_1d_shape_check( | ||
const Tensor& input, | ||
const Tensor& grad_output, | ||
int nbatch, | ||
int nchannels, | ||
int input_width, | ||
int output_width) { | ||
AT_CHECK( | ||
input_width > 0 && output_width > 0, | ||
"input and output sizes should be greater than 0, but got input (W: ", | ||
input_width, | ||
") and output (W: ", | ||
output_width, | ||
")"); | ||
|
||
if (input.defined()) { | ||
AT_CHECK( | ||
input.numel() != 0 && input.dim() == 3, | ||
"non-empty 3D input tensor expected but got a tensor with sizes ", | ||
input.sizes()); | ||
} else if (grad_output.defined()) { | ||
check_dim_size(grad_output, 3, 0, nbatch); | ||
check_dim_size(grad_output, 3, 1, nchannels); | ||
check_dim_size(grad_output, 3, 2, output_width); | ||
} | ||
} | ||
|
||
static inline void upsample_2d_shape_check( | ||
const Tensor& input, | ||
const Tensor& grad_output, | ||
int nbatch, | ||
int nchannels, | ||
int input_height, | ||
int input_width, | ||
int output_height, | ||
int output_width) { | ||
AT_CHECK( | ||
input_height > 0 && input_width > 0 && output_height > 0 && | ||
output_width > 0, | ||
"input and output sizes should be greater than 0," | ||
" but got input (H: ", | ||
input_height, | ||
", W: ", | ||
input_width, | ||
") output (H: ", | ||
output_height, | ||
", W: ", | ||
output_width, | ||
")"); | ||
|
||
if (input.defined()) { | ||
AT_CHECK( | ||
input.numel() != 0 && input.dim() == 4, | ||
"non-empty 4D input tensor expected but got a tensor with sizes ", | ||
input.sizes()); | ||
} else if (grad_output.defined()) { | ||
check_dim_size(grad_output, 4, 0, nbatch); | ||
check_dim_size(grad_output, 4, 1, nchannels); | ||
check_dim_size(grad_output, 4, 2, output_height); | ||
check_dim_size(grad_output, 4, 3, output_width); | ||
} | ||
} | ||
|
||
static inline void upsample_3d_shape_check( | ||
const Tensor& input, | ||
const Tensor& grad_output, | ||
int nbatch, | ||
int nchannels, | ||
int input_depth, | ||
int input_height, | ||
int input_width, | ||
int output_depth, | ||
int output_height, | ||
int output_width) { | ||
AT_CHECK( | ||
input_depth > 0 && input_height > 0 && input_width > 0 && | ||
output_depth > 0 && output_height > 0 && output_width > 0, | ||
"Input and output sizes should be greater than 0, but got input (D: ", | ||
input_depth, | ||
", H: ", | ||
input_height, | ||
", W: ", | ||
input_width, | ||
") output (D: ", | ||
output_depth, | ||
", H: ", | ||
output_height, | ||
", W: ", | ||
output_width, | ||
")"); | ||
|
||
if (input.defined()) { | ||
AT_CHECK( | ||
input.numel() != 0 && input.dim() == 5, | ||
"Non-empty 5D data tensor expected but got a tensor with sizes ", | ||
input.sizes()); | ||
} else if (grad_output.defined()) { | ||
check_dim_size(grad_output, 5, 0, nbatch); | ||
check_dim_size(grad_output, 5, 1, nchannels); | ||
check_dim_size(grad_output, 5, 2, output_depth); | ||
check_dim_size(grad_output, 5, 3, output_height); | ||
check_dim_size(grad_output, 5, 4, output_width); | ||
} | ||
} | ||
|
||
template <typename accscalar_t> | ||
__host__ __forceinline__ static accscalar_t area_pixel_compute_scale( | ||
int input_size, | ||
int output_size, | ||
bool align_corners) { | ||
if (output_size > 1) { | ||
return align_corners ? (accscalar_t)(input_size - 1) / (output_size - 1) | ||
: (accscalar_t)input_size / output_size; | ||
} else { | ||
return static_cast<accscalar_t>(0); | ||
} | ||
} | ||
|
||
template <typename accscalar_t> | ||
__device__ __forceinline__ static accscalar_t area_pixel_compute_source_index( | ||
accscalar_t scale, | ||
int dst_index, | ||
bool align_corners, | ||
bool cubic) { | ||
if (align_corners) { | ||
return scale * dst_index; | ||
} else { | ||
accscalar_t src_idx = scale * (dst_index + static_cast<accscalar_t>(0.5)) - | ||
static_cast<accscalar_t>(0.5); | ||
// See Note[Follow Opencv resize logic] | ||
return (!cubic && src_idx < static_cast<accscalar_t>(0)) | ||
? static_cast<accscalar_t>(0) | ||
: src_idx; | ||
} | ||
} | ||
|
||
__device__ __forceinline__ static int nearest_neighbor_compute_source_index( | ||
const float scale, | ||
int dst_index, | ||
int input_size) { | ||
const int src_index = min<int>( | ||
static_cast<int>(floorf(dst_index * scale)), input_size - 1); | ||
return src_index; | ||
} | ||
|
||
/* just affect UpSampleBicubic2d.cu */ | ||
/* TODO: change width and height order in the arguments */ | ||
/* TODO: maybe change x and y order in the arguments */ | ||
/* TODO: maybe change channel and batch order in the arguments */ | ||
template <typename scalar_t> | ||
__device__ __forceinline__ static scalar_t upsample_get_value_bounded( | ||
const PackedTensorAccessor<scalar_t, 4>& data, | ||
int channel, | ||
int batch, | ||
int width, | ||
int height, | ||
int x, | ||
int y) { | ||
int access_x = | ||
max<int>(min<int>(x, width - 1), static_cast<int>(0)); | ||
int access_y = | ||
max<int>(min<int>(y, height - 1), static_cast<int>(0)); | ||
return data[batch][channel][access_y][access_x]; | ||
} | ||
|
||
/* just affect UpSampleBicubic2d.cu */ | ||
/* TODO: change width and height order in the arguments */ | ||
/* TODO: maybe change x and y order in the arguments */ | ||
/* TODO: maybe change channel and batch order in the arguments */ | ||
template <typename scalar_t, typename accscalar_t> | ||
__device__ __forceinline__ static void upsample_increment_value_bounded( | ||
PackedTensorAccessor<scalar_t, 4>& data, | ||
int channel, | ||
int batch, | ||
int width, | ||
int height, | ||
int x, | ||
int y, | ||
accscalar_t value) { | ||
int access_x = | ||
max<int>(min<int>(x, width - 1), static_cast<int>(0)); | ||
int access_y = | ||
max<int>(min<int>(y, height - 1), static_cast<int>(0)); | ||
/* TODO: result here is trucated to scalar_t, | ||
check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912 */ | ||
atomicAdd( | ||
&data[batch][channel][access_y][access_x], static_cast<scalar_t>(value)); | ||
} | ||
|
||
// Based on | ||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm | ||
template <typename accscalar_t> | ||
__device__ __forceinline__ static accscalar_t cubic_convolution1( | ||
accscalar_t x, | ||
accscalar_t A) { | ||
return ((A + 2) * x - (A + 3)) * x * x + 1; | ||
} | ||
|
||
template <typename accscalar_t> | ||
__device__ __forceinline__ static accscalar_t cubic_convolution2( | ||
accscalar_t x, | ||
accscalar_t A) { | ||
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; | ||
} | ||
|
||
template <typename accscalar_t> | ||
__device__ __forceinline__ static void get_cubic_upsampling_coefficients( | ||
accscalar_t coeffs[4], | ||
accscalar_t t) { | ||
accscalar_t A = -0.75; | ||
|
||
accscalar_t x1 = t; | ||
coeffs[0] = cubic_convolution2<accscalar_t>(x1 + 1.0, A); | ||
coeffs[1] = cubic_convolution1<accscalar_t>(x1, A); | ||
|
||
// opposite coefficients | ||
accscalar_t x2 = 1.0 - t; | ||
coeffs[2] = cubic_convolution1<accscalar_t>(x2, A); | ||
coeffs[3] = cubic_convolution2<accscalar_t>(x2 + 1.0, A); | ||
} | ||
|
||
template <typename scalar_t, typename accscalar_t> | ||
__device__ __forceinline__ static accscalar_t cubic_interp1d( | ||
scalar_t x0, | ||
scalar_t x1, | ||
scalar_t x2, | ||
scalar_t x3, | ||
accscalar_t t) { | ||
accscalar_t coeffs[4]; | ||
get_cubic_upsampling_coefficients<accscalar_t>(coeffs, t); | ||
|
||
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; | ||
} | ||
|
||
} // namespace native | ||
} // namespace at |
Oops, something went wrong.