Skip to content

Commit

Permalink
Add relevant shape check for tf.reshape to prevent crash
Browse files Browse the repository at this point in the history
This PR tries to address the issue raised in 46693 where
a shape with large number of elements will cause the
tf.reshape to crash.

This PR adds relevant shape check so that error message can
be returned gracefully.

This PR fixes 46693

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang committed Mar 1, 2021
1 parent b904b76 commit a41733a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/reshape_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

Expand Down Expand Up @@ -135,6 +136,17 @@ class ReshapeOp : public OpKernel {
shape->AddDim(size);
*has_zero_dim = true;
} else {
if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) {
string msg;
for (int ii = 0; ii < num_dims; ++ii) {
if (ii != 0) {
strings::StrAppend(&msg, ", ");
}
strings::StrAppend(&msg, Svec(ii));
}
return errors::InvalidArgument("Shape [", msg,
"] has too many elements");
}
shape->AddDim(size);
(*product) *= size;
}
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/reshape_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
Expand Down Expand Up @@ -214,6 +215,14 @@ def testInt64Shape(self):
y = array_ops.reshape(x, [1, 50000**2])
self.assertEqual([1, 50000**2], y.get_shape().as_list())

@test_util.run_v2_only
def testTooLargeShape(self):
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError, "too many elements"):
x = array_ops.reshape([1], np.array([21943, 45817, 30516, 61760, 38987]))
self.evaluate(x)



if __name__ == "__main__":
test.main()

0 comments on commit a41733a

Please sign in to comment.