Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add a check for Key being scalar tensor for MapStage and OrderedMapSt…
…age ops.

According to documentation[1][2], key must be int64 value, but this wasn't enforced and the ops would fail with check failure for non-scalar key value.

[1]https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/ordered-map-stage
[2]https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/map-stage

PiperOrigin-RevId: 413822112
Change-Id: I9d118faf990e6361900aa32272eff486ad9f0e2e
  • Loading branch information
ishark authored and tensorflower-gardener committed Dec 3, 2021
1 parent 0288e9c commit f573155
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 124 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/map_stage_op.cc
Expand Up @@ -536,6 +536,11 @@ class MapStageOp : public OpKernel {
OP_REQUIRES(ctx, key_tensor->NumElements() > 0,
errors::InvalidArgument("key must not be empty"));

OP_REQUIRES(ctx, key_tensor->NumElements() == 1,
errors::InvalidArgument(
"key must be an int64 scalar, got tensor with shape: ",
key_tensor->shape()));

// Create copy for insertion into Staging Area
Tensor key(*key_tensor);

Expand Down

0 comments on commit f573155

Please sign in to comment.