Skip to content
Permalink
Browse files Browse the repository at this point in the history
[security] Fix int overflow in RaggedRangeOp.
PiperOrigin-RevId: 461749624
  • Loading branch information
JXRiver authored and tensorflower-gardener committed Jul 19, 2022
1 parent 222ca8a commit 37cefa9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
35 changes: 20 additions & 15 deletions tensorflow/core/kernels/ragged_range_op.cc
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
Expand Down Expand Up @@ -78,8 +79,25 @@ class RaggedRangeOp : public OpKernel {
T limit = broadcast_limits ? limits(0) : limits(row);
T delta = broadcast_deltas ? deltas(0) : deltas(row);
OP_REQUIRES(context, delta != 0, InvalidArgument("Requires delta != 0"));
rt_nested_splits(row + 1) =
rt_nested_splits(row) + RangeSize(start, limit, delta);
int64_t size; // The number of elements in the specified range.
if (((delta > 0) && (limit < start)) ||
((delta < 0) && (limit > start))) {
size = 0;
} else if (std::is_integral<T>::value) {
// The following is copied from tensorflow::RangeOp::Compute().
size = Eigen::divup(Eigen::numext::abs(limit - start),
Eigen::numext::abs(delta));
} else {
// The following is copied from tensorflow::RangeOp::Compute().
auto size_auto =
Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
OP_REQUIRES(
context, size_auto <= std::numeric_limits<int64_t>::max(),
errors::InvalidArgument("Requires ((limit - start) / delta) <= ",
std::numeric_limits<int64_t>::max()));
size = static_cast<int64_t>(size_auto);
}
rt_nested_splits(row + 1) = rt_nested_splits(row) + size;
}
SPLITS_TYPE nvals = rt_nested_splits(nrows);

Expand All @@ -99,19 +117,6 @@ class RaggedRangeOp : public OpKernel {
}
}
}

private:
// Returns the number of elements in the specified range.
SPLITS_TYPE RangeSize(T start, T limit, T delta) {
if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
return 0;
}
// The following is copied from tensorflow::RangeOp::Compute().
return (std::is_integral<T>::value
? ((std::abs(limit - start) + std::abs(delta) - 1) /
std::abs(delta))
: std::ceil(std::abs((limit - start) / delta)));
}
};

#define REGISTER_CPU_KERNEL(TYPE) \
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/ragged_range_op_test.cc
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <gtest/gtest.h>
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
Expand Down Expand Up @@ -77,6 +78,17 @@ TEST_F(RaggedRangeOpTest, FloatValues) {
test::AsTensor<float>({0, 2, 4, 6, 5, 6, 5, 4, 3, 2}), 0.1);
}

TEST_F(RaggedRangeOpTest, RangeSizeOverflow) {
BuildRaggedRangeGraph<float>();
AddInputFromArray<float>(TensorShape({2}), {1.1, 0.1}); // starts
AddInputFromArray<float>(TensorShape({2}), {10.0, 1e10}); // limits
AddInputFromArray<float>(TensorShape({2}), {1, 1e-10}); // deltas

EXPECT_EQ(absl::StrCat("Requires ((limit - start) / delta) <= ",
std::numeric_limits<int64_t>::max()),
RunOpKernel().error_message());
}

TEST_F(RaggedRangeOpTest, BroadcastDeltas) {
BuildRaggedRangeGraph<int>();
AddInputFromArray<int>(TensorShape({3}), {0, 5, 8}); // starts
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/python/ops/ragged/ragged_range_op_test.py
Expand Up @@ -84,8 +84,7 @@ def testBroadcast(self):
list(range(5, 15, 3))])

# Broadcast all arguments.
self.assertAllEqual(
ragged_math_ops.range(0, 5, 1), [list(range(0, 5, 1))])
self.assertAllEqual(ragged_math_ops.range(0, 5, 1), [list(range(0, 5, 1))])

def testEmptyRanges(self):
rt1 = ragged_math_ops.range([0, 5, 3], [0, 3, 5])
Expand All @@ -108,6 +107,10 @@ def testKernelErrors(self):
r'Requires delta != 0'):
self.evaluate(ragged_math_ops.range(0, 0, 0))

with self.assertRaisesRegex(errors.InvalidArgumentError,
r'Requires \(\(limit - start\) / delta\) <='):
self.evaluate(ragged_math_ops.range(0.1, 1e10, 1e-10))

def testShape(self):
self.assertAllEqual(
ragged_math_ops.range(0, 0, 1).shape.as_list(), [1, None])
Expand Down

0 comments on commit 37cefa9

Please sign in to comment.