diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp
index 396acc9e0403..9b39006b106e 100644
--- a/aten/src/ATen/BatchedFallback.cpp
+++ b/aten/src/ATen/BatchedFallback.cpp
@@ -156,11 +156,7 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
   auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
   auto batch_sizes = ArrayRef<int64_t>(
       first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
-  auto num_batches = std::accumulate(
-      batch_sizes.begin(),
-      batch_sizes.end(),
-      1,
-      std::multiplies<int64_t>());
+  const auto num_batches = prod_intlist(batch_sizes);
   // Without a shape-checking API, we're unable to compute the correct shape of
   // the output so we just error out.
   TORCH_CHECK(num_batches > 0,
@@ -293,11 +289,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   auto num_batch_dims = input_physical_views.front().numBatchDims();
   auto some_sizes = input_physical_views.front().tensor().sizes();
   auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
-  auto num_batches = std::accumulate(
-      batch_sizes.begin(),
-      batch_sizes.end(),
-      1,
-      std::multiplies<int64_t>());
+  const auto num_batches = prod_intlist(batch_sizes);
   // Without a shape-checking API, we're unable to compute the correct shape of
   // the output so we just error out.
   TORCH_CHECK(num_batches > 0,
diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp
index 626e0c73e45e..08588f6a8cdd 100644
--- a/aten/src/ATen/TensorUtils.cpp
+++ b/aten/src/ATen/TensorUtils.cpp
@@ -335,8 +335,7 @@ c10::optional<std::vector<int64_t>> computeStride(
   // we use the stride as if it were computed via resize.
   // This could perhaps be combined with the below code, but the complexity
   // didn't seem worth it.
-  int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1,
-                                  std::multiplies<int64_t>());
+  const int64_t numel = prod_intlist(oldshape);
   if (numel == 0 && oldshape.equals(newshape)) {
     return oldstride.vec();
   }
diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h
index 87446df09487..4fe4b632362b 100644
--- a/aten/src/ATen/Utils.h
+++ b/aten/src/ATen/Utils.h
@@ -93,10 +93,18 @@ inline int64_t sum_intlist(ArrayRef<int64_t> list) {
   return std::accumulate(list.begin(), list.end(), 0ll);
 }
 
-inline int64_t prod_intlist(ArrayRef<int64_t> list) {
-  return std::accumulate(list.begin(), list.end(), 1ll, std::multiplies<int64_t>());
+//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow
+//the next 2 functions set `init` type to int64_t to avoid overflow.
+template<typename C,
+  typename std::enable_if<std::is_integral<typename C::value_type>::value, int>::type = 0>
+inline int64_t prod_intlist(const C &container){
+  return std::accumulate(container.begin(), container.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
 }
 
+template<typename Iter,
+  typename std::enable_if<std::is_integral<typename std::iterator_traits<Iter>::value_type>::value, int>::type = 0>
+inline int64_t prod_intlist(Iter begin, Iter end){
+  return std::accumulate(begin, end, static_cast<int64_t>(1), std::multiplies<int64_t>());
+}
 /**
  * Utility function to static cast input Generator* to
  * the backend generator type (CPU/CUDAGeneratorImpl etc.)
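The comment added to `aten/src/ATen/Utils.h` above is the heart of the patch: `std::accumulate` deduces its accumulator type from the `init` argument, so a bare `1` keeps the running product in 32-bit `int` even when the elements are `int64_t`. A standalone sketch of that failure mode and of the fix the new `prod_intlist` overloads apply (the shape values are invented for illustration; nothing below is part of the patch):

```cpp
// Standalone illustration; only the C++ standard library is used.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // A shape whose element count (2^33) does not fit in a 32-bit int.
  std::vector<int64_t> sizes{65536, 65536, 2};

  // The accumulator type is deduced from `init`: the literal 1 is an int,
  // so the running product is squeezed back into 32 bits at every step
  // and the result is wrong (typically 0 on common LP64 platforms).
  auto overflowed = std::accumulate(
      sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());

  // An int64_t init keeps the whole computation in 64 bits; this is what
  // the new prod_intlist overloads do internally.
  auto correct = std::accumulate(
      sizes.begin(), sizes.end(), static_cast<int64_t>(1),
      std::multiplies<int64_t>());

  std::cout << "int init:     " << overflowed << '\n';
  std::cout << "int64_t init: " << correct << '\n';  // 8589934592
}
```

This is also why the second overload takes an iterator pair: call sites such as `prod_intlist(sizes.begin(), sizes.begin() + dim)` in the ScanKernels.cu hunks below can reduce a sub-range without materializing a new container.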
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index b2b760513a1d..91d804687290 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -27,7 +27,7 @@ Tensor pdist(const Tensor& self, const double p) {
 
 Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
   /** This function does the fist part of the euclidean distance calculation
-   * We divide it in two steps to simplify dealing with subgradients in the 
+   * We divide it in two steps to simplify dealing with subgradients in the
    * backward step */
   Tensor x1_norm = x1.pow(2).sum(-1, true);
   Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
@@ -74,7 +74,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
   tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
 
-  int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies<int64_t>());
+  const int64_t expand_batch_product = prod_intlist(expand_batch_portion);
   std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
   std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
 
@@ -147,8 +147,10 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
   auto device2 = x2.device().type();
   TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
   IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
-  int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies<int64_t>());
-  Tensor grad_x1 = at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT).view({batch_product, n, m});
+  const int64_t batch_product = prod_intlist(batch_tensor1);
+  Tensor grad_x1 =
+      at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT)
+          .view({batch_product, n, m});
   cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);
   return grad_x1;
 }
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 853d913e9336..8796657dc293 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -675,8 +675,8 @@ Tensor matmul(
     std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
     tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});
 
-    int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(),
-                                               1, std::multiplies<int64_t>());
+    const int64_t expand_batch_product =
+        prod_intlist(expand_batch_portion);
 
     std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
     tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
@@ -742,7 +742,7 @@ Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) {
     {n_copies, a.size(0), a.size(1), a.size(2)},
     a.options().memory_format(at::MemoryFormat::Contiguous)
   );
-  
+
   if (is_zero) {
     res.zero_();
   }
@@ -850,7 +850,7 @@ Tensor compute_T4(const Tensor& A) {
   auto As = _allocate_buffer(A, 4);
   // 3 for {I, A, A^2}
   _fill_matrix_powers(As, A, 3);
-  
+
   at::native::matmul(
     // output for A^2 * (I / 2 + A / 6 + A^2 / 24)
     As.select(0, 3),
@@ -1101,7 +1101,7 @@ Tensor mexp_impl(
   if (!compute_highest_degree_approx) {
     constexpr std::array<
       Tensor(*)(const Tensor&),
-      total_n_degs - 1> 
+      total_n_degs - 1>
     compute_Ts = {
       compute_T1, compute_T2, compute_T4,
       compute_T8, compute_T12
@@ -1192,7 +1192,7 @@ Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) {
 // Based on:
 //
-// Mathias, Roy. 
+// Mathias, Roy.
 // A Chain Rule for Matrix Functions and Applications.
 // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
 //
@@ -1227,8 +1227,8 @@ Tensor backward_analytic_function_of_a_matrix(
 // Mathematics 2019, 7, 1174.
 //
 Tensor matrix_exp(const Tensor& a) {
-  TORCH_CHECK(a.dim() >= 2 
-          && (at::isFloatingType(a.scalar_type()) 
+  TORCH_CHECK(a.dim() >= 2
+          && (at::isFloatingType(a.scalar_type())
               || at::isComplexType(a.scalar_type())),
               "matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
               "of floating or complex types with dim at least 2");
diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp
index 459dd857727f..e80b0c546362 100644
--- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp
+++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp
@@ -2,6 +2,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
+#include <ATen/Utils.h>
 #include <ATen/div_rtn.h>
 #include <ATen/native/DilatedConvolutionUtils.h>
 
@@ -181,10 +182,8 @@ void slow_conv_dilated_all_cpu_template(
   // Temporary buffer:
   Tensor columns = at::empty({0}, options);
   if (output.defined() || grad_weight.defined() || grad_input.defined()) {
-    int64_t m = std::accumulate(
-        kernel_size.begin(), kernel_size.end(), 1, std::multiplies<int64_t>());
-    int64_t n = std::accumulate(
-        output_size.begin(), output_size.end(), 1, std::multiplies<int64_t>());
+    const int64_t m = prod_intlist(kernel_size);
+    const int64_t n = prod_intlist(output_size);
     columns.resize_({nInputPlane * m, n});
   }
   // Initialize
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index eec16a4cf824..14fb67e5d4ba 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -368,7 +368,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) {
     // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting.
     int64_t values_dim = wrapped - sparse_dim + 1;
     // The final size along the catted dimension.
-    int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), 0, [values_dim](int64_t l, Tensor const &r) {
+    const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast<int64_t>(0), [values_dim](int64_t l, Tensor const &r) {
       return l + r._values().size(values_dim);
     });
     auto zeros_sizes = tensors[0]._values().sizes().vec();
@@ -1675,7 +1675,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
   TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
   TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size());
 
-  auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());
+  const int64_t numel = prod_intlist(sizes);
   if (self.has_names()) {
     TORCH_CHECK(numel == self.size(dim), "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
diff --git a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
index 2ad6f0785a17..522e3bbd8760 100644
--- a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
+++ b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
@@ -188,10 +188,8 @@ void slow_conv_dilated_all_cuda_template(
   int64_t nInputPlane = weight.size(1);
   int64_t nOutputPlane = weight.size(0);
   // Temporary buffers:
-  int64_t m = std::accumulate(
-      kernel_size.begin(), kernel_size.end(), 1, std::multiplies<int64_t>());
-  int64_t output_vsize = std::accumulate(
-      output_size.begin(), output_size.end(), 1, std::multiplies<int64_t>());
+  const int64_t m = prod_intlist(kernel_size);
+  const int64_t output_vsize = prod_intlist(output_size);
   Tensor columns = at::empty({0}, options);
   if (output.defined() || grad_weight.defined() || grad_input.defined()) {
     columns.resize_({nInputPlane * m, output_vsize});
diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu
index 6bc2c381e1db..b0dc71c568ba 100644
--- a/aten/src/ATen/native/cuda/ScanKernels.cu
+++ b/aten/src/ATen/native/cuda/ScanKernels.cu
@@ -128,16 +128,16 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se
  */
 template<typename scalar_t, class BinaryFunction>
 __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_,
-                  int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) {
-  for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
-    for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
+                  const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
+  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
       scalar_t *self = self_ + orow * row_size * num_irows + irow;
       scalar_t *values = values_ + orow * row_size * num_irows + irow;
       int64_t *indices = indices_ + orow * row_size * num_irows + irow;
       scalar_t out = init;
       int64_t out_idx = 0;
 
-      for (int64_t col = 0; col < row_size; ++col) {
+      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
         if(THCNumerics<scalar_t>::isnan(*self) || (!THCNumerics<scalar_t>::isnan(out) && binary_op(*self, out))) {
           out = *self;
           out_idx = col;
@@ -152,21 +152,34 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala
   }
 }
 
+void check_fits_in_unsigned(int64_t val, const char* name) {
+  constexpr auto umax = std::numeric_limits<uint32_t>::max();
+  TORCH_CHECK(
+      val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value");
32-bit uint32_t value"); +} + + template __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices, int dim, scalar_t init, BinaryFunction binary_op) { - int row_size = self.size(dim); + int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + dim3 threads(std::min(512, int(num_irows))); - int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x)))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); tensor_kernel_scan_outer_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_orows, num_irows, row_size, init, binary_op); @@ -254,16 +267,16 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int */ template __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - scalar_t init, BinaryOp binary_op) + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) { - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *src = src_ + orow * row_size * num_irows + irow; scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; scalar_t acc = init; - for (unsigned col = 0; col < row_size; ++col) { + for (uint32_t col = 0; col < row_size; ++col) { acc = binary_op(acc, *src); *tgt = acc; @@ -286,12 +299,12 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, */ template __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_, - unsigned num_rows, unsigned row_size, + const uint32_t num_rows, const uint32_t row_size, T init, BinaryFunction binary_op){ - for (unsigned block_row = blockIdx.x * blockDim.y; + for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; + uint32_t row = block_row + threadIdx.y; T block_total = init; T *row_src = src_ + row * row_size; @@ -299,10 +312,10 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr // Perform scan on one block at a time, keeping track of the total value of // all blocks processed so far. 
-    for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
+    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
       // Load data into shared memory (two values per thread).
-      unsigned col1 = block_col + threadIdx.x;
-      unsigned col2 = block_col + num_threads_x + threadIdx.x;
+      uint32_t col1 = block_col + threadIdx.x;
+      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
       if (row < num_rows) {
         if (col1 < row_size) {
           row_buf[threadIdx.x] = row_src[col1];
@@ -324,18 +337,18 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr
       __syncthreads();
 
       // Parallel reduction (up-sweep).
-      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
+      for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
         if (row < num_rows && threadIdx.x < s) {
-          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
+          uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
           row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
         }
         __syncthreads();
       }
 
       // Down-sweep.
-      for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
+      for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
         if (row < num_rows && threadIdx.x < s - 1) {
-          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
+          uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
           row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
         }
         __syncthreads();
@@ -361,8 +374,8 @@ __global__ typename std::enable_if<!c10::is_complex<T>::value, void>::type
 tensor_kernel_scan_innermost_dim(
     T* tgt_,
     T* src_,
-    unsigned num_rows,
-    unsigned row_size,
+    const uint32_t num_rows,
+    const uint32_t row_size,
     T init,
     BinaryFunction binary_op) {
   __shared__ T sbuf[num_threads_y][2 * num_threads_x];
@@ -381,8 +394,8 @@ __global__ typename std::enable_if<c10::is_complex<T>::value, void>::type
 tensor_kernel_scan_innermost_dim(
     T* tgt_,
     T* src_,
-    unsigned num_rows,
-    unsigned row_size,
+    const uint32_t num_rows,
+    const uint32_t row_size,
     T init,
     BinaryFunction binary_op) {
   // As we cannot directly initialize shared array for complex types
@@ -399,23 +412,18 @@ tensor_kernel_scan_innermost_dim(
       row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
 }
 
-void check_fits_in_unsigned(int64_t val, const char* name) {
-  constexpr auto umax = std::numeric_limits<uint32_t>::max();
-  TORCH_CHECK(
-      val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value");
-}
 
 template<typename scalar_t, class BinaryFunction>
 __host__ void scan_outer_dim(const Tensor& self, Tensor& result, int dim,
                              scalar_t init, BinaryFunction binary_op) {
-  int64_t row_size = self.size(dim);
+  const int64_t row_size = self.size(dim);
   auto sizes = self.sizes();
 
   // Treat all outer dimensions (i.e. dim_ < dim) as one.
-  int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies<int64_t>());
+  const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim);
 
   // Treat all inner dimensions (i.e. dim > dimension) as one.
-  int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies<int64_t>());
+  const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end());
 
   dim3 threads(std::min(512, int(num_irows)));
   int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp
index 759167095ae3..3bd4daac917b 100644
--- a/aten/src/ATen/native/group_norm.cpp
+++ b/aten/src/ATen/native/group_norm.cpp
@@ -106,11 +106,8 @@ Tensor group_norm(
       input.sizes());
 
   const auto input_shape = input.sizes();
-  const int64_t HxW = std::accumulate(
-      input_shape.cbegin() + 2,
-      input_shape.cend(),
-      1LL,
-      std::multiplies<int64_t>());
+  const int64_t HxW =
+      prod_intlist(input_shape.cbegin() + 2, input_shape.cend());
 
   const Tensor kEmpty;
   const auto& X = input.is_contiguous() ? input : input.contiguous();
diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h
index bf931fb26c5f..fa936ab7d4ce 100644
--- a/aten/src/ATen/native/layer_norm.h
+++ b/aten/src/ATen/native/layer_norm.h
@@ -52,16 +52,10 @@ std::tuple<Tensor, Tensor, Tensor, int64_t, int64_t> _prepare_layer_norm_inputs(
   }
 
   const int axis = input_ndim - normalized_ndim;
-  const int64_t M = std::accumulate(
-      input_shape.cbegin(),
-      input_shape.cbegin() + axis,
-      1LL,
-      std::multiplies<int64_t>());
-  const int64_t N = std::accumulate(
-      input_shape.cbegin() + axis,
-      input_shape.cend(),
-      1LL,
-      std::multiplies<int64_t>());
+  const int64_t M =
+      prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis);
+  const int64_t N =
+      prod_intlist(input_shape.cbegin() + axis, input_shape.cend());
 
   const auto& X = input.is_contiguous() ? input : input.contiguous();
   const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
diff --git a/aten/src/ATen/native/metal/MetalTensor.mm b/aten/src/ATen/native/metal/MetalTensor.mm
index 6dfe3932bf16..b1fc38d92a6b 100644
--- a/aten/src/ATen/native/metal/MetalTensor.mm
+++ b/aten/src/ATen/native/metal/MetalTensor.mm
@@ -17,7 +17,7 @@ class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl {
         _numel(std::accumulate(
             std::begin(_sizes),
             std::end(_sizes),
-            1,
+            (int64_t)1,
             std::multiplies<int64_t>())),
         _textureImpl(std::make_unique<MPSImageWrapper>(sizes)) {}
 
diff --git a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp
index 6ed193cd82c9..c8bbe9d29b24 100644
--- a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp
@@ -71,11 +71,8 @@ Tensor quantized_group_norm_impl(
   const int64_t batches = input_shape[0];
   const int64_t num_channels = input_shape[1];
 
-  const int64_t elements_per_batch = std::accumulate(
-      input_shape.cbegin() + 1,
-      input_shape.cend(),
-      1LL,
-      std::multiplies<int64_t>());
+  const int64_t elements_per_batch =
+      prod_intlist(input_shape.cbegin() + 1, input_shape.cend());
 
   const int64_t M = batches * num_groups;
   const int64_t N = elements_per_batch / num_groups;
diff --git a/aten/src/ATen/native/vulkan/Vulkan.cpp b/aten/src/ATen/native/vulkan/Vulkan.cpp
index d6fa6a32291b..d58d7d7dcc09 100644
--- a/aten/src/ATen/native/vulkan/Vulkan.cpp
+++ b/aten/src/ATen/native/vulkan/Vulkan.cpp
@@ -7,6 +7,7 @@
 #include <ATen/native/vulkan/Vulkan.h>
 #include <ATen/native/vulkan/VulkanCommon.h>
+#include <ATen/Utils.h>
 
 #ifdef USE_VULKAN_WRAPPER
 #include <vulkan_wrapper.h>
@@ -1182,11 +1183,7 @@ class VulkanTensor::Impl final {
   explicit Impl(std::vector<int64_t> sizes)
       : sizes_(std::move(sizes)),
         strides_(std::vector<int64_t>(sizes_.size())),
-        numel_(std::accumulate(
-            std::begin(sizes_),
-            std::end(sizes_),
-            1,
-            std::multiplies<int64_t>())) {
+        numel_(prod_intlist(sizes_)) {
     TORCH_CHECK(
         initVulkanContextOnce(), "Vulkan Failed to create Vulkan Context");
   }
@@ -1289,8 +1286,7 @@ class VulkanTensor::Impl final {
 
   VkDeviceSize buffer_size_for_sizes(std::vector<int64_t> sizes) const {
     const auto d = sizes.size();
-    const auto numel = std::accumulate(
-        std::begin(sizes), std::end(sizes), 1, std::multiplies<int64_t>());
+    const auto numel = prod_intlist(sizes);
     VkDeviceSize bufferSize{sizeof(float) * numel};
     // alignment to be able to copy between image and buffer
     if (d == 4) {
diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp
index 2ad3a695d65b..0e13dce41a7c 100644
--- a/aten/src/ATen/native/vulkan/VulkanOps.cpp
+++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp
@@ -5,6 +5,7 @@
 #include <ATen/native/vulkan/Vulkan.h>
 #include <ATen/native/vulkan/VulkanCommon.h>
 #include <ATen/native/vulkan/VulkanOps.h>
+#include <ATen/Utils.h>
 #include <c10/util/Optional.h>
 #include <c10/util/intrusive_ptr.h>
 
@@ -553,8 +554,7 @@ void add(
 
 void add(VulkanTensor& output, const VulkanTensor& input, const float s) {
   const auto sizes = input.sizes();
-  const auto C = std::accumulate(
-      sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies<int64_t>());
+  const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2);
   const auto C_4 = UP_DIV(C, 4);
   const auto H = sizes[2];
   const auto W = sizes[3];
@@ -605,8 +605,7 @@ void add(VulkanTensor& output, const VulkanTensor& input, const float s) {
 
 void mul(VulkanTensor& output, const VulkanTensor& input, const float s) {
   const auto sizes = input.sizes();
-  const auto C = std::accumulate(
-      sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies<int64_t>());
+  const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2);
   const auto C_4 = UP_DIV(C, 4);
   const auto H = sizes[2];
   const auto W = sizes[3];
diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp
index a51fd972d19a..a5baf716069f 100644
--- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp
@@ -22,11 +22,7 @@ VkDeviceSize bytes(
     size *= extents.width * extents.height * (4u * extents.depth);
   }
   else {
-    size = std::accumulate(
-        sizes.cbegin(),
-        sizes.cend(),
-        size,
-        std::multiplies<int64_t>());
+    size *= prod_intlist(sizes);
   }
 
   return size;
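A recurring pattern in the ScanKernels.cu hunks above: the kernels now take `const uint32_t` extents because, per the new comment, 32-bit loop counters are used for performance, while the host wrappers compute `num_orows`/`num_irows`/`row_size` as `int64_t` and call `check_fits_in_unsigned` before launching. A standalone sketch of that validate-then-narrow pattern, with a plain exception standing in for `TORCH_CHECK` and made-up sizes (not part of the patch):

```cpp
// Host-side sketch only; the CUDA launch itself is just hinted at in a comment.
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  if (val < 0 || val > umax) {
    throw std::runtime_error(std::string(name) +
                             " must fit in a 32-bit uint32_t value");
  }
}

int main() {
  // Sizes are computed in 64 bits, so the products themselves cannot overflow.
  const int64_t num_orows = 123;
  const int64_t num_irows = int64_t{1} << 20;
  const int64_t row_size = 4096;

  // Reject anything the 32-bit kernel indices could not represent...
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(row_size, "row_size");

  // ...and only then narrow to the uint32_t parameters the kernel expects,
  // as scan_outer_dim_with_indices does before its <<<grid, threads>>> launch.
  const auto orows32 = static_cast<uint32_t>(num_orows);
  const auto irows32 = static_cast<uint32_t>(num_irows);
  const auto row32 = static_cast<uint32_t>(row_size);
  (void)orows32; (void)irows32; (void)row32;
  return 0;
}
```

Doing the product in 64 bits first matters: a product that overflowed 32-bit arithmetic could wrap back into the valid range and silently pass the check, whereas an int64_t product followed by an explicit bounds check cannot.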