From 0a7bd116bf29dccde91d0dea6d209fe6d592bd32 Mon Sep 17 00:00:00 2001 From: Alan Liu Date: Thu, 28 Apr 2022 11:37:31 -0700 Subject: [PATCH] Fix tf.raw_ops.UnsortedSegmentJoin vulnerability with invalid num_segments. Check that input is actually a scalar before treating it as such. PiperOrigin-RevId: 445206880 --- tensorflow/core/kernels/unsorted_segment_join_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/core/kernels/unsorted_segment_join_op.cc b/tensorflow/core/kernels/unsorted_segment_join_op.cc index 9acfe7fb1e4952..449db530513d5f 100644 --- a/tensorflow/core/kernels/unsorted_segment_join_op.cc +++ b/tensorflow/core/kernels/unsorted_segment_join_op.cc @@ -92,6 +92,9 @@ class UnsortedSegmentJoinOp : public OpKernel { const Tensor& num_segments_tensor = context->input(2); OP_REQUIRES(context, num_segments_tensor.NumElements() != 0, errors::InvalidArgument("Number of segments cannot be empty.")); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(num_segments_tensor.shape()), + errors::InvalidArgument("Number of segments must be a scalar")); auto num_segments = num_segments_tensor.scalar()(); OP_REQUIRES(context, segment_dims != 0,