fix calculation of number of elements to not overflow (#46997)
Summary:
Possibly fixes #46764.
In many places, computing the number of tensor elements is written as
```
int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1,
                                  std::multiplies<int64_t>());
```
`std::accumulate` performs the accumulation with the type of the `init` argument, which here is the `1` literal of type `int`. When a tensor has more than INT_MAX elements, the result overflows. In #46746, the tensor passed to reshape had 256^4 elements; its element count was computed as `0`, so the reshape was not done correctly.
I've audited the usages of `std::accumulate` and changed them to use `int64_t` as the `init` type.
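
For illustration, here is a minimal standalone reproduction of the failure mode and of the fix. It is not part of the commit, and the shape is made up:
```
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // 256^4 = 4294967296 elements, which is more than INT_MAX.
  std::vector<int64_t> shape{256, 256, 256, 256};

  // Buggy pattern: the accumulator takes the type of the `1` literal (int),
  // so each 64-bit product is narrowed back to int and the final count
  // ends up as 0 on typical platforms.
  int64_t overflowed = std::accumulate(shape.begin(), shape.end(), 1,
                                       std::multiplies<int64_t>());

  // Fixed pattern: a 64-bit init keeps the whole computation in int64_t.
  int64_t correct = std::accumulate(shape.begin(), shape.end(),
                                    static_cast<int64_t>(1),
                                    std::multiplies<int64_t>());

  std::cout << overflowed << " vs " << correct << "\n";  // typically "0 vs 4294967296"
  return 0;
}
```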

Pull Request resolved: #46997

Reviewed By: albanD

Differential Revision: D24624654

Pulled By: ngimel

fbshipit-source-id: 3d9c5e6355531a9df6b10500eec140e020aac77e
ngimel authored and facebook-github-bot committed Oct 29, 2020
1 parent 78de12f commit e17b8de
Showing 16 changed files with 94 additions and 109 deletions.
12 changes: 2 additions & 10 deletions aten/src/ATen/BatchedFallback.cpp
@@ -156,11 +156,7 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
auto batch_sizes = ArrayRef<int64_t>(
first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
auto num_batches = std::accumulate(
batch_sizes.begin(),
batch_sizes.end(),
1,
std::multiplies<int64_t>());
const auto num_batches = prod_intlist(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
@@ -293,11 +289,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto some_sizes = input_physical_views.front().tensor().sizes();
auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
auto num_batches = std::accumulate(
batch_sizes.begin(),
batch_sizes.end(),
1,
std::multiplies<int64_t>());
const auto num_batches = prod_intlist(batch_sizes);
// Without a shape-checking API, we're unable to compute the correct shape of
// the output so we just error out.
TORCH_CHECK(num_batches > 0,
3 changes: 1 addition & 2 deletions aten/src/ATen/TensorUtils.cpp
@@ -335,8 +335,7 @@ c10::optional<std::vector<int64_t>> computeStride(
// we use the stride as if it were computed via resize.
// This could perhaps be combined with the below code, but the complexity
// didn't seem worth it.
int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1,
std::multiplies<int64_t>());
const int64_t numel = prod_intlist(oldshape);
if (numel == 0 && oldshape.equals(newshape)) {
return oldstride.vec();
}
12 changes: 10 additions & 2 deletions aten/src/ATen/Utils.h
@@ -93,10 +93,18 @@ inline int64_t sum_intlist(ArrayRef<int64_t> list) {
return std::accumulate(list.begin(), list.end(), 0ll);
}

inline int64_t prod_intlist(ArrayRef<int64_t> list) {
return std::accumulate(list.begin(), list.end(), 1ll, std::multiplies<int64_t>());
//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow
//the next 2 functions set `init` type to int64_t to avoid overflow.
template<typename C, typename std::enable_if<std::is_integral<typename C::value_type>::value, int>::type = 0>
inline int64_t prod_intlist(const C &container){
return std::accumulate(container.begin(), container.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
}

template<typename Iter,
typename std::enable_if<std::is_integral<typename std::iterator_traits<Iter>::value_type>::value, int>::type = 0>
inline int64_t prod_intlist(Iter begin, Iter end){
return std::accumulate(begin, end, static_cast<int64_t>(1), std::multiplies<int64_t>());
}
/**
* Utility function to static cast input Generator* to
* the backend generator type (CPU/CUDAGeneratorImpl etc.)
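For context, the behavior of the two new overloads can be sketched outside of ATen as below. The template definitions are restated from the hunk above (the real ones live in `aten/src/ATen/Utils.h`); the `main` function and its `sizes` vector are made up for illustration:
```
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <vector>

template <typename C,
          typename std::enable_if<std::is_integral<typename C::value_type>::value, int>::type = 0>
int64_t prod_intlist(const C& container) {
  return std::accumulate(container.begin(), container.end(),
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}

template <typename Iter,
          typename std::enable_if<std::is_integral<typename std::iterator_traits<Iter>::value_type>::value, int>::type = 0>
int64_t prod_intlist(Iter begin, Iter end) {
  return std::accumulate(begin, end, static_cast<int64_t>(1), std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> sizes{256, 256, 256, 256};
  // Whole-container overload: the product is accumulated in int64_t,
  // so 256^4 = 4294967296 is returned without overflow.
  std::cout << prod_intlist(sizes) << "\n";
  // Iterator-range overload, useful for sub-ranges of a shape such as
  // "all dimensions before dim".
  std::cout << prod_intlist(sizes.begin(), sizes.begin() + 2) << "\n";  // 65536
  return 0;
}
```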
10 changes: 6 additions & 4 deletions aten/src/ATen/native/Distance.cpp
@@ -27,7 +27,7 @@ Tensor pdist(const Tensor& self, const double p) {

Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
/** This function does the fist part of the euclidean distance calculation
* We divide it in two steps to simplify dealing with subgradients in the
* We divide it in two steps to simplify dealing with subgradients in the
* backward step */
Tensor x1_norm = x1.pow(2).sum(-1, true);
Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
@@ -74,7 +74,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});

int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies<int64_t>());
const int64_t expand_batch_product = prod_intlist(expand_batch_portion);
std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};

@@ -147,8 +147,10 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
auto device2 = x2.device().type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies<int64_t>());
Tensor grad_x1 = at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT).view({batch_product, n, m});
const int64_t batch_product = prod_intlist(batch_tensor1);
Tensor grad_x1 =
at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT)
.view({batch_product, n, m});
cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);
return grad_x1;
}
16 changes: 8 additions & 8 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -675,8 +675,8 @@ Tensor matmul(
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});

int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(),
1, std::multiplies<int64_t>());
const int64_t expand_batch_product =
prod_intlist(expand_batch_portion);

std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
@@ -742,7 +742,7 @@ Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) {
{n_copies, a.size(0), a.size(1), a.size(2)},
a.options().memory_format(at::MemoryFormat::Contiguous)
);

if (is_zero) {
res.zero_();
}
@@ -850,7 +850,7 @@ Tensor compute_T4(const Tensor& A) {
auto As = _allocate_buffer(A, 4);
// 3 for {I, A, A^2}
_fill_matrix_powers(As, A, 3);

at::native::matmul(
// output for A^2 * (I / 2 + A / 6 + A^2 / 24)
As.select(0, 3),
@@ -1101,7 +1101,7 @@ Tensor mexp_impl(
if (!compute_highest_degree_approx) {
constexpr std::array<
Tensor(*)(const Tensor&),
total_n_degs - 1>
total_n_degs - 1>
compute_Ts = {
compute_T1, compute_T2, compute_T4<scalar_t>,
compute_T8<scalar_t>, compute_T12<scalar_t>
@@ -1192,7 +1192,7 @@ Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) {

// Based on:
//
// Mathias, Roy.
// Mathias, Roy.
// A Chain Rule for Matrix Functions and Applications.
// SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
//
@@ -1227,8 +1227,8 @@ Tensor backward_analytic_function_of_a_matrix(
// Mathematics 2019, 7, 1174.
//
Tensor matrix_exp(const Tensor& a) {
TORCH_CHECK(a.dim() >= 2
&& (at::isFloatingType(a.scalar_type())
TORCH_CHECK(a.dim() >= 2
&& (at::isFloatingType(a.scalar_type())
|| at::isComplexType(a.scalar_type())),
"matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
"of floating or complex types with dim at least 2");
7 changes: 3 additions & 4 deletions aten/src/ATen/native/NaiveDilatedConvolution.cpp
@@ -2,6 +2,7 @@

#include <tuple>
#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <ATen/native/im2col.h>
#include <ATen/native/vol2col.h>

@@ -181,10 +182,8 @@ void slow_conv_dilated_all_cpu_template(
// Temporary buffer:
Tensor columns = at::empty({0}, options);
if (output.defined() || grad_weight.defined() || grad_input.defined()) {
int64_t m = std::accumulate(
kernel_size.begin(), kernel_size.end(), 1, std::multiplies<int64_t>());
int64_t n = std::accumulate(
output_size.begin(), output_size.end(), 1, std::multiplies<int64_t>());
const int64_t m = prod_intlist(kernel_size);
const int64_t n = prod_intlist(output_size);
columns.resize_({nInputPlane * m, n});
}
// Initialize
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorShape.cpp
@@ -368,7 +368,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) {
// The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting.
int64_t values_dim = wrapped - sparse_dim + 1;
// The final size along the catted dimension.
int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), 0, [values_dim](int64_t l, Tensor const &r) {
const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast<int64_t>(0), [values_dim](int64_t l, Tensor const &r) {
return l + r._values().size(values_dim);
});
auto zeros_sizes = tensors[0]._values().sizes().vec();
@@ -1675,7 +1675,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size());

auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<int64_t>());
const int64_t numel = prod_intlist(sizes);
if (self.has_names()) {
TORCH_CHECK(numel == self.size(dim),
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
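The `cat_sparse` hunk above keeps `std::accumulate` but changes its `init` to `static_cast<int64_t>(0)`. Even when the binary operation is a lambda working on `int64_t`, the accumulator still takes the type of `init`, which the following standalone sketch (with made-up sizes) illustrates:
```
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Stand-in for summing r._values().size(values_dim) over a list of tensors.
  std::vector<int64_t> sizes(3, int64_t{2} * 1024 * 1024 * 1024);  // 3 x 2^31

  // With a plain 0 the accumulator is int, so the lambda's 64-bit result is
  // narrowed back to int at every step and the total overflows.
  auto bad = std::accumulate(sizes.begin(), sizes.end(), 0,
                             [](int64_t l, int64_t r) { return l + r; });

  // With a 64-bit init the whole sum stays in int64_t.
  auto good = std::accumulate(sizes.begin(), sizes.end(), static_cast<int64_t>(0),
                              [](int64_t l, int64_t r) { return l + r; });

  std::cout << bad << " vs " << good << "\n";  // typically "-2147483648 vs 6442450944"
  return 0;
}
```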
6 changes: 2 additions & 4 deletions aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
@@ -188,10 +188,8 @@ void slow_conv_dilated_all_cuda_template(
int64_t nInputPlane = weight.size(1);
int64_t nOutputPlane = weight.size(0);
// Temporary buffers:
int64_t m = std::accumulate(
kernel_size.begin(), kernel_size.end(), 1, std::multiplies<int64_t>());
int64_t output_vsize = std::accumulate(
output_size.begin(), output_size.end(), 1, std::multiplies<int64_t>());
const int64_t m = prod_intlist(kernel_size);
const int64_t output_vsize = prod_intlist(output_size);
Tensor columns = at::empty({0}, options);
if (output.defined() || grad_weight.defined() || grad_input.defined()) {
columns.resize_({nInputPlane * m, output_vsize});
80 changes: 44 additions & 36 deletions aten/src/ATen/native/cuda/ScanKernels.cu
@@ -128,16 +128,16 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se
*/
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_,
int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) {
for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
scalar_t *self = self_ + orow * row_size * num_irows + irow;
scalar_t *values = values_ + orow * row_size * num_irows + irow;
int64_t *indices = indices_ + orow * row_size * num_irows + irow;
scalar_t out = init;
int64_t out_idx = 0;

for (int64_t col = 0; col < row_size; ++col) {
for (auto col = decltype(row_size){0}; col < row_size; ++col) {
if(THCNumerics<scalar_t>::isnan(*self) || (!THCNumerics<scalar_t>::isnan(out) && binary_op(*self, out))) {
out = *self;
out_idx = col;
@@ -152,21 +152,34 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala
}
}

void check_fits_in_unsigned(int64_t val, const char* name) {
constexpr auto umax = std::numeric_limits<uint32_t>::max();
TORCH_CHECK(
val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value");
}


template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices,
int dim, scalar_t init, BinaryFunction binary_op) {
int row_size = self.size(dim);
int64_t row_size = self.size(dim);
auto sizes = self.sizes();

// Treat all outer dimensions (i.e. dim_ < dim) as one.
int num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies<int>());
const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim);

// Treat all inner dimensions (i.e. dim > dimension) as one.
int num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies<int>());
const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end());
//for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row,
//make sure that input is not bigger than supported by uint32_t
check_fits_in_unsigned(num_irows, "num_irows");
check_fits_in_unsigned(num_orows, "num_orows");
check_fits_in_unsigned(row_size, "row_size");


dim3 threads(std::min(512, int(num_irows)));
int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x))));
int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
num_orows, num_irows, row_size, init, binary_op);
@@ -254,16 +267,16 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int
*/
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
unsigned num_orows, unsigned num_irows, unsigned row_size,
scalar_t init, BinaryOp binary_op)
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
const scalar_t init, BinaryOp binary_op)
{
for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
scalar_t *src = src_ + orow * row_size * num_irows + irow;
scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
scalar_t acc = init;

for (unsigned col = 0; col < row_size; ++col) {
for (uint32_t col = 0; col < row_size; ++col) {
acc = binary_op(acc, *src);
*tgt = acc;

@@ -286,23 +299,23 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
*/
template<typename T, int num_threads_x, int num_threads_y, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_,
unsigned num_rows, unsigned row_size,
const uint32_t num_rows, const uint32_t row_size,
T init, BinaryFunction binary_op){
for (unsigned block_row = blockIdx.x * blockDim.y;
for (uint32_t block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
unsigned row = block_row + threadIdx.y;
uint32_t row = block_row + threadIdx.y;
T block_total = init;

T *row_src = src_ + row * row_size;
T *row_tgt = tgt_ + row * row_size;

// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
// Load data into shared memory (two values per thread).
unsigned col1 = block_col + threadIdx.x;
unsigned col2 = block_col + num_threads_x + threadIdx.x;
uint32_t col1 = block_col + threadIdx.x;
uint32_t col2 = block_col + num_threads_x + threadIdx.x;
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_src[col1];
@@ -324,18 +337,18 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr
__syncthreads();

// Parallel reduction (up-sweep).
for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
unsigned offset = (2 * threadIdx.x + 1) * d - 1;
uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}

// Down-sweep.
for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
@@ -361,8 +374,8 @@ __global__ typename std::enable_if<!c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
T* tgt_,
T* src_,
unsigned num_rows,
unsigned row_size,
const uint32_t num_rows,
const uint32_t row_size,
T init,
BinaryFunction binary_op) {
__shared__ T sbuf[num_threads_y][2 * num_threads_x];
@@ -381,8 +394,8 @@ __global__ typename std::enable_if<c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
T* tgt_,
T* src_,
unsigned num_rows,
unsigned row_size,
const uint32_t num_rows,
const uint32_t row_size,
T init,
BinaryFunction binary_op) {
// As we cannot directly initialize shared array for complex types
@@ -399,23 +412,18 @@ tensor_kernel_scan_innermost_dim(
row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
}

void check_fits_in_unsigned(int64_t val, const char* name) {
constexpr auto umax = std::numeric_limits<unsigned>::max();
TORCH_CHECK(
val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value");
}

template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const Tensor& self, Tensor& result,
int dim, scalar_t init, BinaryFunction binary_op) {
int64_t row_size = self.size(dim);
const int64_t row_size = self.size(dim);
auto sizes = self.sizes();

// Treat all outer dimensions (i.e. dim_ < dim) as one.
int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies<int64_t>());
const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim);

// Treat all inner dimensions (i.e. dim > dimension) as one.
int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies<int64_t>());
const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end());

dim3 threads(std::min(512, int(num_irows)));
int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
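The `ScanKernels.cu` hunks follow a single host-side pattern: compute the sizes in `int64_t` via `prod_intlist`, verify that each value fits into 32 bits, and only then pass them to kernels that index with `uint32_t` for performance. A rough standalone sketch of that pattern is below; `TORCH_CHECK` is replaced by a plain exception and `prepare_launch` is a hypothetical caller:
```
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Same check as in the hunk above, with TORCH_CHECK swapped for a throw.
void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  if (val < 0 || val > static_cast<int64_t>(umax)) {
    throw std::runtime_error(std::string(name) + " must fit in a 32-bit uint32_t value");
  }
}

// Hypothetical caller: sizes are computed in int64_t (e.g. via prod_intlist),
// validated, and only then narrowed to the uint32_t parameters a kernel expects.
void prepare_launch(int64_t num_orows, int64_t num_irows, int64_t row_size) {
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(row_size, "row_size");
  const auto orows = static_cast<uint32_t>(num_orows);
  const auto irows = static_cast<uint32_t>(num_irows);
  const auto rsize = static_cast<uint32_t>(row_size);
  // tensor_kernel_scan_outer_dim<<<grid, threads, 0, stream>>>(..., orows, irows, rsize, ...);
  (void)orows; (void)irows; (void)rsize;
}
```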
7 changes: 2 additions & 5 deletions aten/src/ATen/native/group_norm.cpp
@@ -106,11 +106,8 @@ Tensor group_norm(
input.sizes());

const auto input_shape = input.sizes();
const int64_t HxW = std::accumulate(
input_shape.cbegin() + 2,
input_shape.cend(),
1LL,
std::multiplies<int64_t>());
const int64_t HxW =
prod_intlist(input_shape.cbegin() + 2, input_shape.cend());

const Tensor kEmpty;
const auto& X = input.is_contiguous() ? input : input.contiguous();
