diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
index 7b44c76fc5cde..f611d1c529797 100644
--- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp
+++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
@@ -509,30 +509,32 @@ void cpu_upsample_linear_channels_last(
 }
 
 // Helper structs to use with upsample_generic_Nd_kernel_impl
-template <typename scalar_t>
 struct HelperInterpBase {
 
   static inline void init_indices_weights(
-    std::vector<Tensor> & output, int64_t output_size, int64_t ndims, int64_t reshape_dim, int interp_size
+    at::ScalarType output_type,
+    std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
+    int64_t reshape_dim, int interp_size
   ) {
     auto new_shape = std::vector<int64_t>(ndims, 1);
     new_shape[reshape_dim] = output_size;
 
     for (int j=0; j<interp_size; j++) {
       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
-      output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+      output.emplace_back(empty(new_shape, CPU(output_type)));
     }
   }
 
 };
 
-template <typename scalar_t>
-struct HelperInterpNearest : public HelperInterpBase<scalar_t> {
+struct HelperInterpNearest : public HelperInterpBase {
 
   static const int interp_size = 1;
 
   static inline void init_indices_weights(
-    std::vector<Tensor> & output, int64_t output_size, int64_t ndims, int64_t reshape_dim, int interp_size
+    at::ScalarType output_type,
+    std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
+    int64_t reshape_dim, int interp_size
   ) {
     auto new_shape = std::vector<int64_t>(ndims, 1);
     new_shape[reshape_dim] = output_size;
@@ -540,7 +542,7 @@ struct HelperInterpNearest : public HelperInterpBase<scalar_t> {
     for (int j=0; j<interp_size; j++) {
       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
       // Defines weights for consistency, but not used
-      output.emplace_back(at::ones(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+      output.emplace_back(at::ones(new_shape, CPU(output_type)));
     }
   }
 
@@ -555,32 +557,37 @@ struct HelperInterpNearest : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
-    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
-    bool align_corners, const c10::optional<double> opt_scale
+    at::ScalarType scalar_type,
+    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
+    int64_t reshape_dim, bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpNearest<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpNearest<scalar_t>::interp_size);
+    HelperInterpNearest::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
+
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_nearest", [&] {
 
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
 
-    auto input_index_ptr = output[0].data_ptr<int64_t>();
-    int64_t input_index;
+        auto input_index_ptr = output[0].data_ptr<int64_t>();
+        int64_t input_index;
 
-    for (int64_t i=0; i<output_size; i++) {
-      const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
-          scale, i, /*align_corners=*/true, /*cubic=*/false);
-      input_index = static_cast<int64_t>(floorf(real_input_index));
-      input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
-    }
+        for (int64_t i=0; i<output_size; i++) {
+          const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
+              scale, i, /*align_corners=*/true, /*cubic=*/false);
+          input_index = static_cast<int64_t>(floorf(real_input_index));
+          input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
+        }
+      }
+    );
     return output;
   }
 
 };
 
-template <typename scalar_t>
-struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
+struct HelperInterpLinear : public HelperInterpBase {
 
   static const int interp_size = 2;
 
@@ -595,43 +602,47 @@ struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
+    at::ScalarType scalar_type,
     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
     bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpLinear<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpLinear<scalar_t>::interp_size);
-
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
-
-    auto input_index0_ptr = output[0].data_ptr<int64_t>();
-    auto lambda0_ptr = output[1].data_ptr<scalar_t>();
-    auto input_index1_ptr = output[2].data_ptr<int64_t>();
-    auto lambda1_ptr = output[3].data_ptr<scalar_t>();
-
-    for (int64_t i=0; i<output_size; i++) {
-
-      compute_source_index_and_lambda<scalar_t>(
-        input_index0_ptr[i], input_index1_ptr[i],
-        lambda0_ptr[i], lambda1_ptr[i],
-        scale, i, input_size, output_size, align_corners
-      );
-      // put stride into indices
-      // index values correspond to input indices (0, 1, 2, 3, ...)
-      // when multiplied by input stride, maximum possible value
-      // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
-      input_index0_ptr[i] *= stride;
-      input_index1_ptr[i] *= stride;
-    }
+    HelperInterpLinear::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
+
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_linear", [&] {
+
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+
+        auto input_index0_ptr = output[0].data_ptr<int64_t>();
+        auto lambda0_ptr = output[1].data_ptr<scalar_t>();
+        auto input_index1_ptr = output[2].data_ptr<int64_t>();
+        auto lambda1_ptr = output[3].data_ptr<scalar_t>();
+
+        for (int64_t i=0; i<output_size; i++) {
+
+          compute_source_index_and_lambda<scalar_t>(
+            input_index0_ptr[i], input_index1_ptr[i],
+            lambda0_ptr[i], lambda1_ptr[i],
+            scale, i, input_size, output_size, align_corners
+          );
+          // put stride into indices
+          // index values correspond to input indices (0, 1, 2, 3, ...)
+          // when multiplied by input stride, maximum possible value
+          // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
+          input_index0_ptr[i] *= stride;
+          input_index1_ptr[i] *= stride;
+        }
+      }
+    );
     return output;
   }
 
 };
 
-
-template <typename scalar_t>
-struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
+struct HelperInterpCubic : public HelperInterpBase {
 
   static const int interp_size = 4;
 
@@ -646,37 +657,43 @@ struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
   // fit input/output tensors.
   // Indices are already containing the strides to optimize the computations
   static inline std::vector<Tensor> compute_indices_weights(
+    at::ScalarType scalar_type,
    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
     bool align_corners, const c10::optional<double> opt_scale
   ) {
 
     std::vector<Tensor> output;
-    HelperInterpCubic<scalar_t>::init_indices_weights(
-      output, output_size, ndims, reshape_dim, HelperInterpCubic<scalar_t>::interp_size);
+    HelperInterpCubic::init_indices_weights(
+      scalar_type, output, output_size, ndims, reshape_dim, HelperInterpCubic::interp_size);
+
+    AT_DISPATCH_FLOATING_TYPES(
+      scalar_type, "compute_indices_weights_cubic", [&] {
 
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
+        scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
 
-    int64_t input_index;
-    int64_t zero = static_cast<int64_t>(0);
-    scalar_t coeffs[4];
+        int64_t input_index;
+        int64_t zero = static_cast<int64_t>(0);
+        scalar_t coeffs[4];
 
-    int64_t * idx_ptr;
-    scalar_t * wt_ptr;
+        int64_t * idx_ptr;
+        scalar_t * wt_ptr;
 
-    for (int64_t i=0; i<output_size; i++) {
-
-      const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
-          scale, i, align_corners, /*cubic=*/true);
-      input_index = static_cast<int64_t>(floorf(real_input_index));
-      get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
+        for (int64_t i=0; i<output_size; i++) {
+
+          const scalar_t real_input_index = area_pixel_compute_source_index<scalar_t>(
+              scale, i, align_corners, /*cubic=*/true);
+          input_index = static_cast<int64_t>(floorf(real_input_index));
+          get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
 
-      for (int j=0; j<interp_size; j++) {
-        idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
-        idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
-        wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
-        wt_ptr[i] = coeffs[j];
+          for (int j=0; j<interp_size; j++) {
+            idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
+            idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
+            wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
+            wt_ptr[i] = coeffs[j];
+          }
+        }
       }
-    }
+    );
     return output;
   }
 
 };
 
@@ -689,7 +706,7 @@ struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
 // - out_ndims is the number of interpolated dims: 1, 2, 3
 // - scale_type is template type for scales, typically c10::optional<double>
 // - template class F is one of the above structs to compute indices and weights
-template <int out_ndims, typename scale_type, template<typename> class F>
+template <int out_ndims, typename scale_type, class F>
 void upsample_generic_Nd_kernel_impl(
     const Tensor& output,
     const Tensor& input,
@@ -714,7 +731,7 @@ void upsample_generic_Nd_kernel_impl(
 
   std::vector<std::vector<Tensor>> indices_weights;
 
-  constexpr int interp_size = F<float>::interp_size;
+  constexpr int interp_size = F::interp_size;
   auto input_scalar_type = input.scalar_type();
 
   if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
     // nearest also supports uint8 tensor, but we have to use float
     input_scalar_type = at::ScalarType::Float;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(
-    input_scalar_type, "compute_indices_weights_generic", [&] {
-      for (int i=0; i<out_ndims; i++) {
-        indices_weights.emplace_back(
-          F<scalar_t>::compute_indices_weights(
-            input.size(i + 2), oshape[i + 2],
-            input.stride(i + 2) * input.element_size(),
-            input.dim(), i + 2, align_corners, scales[i]
-          )
-        );
-      }
-    }
-  );
+  for (int i=0; i<out_ndims; i++) {
+    indices_weights.emplace_back(
+      F::compute_indices_weights(
+        input_scalar_type, input.size(i + 2), oshape[i + 2],
+        input.stride(i + 2) * input.element_size(),
+        input.dim(), i + 2, align_corners, scales[i]
+      )
+    );
+  }
 
   if (interp_size > 1) {
     AT_DISPATCH_FLOATING_TYPES(
         iter.dtype(), "upsample_generic_Nd", [&] {
-      constexpr int mode = F<scalar_t>::interp_size;
+      constexpr int mode = F::interp_size;
       cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
-      constexpr int mode = F<scalar_t>::interp_size;
+      constexpr int mode = F::interp_size;
       cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
     });
   }
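
Note on the pattern in this patch (the sketch below is not part of it): the HelperInterp* structs stop being class templates over a scalar type; instead init_indices_weights and compute_indices_weights take an at::ScalarType at runtime and perform the floating-point dispatch internally via AT_DISPATCH_FLOATING_TYPES, so upsample_generic_Nd_kernel_impl can refer to F::interp_size and F::compute_indices_weights without a template-template parameter. The following standalone C++ sketch illustrates only that refactoring idea; ScalarType, dispatch_floating, HelperNearest and make_weights are hypothetical stand-ins, not the ATen API.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for at::ScalarType, restricted to the two types we dispatch over.
enum class ScalarType { Float, Double };

// Minimal stand-in for AT_DISPATCH_FLOATING_TYPES: selects scalar_t at runtime
// and invokes the functor with a value of that type as a tag.
template <typename Fn>
void dispatch_floating(ScalarType t, Fn&& fn) {
  switch (t) {
    case ScalarType::Float:  fn(float{});  break;   // scalar_t == float
    case ScalarType::Double: fn(double{}); break;   // scalar_t == double
  }
}

// "Before": the helper would be a template on scalar_t, forcing every caller to
// sit inside a type dispatch. "After" (the shape of this diff): a non-template
// helper that accepts the dtype as a runtime argument and dispatches inside.
struct HelperNearest {
  static std::vector<double> make_weights(ScalarType scalar_type, int64_t output_size) {
    std::vector<double> weights(output_size);
    dispatch_floating(scalar_type, [&](auto tag) {
      using scalar_t = decltype(tag);           // chosen by the dispatcher
      for (int64_t i = 0; i < output_size; i++) {
        weights[i] = static_cast<scalar_t>(1);  // nearest mode: weight is always 1
      }
    });
    return weights;
  }
};

int main() {
  // The caller no longer needs a template-template parameter or an outer
  // dispatch block; it just forwards the runtime dtype.
  auto w = HelperNearest::make_weights(ScalarType::Float, 4);
  std::cout << w.size() << " weights\n";  // prints "4 weights"
  return 0;
}

The practical effect mirrored in the diff: type selection happens once inside each helper, so the caller no longer multiplies template instantiations at every call site and the uint8 special case can be handled by passing a different ScalarType rather than a different template argument.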