Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.UnsortedSegmentJoin vulnerability with invalid num_seg…
…ments.

Check that input is actually a scalar before treating it as such.

PiperOrigin-RevId: 445206880
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Apr 28, 2022
1 parent fa57990 commit 13d38a0
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/unsorted_segment_join_op.cc
Expand Up @@ -92,6 +92,9 @@ class UnsortedSegmentJoinOp : public OpKernel {
const Tensor& num_segments_tensor = context->input(2);
OP_REQUIRES(context, num_segments_tensor.NumElements() != 0,
errors::InvalidArgument("Number of segments cannot be empty."));
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(num_segments_tensor.shape()),
errors::InvalidArgument("Number of segments must be a scalar"));
auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();

OP_REQUIRES(
Expand Down

0 comments on commit 13d38a0

Please sign in to comment.