Skip to content
Permalink
Browse files Browse the repository at this point in the history
[lite] Check for overflow when creating required bytes.
PiperOrigin-RevId: 417629001
Change-Id: Ia7feb3ea8e988f4fd4b3c98c1a1fed4557d99fd7
  • Loading branch information
karimnosseir authored and tensorflower-gardener committed Dec 21, 2021
1 parent ca38b92 commit 1de4972
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions tensorflow/lite/kernels/embedding_lookup_sparse.cc
Expand Up @@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace ops {
Expand Down Expand Up @@ -175,25 +176,33 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
TF_LITE_ENSURE(context, output_shape != nullptr);
int k = 0;
int embedding_size = 1;
int lookup_size = 1;
size_t embedding_size = 1;
size_t lookup_size = 1;
for (int i = 0; i < lookup_rank - 1; i++, k++) {
const int dim = dense_shape->data.i32[i];
lookup_size *= dim;
const size_t dim = dense_shape->data.i32[i];
TF_LITE_ENSURE_MSG(
context,
MultiplyAndCheckOverflow(lookup_size, dim, &lookup_size) == kTfLiteOk,
"Lookup size overflowed.");
output_shape->data[k] = dim;
}
for (int i = 1; i < embedding_rank; i++, k++) {
const int dim = SizeOfDimension(value, i);
embedding_size *= dim;
const size_t dim = SizeOfDimension(value, i);
TF_LITE_ENSURE_MSG(context,
MultiplyAndCheckOverflow(embedding_size, dim,
&embedding_size) == kTfLiteOk,
"Embedding size overflowed.");
output_shape->data[k] = dim;
}
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
const int output_size = lookup_size * embedding_size;
const size_t output_size = lookup_size * embedding_size;
TfLiteTensorRealloc(output_size * sizeof(float), output);

float* output_ptr = GetTensorData<float>(output);
const float* weights_ptr = GetTensorData<float>(weights);
const float* value_ptr = GetTensorData<float>(value);
// Makes sure reallocation was successful.
TF_LITE_ENSURE(context, output_ptr != nullptr);

std::fill_n(output_ptr, output_size, 0.0f);

Expand Down

0 comments on commit 1de4972

Please sign in to comment.