From 1a76e9ffe78e0dcb981e1099faf4c8ce6deca22a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 1 Mar 2021 09:10:57 -0800
Subject: [PATCH] Optimized bilinear interpolation using TensorIterator (#51653)

Summary:
Related to https://github.com/pytorch/pytorch/issues/10482

Description:

- Optimized bilinear interpolation for the 1d, 2d and 3d cases using TensorIterator
Interpolation 2d - 6 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[1, 3, 320, 320] | [256, 256] | True | False | 0.3938 | 0.0782 | 5.0339
[1, 3, 320, 320] | [512, 512] | True | False | 1.5585 | 0.4105 | 3.7965
[1, 3, 320, 320] | [256, 256] | False | False | 0.3481 | 0.0760 | 4.5780
[1, 3, 320, 320] | [512, 512] | False | False | 1.5848 | 0.4091 | 3.8734
[1, 3, 320, 320] | [256, 256] | False | True | 1.2058 | 1.2034 | 1.0020
[1, 3, 320, 320] | [512, 512] | False | True | 4.8691 | 4.8537 | 1.0032
[32, 128, 64, 64] | [32, 32] | False | True | 6.3915 | 6.4041 | 0.9980
[32, 128, 64, 64] | [128, 128] | False | True | 166.1769 | 164.5621 | 1.0098
[32, 128, 64, 64] | [32, 32] | True | False | 3.7194 | 2.4720 | 1.5046
[32, 128, 64, 64] | [128, 128] | True | False | 86.6704 | 52.3754 | 1.6548
[1, 3, 500, 500] | [256, 256] | True | False | 0.3270 | 0.0792 | 4.1307
[1, 3, 500, 500] | [800, 800] | True | False | 3.3116 | 0.5567 | 5.9482
[1, 3, 500, 500] | [256, 256] | False | False | 0.3763 | 0.0773 | 4.8700
[1, 3, 500, 500] | [800, 800] | False | False | 3.2577 | 0.5590 | 5.8279
Interpolation 1d - 6 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[4, 512, 320] | 256 | True | False | 0.2795 | 0.1032 | 2.7089
[4, 512, 320] | 512 | True | False | 0.5533 | 0.1888 | 2.9303
Interpolation 3d - 6 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[1, 3, 16, 320, 320] | [8, 256, 256] | True | False | 4.4105 | 2.1236 | 2.0769
[1, 3, 16, 320, 320] | [32, 512, 512] | True | False | 83.9426 | 42.6641 | 1.9675
[1, 3, 16, 320, 320] | [8, 256, 256] | False | True | 15.5736 | 15.5758 | 0.9999
[1, 3, 16, 320, 320] | [32, 512, 512] | False | True | 272.4795 | 273.2745 | 0.9971
Interpolation 2d - 1 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[1, 3, 320, 320] | [256, 256] | True | False | 1.0240 | 0.4145 | 2.4705
[1, 3, 320, 320] | [512, 512] | True | False | 4.0771 | 1.3836 | 2.9467
[1, 3, 320, 320] | [256, 256] | False | False | 0.9771 | 0.3270 | 2.9878
[1, 3, 320, 320] | [512, 512] | False | False | 4.1732 | 1.2209 | 3.4180
[1, 3, 320, 320] | [256, 256] | False | True | 1.5466 | 1.5363 | 1.0067
[1, 3, 320, 320] | [512, 512] | False | True | 6.1555 | 6.1199 | 1.0058
[32, 128, 64, 64] | [32, 32] | False | True | 27.6362 | 27.5901 | 1.0017
[32, 128, 64, 64] | [128, 128] | False | True | 468.6442 | 465.5163 | 1.0067
[32, 128, 64, 64] | [32, 32] | True | False | 20.1495 | 10.0694 | 2.0011
[32, 128, 64, 64] | [128, 128] | True | False | 400.0401 | 204.0662 | 1.9603
[1, 3, 500, 500] | [256, 256] | True | False | 0.8956 | 0.3366 | 2.6606
[1, 3, 500, 500] | [800, 800] | True | False | 8.6554 | 2.9530 | 2.9310
[1, 3, 500, 500] | [256, 256] | False | False | 1.0921 | 0.3385 | 3.2263
[1, 3, 500, 500] | [800, 800] | False | False | 8.9594 | 2.9627 | 3.0241
Interpolation 1d - 1 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[4, 512, 320] | 256 | True | False | 1.5233 | 0.5027 | 3.0301
[4, 512, 320] | 512 | True | False | 3.0302 | 0.9735 | 3.1128
Interpolation 3d - 1 thread(s)

In | Out | Is contiguous | Channels last | master | this PR | speed-up
---|---|---|---|---|---|---
[1, 3, 16, 320, 320] | [8, 256, 256] | True | False | 12.0477 | 11.3196 | 1.0643
[1, 3, 16, 320, 320] | [32, 512, 512] | True | False | 222.8618 | 209.9955 | 1.0613
[1, 3, 16, 320, 320] | [8, 256, 256] | False | True | 17.9883 | 17.9937 | 0.9997
[1, 3, 16, 320, 320] | [32, 512, 512] | False | True | 380.7244 | 380.1916 | 1.0014
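Timings of this shape can be gathered with `torch.utils.benchmark`; the sketch below is only an illustration (the harness actually used lives in the benchmark repository linked in the description further down, and `bench_bilinear` is a made-up helper, not part of this PR):

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

def bench_bilinear(in_shape, out_size, channels_last=False, num_threads=1):
    # Hypothetical helper mirroring one row of the 2d tables above:
    # time F.interpolate in bilinear mode and report the median in milliseconds.
    x = torch.rand(*in_shape)
    if channels_last:
        x = x.contiguous(memory_format=torch.channels_last)
    timer = Timer(
        stmt="F.interpolate(x, size=size, mode='bilinear', align_corners=False)",
        globals={"F": F, "x": x, "size": out_size},
        num_threads=num_threads,
    )
    return timer.blocked_autorange().median * 1e3

print(bench_bilinear([1, 3, 320, 320], [256, 256]))                      # channels-first row
print(bench_bilinear([1, 3, 320, 320], [256, 256], channels_last=True))  # channels-last row
```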
Versions and build configs

PyTorch master: 1.9.0.dev20210223

PyTorch master build setting:

```
BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,
```

PR: 1.9.0a0+74b172b

PR build setting:

```
BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/g++-7, CXX_FLAGS=-O3 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON,
```
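The build strings above are the standard PyTorch build-configuration dump; to record the same information for your own environment, something along these lines should work:

```python
import torch

# Print the installed version and the compiled-in build settings
# (BLAS backend, CUDA/cuDNN versions, CXX_FLAGS, USE_* switches, ...).
print(torch.__version__)
print(torch.__config__.show())
```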
This description is based on the benchmarks and the code from [here](https://github.com/vfdev-5/interpolate-tensoriterator/tree/master/step_six).

TL;DR

- A generic linear upsampling implementation for the Nd case using TensorIterator (a single loop function covers the 1d, 2d and 3d cases).
- The approach can be generalized to the nearest and bicubic interpolation modes.
- It works for both channels-first and channels-last inputs.

Joint work with Francisco Massa (fmassa).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51653

Reviewed By: malfet

Differential Revision: D26619437

Pulled By: fmassa

fbshipit-source-id: 7d435e23881c5b40a18bf0dbcab4906d5462025f
---
 aten/src/ATen/native/UpSample.h               |   2 +
 .../ATen/native/cpu/UpSampleMoreKernel.cpp    | 392 +++++++++++-------
 2 files changed, 255 insertions(+), 139 deletions(-)

diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h
index 31351293d511c..2abe6d5b02571 100644
--- a/aten/src/ATen/native/UpSample.h
+++ b/aten/src/ATen/native/UpSample.h
@@ -1,3 +1,5 @@
+#pragma once
+
 #include
 #include
 
diff --git a/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp b/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
index 4cf021e6b432e..49a746a22c371 100644
--- a/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
+++ b/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
@@ -1,14 +1,20 @@
+#include
+#include
 #include
 #include
 #include
 #include
 #include
+#include
 
 namespace at {
 namespace native {
 namespace {
 
+using scale_t = std::vector<c10::optional<double>>;
+
+
 template <typename scalar_t>
 static inline void compute_source_index_and_lambda(
@@ -37,146 +43,147 @@ static inline void compute_source_index_and_lambda(
   }
 }
 
-template <typename scalar_t, typename scale_type>
-void cpu_upsample_linear(
-    const Tensor& output_,
-    const Tensor& input_,
-    bool align_corners,
-    const scale_type& scales) {
-  TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
-              " for `output` but got dtype ", output_.dtype());
-  auto input = input_.contiguous();
-  auto output = output_.contiguous();
-
-  auto input_data = input.data_ptr<scalar_t>();
-  auto output_data = output.data_ptr<scalar_t>();
-  auto input_sizes = input.sizes().vec();
-  auto output_sizes = output.sizes().vec();
-  auto ndim = input_sizes.size();
-  auto numel = output.numel();
-
-  // treat nbatch and channels as one dimension
-  int64_t channels = input_sizes[0] * input_sizes[1];
-  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
-  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
-  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
-  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
-  int64_t input_width = input_sizes[ndim - 1];
-  int64_t output_width = output_sizes[ndim - 1];
-
-  int64_t output_slice_size = output_depth * output_height * output_width;
+// Helper structs and methods for cpu_upsample_linear
+//
+// The interpolation methods used below are separable, so the interpolation can be computed
+// independently per dimension in a recursive way. Please refer to #10482 for more context.
+//
+// Linear interpolation structure to compute the output value in the n-dimensional case.
+// - recursively compute interpolated output for each dimension
+// - we rely a lot on the compiler's code optimization, such that the implemented operations
+//   can be automatically factorized and vectorized using SSE and AVX2
+template <int n, typename scalar_t, typename index_t>
+struct InterpLinear {
+  static inline scalar_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
+    index_t i0 = *(index_t*)&data[0][i * strides[0]];
+    index_t i1 = *(index_t*)&data[2][i * strides[2]];
+    scalar_t w0 = *(scalar_t *)&data[1][i * strides[1]];
+    scalar_t w1 = *(scalar_t *)&data[3][i * strides[3]];
+
+    scalar_t t0 = InterpLinear<n - 1, scalar_t, index_t>::eval(src + i0, &data[4], &strides[4], i);
+    scalar_t t1 = InterpLinear<n - 1, scalar_t, index_t>::eval(src + i1, &data[4], &strides[4], i);
+
+    return t0 * w0 + t1 * w1;
+  }
+};
+
+template <typename scalar_t, typename index_t>
+struct InterpLinear<1, scalar_t, index_t> {
+  static inline scalar_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
+    index_t i0 = *(index_t*)&data[0][i * strides[0]];
+    index_t i1 = *(index_t*)&data[2][i * strides[2]];
+    scalar_t w0 = *(scalar_t *)&data[1][i * strides[1]];
+    scalar_t w1 = *(scalar_t *)&data[3][i * strides[3]];
+    scalar_t t0 = *(scalar_t *)&src[i0];
+    scalar_t t1 = *(scalar_t *)&src[i1];
+    return t0 * w0 + t1 * w1;
+  }
+};
 
-  auto loop1d = [&](int64_t begin, int64_t end) {
-    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
-        input_width, output_width, align_corners, scales[0]);
+template <int n, typename scalar_t, typename index_t>
+static inline scalar_t interp_linear(char* src, char** data, const int64_t* strides, int64_t i) {
+  return InterpLinear<n, scalar_t, index_t>::eval(src, data, strides, i);
+}
 
-    auto input_indexr = [=](int64_t c, int64_t w) {
-      return input_data[c * input_width + w];
-    };
+static inline bool is_zero_stride(const int64_t* strides) {
+  return (strides[0] == 0) && (strides[1] == 0) && (strides[2] == 0) && (strides[3] == 0);
+}
 
-    int64_t iw0, iw1;
-    scalar_t w0lambda, w1lambda;
-    for (int64_t c = begin; c < end; c++) {
-      for (int64_t ow = 0; ow < output_width; ow++) {
-        compute_source_index_and_lambda(
-            iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
-        int64_t output_offset = c * output_slice_size + ow;
-        output_data[output_offset] =
-            w0lambda * input_indexr(c, iw0) + /* w0 * i0 */
-            w1lambda * input_indexr(c, iw1);  /* w1 * i1 */
-      }
-    }
-  };
+template <typename scalar_t, typename index_t>
+static inline bool is_contiguous_stride(const int64_t* strides) {
+  return (strides[0] == sizeof(index_t)) && (strides[1] == sizeof(scalar_t)) &&
+         (strides[2] == sizeof(index_t)) && (strides[3] == sizeof(scalar_t));
+}
 
-  auto loop2d = [&](int64_t begin, int64_t end) {
-    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
-        input_height, output_height, align_corners, scales[0]);
-    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
-        input_width, output_width, align_corners, scales[1]);
 
-    auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
-      return input_data[c * input_height * input_width + h * input_width + w];
-    };
+// Helper class to recursively check that all input strides corresponding to interpolated dimensions
+// are zero, except on a single dimension.
+//
+// Inputs: an array of strides (4 per interpolated dimension) and non_zero_stride_dim, which can be -1, 0, 1, 2, ...
+//   For every dimension other than non_zero_stride_dim, we check that all 4 strides are zero;
+//   for non_zero_stride_dim itself, the 4 strides corresponding to index_0, weight_0, index_1
+//   and weight_1 should be the non-zero (contiguous) ones.
+//
+// The base check of the recursion verifies whether the 4 strides for one interpolated dimension are either zero,
+// see method is_zero_stride, or (sizeof(index_t), sizeof(scalar_t), sizeof(index_t), sizeof(scalar_t)), see
+// method is_contiguous_stride.
+//
+// In practice, we have the following cases:
+// - for ND, float32, channels first, strides are
+//       dimN-1,            dim1,           dim0
+//       i0, w0, i1, w1, ..., i0, w0, i1, w1, i0, w0, i1, w1
+//   strides=(0, 0, 0, 0, ..., 0, 0, 0, 0, 4, 4, 4, 4)
+//
+//   if the size of dim0 is 1 then its strides are 0 and the dim1 strides are equal to 4
+//
+// - for ND, float32, channels last, strides are
+//       dimN-1,         dimN-2,         dim0
+//       i0, w0, i1, w1, i0, w0, i1, w1, ..., i0, w0, i1, w1
+//   strides=(0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0)
+//
+// Using these methods we can hint the compiler to factorize constant indices and weights
+// in the cpu_upsample_linear method
+template <int N, int non_zero_stride_dim, typename scalar_t, typename index_t>
+struct CheckAlmostAllZeroStrides {
+  static inline bool eval(const int64_t* strides) {
+    return (N == non_zero_stride_dim ? is_contiguous_stride<scalar_t, index_t>(strides) : is_zero_stride(strides)) &&
+           CheckAlmostAllZeroStrides<N - 1, non_zero_stride_dim, scalar_t, index_t>::eval(&strides[4]);
+  }
+};
 
-    int64_t ih0, ih1, iw0, iw1;
-    scalar_t h0lambda, h1lambda, w0lambda, w1lambda;
-    for (int64_t c = begin; c < end; c++) {
-      for (int64_t oh = 0; oh < output_height; oh++) {
-        compute_source_index_and_lambda(
-            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
-        for (int64_t ow = 0; ow < output_width; ow++) {
-          compute_source_index_and_lambda(
-              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
-          int64_t output_offset = c * output_slice_size + oh * output_width + ow;
-          output_data[output_offset] =
-              h0lambda * w0lambda * input_indexr(c, ih0, iw0) + /* h0 * w0 * i00 */
-              h0lambda * w1lambda * input_indexr(c, ih0, iw1) + /* h0 * w1 * i01 */
-              h1lambda * w0lambda * input_indexr(c, ih1, iw0) + /* h1 * w0 * i10 */
-              h1lambda * w1lambda * input_indexr(c, ih1, iw1);  /* h1 * w1 * i11 */
-        }
-      }
-    }
-  };
+template <int non_zero_stride_dim, typename scalar_t, typename index_t>
+struct CheckAlmostAllZeroStrides<0, non_zero_stride_dim, scalar_t, index_t> {
+  static inline bool eval(const int64_t* strides) {
+    return true;
+  }
+};
 
-  auto loop3d = [&](int64_t begin, int64_t end) {
-    const scalar_t depth_scale = area_pixel_compute_scale<scalar_t>(
-        input_depth, output_depth, align_corners, scales[0]);
-    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
-        input_height, output_height, align_corners, scales[1]);
-    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
-        input_width, output_width, align_corners, scales[2]);
+template <int n, int s, typename scalar_t, typename index_t>
+static inline bool is_all_zero_stride(const int64_t* strides) {
+  return CheckAlmostAllZeroStrides<n, s, scalar_t, index_t>::eval(strides);
+}
 
-    auto input_indexr = [=](int64_t c, int64_t d, int64_t h, int64_t w) {
-      return input_data[c * input_depth * input_height * input_width +
-          d * input_height * input_width + h * input_width + w];
-    };
+// Helper method to compute linear interpolation
+template <int out_ndims, typename scalar_t, typename index_t>
+static inline void basic_loop(char** data, const int64_t* strides, int64_t n) {
+  char* dst = data[0];
+  char* src = data[1];
+  for (int64_t i = 0; i < n; i++) {
+    *(scalar_t*)&dst[i * strides[0]] = interp_linear<out_ndims, scalar_t, index_t>(
+        src + i * strides[1], &data[2], &strides[2], i);
+  }
+}
 
-    int64_t id0, id1, ih0, ih1, iw0, iw1;
-    scalar_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
-    for (int64_t c = begin; c < end; c++) {
-      for (int64_t od = 0; od < output_depth; od++) {
-        compute_source_index_and_lambda(
-            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
-        for (int64_t oh = 0; oh < output_height; oh++) {
-          compute_source_index_and_lambda(
-              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
-          for (int64_t ow = 0; ow < output_width; ow++) {
-            compute_source_index_and_lambda(
-                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
-            int64_t output_offset = c * output_slice_size +
-                od * output_height * output_width + oh * output_width + ow;
-            output_data[output_offset] =
-                d0lambda * h0lambda * w0lambda * input_indexr(c, id0, ih0, iw0) + /* d0 * h0 * w0 * i000 */
-                d0lambda * h0lambda * w1lambda * input_indexr(c, id0, ih0, iw1) + /* d0 * h0 * w1 * i001 */
-                d0lambda * h1lambda * w0lambda * input_indexr(c, id0, ih1, iw0) + /* d0 * h1 * w0 * i010 */
-                d0lambda * h1lambda * w1lambda * input_indexr(c, id0, ih1, iw1) + /* d0 * h1 * w1 * i011 */
-                d1lambda * h0lambda * w0lambda * input_indexr(c, id1, ih0, iw0) + /* d1 * h0 * w0 * i100 */
-                d1lambda * h0lambda * w1lambda * input_indexr(c, id1, ih0, iw1) + /* d1 * h0 * w1 * i101 */
-                d1lambda * h1lambda * w0lambda * input_indexr(c, id1, ih1, iw0) + /* d1 * h1 * w0 * i110 */
-                d1lambda * h1lambda * w1lambda * input_indexr(c, id1, ih1, iw1);  /* d1 * h1 * w1 * i111 */
-          }
-        }
-      }
+// Linear upsampling computation method using TensorIterator for the Nd case.
+//
+// Single loop function for the 1d, 2d and 3d cases.
+// For N dimensions, the output value up to dimension Di can be computed as
+//
+//   output_i[a] = linear_interp(output_{i+1}[a], w_{i+1}[a], output_{i+1}[a+1], w_{i+1}[a+1])
+// with
+//   output_DN[a] = linear_interp(input_DN[a], w_DN[a], input_DN[a+1], w_DN[a+1])
+//
+// The recursion is implemented with the InterpLinear struct, using templates so that
+// the loop is unrolled at compile time.
+template <typename scalar_t, int out_ndims>
+void cpu_upsample_linear(at::TensorIterator& iter)
+{
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    // special cases to let the compiler apply compile-time, input-specific optimizations
+    if ((strides[0] == sizeof(scalar_t) && (strides[1] == 0) &&
+        is_all_zero_stride<out_ndims, 1, scalar_t, int64_t>(&strides[2]))) {
+      // contiguous channels-first case
+      basic_loop<out_ndims, scalar_t, int64_t>(data, strides, n);
+    } else if ((strides[0] == sizeof(scalar_t) && (strides[1] == sizeof(scalar_t)) &&
+        is_all_zero_stride<out_ndims, -1, scalar_t, int64_t>(&strides[2]))) {
+      // contiguous channels-last case
+      basic_loop<out_ndims, scalar_t, int64_t>(data, strides, n);
+    } else {
+      // fallback
+      basic_loop<out_ndims, scalar_t, int64_t>(data, strides, n);
     }
   };
-
-  // compared to "nearest" mode, lower the grain size:
-  // "linear", "bilinear", "trilinear" modes are more computationally expensive
-  if (ndim == 3) {
-    // upsample linear 1d
-    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 2, loop1d);
-  } else if (ndim == 4){
-    // upsample bilinear 2d
-    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
-  } else {
-    // upsample trilinear 3d
-    TORCH_INTERNAL_ASSERT(ndim == 5);
-    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
-  }
-
-  if (!output_.is_contiguous()) {
-    output_.copy_(output);
-  }
+  iter.for_each(loop);
 }
 
 template <typename scalar_t, typename scale_type>
@@ -484,15 +491,126 @@ void cpu_upsample_linear_backward(
   }
 }
 
-using scale_t = std::vector<c10::optional<double>>;
+
+// Method to compute indices and weights for each interpolated dimension
+// indices_weights = {
+//   {indices_0, weights_0, indices_1, weights_1},  // dim -n
+//   {indices_0, weights_0, indices_1, weights_1},  // dim -(n-1)
+//   ...
+//   {indices_0, weights_0, indices_1, weights_1},  // dim -1
+// }
+// Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
+// fit the input/output tensors.
+// Indices already contain the strides, to optimize the computations.
+//
+template <typename scalar_t>
+std::vector<Tensor> compute_indices_weights_linear(
+  int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
+  bool align_corners, const c10::optional<double> opt_scale
+) {
+
+  scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+
+  std::vector<Tensor> output;
+  auto new_shape = std::vector<int64_t>(ndims, 1);
+  new_shape[reshape_dim] = output_size;
+
+  output.emplace_back(empty(new_shape, CPU(at::kLong)));
+  output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+  output.emplace_back(empty(new_shape, CPU(at::kLong)));
+  output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+
+  auto input_index0_ptr = output[0].data_ptr<int64_t>();
+  auto lambda0_ptr = output[1].data_ptr<scalar_t>();
+  auto input_index1_ptr = output[2].data_ptr<int64_t>();
+  auto lambda1_ptr = output[3].data_ptr<scalar_t>();
+
+  for (int64_t i=0; i<output_size; i++) {
+    compute_source_index_and_lambda<scalar_t>(
+      input_index0_ptr[i], input_index1_ptr[i],
+      lambda0_ptr[i], lambda1_ptr[i],
+      scale, i, input_size, output_size, align_corners
+    );
+    // put the stride into the indices:
+    // index values correspond to input indices (0, 1, 2, 3, ...);
+    // when multiplied by the input stride, the maximum possible value is
+    // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
+    input_index0_ptr[i] *= stride;
+    input_index1_ptr[i] *= stride;
+  }
+  return output;
+}
+
+// Upsampling linear interpolation kernel for the N-d case.
+// Input is assumed to be like NCHW, NCL or NCKHW - the interpolated spatial dimensions
+// are those after the batch size N and the number of channels C.
+//
+// Internally, it uses TensorIterator to optimize the computations.
+// - out_ndims is the number of interpolated dims: 1, 2, 3
+// - scale_type is the template type for scales, typically c10::optional<double>
+template <int out_ndims, typename scale_type>
+void upsample_linearNd_kernel_impl(
+    const Tensor& output,
+    const Tensor& input,
+    bool align_corners,
+    const scale_type& scales) {
+
+  // We apply a similar logic as in the advanced indexing implementation:
+  // output spatial dimensions are different from the input ones, so
+  // we have to restride the input to have strides equal to zero on the
+  // spatial dimensions
+  auto shape = input.sizes().vec();
+  auto strides = input.strides().vec();
+  auto oshape = output.sizes();
+
+  for (int i=0; i<out_ndims; i++) {
+    shape[i + 2] = oshape[i + 2];
+    strides[i + 2] = 0;
+  }
+  auto restrided_input = input.as_strided(shape, strides);
+
+  std::vector<std::vector<Tensor>> indices_weights;
+  AT_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(), "compute_indices_weights_linear", [&] {
+      auto es = input.element_size();
+      for (int i=0; i<out_ndims; i++) {
+        indices_weights.emplace_back(
+          compute_indices_weights_linear<scalar_t>(
+            input.size(i + 2), oshape[i + 2], input.stride(i + 2) * es, input.dim(), i + 2, align_corners, scales[i])
+        );
+      }
+    }
+  );
+
+  TensorIteratorConfig config;
+  config.check_all_same_dtype(false)
+    .declare_static_dtype_and_device(input.scalar_type(), input.device())
+    .add_output(output)
+    .add_input(restrided_input);
+
+  for (auto iter=indices_weights.begin(); iter!=indices_weights.end(); iter++) {
+    for (auto& tensor : *iter) {
+      config.add_input(tensor);
+    }
+  }
+
+  auto iter = config.build();
+
+  AT_DISPATCH_FLOATING_TYPES(
+    iter.dtype(), "upsample_linearNd", [&] {
+      cpu_upsample_linear<scalar_t, out_ndims>(iter);
+  });
+
+}
+
 void upsample_linear1d_kernel_impl(
     const Tensor& output,
     const Tensor& input,
     bool align_corners,
     c10::optional<double> scales_w) {
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "upsample_linear1d", [&] {
-    cpu_upsample_linear<scalar_t, scale_t>(output, input, align_corners, {scales_w});
-  });
+  upsample_linearNd_kernel_impl<1, scale_t>(
+    output, input, align_corners, {scales_w});
 }
 
 void upsample_bilinear2d_kernel_impl(
@@ -506,9 +624,7 @@ void upsample_bilinear2d_kernel_impl(
      cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_h, scales_w});
    });
  } else {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "upsample_bilinear2d", [&] {
-      cpu_upsample_linear<scalar_t, scale_t>(output, input, align_corners, {scales_h, scales_w});
-    });
+    upsample_linearNd_kernel_impl<2, scale_t>(output, input, align_corners, {scales_h, scales_w});
  }
 }
 
@@ -524,9 +640,7 @@ void upsample_trilinear3d_kernel_impl(
      cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_d, scales_h, scales_w});
    });
  } else {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "upsample_trilinear3d", [&] {
-      cpu_upsample_linear<scalar_t, scale_t>(output, input, align_corners, {scales_d, scales_h, scales_w});
-    });
+    upsample_linearNd_kernel_impl<3, scale_t>(output, input, align_corners, {scales_d, scales_h, scales_w});
  }
 }
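Editorial note: the precompute-indices-and-weights plus separable-evaluation scheme used by the patch can be sketched in a few lines of Python. This is only an illustration of the idea (NumPy instead of the C++ TensorIterator kernel; `linear_indices_weights`, `upsample_linear1d_ref` and `upsample_bilinear2d_ref` are made-up names), not the code added by the PR:

```python
import numpy as np

def linear_indices_weights(in_size, out_size, align_corners=False):
    # Analogue of compute_source_index_and_lambda for one dimension:
    # for every output position, the two source indices and their lambdas.
    if align_corners:
        scale = (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        src = np.arange(out_size) * scale
    else:
        scale = in_size / out_size
        src = np.maximum((np.arange(out_size) + 0.5) * scale - 0.5, 0.0)
    i0 = np.floor(src).astype(np.int64)
    i1 = np.minimum(i0 + 1, in_size - 1)
    w1 = src - i0
    w0 = 1.0 - w1
    return i0, w0, i1, w1

def upsample_linear1d_ref(x, out_size, align_corners=False):
    # x has shape (N, C, L); the gather + weighted sum below is what the
    # basic_loop/InterpLinear recursion does once per interpolated dimension.
    i0, w0, i1, w1 = linear_indices_weights(x.shape[-1], out_size, align_corners)
    return x[..., i0] * w0 + x[..., i1] * w1

def upsample_bilinear2d_ref(x, out_h, out_w, align_corners=False):
    # Because linear interpolation is separable, bilinear 2d upsampling is the
    # 1d pass applied along H and then along W; the N-d kernel exploits exactly this.
    y = upsample_linear1d_ref(x.transpose(0, 1, 3, 2), out_h, align_corners)   # along H
    return upsample_linear1d_ref(y.transpose(0, 1, 3, 2), out_w, align_corners)  # along W
```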