Skip to content
Permalink
Browse files Browse the repository at this point in the history
[lite] Move MultiplyAndCheckOverflow to util to be able to share it.
PiperOrigin-RevId: 416897229
Change-Id: I5feb44881bdcbb6ed911da4f17c55bb978754059
  • Loading branch information
karimnosseir authored and tensorflower-gardener committed Dec 16, 2021
1 parent fba06a0 commit f19be71
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 21 deletions.
2 changes: 2 additions & 0 deletions tensorflow/lite/BUILD
Expand Up @@ -1030,6 +1030,7 @@ cc_library(
copts = tflite_copts_warnings() + tflite_copts(),
deps = [
":kernel_api",
":macros",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
],
Expand Down Expand Up @@ -1083,6 +1084,7 @@ cc_test(
features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs
deps = [
":util",
"//tensorflow/lite/c:c_api_types",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
Expand Down
21 changes: 0 additions & 21 deletions tensorflow/lite/core/subgraph.cc
Expand Up @@ -690,27 +690,6 @@ TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices,
return kTfLiteOk;
}

namespace {
// Multiply two sizes and return true if overflow occurred;
// This is based off tensorflow/overflow.h but is simpler as we already
// have unsigned numbers. It is also generalized to work where sizeof(size_t)
// is not 8.
TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
// Multiplying a * b where a and b are size_t cannot result in overflow in a
// size_t accumulator if both numbers have no non-zero bits in their upper
// half.
constexpr size_t size_t_bits = 8 * sizeof(size_t);
constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2;
*product = a * b;
// If neither integers have non-zero bits past 32 bits can't overflow.
// Otherwise check using slow devision.
if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) {
if (a != 0 && *product / a != b) return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace

TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
size_t dims_size, size_t* bytes) {
TF_LITE_ENSURE(&context_, bytes != nullptr);
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/lite/util.cc
Expand Up @@ -27,6 +27,7 @@ limitations under the License.

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
Expand Down Expand Up @@ -176,4 +177,19 @@ bool IsValidationSubgraph(const char* name) {
// NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
}

TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
// Multiplying a * b where a and b are size_t cannot result in overflow in a
// size_t accumulator if both numbers have no non-zero bits in their upper
// half.
constexpr size_t size_t_bits = 8 * sizeof(size_t);
constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2;
*product = a * b;
// If neither integers have non-zero bits past 32 bits can't overflow.
// Otherwise check using slow devision.
if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) {
if (a != 0 && *product / a != b) return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace tflite
6 changes: 6 additions & 0 deletions tensorflow/lite/util.h
Expand Up @@ -99,6 +99,12 @@ constexpr char kValidationSubgraphNamePrefix[] = "VALIDATION:";
// Checks whether the prefix of the subgraph name indicates the subgraph is a
// validation subgraph.
bool IsValidationSubgraph(const char* name);

// Multiply two sizes and return true if overflow occurred;
// This is based off tensorflow/overflow.h but is simpler as we already
// have unsigned numbers. It is also generalized to work where sizeof(size_t)
// is not 8.
TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product);
} // namespace tflite

#endif // TENSORFLOW_LITE_UTIL_H_
8 changes: 8 additions & 0 deletions tensorflow/lite/util_test.cc
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"

Expand Down Expand Up @@ -130,5 +131,12 @@ TEST(ValidationSubgraph, NameIsDetected) {
EXPECT_TRUE(IsValidationSubgraph("VALIDATION:main"));
}

TEST(MultiplyAndCheckOverflow, Validate) {
size_t res = 0;
EXPECT_TRUE(MultiplyAndCheckOverflow(1, 2, &res) == kTfLiteOk);
EXPECT_FALSE(MultiplyAndCheckOverflow(static_cast<size_t>(123456789023),
1223423425, &res) == kTfLiteOk);
}

} // namespace
} // namespace tflite

0 comments on commit f19be71

Please sign in to comment.