diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index 59242e0e6c03..667cbe8f07b3 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -422,11 +423,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, for (int64_t w = 0; w < out_W; ++w) { // get the corresponding input x, y, z co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -483,6 +484,43 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + // grid_sampler_compute_source_index would "clip" the index depending on the padding mode, + // which would make the calculation here wrong: + // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix = floor(x) = -1 + // Reflection padding is even more problematic, since the -1 and +1 directions are not fixed at the boundary. + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t *inp_ptr_NC = inp_ptr_N; + scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { + scalar_t coefficients[4]; + + // Interpolate 4 values in the x direction + for (int64_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + // Interpolate in the y direction + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -547,13 +585,13 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { // get the corresponding input x, y co-ordinates from grid scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - scalar_t ix = *grid_ptr_NHW; - scalar_t iy = grid_ptr_NHW[grid_sCoor]; + scalar_t x = *grid_ptr_NHW; + scalar_t y = grid_ptr_NHW[grid_sCoor]; // multipliers for gradients on ix, iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners,
&gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get corner pixel values from (x, y) @@ -628,6 +666,55 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW); } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = std::floor(ix); + scalar_t iy_nw = std::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsample_coefficients(x_coeffs, tx); + get_cubic_upsample_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; + scalar_t *inp_ptr_NC = inp_ptr_N; + + for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } @@ -640,6 +727,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. // Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -682,6 +770,7 @@ Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid, std::tuple grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + // AVX gather instructions use signed 32-bit offsets to gather float values. 
// Check for possible overflow and fallback to scalar implementation if (input.scalar_type() != kDouble) { @@ -757,6 +846,10 @@ Tensor grid_sampler(const Tensor& input, const Tensor& grid, grid.size(-1) == input.dim() - 2, "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " "dimension, but got grid with sizes ", grid.sizes()); + TORCH_CHECK( + !(input.dim() == 5 && static_cast(interpolation_mode) == GridSamplerInterpolation::Bicubic), + "grid_sampler(): bicubic interpolation only supports 4D input" + ); for (int64_t i = 2; i < input.dim(); i++) { TORCH_CHECK(input.size(i) > 0, "grid_sampler(): expected input to have non-empty spatial dimensions, " diff --git a/aten/src/ATen/native/GridSampler.h b/aten/src/ATen/native/GridSampler.h index ebafc9727061..effc322c0d3a 100644 --- a/aten/src/ATen/native/GridSampler.h +++ b/aten/src/ATen/native/GridSampler.h @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace detail @@ -139,14 +139,12 @@ static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_l } } -// Computes the pixel source index value for a grid coordinate -template -static inline scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int64_t size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +// Map out-of-boundary points back into the boundary. +// This only affects padding_mode=border or reflection +template +static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -163,6 +161,18 @@ static inline scalar_t grid_sampler_compute_source_index( return coord; } +// Computes the pixel source index value for a grid coordinate +template +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + // grid_sampler_compute_source_index_set_grad works similarly to // grid_sampler_compute_source_index except that it also returns the // `d output / d input` via pointer argument `grad_in`.
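For orientation: the refactor above splits grid_sampler_compute_source_index into unnormalize-then-remap, and the bicubic path calls only the unnormalize half so it can see true out-of-bound coordinates before taking floor(). A rough Python sketch of that mapping follows; it is illustrative only (the helper names are hypothetical and the reflection case is simplified to the align_corners=True span), not the ATen code itself:

def unnormalize(coord, size, align_corners):
    # map a normalized grid value in [-1, 1] to pixel space
    if align_corners:
        return (coord + 1) / 2 * (size - 1)      # -1 -> 0, +1 -> size - 1
    return ((coord + 1) * size - 1) / 2          # -1 -> -0.5, +1 -> size - 0.5

def remap_out_of_bounds(coord, size, padding_mode):
    # counterpart of compute_coordinates(): only border/reflection move the point
    if padding_mode == 'border':
        return min(max(coord, 0.0), size - 1.0)
    if padding_mode == 'reflection':
        span = 2.0 * (size - 1) if size > 1 else 1.0
        coord = abs(coord) % span
        if coord > size - 1:
            coord = span - coord
        return min(max(coord, 0.0), size - 1.0)  # clip after reflecting, as the kernel does
    return coord  # 'zeros': leave as-is; out-of-bound reads simply contribute 0

The bicubic branch applies the remapping per tap inside get_value_bounded/add_value_bounded rather than once up front, which is exactly why compute_coordinates was factored out.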
@@ -202,6 +212,30 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static inline scalar_t get_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, int64_t sH, int64_t sW, int64_t H, int64_t W, @@ -221,4 +255,47 @@ static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, } } +template +static inline void add_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static inline void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 1b86b3f2d634..fdee519c4bd0 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -61,15 +61,30 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) { } } - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] { - flip_cpu_kernel( - total_dims, - stride_contiguous_v, - flip_dims_b, - in_tensor, - out_tensor - ); - }); + if (in_tensor.is_quantized()) { + AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(), + "flip_quantized_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, + in_tensor.scalar_type(), + "flip_cpu", [&] { + flip_cpu_kernel( + total_dims, + stride_contiguous_v, + flip_dims_b, + in_tensor, + out_tensor + ); + }); + } return out_tensor; } diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 8c3fc182e646..7d57651005d5 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -351,8 +351,8 @@ Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out( Tensor cosh(const Tensor& self) { return unary_op_impl(self, at::cosh_out); } Tensor& cosh_(Tensor& self) { return 
unary_op_impl_(self, at::cosh_out); } -Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acosh_stub); } -Tensor acosh(const Tensor& self) { return unary_op_impl(self, at::acosh_out); } +Tensor& acosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, acosh_stub); } +Tensor acosh(const Tensor& self) { return unary_op_impl_float(self, acosh_stub); } Tensor& acosh_(Tensor& self) { return unary_op_impl_(self, at::acosh_out); } // arccosh, alias for acosh @@ -360,8 +360,8 @@ Tensor& arccosh_out(Tensor& result, const Tensor& self) { return at::acosh_out(r Tensor arccosh(const Tensor& self) { return at::acosh(self); } Tensor& arccosh_(Tensor& self) { return at::acosh_(self); } -Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asinh_stub); } -Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); } +Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asinh_stub); } +Tensor asinh(const Tensor& self) { return unary_op_impl_float(self, asinh_stub); } Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); } // arcsinh, alias for asinh diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 4dfe644b89a4..ece2d527e899 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -354,6 +354,10 @@ struct ComputeLocation return unnormalize(in); } + inline Vec compute_coordinates(const Vec &in) const { + return in; + } + inline std::pair apply_get_grad(const Vec &in) const { return std::make_pair(unnormalize(in), Vec(scaling_factor)); } @@ -374,6 +378,10 @@ struct ComputeLocation return clip_coordinates(unnormalize(in)); } + inline Vec compute_coordinates(const Vec &in) const { + return clip_coordinates(in); + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_clip; std::tie(res, grad_clip) = clip_coordinates_get_grad(unnormalize(in)); @@ -400,6 +408,12 @@ struct ComputeLocation return res; } + inline Vec compute_coordinates(const Vec &in) const { + auto res = reflect_coordinates(in); + res = clip_coordinates(res); + return res; + } + inline std::pair apply_get_grad(const Vec &in) const { Vec res, grad_refl, grad_clip, grad(scaling_factor); std::tie(res, grad_refl) = reflect_coordinates_get_grad(unnormalize(in)); @@ -764,6 +778,202 @@ struct ApplyGridSample +struct ApplyGridSample { + using Vec = Vec256; + using integer_t = int_same_size_t; + using iVec = Vec256; + + const int64_t inp_H; + const int64_t inp_W; + const int64_t inp_sH; + const int64_t inp_sW; + const int64_t C; + const int64_t inp_sC; + const ComputeLocation compute_H; + const ComputeLocation compute_W; + const bool must_in_bound = padding != GridSamplerPadding::Zeros; + + // constant used in cubic convolution + // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h + const Vec A = Vec(-0.75); + + ApplyGridSample(const TensorAccessor& input) + : inp_H(input.size(2)) + , inp_W(input.size(3)) + , inp_sH(input.stride(2)) + , inp_sW(input.stride(3)) + , C(input.size(1)) + , inp_sC(input.stride(1)) + , compute_H(input.size(2)) + , compute_W(input.size(3)) {} + + // Calculate the cubic convolution coefficient + inline void get_cubic_coefficients(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = tx + Vec(1); // 1 < x = |-1 - tx| < 2 + coeffs[0] = ((A * x - Vec(5) * A) * x + Vec(8) 
* A) * x - Vec(4) * A; + x = tx; // x = |0 - tx| <= 1 + coeffs[1] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1); + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A; + } + + // Calculate the differential of the cubic convolution, i.e. `d coeff / d x` + inline void get_cubic_coefficients_grad(Vec (&coeffs)[4], const Vec& tx) const { + Vec x; + x = Vec(-1) - tx; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (Vec(-3) * A * x - Vec(10) * A ) * x - Vec(8) * A; + x = Vec(0) - tx; // x = |0 - tx| <= 1 + coeffs[1] = (Vec(-3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(1) - tx; // x = |1 - tx| <= 1 + coeffs[2] = (Vec(3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x; + x = Vec(2) - tx; // 1 < x = |2 - tx| < 2 + coeffs[3] = (Vec(3) * A * x - Vec(10) * A) * x + Vec(8) * A; + } + + inline Vec get_value_bounded(const scalar_t* data, const Vec& x, const Vec& y) const { + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto offset = iy * iVec(inp_sH) + ix * iVec(inp_sW); + + auto val = mask_gather(Vec(0), data, offset, mask); + return val; + } + + inline void add_value_bounded(scalar_t* data, int64_t len, const Vec& x, const Vec&y, + const Vec& delta) const { + + auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x)); + auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y)); + + auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W)); + auto mask_y = must_in_bound ? 
iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H)); + auto mask = cast(mask_x & mask_y); + + auto i_gInp_offset = iy * iVec(inp_W) + ix; + integer_t i_gInp_offset_arr[iVec::size()]; + i_gInp_offset.store(i_gInp_offset_arr); + + integer_t mask_arr[iVec::size()]; + mask.store(mask_arr); + + scalar_t gInp_corner_arr[Vec::size()]; + delta.store(gInp_corner_arr); + + mask_scatter_add(gInp_corner_arr, data, i_gInp_offset_arr, mask_arr, len); + } + + inline void forward(TensorAccessor& out_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + auto x = compute_W.unnormalize(grid_x); + auto y = compute_H.unnormalize(grid_y); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + + // Interpolate the 4 values in the x direction + Vec interp_x[4]; + for (int64_t i = 0; i < 4; ++i) { + interp_x[i] = + coeff_x[0] * get_value_bounded(inp_slice_C_ptr, ix - Vec(1), iy + Vec(-1 + i)) + + coeff_x[1] * get_value_bounded(inp_slice_C_ptr, ix + Vec(0), iy + Vec(-1 + i)) + + coeff_x[2] * get_value_bounded(inp_slice_C_ptr, ix + Vec(1), iy + Vec(-1 + i)) + + coeff_x[3] * get_value_bounded(inp_slice_C_ptr, ix + Vec(2), iy + Vec(-1 + i)); + } + + // Interpolate the 4 values in the y direction + auto interpolated = coeff_y[0] * interp_x[0] + coeff_y[1] * interp_x[1] + + coeff_y[2] * interp_x[2] + coeff_y[3] * interp_x[3]; + interpolated.store(out_slice[c].data() + offset, len); + } + } + + inline void backward(TensorAccessor& gInp_slice, + TensorAccessor& gGrid_slice, + const TensorAccessor& gOut_slice, + const TensorAccessor& inp_slice, + int64_t offset, const Vec& grid_x, const Vec& grid_y, + int64_t len) const { + + Vec x = compute_W.unnormalize(grid_x); + Vec y = compute_H.unnormalize(grid_y); + Vec gx_mult = Vec(compute_W.scaling_factor); + Vec gy_mult = Vec(compute_H.scaling_factor); + + auto ix = x.floor(); + auto iy = y.floor(); + + Vec coeff_x[4]; + Vec coeff_y[4]; + get_cubic_coefficients(coeff_x, x - ix); + get_cubic_coefficients(coeff_y, y - iy); + + Vec coeff_x_grad[4]; + Vec coeff_y_grad[4]; + get_cubic_coefficients_grad(coeff_x_grad, x - ix); + get_cubic_coefficients_grad(coeff_y_grad, y - iy); + + auto gx = Vec(0), gy = Vec(0); + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (int64_t c = 0; c < C; ++c) { + auto inp_slice_C_ptr = inp_slice[c].data(); + auto gInp_slice_C_ptr = gInp_slice[c].data(); + auto gOut = Vec::loadu(gOut_slice[c].data() + offset, len); + + for (int64_t i = 0; i < 4; ++i) { + for (int64_t j = 0; j < 4; ++j) { + auto xx = ix + Vec(-1 + i); + auto yy = iy + Vec(-1 + j); + + add_value_bounded(gInp_slice_C_ptr, len, xx, yy, gOut * coeff_x[i] * coeff_y[j]); + + auto val = get_value_bounded(inp_slice_C_ptr, xx, yy); + gx = gx - val * gOut * coeff_x_grad[i] * coeff_y[j]; + gy = gy - val * gOut * coeff_y_grad[j] * coeff_x[i]; + } + } + } + + gx = gx * gx_mult; + gy = gy * gy_mult; + + constexpr int64_t step = Vec::size(); + auto interleaved_gGrid = interleave2(gx, gy); + auto gGrid_ptr = gGrid_slice.data() + offset * 2; + std::get<0>(interleaved_gGrid).store(gGrid_ptr, + std::min(len * 2, step)); + std::get<1>(interleaved_gGrid).store(gGrid_ptr + step, + std::max(static_cast(0), len * 2 - 
step)); + } +}; + // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~ // Function to apply a vectorized function on a grid slice tensor (without batch // dimension). @@ -940,11 +1150,13 @@ Tensor grid_sampler_2d_cpu_kernel_impl(const Tensor& input, const Tensor& grid, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); @@ -1014,11 +1226,13 @@ grid_sampler_2d_backward_cpu_kernel_impl(const Tensor& grad_output_, switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true); HANDLE_INTERP(GridSamplerInterpolation::Nearest, true); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true); } } else { switch (static_cast(interpolation_mode)) { HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false); HANDLE_INTERP(GridSamplerInterpolation::Nearest, false); + HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false); } } }); diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index 023167109af2..7674e9137238 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -50,11 +51,11 @@ namespace { const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -105,6 +106,38 @@ namespace { *out_ptr_NCHW = static_cast(0); } } + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + scalar_t coefficients[4]; + + for (index_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, 
align_corners), + tx); + } + + *out_ptr_NCHW = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + ty); + } } } } @@ -300,13 +333,13 @@ namespace { const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t ix = grid.data[grid_offset]; - scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; // multipliers for gradients on ix and iy scalar_t gix_mult, giy_mult; - ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); - iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); if (interpolation_mode == GridSamplerInterpolation::Bilinear) { // get NE, NW, SE, SW pixel values from (x, y) @@ -387,6 +420,57 @@ namespace { scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; gGrid_ptr_NHW[0] = static_cast(0); gGrid_ptr_NHW[1] = static_cast(0); + } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { + + ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult); + iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult); + + scalar_t ix_nw = ::floor(ix); + scalar_t iy_nw = ::floor(iy); + + const scalar_t tx = ix - ix_nw; + const scalar_t ty = iy - iy_nw; + + scalar_t x_coeffs[4]; + scalar_t y_coeffs[4]; + scalar_t x_coeffs_grad[4]; + scalar_t y_coeffs_grad[4]; + + get_cubic_upsampling_coefficients(x_coeffs, tx); + get_cubic_upsampling_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + scalar_t gix = static_cast(0); + scalar_t giy = static_cast(0); + + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + scalar_t *gInp_ptr_NC = grad_input.data + n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + for (index_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC+= inp_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + for (index_t i = 0; i < 4; ++i) { + for (index_t j = 0; j < 4; ++j) { + + // set input gradient + add_value_bounded(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, + gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners); + + // set grid gradient + scalar_t val = get_value_bounded(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j, + inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; } } } diff --git a/aten/src/ATen/native/cuda/GridSampler.cuh b/aten/src/ATen/native/cuda/GridSampler.cuh index 4a94a3fda1bb..0c4acd1be41c 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cuh +++ b/aten/src/ATen/native/cuda/GridSampler.cuh @@ -7,7 +7,7 @@ namespace at { namespace native { namespace detail { - enum class GridSamplerInterpolation {Bilinear, Nearest}; + enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; enum class GridSamplerPadding {Zeros, Border, Reflection}; } // namespace 
detail @@ -153,15 +153,11 @@ scalar_t safe_downgrade_to_int_range(scalar_t x){ return x; } -// Computes the pixel source index value for a grid coordinate -template +template static __forceinline__ __device__ -scalar_t grid_sampler_compute_source_index( - scalar_t coord, - int size, - GridSamplerPadding padding_mode, - bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { if (padding_mode == GridSamplerPadding::Border) { // clip coordinates to image borders coord = clip_coordinates(coord, size); @@ -176,7 +172,20 @@ scalar_t grid_sampler_compute_source_index( coord = clip_coordinates(coord, size); } - coord = safe_downgrade_to_int_range(coord); + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static __forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); return coord; } @@ -224,6 +233,25 @@ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } +template +static __forceinline__ __device__ +scalar_t get_value_bounded( + scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + template static __forceinline__ __device__ void safe_add_2d(scalar_t *data, int h, int w, @@ -244,4 +272,44 @@ void safe_add_3d(scalar_t *data, int d, int h, int w, } } +template +static __forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. 
`d coeff / d x` +template +static __forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 6bf0bdc3ea89..bc64b7eb2268 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -75,7 +75,7 @@ void tanh_kernel_cuda(TensorIterator& iter) { } void acosh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "acosh_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::acosh(a); }); @@ -83,7 +83,7 @@ void acosh_kernel_cuda(TensorIterator& iter) { } void asinh_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "asinh_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::asinh(a); }); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 398aa7474eab..349256477df3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3882,7 +3882,7 @@ use_c10_dispatcher: full variants: function, method dispatch: - CPU: flip_cpu + CPU, QuantizedCPU: flip_cpu CUDA: flip_cuda - func: fliplr(Tensor self) -> Tensor diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 7f4e1a0afc24..61bdf9d9b974 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -278,6 +278,106 @@ void testLLVMIfThenElseTest() { ASSERT_EQ(b_buffer[0], 42); } +// if (x < 10) x = x + 1 +void testLLVMCondNoFalseBlockTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value); + } + } +} + +// if (x < 10) { +// x = x + 1; +// } else { +// x = x - 1; +// } +void testLLVMCondTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto block = Block::make({ + cond, + x.store({0}, x.load(0) * 2), + }); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer 
= {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(block, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); + } else { + ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); + } + } +} + +// if (x < 10) { +// if (x > 5) { +// x = x + 1; +// } else { +// x = x - 1; +// } +// } else { +// if (x <= 15) { +// x = x + 2; +// } else { +// x = x - 2; +// } +// } +void testLLVMCondNestedTest() { + KernelScope kernel_scope; + + Placeholder x(BufHandle("X", {1}, kInt)); + auto true_cmp = + CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); + auto true_cond = Cond::make( + true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto false_cmp = + CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); + auto false_cond = Cond::make( + false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, true_cond, false_cond); + + for (int32_t x_value : {0, 8, 15, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + if (x_value > 5) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value - 1); + } + } else { + if (x_value <= 15) { + ASSERT_EQ(x_buffer[0], x_value + 2); + } else { + ASSERT_EQ(x_buffer[0], x_value - 2); + } + } + } +} + void testLLVMVecLoadStoreTest() { KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 6f7b50790477..72fedbc2a95a 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -506,6 +506,9 @@ namespace jit { _(LLVMEmptyStmt) \ _(LLVMEliminatedStmt) \ _(LLVMIfThenElseTest) \ + _(LLVMCondNoFalseBlockTest) \ + _(LLVMCondTest) \ + _(LLVMCondNestedTest) \ _(LLVMVectorizerLoadStoreTest) \ _(LLVMSimpleReduction) \ _(LLVMRFactorReduction) diff --git a/test/test_nn.py b/test/test_nn.py index 15b877391c14..2ce752aa0eb8 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7167,6 +7167,9 @@ def test_grid_sample_error_checking(self): with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"): F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False) + with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"): + F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic') + if TEST_CUDA: with self.assertRaisesRegex(RuntimeError, "expected input and grid to be on same device"): F.grid_sample(input.cuda(), grid, align_corners=False) @@ -7299,8 +7302,8 @@ def get_grid(device='cpu', data=None): self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5) out_fallback.backward(gradients.float()) - self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-5, rtol=5e-5) - self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-5, rtol=5e-5) + self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5) + self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5) if TEST_CUDA: input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_() @@ -7378,7 +7381,7 @@ def get_grid(device='cpu', data=None): W = random.randint(3, IW + 2) test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners) - for mode in ('bilinear', 'nearest'): + 
for mode in ('bilinear', 'nearest', 'bicubic'): for padding_mode in ('zeros', 'border', 'reflection'): for align_corners in (True, False): # test known input on CPU @@ -7446,6 +7449,37 @@ def get_grid(device='cpu', data=None): [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5) else: raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000], + [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264], + [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000], + [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781], + [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000], + [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531], + [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + else: raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode)) output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, @@ -7501,11 +7535,42 @@ def get_grid(device='cpu', data=None): groundtruth = torch.tensor( [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]], + [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]], + [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]] + ).view(1, 2, 4, 2) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]], + [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]], + [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]], + [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2) + else: + groundtruth = torch.tensor( + [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]], + [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2) + else: + raise AssertionError("missing gradient groundtruth test for padding mode '{}'".format(padding_mode)) else: raise AssertionError("missing gradient groundtruth test for interpolation mode '{}'".format(mode)) F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, 
align_corners=align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth, + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0, msg="gradient groundtruth comparison failed for mode={}, " "padding_mode={}".format(mode, padding_mode)) @@ -7516,7 +7581,7 @@ def get_grid(device='cpu', data=None): F.GRID_SAMPLE_INTERPOLATION_MODES[mode], F.GRID_SAMPLE_PADDING_MODES[padding_mode], align_corners).sum().backward() - self.assertEqual(grid.grad, groundtruth) + self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0) # do gradcheck N = random.randint(2, 8) @@ -11075,7 +11140,7 @@ def test_grid_sample_large_index_2d(self, device, dtype): sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31, msg="View must use 64-bit indexing") for mode, padding_mode, align_corners in itertools.product( - ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)): + ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)): a = F.grid_sample( small_image, coords, mode=mode, padding_mode=padding_mode, align_corners=align_corners) diff --git a/torch/csrc/api/include/torch/nn/functional/vision.h b/torch/csrc/api/include/torch/nn/functional/vision.h index e1041cb21d8c..1fe084d02c79 100644 --- a/torch/csrc/api/include/torch/nn/functional/vision.h +++ b/torch/csrc/api/include/torch/nn/functional/vision.h @@ -61,8 +61,10 @@ inline Tensor grid_sample( if (c10::get_if(&mode)) { mode_enum = 0; - } else { /// mode == 'nearest' + } else if (c10::get_if(&mode)) { mode_enum = 1; + } else { /// mode == 'bicubic' + mode_enum = 2; } if (c10::get_if(&padding_mode)) { diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index e0b7e15556eb..a85f1ef69b1e 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1897,6 +1897,19 @@ std::vector inlineCallTo( std::unordered_map new_callstack_entries; + c10::optional module_instance_info = c10::nullopt; + if (to_replace->kind() == prim::CallMethod) { + auto class_type_ptr = to_replace->input(0)->type()->cast(); + if (to_replace->input(0)->node()->kind() == prim::GetAttr) { + module_instance_info = c10::make_optional(ModuleInstanceInfo( + class_type_ptr, to_replace->input(0)->node()->s(attr::name))); + } else { + std::string instance_name_unknown("INSTANCE_NAME_UNKNOWN"); + module_instance_info = c10::make_optional( + ModuleInstanceInfo(class_type_ptr, instance_name_unknown)); + } + } + // TODO: We might need to use nodes_map instead of value_map. Otherwise, we // are missing nodes without outputs (e.g. prim::Print). 
std::unordered_set updated_nodes; @@ -1915,11 +1928,14 @@ std::vector inlineCallTo( if (new_node_cs) { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - *new_node_cs, callee, to_replace->sourceRange()); + *new_node_cs, + callee, + to_replace->sourceRange(), + module_instance_info); } else { new_callstack_entries[raw_callstack_ptr] = c10::make_intrusive( - callee, to_replace->sourceRange()); + callee, to_replace->sourceRange(), module_instance_info); } } new_node->setCallStack(new_callstack_entries.at(raw_callstack_ptr)); diff --git a/torch/csrc/jit/ir/scope.cpp b/torch/csrc/jit/ir/scope.cpp index 900722427225..3901ce1038bf 100644 --- a/torch/csrc/jit/ir/scope.cpp +++ b/torch/csrc/jit/ir/scope.cpp @@ -89,6 +89,14 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() { InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range) : fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + InlinedCallStack::InlinedCallStack( InlinedCallStackPtr callee, Function* fn, @@ -97,6 +105,16 @@ InlinedCallStack::InlinedCallStack( fn_(fn), source_range_(std::move(source_range)) {} +InlinedCallStack::InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info) + : callee_(std::move(callee)), + fn_(fn), + source_range_(std::move(source_range)), + module_instance_info_(std::move(module_instance_info)) {} + c10::optional InlinedCallStack::callee() const { return callee_; } @@ -110,5 +128,11 @@ std::vector InlinedCallStack::vec() { } return r; } + +ModuleInstanceInfo::ModuleInstanceInfo( + c10::ClassTypePtr module_type, + std::string instance_name) + : module_type_(std::move(module_type)), + instance_name_(std::move(instance_name)) {} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index d75f3e060f36..784c2942c263 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -51,6 +52,32 @@ struct TORCH_API Scope : public c10::intrusive_ptr_target { struct Function; struct InlinedCallStack; +/** + * ModuleInstanceInfo is a structure to include the module type and instance + * name. It also provide public methods to get the pointer to module type and + * instance name. + * + * This structure is mainly used as a private member in InlinedCallStack, such + * that one can follow the callstack to find the relevant module hierarchy. + */ +struct ModuleInstanceInfo { + private: + c10::ClassTypePtr module_type_{nullptr}; + std::string instance_name_; + + public: + ModuleInstanceInfo(c10::ClassTypePtr module_type, std::string instance_name); + c10::ClassTypePtr class_type() { + return module_type_; + } + c10::ClassTypePtr class_type() const { + return module_type_; + } + std::string instance_name() const { + return instance_name_; + } +}; + /** * InlinedCallStack is an element in a list representing callstack of functions * that have been inlined. 
@@ -81,6 +108,8 @@ struct InlinedCallStack; */ using InlinedCallStackPtr = c10::intrusive_ptr; using InlinedCallStackEntry = std::pair; +using InlinedCallStackWithModuleInfo = + std::tuple>; struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { private: @@ -88,17 +117,30 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { Function* fn_; SourceRange source_range_; InlinedCallStackPtr intrusive_from_this(); + c10::optional module_instance_info_; public: // Constructor for a leaf callstack node. InlinedCallStack(Function* fn, SourceRange source_range); + // Constructor for a leaf callstack node. + InlinedCallStack( + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Constructor for an inner callstack node. InlinedCallStack( InlinedCallStackPtr callee, Function* fn, SourceRange source_range); + InlinedCallStack( + InlinedCallStackPtr callee, + Function* fn, + SourceRange source_range, + c10::optional module_instance_info); + // Return next element in the callstack list. c10::optional callee() const; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 0400b6f14143..e692237b7c6f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -1689,7 +1689,45 @@ void LLVMCodeGenImpl::visit(const Let* v) { } void LLVMCodeGenImpl::visit(const Cond* v) { - throw unimplemented_lowering(v); + // Even if true_stmt and false_stmt are nullptr, + // in case condition is a function call with side effect, + // we still evaluate it. + v->condition()->accept(this); + + if (!v->true_stmt() && !v->false_stmt()) { + return; + } + assert(v->true_stmt()); + + llvm::Value* condition = value_; + llvm::Value* c = irb_.CreateICmpNE( + condition, llvm::ConstantInt::get(condition->getType(), 0)); + llvm::BasicBlock* then_block = + llvm::BasicBlock::Create(getContext(), "then", fn_); + llvm::BasicBlock* else_block = nullptr; + if (v->false_stmt()) { + else_block = llvm::BasicBlock::Create(getContext(), "else", fn_); + } + llvm::BasicBlock* end_block = + llvm::BasicBlock::Create(getContext(), "end", fn_); + + if (else_block) { + irb_.CreateCondBr(c, then_block, else_block); + } else { + irb_.CreateCondBr(c, then_block, end_block); + } + + irb_.SetInsertPoint(then_block); + v->true_stmt()->accept(this); + irb_.CreateBr(end_block); + + if (else_block) { + irb_.SetInsertPoint(else_block); + v->false_stmt()->accept(this); + irb_.CreateBr(end_block); + } + + irb_.SetInsertPoint(end_block); } void LLVMCodeGenImpl::optimize(llvm::Module& M) { diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 6253f9ddda7b..031d974d4973 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3227,6 +3227,7 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 GRID_SAMPLE_INTERPOLATION_MODES = { 'bilinear': 0, 'nearest': 1, + 'bicubic': 2, } GRID_SAMPLE_PADDING_MODES = { @@ -3293,8 +3294,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) mode (str): interpolation mode to calculate output values - ``'bilinear'`` | ``'nearest'``. Default: ``'bilinear'`` - Note: When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. 
Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode used internally will actually be trilinear. However, when the input is 4-D, the interpolation mode will legitimately be bilinear. padding_mode (str): padding mode for outside grid values @@ -3324,6 +3326,17 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner The default behavior up to version 1.2.0 was ``align_corners = True``. Since then, the default behavior has been changed to ``align_corners = False``, in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` may differ from package to package. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func:`torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 """ if not torch.jit.is_scripting(): tens_ops = (input, grid) @@ -3331,9 +3344,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner return handle_torch_function( grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners) - if mode != 'bilinear' and mode != 'nearest': + if mode != 'bilinear' and mode != 'nearest' and mode != 'bicubic': raise ValueError("nn.functional.grid_sample(): expected mode to be " - "'bilinear' or 'nearest', but got: '{}'".format(mode)) + "'bilinear', 'nearest' or 'bicubic', but got: '{}'".format(mode)) if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection': raise ValueError("nn.functional.grid_sample(): expected padding_mode " "to be 'zeros', 'border', or 'reflection', " @@ -3341,8 +3354,10 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner if mode == 'bilinear': mode_enum = 0 - else: # mode == 'nearest' + elif mode == 'nearest': mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 if padding_mode == 'zeros': padding_mode_enum = 0 diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b588ded49494..277124466101 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -237,8 +237,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('acosh', ref=np.arccosh, domain=(1, float('inf')), - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), test_inplace_grad=False), UnaryUfuncInfo('asin', @@ -255,8 +257,10 @@ def sample_inputs(self, device, dtype,
requires_grad=False): # NOTE: derivative for inplace asinh is not implemented UnaryUfuncInfo('asinh', ref=np.arcsinh, - dtypesIfCPU=floating_types(), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), test_inplace_grad=False), UnaryUfuncInfo('atan',
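As a wrap-up, the user-visible behavior touched by this patch can be exercised with a short, illustrative Python sketch (arbitrary shapes and values, not part of the test suite above):

import torch
import torch.nn.functional as F

# bicubic grid_sample: 4-D input only; results may overshoot the input range,
# so clamp when a bounded range is required (see the docstring note in this patch)
inp = torch.rand(1, 1, 4, 4)
grid = torch.rand(1, 3, 3, 2) * 2 - 1        # normalized coordinates in [-1, 1]
out = F.grid_sample(inp, grid, mode='bicubic',
                    padding_mode='border', align_corners=False)
out = out.clamp(0.0, 1.0)

# acosh/asinh now accept integer (and bool) inputs and promote them to a float dtype
x = torch.arange(1, 5)                       # int64, within acosh's domain [1, inf)
print(torch.acosh(x).dtype, torch.asinh(x).dtype)   # both floating point

# flip is now also dispatched for quantized CPU tensors
q = torch.quantize_per_tensor(torch.rand(2, 3), scale=0.1, zero_point=0,
                              dtype=torch.quint8)
print(torch.flip(q, dims=(1,)))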