Implement bicubic grid sampler (#44780)
Summary:
Fix #44601

I added a bicubic grid sampler on both the CPU and CUDA sides, but have not added an AVX2 (vectorized CPU kernel) version yet.

There is a [colab notebook](https://colab.research.google.com/drive/1mIh6TLLj5WWM_NcmKDRvY5Gltbb781oU?usp=sharing) showing some test results. The notebook uses bilinear mode for its tests, since only the distributed (released) version of PyTorch was available there. You can download it and set `mode_torch=bicubic` to reproduce the results with the new mode.

There is some duplicated code for getting and setting values: the helper functions used by bilinear first clip coordinates that fall outside the boundary and then get or set the value, whereas bicubic has to consider more points. I could refactor that part once the overall calculation is confirmed to be correct.

Thanks
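
For reference, once this lands the new mode should be reachable from the Python API, assuming the Python-level mode string is `'bicubic'` (matching the issue request). A minimal usage sketch with an identity sampling grid, which bicubic should reproduce up to floating point:

```python
import torch
import torch.nn.functional as F

# 4D NCHW input; bicubic is only supported for 4D inputs (see the TORCH_CHECK below).
input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Identity flow field in normalized [-1, 1] coordinates, shape (N, H_out, W_out, 2).
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, size=(1, 1, 4, 4), align_corners=False)

out = F.grid_sample(input, grid, mode='bicubic',
                    padding_mode='zeros', align_corners=False)
print(torch.allclose(out, input, atol=1e-5))  # identity grid -> output matches input
```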

Pull Request resolved: #44780

Reviewed By: mrshenli

Differential Revision: D24681114

Pulled By: mruberry

fbshipit-source-id: d39c8715e2093a5a5906cb0ef040d62bde578567
pomelyu authored and facebook-github-bot committed Nov 3, 2020
1 parent 6397855 commit f41f3e3
Showing 8 changed files with 665 additions and 47 deletions.
109 changes: 101 additions & 8 deletions aten/src/ATen/native/GridSampler.cpp
@@ -6,6 +6,7 @@
#include <c10/core/Layout.h>
#include <ATen/cpu/vml.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/cpu/GridSamplerKernel.h>
#include <c10/util/Exception.h>

@@ -422,11 +423,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
for (int64_t w = 0; w < out_W; ++w) {
// get the corresponding input x, y, z co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
scalar_t x = *grid_ptr_NHW;
scalar_t y = grid_ptr_NHW[grid_sCoor];

ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);

if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y)
@@ -483,6 +484,43 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
*out_ptr_NCHW = static_cast<scalar_t>(0);
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
// grid_sampler_compute_source_index would "clip" the coordinate according to the padding
// mode, which would make the calculation wrong here:
// e.g. x = -0.1 -> ix = 0 for zeros padding, but bicubic needs ix = floor(x) = -1.
// Reflection padding is even more problematic, since the -1/+1 direction of the
// neighbouring taps is not fixed at the boundary.
ix = grid_sampler_unnormalize(x, inp_W, align_corners);
iy = grid_sampler_unnormalize(y, inp_H, align_corners);

scalar_t ix_nw = std::floor(ix);
scalar_t iy_nw = std::floor(iy);

const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;

scalar_t *inp_ptr_NC = inp_ptr_N;
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
scalar_t coefficients[4];

// Interpolate 4 values in the x direction
for (int64_t i = 0; i < 4; ++i) {
coefficients[i] = cubic_interp1d<scalar_t>(
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
tx);
}

// Interpolate in the y direction
*out_ptr_NCHW = cubic_interp1d<scalar_t>(
coefficients[0],
coefficients[1],
coefficients[2],
coefficients[3],
ty);
}
}
}
}
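
For reference, the bicubic branch above is a standard separable cubic convolution: gather the 4x4 neighbourhood around (ix, iy), interpolate the four rows along x, then interpolate those four results along y. A minimal NumPy sketch of the same computation (with hypothetical helper names `cubic_weights` and `bicubic_sample`, assuming zeros padding, a single channel, and unit strides):

```python
import numpy as np

A = -0.75  # same constant as the cubic convolution in aten/src/ATen/native/UpSample.h

def cubic_weights(t):
    # Keys cubic convolution weights for the four taps at offsets -1, 0, +1, +2.
    x = np.array([1.0 + t, t, 1.0 - t, 2.0 - t])  # distances from the sample point
    return np.where(
        x <= 1.0,
        ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0,        # |x| <= 1
        ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A,  # 1 < |x| < 2
    )

def bicubic_sample(img, ix, iy):
    # img: (H, W) array; (ix, iy): unnormalized source coordinates; zeros padding.
    H, W = img.shape
    ix_nw, iy_nw = int(np.floor(ix)), int(np.floor(iy))
    tx, ty = ix - ix_nw, iy - iy_nw
    wx, wy = cubic_weights(tx), cubic_weights(ty)
    out = 0.0
    for j in range(4):                      # y taps
        row = 0.0
        for i in range(4):                  # x taps
            x, y = ix_nw - 1 + i, iy_nw - 1 + j
            v = img[y, x] if (0 <= y < H and 0 <= x < W) else 0.0
            row += wx[i] * v                # interpolate along x
        out += wy[j] * row                  # then interpolate along y
    return out
```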
@@ -547,13 +585,13 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
// get the corresponding input x, y co-ordinates from grid
scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NHW;
scalar_t iy = grid_ptr_NHW[grid_sCoor];
scalar_t x = *grid_ptr_NHW;
scalar_t y = grid_ptr_NHW[grid_sCoor];

// multipliers for gradients on ix, iy
scalar_t gix_mult, giy_mult;
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
// get corner pixel values from (x, y)
@@ -628,6 +666,55 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW,
inp_H, inp_W, *gOut_ptr_NCHW);
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult);
iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult);

scalar_t ix_nw = std::floor(ix);
scalar_t iy_nw = std::floor(iy);

const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;

scalar_t x_coeffs[4];
scalar_t y_coeffs[4];
scalar_t x_coeffs_grad[4];
scalar_t y_coeffs_grad[4];

get_cubic_upsample_coefficients<scalar_t>(x_coeffs, tx);
get_cubic_upsample_coefficients<scalar_t>(y_coeffs, ty);
get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);

scalar_t gix = static_cast<scalar_t>(0);
scalar_t giy = static_cast<scalar_t>(0);

scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
scalar_t *inp_ptr_NC = inp_ptr_N;

for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) {
scalar_t gOut = *gOut_ptr_NCHW;

for (int64_t i = 0; i < 4; ++i) {
for (int64_t j = 0; j < 4; ++j) {

// set input gradient
add_value_bounded<scalar_t>(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners);

// set grid gradient
scalar_t val = get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners);

gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;
giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;
}
}
}
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
}
}
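
The accumulation above is the chain rule applied to the separable interpolation. Writing $c_i$ for the cubic weights, $v_{ij}$ for the bounded input values, and $t_x$, $t_y$ for the fractional offsets, a sketch of the underlying math is:

$$
\mathrm{out} = \sum_{i=0}^{3}\sum_{j=0}^{3} c_i(t_x)\, c_j(t_y)\, v_{ij},
\qquad
\frac{\partial\,\mathrm{out}}{\partial i_x} = \sum_{i,j} \frac{d c_i}{d t_x}\, c_j(t_y)\, v_{ij},
\qquad
\frac{\partial\,\mathrm{out}}{\partial i_y} = \sum_{i,j} c_i(t_x)\, \frac{d c_j}{d t_y}\, v_{ij}.
$$

The grid gradient stored in `gGrid_ptr_NHW` is this quantity summed over channels, weighted by `gOut`, and scaled by `gix_mult`/`giy_mult` (the derivative of the unnormalization), with the sign convention of `get_cubic_coefficients_grad` folded into the `-=` accumulation.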
Expand All @@ -640,6 +727,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode,
bool align_corners) {

// AVX gather instructions use signed 32-bit offsets to gather float values.
// Check for possible overflow and fallback to scalar implementation
if (input.scalar_type() != kDouble) {
@@ -682,6 +770,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {

// AVX gather instructions use signed 32-bit offsets to gather float values.
// Check for possible overflow and fallback to scalar implementation
if (input.scalar_type() != kDouble) {
@@ -757,6 +846,10 @@ Tensor grid_sampler(const Tensor& input, const Tensor& grid,
grid.size(-1) == input.dim() - 2,
"grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
"dimension, but got grid with sizes ", grid.sizes());
TORCH_CHECK(
!(input.dim() == 5 && static_cast<GridSamplerInterpolation>(interpolation_mode) == GridSamplerInterpolation::Bicubic),
"grid_sampler(): bicubic interpolation only supports 4D input"
);
for (int64_t i = 2; i < input.dim(); i++) {
TORCH_CHECK(input.size(i) > 0,
"grid_sampler(): expected input to have non-empty spatial dimensions, "
95 changes: 86 additions & 9 deletions aten/src/ATen/native/GridSampler.h
@@ -7,7 +7,7 @@ namespace at { namespace native {

namespace detail {

enum class GridSamplerInterpolation {Bilinear, Nearest};
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
enum class GridSamplerPadding {Zeros, Border, Reflection};

} // namespace detail
@@ -139,14 +139,12 @@ static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_l
}
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
// Map out-of-boundary points back into the valid range.
// This only affects padding_mode=border or reflection.
template<typename scalar_t>
static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
Expand All @@ -163,6 +161,18 @@ static inline scalar_t grid_sampler_compute_source_index(
return coord;
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
coord = compute_coordinates(coord, size, padding_mode, align_corners);
return coord;
}
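
For context, `grid_sampler_unnormalize` (defined earlier in this header and unchanged by this diff) maps a normalized grid coordinate $x \in [-1, 1]$ to a pixel-space source index:

$$
i_x =
\begin{cases}
\dfrac{x + 1}{2}\,(W - 1), & \texttt{align\_corners} = \text{true},\\[6pt]
\dfrac{(x + 1)\,W - 1}{2}, & \texttt{align\_corners} = \text{false},
\end{cases}
$$

and `compute_coordinates` applies the border/reflection mapping on top of it; the bicubic path calls the two steps separately so that it can keep the unclipped index for `floor()`.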

// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
@@ -202,6 +212,30 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D,
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

template<typename scalar_t>
static inline scalar_t get_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
GridSamplerPadding padding_mode,
bool align_corners) {

x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);

int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);

if (within_bounds_2d(iy, ix, H, W)) {
return data[iy * sH + ix * sW];
}
return static_cast<scalar_t>(0);
}

template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
int64_t sH, int64_t sW, int64_t H, int64_t W,
@@ -221,4 +255,47 @@ static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
}
}

template<typename scalar_t>
static inline void add_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
scalar_t delta,
GridSamplerPadding padding_mode,
bool align_corners) {

x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);

int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);

safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
}

// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
scalar_t coeffs[4],
scalar_t t) {

// Must be the same as forward calculation in
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
scalar_t A = -0.75;

scalar_t x;
x = -1 - t; // 1 < x = |-1 - tx| < 2
coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
x = -t; // x = |0 - tx| <= 1
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 1 - t; // x = |1 - tx| <= 1
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 2 - t; // 1 < x = |2 - tx| < 2
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}

}} // namespace at::native
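
For completeness, the forward coefficients computed by `get_cubic_upsample_coefficients` in `UpSample.h` come from the cubic convolution kernel with $A = -0.75$, evaluated at the tap distances $1+t$, $t$, $1-t$, $2-t$:

$$
W(x) =
\begin{cases}
(A+2)\,|x|^3 - (A+3)\,|x|^2 + 1, & |x| \le 1,\\
A\,|x|^3 - 5A\,|x|^2 + 8A\,|x| - 4A, & 1 < |x| < 2,\\
0, & \text{otherwise}.
\end{cases}
$$

`get_cubic_coefficients_grad` above evaluates $dW/dx$ at the signed offsets $-1-t$, $-t$, $1-t$, $2-t$, which is why each branch is the derivative of the matching forward polynomial.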
