diff --git a/android/ops/CMakeLists.txt b/android/ops/CMakeLists.txt
index 3210925a85c..ad42adbfa71 100644
--- a/android/ops/CMakeLists.txt
+++ b/android/ops/CMakeLists.txt
@@ -14,13 +14,6 @@ file(GLOB VISION_SRCS
     ../../torchvision/csrc/ops/*.h
     ../../torchvision/csrc/ops/*.cpp)
 
-# Remove interpolate_aa sources as they are temporary code
-# see https://github.com/pytorch/vision/pull/3761
-# and IndexingUtils.h is unavailable on Android build
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
-
 add_library(${TARGET} SHARED
     ${VISION_SRCS}
 )
diff --git a/ios/CMakeLists.txt b/ios/CMakeLists.txt
index 2ac46c15018..6b9fd3925b2 100644
--- a/ios/CMakeLists.txt
+++ b/ios/CMakeLists.txt
@@ -11,13 +11,6 @@ file(GLOB VISION_SRCS
     ../torchvision/csrc/ops/*.h
     ../torchvision/csrc/ops/*.cpp)
 
-# Remove interpolate_aa sources as they are temporary code
-# see https://github.com/pytorch/vision/pull/3761
-# and using TensorIterator unavailable with iOS
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
-list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
-
 add_library(${TARGET} STATIC
     ${VISION_SRCS}
 )
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 11fa851ed21..0ac559565b7 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -3,6 +3,7 @@
 import math
 import os
 import re
+from functools import partial
 from typing import Sequence
 
 import numpy as np
@@ -655,11 +656,13 @@ def test_resize_antialias(device, dt, size, interpolation):
 
 def test_assert_resize_antialias(interpolation):
     # Checks implementation on very large scales
-    # and catch TORCH_CHECK inside interpolate_aa_kernels.cu
+    # and catch TORCH_CHECK inside PyTorch implementation
     torch.manual_seed(12)
-    tensor, pil_img = _create_data(1000, 1000, device="cuda")
+    tensor, _ = _create_data(1000, 1000, device="cuda")
 
-    with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
+    # Error message is not yet updated in pytorch nightly
+    # with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
+    with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
 
 
@@ -674,32 +677,12 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
         return
 
     torch.manual_seed(12)
-    if interpolation == BILINEAR:
-        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
-    elif interpolation == BICUBIC:
-        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
-        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward
-
-    class F(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, i):
-            result = forward_op(i, size, False)
-            ctx.save_for_backward(i, result)
-            return result
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            i, result = ctx.saved_tensors
-            ishape = i.shape
-            oshape = result.shape[2:]
-            return backward_op(grad_output, oshape, ishape, False)
-
     x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    resize = partial(F.resize, size=size, interpolation=interpolation, antialias=True)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
 
     x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),)
-    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+    assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
 
 
 def check_functional_vs_PIL_vs_scripted(
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
deleted file mode 100644
index 32652466916..00000000000
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ /dev/null
@@ -1,823 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-// Code temporary is in torchvision before merging it to PyTorch
-namespace at {
-namespace native {
-namespace internal_upsample {
-
-using scale_t = std::vector<c10::optional<double>>;
-
-template <typename scalar_t, typename index_t>
-static inline scalar_t interpolate_aa_single_dim_zero_strides(
-    char* src,
-    char** data,
-    int64_t i,
-    const index_t ids_stride) {
-  const index_t ids_min = *(index_t*)&data[0][0];
-  const index_t ids_size = *(index_t*)&data[1][0];
-
-  char* src_min = src + ids_min;
-
-  scalar_t t = *(scalar_t*)&src_min[0];
-  index_t wts_idx = *(index_t*)&data[4][0];
-  scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
-  scalar_t wts = wts_ptr[0];
-
-  scalar_t output = t * wts;
-  int j = 1;
-  for (; j < ids_size; j++) {
-    wts = wts_ptr[j];
-    t = *(scalar_t*)&src_min[j * ids_stride];
-    output += t * wts;
-  }
-  return output;
-}
-
-template <typename scalar_t, typename index_t>
-static inline scalar_t interpolate_aa_single_dim(
-    char* src,
-    char** data,
-    const int64_t* strides,
-    int64_t i,
-    const index_t ids_stride) {
-  index_t ids_min = *(index_t*)&data[0][i * strides[0]];
-  index_t ids_size = *(index_t*)&data[1][i * strides[1]];
-
-  char* src_min = src + ids_min;
-
-  scalar_t t = *(scalar_t*)&src_min[0];
-  index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
-  scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
-  scalar_t wts = wts_ptr[0];
-
-  scalar_t output = t * wts;
-  int j = 1;
-  for (; j < ids_size; j++) {
-    wts = wts_ptr[j];
-    t = *(scalar_t*)&src_min[j * ids_stride];
-    output += t * wts;
-  }
-  return output;
-}
-
-template <typename scalar_t, typename index_t>
-static inline void basic_loop_aa_single_dim_zero_strides(
-    char** data,
-    const int64_t* strides,
-    int64_t n) {
-  char* dst = data[0];
-  char* src = data[1];
-  // index stride is constant for the given dimension
-  const index_t ids_stride = *(index_t*)&data[2 + 2][0];
-
-  for (int64_t i = 0; i < n; i++) {
-    *(scalar_t*)&dst[i * strides[0]] =
-        interpolate_aa_single_dim_zero_strides<scalar_t, index_t>(
-            src + i * strides[1], &data[2], i, ids_stride);
-  }
-}
-
-template <typename scalar_t, typename index_t>
-static inline void basic_loop_aa_single_dim_nonzero_strides(
-    char** data,
-    const int64_t* strides,
-    int64_t n) {
-  char* dst = data[0];
-  char* src = data[1];
-  // index stride is constant for the given dimension
-  const index_t ids_stride = *(index_t*)&data[2 + 2][0];
-
-  if (strides[1] == 0) {
-    for (int64_t i = 0; i < n; i++) {
-      *(scalar_t*)&dst[i * strides[0]] =
-          interpolate_aa_single_dim<scalar_t, index_t>(
-              src, &data[2], &strides[2], i, ids_stride);
-    }
-  } else {
-    for (int64_t i = 0; i < n; i++) {
-      *(scalar_t*)&dst[i * strides[0]] =
-          interpolate_aa_single_dim<scalar_t, index_t>(
-              src + i * strides[1], &data[2], &strides[2], i, ids_stride);
-    }
-  }
-}
-
-template <int m>
-static inline bool is_zero_stride(const int64_t* strides) {
-  bool output = strides[0] == 0;
-  for (int i = 1; i < m; i++) {
-    output &= (strides[i] == 0);
-  }
-  return output;
-}
-
-template <typename scalar_t, typename index_t>
-void ti_cpu_upsample_generic_aa(
-    at::TensorIterator& iter,
-    int interp_size = -1) {
-  TORCH_INTERNAL_ASSERT(interp_size > 0);
-
-  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
-    if ((strides[0] == sizeof(scalar_t)) && (strides[1] == sizeof(scalar_t)) &&
-        is_zero_stride<3 + 2>(&strides[2])) {
-      basic_loop_aa_single_dim_zero_strides<scalar_t, index_t>(
-          data, strides, n);
-    } else {
-      basic_loop_aa_single_dim_nonzero_strides<scalar_t, index_t>(
-          data, strides, n);
-    }
-  };
-
-  iter.for_each(loop);
-}
-
-// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
-template <typename index_t, typename scalar_t>
-struct HelperInterpBase {
-  template <typename filter_fn_t>
-  static inline void _compute_weights_aa(
-      const int64_t i,
-      const int64_t input_size,
-      const scalar_t scale,
-      const scalar_t support,
-      scalar_t* wt_ptr,
-      const int64_t interp_size,
-      filter_fn_t filter_fn,
-      int64_t& xmin,
-      int64_t& xsize) {
-    scalar_t center = scale * (i + 0.5);
-    scalar_t total_w = 0.0;
-    scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
-    xmin = std::max(
-        static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
-    xsize = std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
-        xmin;
-
-    int64_t j = 0;
-    for (; j < xsize; j++) {
-      scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
-      wt_ptr[j] = w;
-      total_w += w;
-    }
-    for (j = 0; j < xsize; j++) {
-      if (total_w != 0.0) {
-        wt_ptr[j] /= total_w;
-      }
-    }
-    for (; j < interp_size; j++) {
-      wt_ptr[j] = static_cast<scalar_t>(0.0);
-    }
-  }
-
-  template <typename filter_fn_t>
-  static inline std::vector<Tensor> _compute_indices_weights_aa(
-      int64_t input_size,
-      int64_t output_size,
-      int64_t stride,
-      int64_t ndims,
-      int64_t reshape_dim,
-      bool align_corners,
-      scalar_t scale,
-      int& in_out_interp_size,
-      filter_fn_t filter_fn) {
-    int interp_size = in_out_interp_size;
-    scalar_t support =
-        (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
-    interp_size = (int)ceilf(support) * 2 + 1;
-
-    // return interp_size
-    in_out_interp_size = interp_size;
-
-    std::vector<Tensor> output;
-    auto new_shape = std::vector<int64_t>(ndims, 1);
-    new_shape[reshape_dim] = output_size;
-
-    // ---- Bounds approach as in PIL -----
-    // bounds: xmin/xmax
-    output.emplace_back(
-        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
-    output.emplace_back(
-        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
-    output.emplace_back(
-        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
-
-    {
-      // Weights
-      new_shape[reshape_dim] = output_size * interp_size;
-      auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
-      auto strides = wts.strides().vec();
-      strides[reshape_dim] = 0;
-      new_shape[reshape_dim] = output_size;
-      wts = wts.as_strided(new_shape, strides);
-      output.emplace_back(wts);
-      // Weights indices
-      output.emplace_back(
-          empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
-    }
-
-    int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
-    int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
-    int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
-    scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
-    int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
-
-    int64_t xmin, xmax;
-
-    for (int64_t i = 0; i < output_size; i++) {
-      HelperInterpBase::_compute_weights_aa(
-          i,
-          input_size,
-          scale,
-          support,
-          wt_ptr + i * interp_size,
-          interp_size,
-          filter_fn,
-          xmin,
-          xmax);
-
-      idx_ptr_xmin[i] = xmin * stride;
-      idx_ptr_size[i] = xmax;
-      idx_ptr_stride[i] = stride;
-      wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
-    }
-    return output;
-  }
-};
-
-template <typename index_t, typename scalar_t>
-struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
-  static const int interp_size = 2;
-
-  // taken from
-  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-  // src/libImaging/Resample.c#L20-L29
-  static inline scalar_t _filter(scalar_t x) {
-    if (x < 0.0) {
-      x = -x;
-    }
-    if (x < 1.0) {
-      return 1.0 - x;
-    }
-    return 0.0;
-  }
-
-  static inline std::vector<Tensor> compute_indices_weights(
-      int64_t input_size,
-      int64_t output_size,
-      int64_t stride,
-      int64_t ndims,
-      int64_t reshape_dim,
-      bool align_corners,
-      const c10::optional<double> opt_scale,
-      bool antialias,
-      int& out_interp_size) {
-    TORCH_INTERNAL_ASSERT(antialias);
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(
-        input_size, output_size, align_corners, opt_scale);
-
-    out_interp_size = HelperInterpLinear::interp_size;
-    return HelperInterpLinear::_compute_indices_weights_aa(
-        input_size,
-        output_size,
-        stride,
-        ndims,
-        reshape_dim,
-        align_corners,
-        scale,
-        out_interp_size,
-        _filter);
-  }
-};
-
-template <typename index_t, typename scalar_t>
-struct HelperInterpCubic : public HelperInterpBase<index_t, scalar_t> {
-  static const int interp_size = 4;
-
-  static inline std::vector<Tensor> compute_indices_weights(
-      int64_t input_size,
-      int64_t output_size,
-      int64_t stride,
-      int64_t ndims,
-      int64_t reshape_dim,
-      bool align_corners,
-      const c10::optional<double> opt_scale,
-      bool antialias,
-      int& out_interp_size) {
-    TORCH_INTERNAL_ASSERT(antialias);
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(
-        input_size, output_size, align_corners, opt_scale);
-
-    out_interp_size = HelperInterpCubic::interp_size;
-    return HelperInterpCubic::_compute_indices_weights_aa(
-        input_size,
-        output_size,
-        stride,
-        ndims,
-        reshape_dim,
-        align_corners,
-        scale,
-        out_interp_size,
-        _filter);
-  }
-
-  // taken from
-  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-  // src/libImaging/Resample.c#L46-L62
-  static inline scalar_t _filter(scalar_t x) {
-    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
-#define a -0.5
-    if (x < 0.0) {
-      x = -x;
-    }
-    if (x < 1.0) {
-      return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
-    }
-    if (x < 2.0) {
-      return (((x - 5) * x + 8) * x - 4) * a;
-    }
-    return 0.0;
-#undef a
-  }
-};
-
-template <
-    typename index_t,
-    int out_ndims,
-    typename scale_type,
-    template <typename, typename>
-    class F>
-void _ti_separable_upsample_generic_Nd_kernel_impl_single_dim(
-    Tensor& output,
-    const Tensor& input,
-    int interp_dim,
-    bool align_corners,
-    const scale_type& scales,
-    bool antialias) {
-  // input can be NCHW, NCL or NCKHW
-  auto shape = input.sizes().vec();
-  auto strides = input.strides().vec();
-  auto oshape = output.sizes();
-
-  TORCH_INTERNAL_ASSERT(
-      shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
-  TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
-  TORCH_INTERNAL_ASSERT(antialias);
-
-  for (int i = 0; i < out_ndims; i++) {
-    shape[i + 2] = oshape[i + 2];
-  }
-  strides[interp_dim] = 0;
-  auto restrided_input = input.as_strided(shape, strides);
-
-  std::vector<std::vector<Tensor>> indices_weights;
-
-  int interp_size = F<index_t, float>::interp_size;
-  auto input_scalar_type = input.scalar_type();
-
-  if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
-    // nearest also supports uint8 tensor, but we have to use float
-    // with compute_indices_weights
-    input_scalar_type = at::ScalarType::Float;
-  }
-
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      at::ScalarType::Byte,
-      input_scalar_type,
-      "compute_indices_weights_generic",
-      [&] {
-        indices_weights.emplace_back(
-            F<index_t, scalar_t>::compute_indices_weights(
-                input.size(interp_dim),
-                oshape[interp_dim],
-                input.stride(interp_dim) * input.element_size(),
-                input.dim(),
-                interp_dim,
-                align_corners,
-                scales[interp_dim - 2],
-                antialias,
-                interp_size));
-      });
-
-  TensorIteratorConfig config;
-  config.check_all_same_dtype(false)
-      .declare_static_dtype_and_device(input.scalar_type(), input.device())
-      .add_output(output)
-      .add_input(restrided_input);
-
-  for (auto& idx_weight : indices_weights) {
-    for (auto& tensor : idx_weight) {
-      config.add_input(tensor);
-    }
-  }
-
-  auto iter = config.build();
-
-  if (interp_size > 1) {
-    // Nearest also supports uint8 tensor, so need to handle it separately
-    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "upsample_generic_Nd", [&] {
-      ti_cpu_upsample_generic_aa<scalar_t, index_t>(
-          iter, interp_size);
-    });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND(
-        at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
-          ti_cpu_upsample_generic_aa<scalar_t, index_t>(
-              iter, interp_size);
-        });
-  }
-}
-
-template <
-    typename index_t,
-    int out_ndims,
-    typename scale_type,
-    template <typename, typename>
-    class F>
-void ti_separable_upsample_generic_Nd_kernel_impl(
-    Tensor& output,
-    const Tensor& input,
-    bool align_corners,
-    const scale_type& scales,
-    bool antialias) {
-  auto temp_oshape = input.sizes().vec();
-  at::Tensor temp_output, temp_input = input;
-  for (int i = 0; i < out_ndims - 1; i++) {
-    int interp_dim = 2 + out_ndims - 1 - i;
-    temp_oshape[interp_dim] = output.sizes()[interp_dim];
-    temp_output = at::empty(temp_oshape, input.options());
-    _ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
-        index_t,
-        out_ndims,
-        scale_t,
-        F>(
-        temp_output, temp_input, interp_dim, align_corners, scales, antialias);
-    temp_input = temp_output;
-  }
-  _ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
-      index_t,
-      out_ndims,
-      scale_t,
-      F>(output, temp_input, 2, align_corners, scales, antialias);
-}
-
-void _ti_upsample_bilinear2d_kernel_impl(
-    Tensor& output,
-    const Tensor& input,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w,
-    bool antialias) {
-  ti_separable_upsample_generic_Nd_kernel_impl<
-      int64_t,
-      2,
-      scale_t,
-      HelperInterpLinear>(
-      output, input, align_corners, {scales_h, scales_w}, antialias);
-}
-
-void _ti_upsample_bicubic2d_kernel_impl(
-    Tensor& output,
-    const Tensor& input,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w,
-    bool antialias) {
-  ti_separable_upsample_generic_Nd_kernel_impl<
-      int64_t,
-      2,
-      scale_t,
-      HelperInterpCubic>(
-      output, input, align_corners, {scales_h, scales_w}, antialias);
-}
-
-template <
-    typename scalar_t,
-    typename scale_type,
-    template <typename, typename>
-    class F>
-void cpu_upsample_genNd_backward_aa(
-    const Tensor& grad_input_,
-    const Tensor& grad_output_,
-    bool align_corners,
-    const scale_type& scales) {
-  TORCH_CHECK(
-      grad_input_.dtype() == grad_output_.dtype(),
-      "expected dtype ",
-      grad_output_.dtype(),
-      " for `grad_input` but got dtype ",
-      grad_input_.dtype());
-
-  auto grad_output = grad_output_.contiguous();
-  auto grad_input = grad_input_.contiguous();
-
-  auto grad_output_data = grad_output.data_ptr<scalar_t>();
-  auto grad_input_data = grad_input.data_ptr<scalar_t>();
-  auto input_sizes = grad_input.sizes().vec();
-  auto output_sizes = grad_output.sizes().vec();
-  auto ndim = input_sizes.size();
-
-  // treat nbatch and channels as one dimension
-  int64_t channels = input_sizes[0] * input_sizes[1];
-  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
-  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
-  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
-  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
-  int64_t input_width = input_sizes[ndim - 1];
-  int64_t output_width = output_sizes[ndim - 1];
-
-  int64_t output_slice_size = output_depth * output_height * output_width;
-  int interp_size = F<int64_t, scalar_t>::interp_size;
-
-  auto loop2d = [&](int64_t begin, int64_t end) {
-    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
-        input_height, output_height, align_corners, scales[0]);
-    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
-        input_width, output_width, align_corners, scales[1]);
-
-    auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
-      return grad_input_data + c * input_height * input_width +
-          h * input_width + w;
-    };
-
-    const scalar_t support_h = (height_scale >= 1.0)
-        ? (interp_size * 0.5) * height_scale
-        : interp_size * 0.5;
-    const scalar_t support_w = (width_scale >= 1.0)
-        ? (interp_size * 0.5) * width_scale
-        : interp_size * 0.5;
-
-    const int interp_height = (int)ceilf(support_h) * 2 + 1;
-    const int interp_width = (int)ceilf(support_w) * 2 + 1;
-
-    std::vector<scalar_t> wx(interp_width, 0.0);
-    std::vector<scalar_t> wy(interp_height, 0.0);
-
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int64_t xmin, ymin;
-    int64_t xsize, ysize;
-    auto filter_fn = F<int64_t, scalar_t>::_filter;
-
-    for (int64_t oh = 0; oh < output_height; oh++) {
-      F<int64_t, scalar_t>::_compute_weights_aa(
-          oh,
-          input_height,
-          height_scale,
-          support_h,
-          wy.data(),
-          interp_height,
-          filter_fn,
-          ymin,
-          ysize);
-
-      for (int64_t ow = 0; ow < output_width; ow++) {
-        F<int64_t, scalar_t>::_compute_weights_aa(
-            ow,
-            input_width,
-            width_scale,
-            support_w,
-            wx.data(),
-            interp_width,
-            filter_fn,
-            xmin,
-            xsize);
-
-        for (int64_t c = begin; c < end; c++) {
-          scalar_t grad_output_value =
-              grad_output_data[c * output_slice_size + oh * output_width + ow];
-
-          for (size_t y = 0; y < ysize; y++) {
-            for (size_t x = 0; x < xsize; x++) {
-              *input_indexr(c, ymin + y, xmin + x) +=
-                  wx[x] * wy[y] * grad_output_value;
-            }
-          }
-        }
-      }
-    }
-  };
-
-  if (ndim == 4) {
-    // upsample bilinear 2d
-    at::parallel_for(
-        0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
-  } else {
-    TORCH_CHECK(false, "Unsupported tensor ndim");
-  }
-
-  if (!grad_input_.is_contiguous()) {
-    grad_input_.copy_(grad_input);
-  }
-}
-
-void _upsample_bilinear2d_aa_backward_kernel_impl(
-    const Tensor& grad_input,
-    const Tensor& grad_output,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  AT_DISPATCH_FLOATING_TYPES(
-      grad_output.scalar_type(), "upsample_bilinear2d_backward_cpu", [&] {
-        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
-            grad_input, grad_output, align_corners, {scales_h, scales_w});
-      });
-}
-
-void _upsample_bicubic2d_aa_backward_kernel_impl(
-    const Tensor& grad_input,
-    const Tensor& grad_output,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  AT_DISPATCH_FLOATING_TYPES(
-      grad_output.scalar_type(), "upsample_bicubic2d_backward_cpu", [&] {
-        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
-            grad_input, grad_output, align_corners, {scales_h, scales_w});
-      });
-}
-
-} // namespace internal_upsample
-} // namespace native
-} // namespace at
-
-namespace vision {
-namespace ops {
-
-namespace {
-
-at::Tensor interpolate_bilinear2d_aa_forward_kernel(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners) {
-  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
-
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBilinear2d.cpp
-  auto output = at::empty({0}, input.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input.sizes(), output_size, scale_factors);
-  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
-  auto full_output_size =
-      at::native::upsample_2d_common_check(input.sizes(), osize);
-
-  // Allow for empty batch size but not other dimensions
-  TORCH_CHECK(
-      input.numel() != 0 ||
-          c10::multiply_integers(
-              input.sizes().begin() + 1, input.sizes().end()),
-      "Non-empty 4D data tensor expected but got a tensor with sizes ",
-      input.sizes());
-
-  output.resize_(full_output_size, input.suggest_memory_format());
-  at::native::internal_upsample::_ti_upsample_bilinear2d_kernel_impl(
-      output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
-  return output;
-}
-
-at::Tensor interpolate_bicubic2d_aa_forward_kernel(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners) {
-  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
-
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBilinear2d.cpp
-  auto output = at::empty({0}, input.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input.sizes(), output_size, scale_factors);
-  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
-  auto full_output_size =
-      at::native::upsample_2d_common_check(input.sizes(), osize);
-
-  // Allow for empty batch size but not other dimensions
-  TORCH_CHECK(
-      input.numel() != 0 ||
-          c10::multiply_integers(
-              input.sizes().begin() + 1, input.sizes().end()),
-      "Non-empty 4D data tensor expected but got a tensor with sizes ",
-      input.sizes());
-
-  output.resize_(full_output_size, input.suggest_memory_format());
-  at::native::internal_upsample::_ti_upsample_bicubic2d_kernel_impl(
-      output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
-  return output;
-}
-
-at::Tensor interpolate_bilinear2d_aa_backward_kernel(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBilinear2d.cpp::upsample_bilinear2d_backward
-  auto grad_input = at::empty({0}, grad_output.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input_size, output_size, scale_factors);
-  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
-
-  auto full_output_size =
-      at::native::upsample_2d_common_check(input_size, osize);
-
-  TORCH_CHECK(
-      grad_output.dim() == 4,
-      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
-      grad_output.dim());
-
-  for (int i = 0; i < 4; ++i) {
-    TORCH_CHECK(
-        grad_output.size(i) == full_output_size[i],
-        "Expected grad_output to have the same shape as output;",
-        " output.size(",
-        i,
-        ") = ",
-        full_output_size[i],
-        " but got grad_output.size(",
-        i,
-        ") = ",
-        grad_output.size(i));
-  }
-
-  grad_input.resize_(input_size, grad_output.suggest_memory_format());
-  grad_input.zero_();
-  at::native::internal_upsample::_upsample_bilinear2d_aa_backward_kernel_impl(
-      grad_input, grad_output, align_corners, scale_h, scale_w);
-
-  return grad_input;
-}
-
-at::Tensor interpolate_bicubic2d_aa_backward_kernel(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
-  auto grad_input = at::empty({0}, grad_output.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input_size, output_size, scale_factors);
-  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
-
-  auto full_output_size =
-      at::native::upsample_2d_common_check(input_size, osize);
-
-  TORCH_CHECK(
-      grad_output.dim() == 4,
-      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
-      grad_output.dim());
-
-  for (int i = 0; i < 4; ++i) {
-    TORCH_CHECK(
-        grad_output.size(i) == full_output_size[i],
-        "Expected grad_output to have the same shape as output;",
-        " output.size(",
-        i,
-        ") = ",
-        full_output_size[i],
-        " but got grad_output.size(",
-        i,
-        ") = ",
-        grad_output.size(i));
-  }
-
-  grad_input.resize_(input_size, grad_output.suggest_memory_format());
-  grad_input.zero_();
-  at::native::internal_upsample::_upsample_bicubic2d_aa_backward_kernel_impl(
-      grad_input, grad_output, align_corners, scale_h, scale_w);
-
-  return grad_input;
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
-      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
-      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
-      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
-      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
-}
-
-} // namespace ops
-} // namespace vision
diff --git a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu b/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
deleted file mode 100644
index f52793408f4..00000000000
--- a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
+++ /dev/null
@@ -1,629 +0,0 @@
-#include
-// Copied and adapted from
-// Adapted from interp.cpp from Caffe util by Pauline Luc
-// Originally developed by George Papandreou
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-// Below is experimental temporary code before merging it to PyTorch
-namespace at {
-namespace native {
-namespace internal_upsample {
-
-__device__ __forceinline__ size_t
-idx(const size_t nc,
-    const size_t height,
-    const size_t width,
-    const size_t y,
-    const size_t x) {
-  return (nc * height + y) * width + x;
-}
-
-// taken from
-// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-// src/libImaging/Resample.c#L20-L29
-template <typename accscalar_t>
-__device__ __forceinline__ static accscalar_t bilinear_filter(accscalar_t x) {
-  if (x < 0.0) {
-    x = -x;
-  }
-  if (x < 1.0) {
-    return static_cast<accscalar_t>(1.0) - x;
-  }
-  return static_cast<accscalar_t>(0.0);
-}
-
-// taken from
-// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-// src/libImaging/Resample.c#L46-L62
-template <typename accscalar_t>
-__device__ __forceinline__ static accscalar_t bicubic_filter(accscalar_t x) {
-  // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
-#define a -0.5
-  if (x < 0.0) {
-    x = -x;
-  }
-  if (x < 1.0) {
-    return ((a + 2.0) * x - (a + 3.0)) * x * x + static_cast<accscalar_t>(1.0);
-  }
-  if (x < 2.0) {
-    return (((x - 5) * x + 8) * x - 4) * a;
-  }
-  return static_cast<accscalar_t>(0.0);
-#undef a
-}
-
-template <typename scalar_t, typename accscalar_t, typename filter_fn_t>
-__device__ __forceinline__ static void _compute_weights(
-    const int i,
-    const int input_size,
-    const accscalar_t scale,
-    const accscalar_t support,
-    scalar_t* wt_ptr,
-    int interp_size,
-    filter_fn_t filter_fn,
-    int& xmin,
-    int& xmax) {
-  accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
-  accscalar_t center = scale * (i + 0.5);
-  xmin = max(static_cast<int>(center - support + 0.5), static_cast<int>(0));
-  xmax = min(static_cast<int>(center + support + 0.5), input_size) - xmin;
-
-  accscalar_t total_w = 0.0;
-  int j = 0;
-  for (j = 0; j < xmax; j++) {
-    accscalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
-    wt_ptr[j] = static_cast<scalar_t>(w);
-    total_w += w;
-  }
-  for (j = 0; j < xmax; j++) {
-    if (total_w != 0.0) {
-      wt_ptr[j] /= total_w;
-    }
-  }
-  for (; j < interp_size; j++) {
-    wt_ptr[j] = static_cast<scalar_t>(0.0);
-  }
-}
-
-template <typename scalar_t, typename accscalar_t>
-__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
-    scalar_t* src,
-    scalar_t* weights,
-    int64_t size) {
-  scalar_t t = static_cast<accscalar_t>(*src);
-  scalar_t wts = static_cast<accscalar_t>(weights[0]);
-  accscalar_t output = t * wts;
-
-  int64_t j = 1;
-  for (; j < size; j++) {
-    wts = static_cast<accscalar_t>(weights[j]);
-    t = static_cast<accscalar_t>(*(src + j));
-    output += t * wts;
-  }
-  return output;
-}
-
-template <typename scalar_t, typename accscalar_t, int interp_size>
-C10_LAUNCH_BOUNDS_1(1024)
-__global__ void upsample_gen2d_out_frame(
-    const int n,
-    const accscalar_t rheight,
-    const accscalar_t rwidth,
-    const bool align_corners,
-    const PackedTensorAccessor64<scalar_t, 4> idata,
-    PackedTensorAccessor64<scalar_t, 4> odata) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-
-  const int batchsize = idata.size(0);
-  const int channels = idata.size(1);
-  const int height1 = idata.size(2);
-  const int width1 = idata.size(3);
-  const int height2 = odata.size(2);
-  const int width2 = odata.size(3);
-
-  if (index < n) {
-    const int w2 = index % width2; // 0:width2-1
-    const int h2 = index / width2; // 0:height2-1
-    // special case: just copy
-    if (height1 == height2 && width1 == width2) {
-      const int h1 = h2;
-      const int w1 = w2;
-      for (int n = 0; n < batchsize; n++) {
-        for (int c = 0; c < channels; ++c) {
-          const scalar_t val = idata[n][c][h1][w1];
-          odata[n][c][h2][w2] = val;
-        }
-      }
-      return;
-    }
-
-    const accscalar_t support_h = static_cast<accscalar_t>(
-        (rheight >= 1.0) ? (interp_size * 0.5) * rheight : interp_size * 0.5);
-    const accscalar_t support_w = static_cast<accscalar_t>(
-        (rwidth >= 1.0) ? (interp_size * 0.5) * rwidth : interp_size * 0.5);
-
-    const int interp_height = (int)ceilf(support_h) * 2 + 1;
-    const int interp_width = (int)ceilf(support_w) * 2 + 1;
-
-    // Setup local buffers
-    // TODO: maybe we can specify dynamic shared memory size before calling the
-    // cuda code, however we should then ensure that device has enough shared
-    // memory
-    scalar_t wx[256];
-    scalar_t wy[256];
-    scalar_t buffer1[256];
-    scalar_t buffer2[256];
-
-    // Compute weights
-    int xmin, xsize, ymin, ysize;
-    typedef scalar_t (*filter_fn_t)(scalar_t);
-    filter_fn_t filter_fn;
-    if (interp_size == 2) {
-      filter_fn = bilinear_filter;
-    } else if (interp_size == 4) {
-      filter_fn = bicubic_filter;
-    }
-    _compute_weights(
-        w2,
-        width1,
-        rwidth,
-        support_w,
-        wx,
-        interp_width,
-        filter_fn,
-        xmin,
-        xsize);
-    _compute_weights(
-        h2,
-        height1,
-        rheight,
-        support_h,
-        wy,
-        interp_height,
-        filter_fn,
-        ymin,
-        ysize);
-
-    for (int n = 0; n < batchsize; n++) {
-      for (int c = 0; c < channels; ++c) {
-        // interpolate on x-axis for ymin to ymin + ysize
-        for (int y = 0; y < ysize; y++) {
-          // copy data into the local buffer and use
-          // interpolate_aa_single_dim method
-          for (int x = 0; x < xsize; x++) {
-            buffer1[x] = idata[n][c][ymin + y][xmin + x];
-          }
-
-          buffer2[y] = static_cast<scalar_t>(
-              interpolate_aa_single_dim<scalar_t, accscalar_t>(
-                  buffer1, wx, xsize));
-        }
-        odata[n][c][h2][w2] = static_cast<scalar_t>(
-            interpolate_aa_single_dim<scalar_t, accscalar_t>(
-                buffer2, wy, ysize));
-      }
-    }
-  }
-}
-
-template <int interp_size>
-static void upsample_gen2d_out_cuda_template(
-    const Tensor& output,
-    const Tensor& input,
-    IntArrayRef output_size,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  // Copied and adapted from
-  // UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
-  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
-  checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});
-
-  int output_height = output_size[0];
-  int output_width = output_size[1];
-
-  int nbatch = input.size(0);
-  int channels = input.size(1);
-  int input_height = input.size(2);
-  int input_width = input.size(3);
-
-  const int num_kernels = output_height * output_width;
-  const int num_threads = std::min(
-      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "upsample_gen2d_out_frame", [&] {
-        using accscalar_t = at::acc_type<scalar_t, true>;
-
-        auto idata = input.packed_accessor64<scalar_t, 4>();
-        auto odata = output.packed_accessor64<scalar_t, 4>();
-
-        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
-            input_height, output_height, align_corners, scales_h);
-        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
-            input_width, output_width, align_corners, scales_w);
-
-        // We are using static buffer memory of 256 * sizeof(float) per thread
-        // to store weights. Size of weights array is
-        // interp_size = scale * 2 + 1 for bilinear mode
-        TORCH_CHECK(
-            rheight < (255 / interp_size),
-            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
-        TORCH_CHECK(
-            rwidth < (255 / interp_size),
-            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
-
-        upsample_gen2d_out_frame<scalar_t, accscalar_t, interp_size>
-            <<<at::cuda::ATenCeilDiv(num_kernels, num_threads),
-               num_threads,
-               0,
-               stream>>>(
-                num_kernels, rheight, rwidth, align_corners, idata, odata);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
-}
-
-// Backward (adjoint) operation 1 <- 2 (accumulates)
-template <typename scalar_t, typename accscalar_t, int interp_size>
-C10_LAUNCH_BOUNDS_1(1024)
-__global__ void upsample_gen2d_backward_out_frame(
-    const int num_elements,
-    const accscalar_t height_scale,
-    const accscalar_t width_scale,
-    const bool align_corners,
-    PackedTensorAccessor64<scalar_t, 4> idata,
-    const PackedTensorAccessor64<scalar_t, 4> odata) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-
-  const int batchsize = idata.size(0);
-  const int channels = idata.size(1);
-  const int input_height = idata.size(2);
-  const int input_width = idata.size(3);
-  const int output_height = odata.size(2);
-  const int output_width = odata.size(3);
-
-  if (index >= num_elements) {
-    return;
-  }
-
-  const int output_x = index % output_width;
-  const int output_y = index / output_width;
-  // special case: output just copy
-  if (input_height == output_height && input_width == output_width) {
-    for (int n = 0; n < batchsize; n++) {
-      for (int c = 0; c < channels; ++c) {
-        const scalar_t val = odata[n][c][output_y][output_x];
-        idata[n][c][output_y][output_x] = val;
-      }
-    }
-    return;
-  }
-
-  const accscalar_t support_h = static_cast<accscalar_t>(
-      (height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
-                            : interp_size * 0.5);
-  const accscalar_t support_w = static_cast<accscalar_t>(
-      (width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
-                           : interp_size * 0.5);
-
-  const int interp_height = (int)ceilf(support_h) * 2 + 1;
-  const int interp_width = (int)ceilf(support_w) * 2 + 1;
-
-  // Setup local buffers
-  // TODO: maybe we can specify dynamic shared memory size before calling the
-  // cuda code, however we should then ensure that device has enough shared
-  // memory
-  scalar_t wx[256];
-  scalar_t wy[256];
-
-  // Compute weights
-  int xmin, xsize, ymin, ysize;
-  typedef scalar_t (*filter_fn_t)(scalar_t);
-  filter_fn_t filter_fn;
-  if (interp_size == 2) {
-    filter_fn = bilinear_filter;
-  } else if (interp_size == 4) {
-    filter_fn = bicubic_filter;
-  }
-  _compute_weights(
-      output_x,
-      input_width,
-      width_scale,
-      support_w,
-      wx,
-      interp_width,
-      filter_fn,
-      xmin,
-      xsize);
-  _compute_weights(
-      output_y,
-      input_height,
-      height_scale,
-      support_h,
-      wy,
-      interp_height,
-      filter_fn,
-      ymin,
-      ysize);
-
-  for (int n = 0; n < batchsize; n++) {
-    for (int c = 0; c < channels; ++c) {
-      scalar_t out_value = odata[n][c][output_y][output_x];
-      for (int y = 0; y < ysize; y++) {
-        for (int x = 0; x < xsize; x++) {
-          upsample_increment_value_bounded(
-              idata,
-              n,
-              c,
-              input_height,
-              input_width,
-              ymin + y,
-              xmin + x,
-              wx[x] * wy[y] * out_value);
-        }
-      }
-    }
-  }
-}
-
-template <int interp_size>
-static void upsample_gen2d_backward_out_cuda_template(
-    const Tensor& grad_input,
-    const Tensor& grad_output_,
-    IntArrayRef output_size,
-    IntArrayRef input_size,
-    bool align_corners,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  // Copied and adapted from
-  // UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
-  TensorArg grad_input_arg{grad_input, "grad_input", 1},
-      grad_output_arg{grad_output_, "grad_output_", 2};
-  checkAllSameGPU(
-      "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});
-
-  int output_height = output_size[0];
-  int output_width = output_size[1];
-
-  int nbatch = input_size[0];
-  int channels = input_size[1];
-  int input_height = input_size[2];
-  int input_width = input_size[3];
-
-  Tensor grad_output = grad_output_.contiguous();
-
-  grad_input.zero_();
-
-  const int num_kernels = output_height * output_width;
-  const int num_threads = std::min(
-      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
-        using accscalar_t = at::acc_type<scalar_t, true>;
-
-        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
-        auto odata = grad_output.packed_accessor64<scalar_t, 4>();
-
-        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
-            input_height, output_height, align_corners, scales_h);
-        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
-            input_width, output_width, align_corners, scales_w);
-
-        // We are using static buffer memory of 256 * sizeof(float) per thread
-        // to store weights. Size of weights array is
-        // interp_size = scale * 2 + 1 for bilinear mode
-        TORCH_CHECK(
-            rheight < (255 / interp_size),
-            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
-        TORCH_CHECK(
-            rwidth < (255 / interp_size),
-            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
-
-        upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
-            <<<at::cuda::ATenCeilDiv(num_kernels, num_threads),
-               num_threads,
-               0,
-               stream>>>(
-                num_kernels, rheight, rwidth, align_corners, idata, odata);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
-}
-
-} // namespace internal_upsample
-} // namespace native
-} // namespace at
-
-namespace vision {
-namespace ops {
-
-namespace {
-
-// Copied from "UpSample.h" as we can not use UpSample.h with UpSample.cuh
-static std::array<int64_t, 4> upsample_2d_common_check(
-    at::IntArrayRef input_size,
-    at::IntArrayRef output_size) {
-  TORCH_CHECK(
-      output_size.size() == 2,
-      "It is expected output_size equals to 2, but got size ",
-      output_size.size());
-
-  TORCH_CHECK(
-      input_size.size() == 4,
-      "It is expected input_size equals to 4, but got size ",
-      input_size.size());
-
-  int64_t output_height = output_size[0];
-  int64_t output_width = output_size[1];
-
-  int64_t nbatch = input_size[0];
-  int64_t channels = input_size[1];
-  int64_t input_height = input_size[2];
-  int64_t input_width = input_size[3];
-
-  TORCH_CHECK(
-      input_height > 0 && input_width > 0 && output_height > 0 &&
-          output_width > 0,
-      "Input and output sizes should be greater than 0,"
-      " but got input (H: ",
-      input_height,
-      ", W: ",
-      input_width,
-      ") output (H: ",
-      output_height,
-      ", W: ",
-      output_width,
-      ")");
-
-  return {nbatch, channels, output_height, output_width};
-}
-
-template <int interp_size>
-at::Tensor interpolate_gen2d_aa_forward_kernel(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners) {
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBilinear2d.cpp
-  auto output = at::empty({0}, input.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input.sizes(), output_size, scale_factors);
-  auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
-
-  auto full_output_size = upsample_2d_common_check(input.sizes(), osize);
-
-  // Allow for empty batch size but not other dimensions
-  TORCH_CHECK(
-      input.numel() != 0 ||
-          c10::multiply_integers(
-              input.sizes().begin() + 1, input.sizes().end()),
-      "Non-empty 4D data tensor expected but got a tensor with sizes ",
-      input.sizes());
-
-  output.resize_(full_output_size, input.suggest_memory_format());
-
-  at::native::internal_upsample::upsample_gen2d_out_cuda_template<interp_size>(
-      output,
-      input,
-      {full_output_size[2], full_output_size[3]},
-      align_corners,
-      scale_h,
-      scale_w);
-  return output;
-}
-
-template <int interp_size>
-at::Tensor interpolate_gen2d_aa_backward_kernel(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  c10::optional<c10::ArrayRef<double>> scale_factors = {};
-
-  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
-  auto grad_input = at::empty({0}, grad_output.options());
-  auto osize = at::native::upsample::compute_output_size(
-      input_size, output_size, scale_factors);
-  auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
-  auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
-
-  auto full_output_size = upsample_2d_common_check(input_size, osize);
-
-  TORCH_CHECK(
-      grad_output.dim() == 4,
-      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
-      grad_output.dim());
-
-  for (int i = 0; i < 4; ++i) {
-    TORCH_CHECK(
-        grad_output.size(i) == full_output_size[i],
-        "Expected grad_output to have the same shape as output;",
-        " output.size(",
-        i,
-        ") = ",
-        full_output_size[i],
-        " but got grad_output.size(",
-        i,
-        ") = ",
-        grad_output.size(i));
-  }
-
-  grad_input.resize_(input_size, grad_output.suggest_memory_format());
-
-  at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
-      interp_size>(
-      grad_input,
-      grad_output,
-      {full_output_size[2], full_output_size[3]},
-      input_size,
-      align_corners,
-      scale_h,
-      scale_w);
-  return grad_input;
-}
-
-at::Tensor interpolate_bilinear2d_aa_forward_kernel(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners) {
-  return interpolate_gen2d_aa_forward_kernel<2>(
-      input, output_size, align_corners);
-}
-
-at::Tensor interpolate_bicubic2d_aa_forward_kernel(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners) {
-  return interpolate_gen2d_aa_forward_kernel<4>(
-      input, output_size, align_corners);
-}
-
-at::Tensor interpolate_bilinear2d_aa_backward_kernel(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  return interpolate_gen2d_aa_backward_kernel<2>(
-      grad_output, output_size, input_size, align_corners);
-}
-
-at::Tensor interpolate_bicubic2d_aa_backward_kernel(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  return interpolate_gen2d_aa_backward_kernel<4>(
-      grad_output, output_size, input_size, align_corners);
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
-      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
-      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
-      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
-      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
-}
-
-} // namespace ops
-} // namespace vision
diff --git a/torchvision/csrc/ops/interpolate_aa.cpp b/torchvision/csrc/ops/interpolate_aa.cpp
deleted file mode 100644
index 6594f78d731..00000000000
--- a/torchvision/csrc/ops/interpolate_aa.cpp
+++ /dev/null
@@ -1,76 +0,0 @@
-#include "interpolate_aa.h"
-
-#include
-#include
-#include
-
-namespace vision {
-namespace ops {
-
-at::Tensor _interpolate_bilinear2d_aa(
-    const at::Tensor& input, // Input image
-    at::IntArrayRef output_size, // Output image size
-    bool align_corners) // The flag to align corners
-{
-  static auto op =
-      c10::Dispatcher::singleton()
-          .findSchemaOrThrow("torchvision::_interpolate_bilinear2d_aa", "")
-          .typed<decltype(_interpolate_bilinear2d_aa)>();
-  return op.call(input, output_size, align_corners);
-}
-
-at::Tensor _interpolate_bicubic_aa(
-    const at::Tensor& input, // Input image
-    at::IntArrayRef output_size, // Output image size
-    bool align_corners) // The flag to align corners
-{
-  static auto op =
-      c10::Dispatcher::singleton()
-          .findSchemaOrThrow("torchvision::_interpolate_bicubic2d_aa", "")
-          .typed<decltype(_interpolate_bicubic_aa)>();
-  return op.call(input, output_size, align_corners);
-}
-
-namespace detail {
-
-at::Tensor _interpolate_bilinear2d_aa_backward(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  static auto op =
-      c10::Dispatcher::singleton()
-          .findSchemaOrThrow(
-              "torchvision::_interpolate_bilinear2d_aa_backward", "")
-          .typed<decltype(_interpolate_bilinear2d_aa_backward)>();
-  return op.call(grad_output, output_size, output_size, align_corners);
-}
-
-at::Tensor _interpolate_bicubic2d_aa_backward(
-    const at::Tensor& grad_output,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners) {
-  static auto op =
-      c10::Dispatcher::singleton()
-          .findSchemaOrThrow(
-              "torchvision::_interpolate_bicubic2d_aa_backward", "")
-          .typed<decltype(_interpolate_bicubic2d_aa_backward)>();
-  return op.call(grad_output, output_size, output_size, align_corners);
-}
-
-} // namespace detail
-
-TORCH_LIBRARY_FRAGMENT(torchvision, m) {
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_bicubic2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_bilinear2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_bicubic2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
-}
-
-} // namespace ops
-} // namespace vision
diff --git a/torchvision/csrc/ops/interpolate_aa.h b/torchvision/csrc/ops/interpolate_aa.h
deleted file mode 100644
index 283418b3935..00000000000
--- a/torchvision/csrc/ops/interpolate_aa.h
+++ /dev/null
@@ -1,36 +0,0 @@
-#pragma once
-
-#include
-#include "../macros.h"
-
-namespace vision {
-namespace ops {
-
-VISION_API at::Tensor _interpolate_bilinear2d_aa(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners = false);
-
-VISION_API at::Tensor _interpolate_bicubic2d_aa(
-    const at::Tensor& input,
-    at::IntArrayRef output_size,
-    bool align_corners = false);
-
-namespace detail {
-
-VISION_API at::Tensor _interpolate_bilinear2d_aa_backward(
-    const at::Tensor& grad,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners = false);
-
-VISION_API at::Tensor _interpolate_bicubic2d_aa_backward(
-    const at::Tensor& grad,
-    at::IntArrayRef output_size,
-    at::IntArrayRef input_size,
-    bool align_corners = false);
-
-} // namespace detail
-
-} // namespace ops
-} // namespace vision
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 2ed403e95cb..c12270c3443 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -481,13 +481,7 @@ def resize(
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    if antialias:
-        if interpolation == "bilinear":
-            img = torch.ops.torchvision._interpolate_bilinear2d_aa(img, [new_h, new_w], align_corners=False)
-        elif interpolation == "bicubic":
-            img = torch.ops.torchvision._interpolate_bicubic2d_aa(img, [new_h, new_w], align_corners=False)
-    else:
-        img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
+    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
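
Note: with this patch, antialiased resizing in torchvision delegates to PyTorch's native torch.nn.functional.interpolate(..., antialias=True) for both bilinear and bicubic modes, on CPU and CUDA alike, which is why the custom _interpolate_*_aa operators and their Android/iOS build exclusions are removed. A minimal sketch of the resulting user-facing path follows; the tensor shapes and target sizes are illustrative assumptions, not values from the patch:

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms import InterpolationMode

    # Antialiased resize now dispatches to torch.nn.functional.interpolate
    # with antialias=True instead of torch.ops.torchvision._interpolate_*_aa.
    img = torch.rand(1, 3, 256, 256)  # NCHW float tensor (shape is illustrative)
    out = F.resize(img, size=[128, 128], interpolation=InterpolationMode.BILINEAR, antialias=True)
    assert out.shape == (1, 3, 128, 128)

    # The native path is differentiable end to end, which is what the
    # reworked gradcheck-based test above relies on:
    x = torch.rand(1, 3, 32, 29, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(
        lambda t: F.resize(t, size=[12, 12], interpolation=InterpolationMode.BILINEAR, antialias=True),
        (x,),
        eps=1e-8,
        atol=1e-6,
        rtol=1e-6,
    )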