Merge https://github.com/pytorch/pytorch into floating_point
KsenijaS committed Nov 3, 2020
2 parents 167a444 + 31ebac3 commit f13e1de
Showing 19 changed files with 930 additions and 70 deletions.
109 changes: 101 additions & 8 deletions aten/src/ATen/native/GridSampler.cpp
@@ -6,6 +6,7 @@
#include <c10/core/Layout.h>
#include <ATen/cpu/vml.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/cpu/GridSamplerKernel.h>
#include <c10/util/Exception.h>

@@ -422,11 +423,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
for (int64_t w = 0; w < out_W; ++w) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
scalar_t x = *grid_ptr_NHW;
scalar_t y = grid_ptr_NHW[grid_sCoor];

ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);

if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y)
@@ -483,6 +484,43 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
*out_ptr_NCHW = static_cast<scalar_t>(0);
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
// grid_sampler_compute_source_index will "clip" the coordinate depending on the padding mode,
// which would make the bicubic calculation wrong:
// e.g. x = -0.1 -> ix = 0 after clipping, but bicubic needs ix = floor(x) = -1.
// Reflection padding is even more problematic, since the -1/+1 neighbour directions are not fixed at the boundary.
ix = grid_sampler_unnormalize(x, inp_W, align_corners);
iy = grid_sampler_unnormalize(y, inp_H, align_corners);

scalar_t ix_nw = std::floor(ix);
scalar_t iy_nw = std::floor(iy);

const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;

scalar_t *inp_ptr_NC = inp_ptr_N;
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
scalar_t coefficients[4];

// Interpolate 4 values in the x direction
for (int64_t i = 0; i < 4; ++i) {
coefficients[i] = cubic_interp1d<scalar_t>(
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
tx);
}

// Interpolate in the y direction
*out_ptr_NCHW = cubic_interp1d<scalar_t>(
coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
ty);
}
}
}
}
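
The bicubic branch above unnormalizes the grid coordinate without clipping, takes floor() to locate the north-west corner of a 4x4 neighbourhood, interpolates each of the four rows along x with cubic_interp1d, and then interpolates the four row results along y; out-of-range taps are fetched through get_value_bounded, which applies the padding mode per tap. A minimal standalone sketch of the same separable scheme (not the ATen kernel: it hard-codes zero padding, plain doubles, and a local cubic_weight helper):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Cubic convolution kernel (Keys, 1981) with A = -0.75, the same constant the ATen helpers use.
static double cubic_weight(double x, double A = -0.75) {
  x = std::fabs(x);
  if (x <= 1.0) return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0;
  if (x < 2.0)  return (((x - 5.0) * x + 8.0) * x - 4.0) * A;
  return 0.0;
}

// Zero-padding read; the real kernel goes through get_value_bounded / compute_coordinates instead.
static double read_zero_pad(const std::vector<double>& img, int64_t W, int64_t H, int64_t x, int64_t y) {
  if (x < 0 || x >= W || y < 0 || y >= H) return 0.0;
  return img[y * W + x];
}

// Separable bicubic sample at a fractional (and possibly out-of-range) location (ix, iy).
static double bicubic_sample(const std::vector<double>& img, int64_t W, int64_t H, double ix, double iy) {
  const int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
  const int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
  const double tx = ix - ix_nw;
  const double ty = iy - iy_nw;
  double out = 0.0;
  for (int64_t j = 0; j < 4; ++j) {      // the four rows of the 4x4 neighbourhood
    double row = 0.0;
    for (int64_t i = 0; i < 4; ++i) {    // interpolate along x within the row
      row += cubic_weight(tx - (i - 1)) * read_zero_pad(img, W, H, ix_nw - 1 + i, iy_nw - 1 + j);
    }
    out += cubic_weight(ty - (j - 1)) * row;  // then combine the rows along y
  }
  return out;
}

int main() {
  const int64_t W = 4, H = 4;
  std::vector<double> img(W * H);
  for (int64_t i = 0; i < W * H; ++i) img[i] = static_cast<double>(i);
  // ix = -0.1 keeps floor(ix) = -1, so the leftmost taps fall outside the image and read zeros,
  // exactly the case the comment above distinguishes from a clipped index of 0.
  std::printf("sample(-0.1, 1.5) = %f\n", bicubic_sample(img, W, H, -0.1, 1.5));
  return 0;
}
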
@@ -547,13 +585,13 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
// get the corresponding input x, y co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
scalar_t x = *grid_ptr_NHW;
scalar_t y = grid_ptr_NHW[grid_sCoor];

// multipliers for gradients on ix, iy
scalar_t gix_mult, giy_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y)
@@ -628,6 +666,55 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW,
inp_H, inp_W, *gOut_ptr_NCHW);
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult);
iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult);

scalar_t ix_nw = std::floor(ix);
scalar_t iy_nw = std::floor(iy);

const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;

scalar_t x_coeffs[4];
scalar_t y_coeffs[4];
scalar_t x_coeffs_grad[4];
scalar_t y_coeffs_grad[4];

get_cubic_upsample_coefficients<scalar_t>(x_coeffs, tx);
get_cubic_upsample_coefficients<scalar_t>(y_coeffs, ty);
get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);

scalar_t gix = static_cast<scalar_t>(0);
scalar_t giy = static_cast<scalar_t>(0);

scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
scalar_t *inp_ptr_NC = inp_ptr_N;

for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) {
scalar_t gOut = *gOut_ptr_NCHW;

for (int64_t i = 0; i < 4; ++i) {
for (int64_t j = 0; j < 4; ++j) {

// set input gradient
add_value_bounded<scalar_t>(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners);

// set grid gradient
scalar_t val = get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners);

gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;
giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;
}
}
}
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
}
}
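
Each bicubic output element is the sum over i, j of v(i, j) * cx_i(tx) * cy_j(ty), so the backward pass above scatters gOut * cx_i * cy_j into the input gradient via add_value_bounded and forms the grid gradient with the product rule, finally scaling by gix_mult / giy_mult from grid_sampler_unnormalize_set_grad. The polynomials in get_cubic_coefficients_grad (added to GridSampler.h in this commit) evaluate to the negative of d coeff / d t, which the `gix -= ...` / `giy -= ...` accumulation flips back into a positive derivative. A small self-contained 1-D check of that convention (helper names are local to this sketch, not ATen's):

#include <cstdio>

// Same polynomials as get_cubic_upsample_coefficients (aten/src/ATen/native/UpSample.h).
static void coeffs_forward(double c[4], double t, double A = -0.75) {
  double x;
  x = t + 1; c[0] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
  x = t;     c[1] = ((A + 2) * x - (A + 3)) * x * x + 1;
  x = 1 - t; c[2] = ((A + 2) * x - (A + 3)) * x * x + 1;
  x = 2 - t; c[3] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}

// Same polynomials as get_cubic_coefficients_grad (GridSampler.h); note they equal
// -d coeff / d t, which is why the loop above subtracts the products from gix/giy.
static void coeffs_grad(double g[4], double t, double A = -0.75) {
  double x;
  x = -1 - t; g[0] = (-3 * A * x - 10 * A) * x - 8 * A;
  x = -t;     g[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 1 - t;  g[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 2 - t;  g[3] = (3 * A * x - 10 * A) * x + 8 * A;
}

int main() {
  const double v[4] = {0.3, 1.2, -0.7, 2.0};  // four neighbouring input values (1-D slice)
  const double t = 0.4, eps = 1e-6;

  double c0[4], c1[4], g[4];
  coeffs_forward(c0, t);
  coeffs_forward(c1, t + eps);
  coeffs_grad(g, t);

  double out0 = 0, out1 = 0, analytic = 0;
  for (int i = 0; i < 4; ++i) {
    out0 += v[i] * c0[i];
    out1 += v[i] * c1[i];
    analytic -= v[i] * g[i];  // the same sign flip as `gix -= ...` in the kernel
  }
  // Both numbers should agree to roughly 1e-5.
  std::printf("finite-difference d out / d t = %.6f, analytic = %.6f\n", (out1 - out0) / eps, analytic);
  return 0;
}
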
@@ -640,6 +727,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {

// AVX gather instructions use signed 32-bit offsets to gather float values.
// Check for possible overflow and fall back to the scalar implementation
if (input.scalar_type() != kDouble) {
@@ -682,6 +770,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {

// AVX gather instructions use signed 32-bit offsets to gather float values.
// Check for possible overflow and fall back to the scalar implementation
if (input.scalar_type() != kDouble) {
@@ -757,6 +846,10 @@ Tensor grid_sampler(const Tensor& input, const Tensor& grid,
grid.size(-1) == input.dim() - 2,
"grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
"dimension, but got grid with sizes ", grid.sizes());
TORCH_CHECK(
!(input.dim() == 5 && static_cast<GridSamplerInterpolation>(interpolation_mode) == GridSamplerInterpolation::Bicubic),
"grid_sampler(): bicubic interpolation only supports 4D input"
);
for (int64_t i = 2; i < input.dim(); i++) {
TORCH_CHECK(input.size(i) > 0,
"grid_sampler(): expected input to have non-empty spatial dimensions, "
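With the dispatcher check above in place, bicubic mode is accepted only for 4-D (2-D spatial) input. A hedged usage sketch of calling the public dispatcher with the new mode; the integer mode values are assumed from the enum order shown in GridSampler.h below, and the example is illustrative rather than part of this commit:

#include <ATen/ATen.h>

void bicubic_grid_sample_example() {
  at::Tensor input = at::rand({1, 3, 8, 8});           // N, C, H_in, W_in
  at::Tensor grid  = at::rand({1, 4, 4, 2}) * 2 - 1;   // N, H_out, W_out, 2; values in [-1, 1]
  at::Tensor out = at::grid_sampler(
      input, grid,
      /*interpolation_mode=*/2,  // Bilinear = 0, Nearest = 1, Bicubic = 2 (enum order in GridSampler.h)
      /*padding_mode=*/0,        // Zeros
      /*align_corners=*/false);
  // out has shape {1, 3, 4, 4}; a 5-D input with mode 2 would now hit the TORCH_CHECK above.
}
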
95 changes: 86 additions & 9 deletions aten/src/ATen/native/GridSampler.h
@@ -7,7 +7,7 @@ namespace at { namespace native {

namespace detail {

enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
enum class GridSamplerPadding {Zeros, Border, Reflection};

} // namespace detail
@@ -139,14 +139,12 @@ static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_l
}
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
// Maps out-of-boundary points back into the valid range.
// This only affects padding_mode == Border or Reflection.
template<typename scalar_t>
static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
@@ -163,6 +161,18 @@ static inline scalar_t grid_sampler_compute_source_index(
return coord;
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
coord = compute_coordinates(coord, size, padding_mode, align_corners);
return coord;
}
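
grid_sampler_compute_source_index is now the composition of two steps: grid_sampler_unnormalize (defined earlier in this header, not shown in this diff) maps a grid value from [-1, 1] into pixel coordinates, and compute_coordinates applies the padding-mode clipping or reflection. For reference, a sketch of the unnormalization convention these helpers rely on; it mirrors the existing helper and is not new code in this commit:

#include <cstdint>
#include <cstdio>

template <typename scalar_t>
static scalar_t unnormalize(scalar_t coord, int64_t size, bool align_corners) {
  if (align_corners) {
    // -1 and +1 map to the centres of the first and last pixel: output range [0, size - 1]
    return ((coord + 1) / 2) * (size - 1);
  } else {
    // -1 and +1 map to the outer edges of the first and last pixel: output range [-0.5, size - 0.5]
    return ((coord + 1) * size - 1) / 2;
  }
}

int main() {
  std::printf("align_corners=true : %.1f .. %.1f\n", unnormalize(-1.0, 5, true),  unnormalize(1.0, 5, true));   // 0.0 .. 4.0
  std::printf("align_corners=false: %.1f .. %.1f\n", unnormalize(-1.0, 5, false), unnormalize(1.0, 5, false));  // -0.5 .. 4.5
  return 0;
}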

// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
@@ -202,6 +212,30 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D,
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

template<typename scalar_t>
static inline scalar_t get_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
GridSamplerPadding padding_mode,
bool align_corners) {

x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);

int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);

if (within_bounds_2d(iy, ix, H, W)) {
return data[iy * sH + ix * sW];
}
return static_cast<scalar_t>(0);
}

template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
int64_t sH, int64_t sW, int64_t H, int64_t W,
@@ -221,4 +255,47 @@ static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
}
}

template<typename scalar_t>
static inline void add_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
scalar_t delta,
GridSamplerPadding padding_mode,
bool align_corners) {

x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);

int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);

safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
}

// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
scalar_t coeffs[4],
scalar_t t) {

// Must be the same as forward calculation in
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
scalar_t A = -0.75;

scalar_t x;
x = -1 - t; // 1 < x = |-1 - tx| < 2
coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
x = -t; // x = |0 - tx| <= 1
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 1 - t; // x = |1 - tx| <= 1
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 2 - t; // 1 < x = |2 - tx| < 2
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}

}} // namespace at::native
33 changes: 24 additions & 9 deletions aten/src/ATen/native/TensorTransformations.cpp
@@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
}
}

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
if (in_tensor.is_quantized()) {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(),
"flip_quantized_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool,
in_tensor.scalar_type(),
"flip_cpu", [&] {
flip_cpu_kernel<scalar_t>(
total_dims,
stride_contiguous_v,
flip_dims_b,
in_tensor,
out_tensor
);
});
}

return out_tensor;
}
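The change above routes quantized CPU tensors through their own dispatch macro instead of the generic one, so flip no longer falls through a dispatch that does not cover quantized dtypes. A hedged usage sketch, assuming the usual per-tensor quantization API; it is illustrative and not part of this commit:

#include <ATen/ATen.h>

void flip_quantized_example() {
  at::Tensor x  = at::rand({2, 3});
  at::Tensor qx = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);
  at::Tensor flipped = at::flip(qx, /*dims=*/{1});  // should now be handled by the flip_quantized_cpu dispatch above
  // flipped is expected to stay quantized, with the same quantization parameters as qx.
}
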
8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -351,17 +351,17 @@ Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(
Tensor cosh(const Tensor& self) { return unary_op_impl(self, at::cosh_out); }
Tensor& cosh_(Tensor& self) { return unary_op_impl_(self, at::cosh_out); }

Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acosh_stub); }
Tensor acosh(const Tensor& self) { return unary_op_impl(self, at::acosh_out); }
Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acosh_stub); }
Tensor acosh(const Tensor& self) { return unary_op_impl_float(self, acosh_stub); }
Tensor& acosh_(Tensor& self) { return unary_op_impl_(self, at::acosh_out); }

// arccosh, alias for acosh
Tensor& arccosh_out(Tensor& result, const Tensor& self) { return at::acosh_out(result, self); }
Tensor arccosh(const Tensor& self) { return at::acosh(self); }
Tensor& arccosh_(Tensor& self) { return at::acosh_(self); }

Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asinh_stub); }
Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); }
Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asinh_stub); }
Tensor asinh(const Tensor& self) { return unary_op_impl_float(self, asinh_stub); }
Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); }

// arcsinh, alias for asinh
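Switching acosh/asinh (and their *_out variants) from unary_op_impl to unary_op_impl_float means integral inputs should now be promoted to the default floating-point dtype before the kernel runs, matching the other float-producing unary ops. A hedged sketch of the expected behaviour (illustrative only; the exact result dtype follows the default dtype):

#include <ATen/ATen.h>
#include <iostream>

void acosh_promotion_example() {
  at::Tensor x = at::arange(1, 5, at::kLong);  // integral (int64) input
  at::Tensor y = at::acosh(x);                 // promoted through the float path
  std::cout << y.scalar_type() << std::endl;   // expected: Float (the default dtype)
}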
