Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent segfault in embedding_lookup_sparse.cc
Previous fixes missed one additional case.

PiperOrigin-RevId: 417676944
Change-Id: I8ab412155cf9b1e897448a6611d209eaa7ca9e66
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 21, 2021
1 parent f435ae9 commit a4e401d
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow/lite/kernels/embedding_lookup_sparse.cc
Expand Up @@ -159,6 +159,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
const TfLiteTensor* value;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
const size_t values_size = NumElements(value);

const int lookup_rank = SizeOfDimension(indices, 1);
const int embedding_rank = NumDimensions(value);
Expand Down Expand Up @@ -253,6 +254,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
current_squares_weight += w * w;
current_total_weight += w;
for (int k = 0; k < embedding_size; k++) {
// only index if indices are valid
if (current_output_offset + k < 0) continue;
if (current_output_offset + k >= output_size) continue;
if (example_embedding_offset + k < 0) continue;
if (example_embedding_offset + k >= values_size) continue;
output_ptr[current_output_offset + k] +=
value_ptr[example_embedding_offset + k] * w;
}
Expand Down

0 comments on commit a4e401d

Please sign in to comment.