Skip to content

Commit 728113a

Browse files
jcai19tensorflower-gardener
authored andcommitted
[Tensorflow] Fix security vulnerability with TensorListSplitOp
PiperOrigin-RevId: 506441188
1 parent 7119c6a commit 728113a

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

Diff for: tensorflow/compiler/tests/tensor_list_ops_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,17 @@ def testZerosLikeForTensorList(self):
236236
self.assertAllEqual(z.shape.as_list(), [None])
237237
self.assertAllEqual(z, [0.0, 0.0])
238238

239+
def testInvalidSplitLength(self):
240+
with self.session(), self.test_scope():
241+
tensor_list_split = list_ops.tensor_list_split(
242+
tensor=[1], element_shape=[-1], lengths=[0]
243+
)
244+
with self.assertRaisesRegex(
245+
errors.UnimplementedError, "All lengths must be positive"
246+
):
247+
self.evaluate(tensor_list_split)
248+
249+
239250
if __name__ == "__main__":
240251
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
241252
os.environ.get("TF_XLA_FLAGS", ""))

Diff for: tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc

+2
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel {
553553
OP_REQUIRES(ctx, len == length,
554554
errors::Unimplemented("All lengths have to be the same"));
555555
}
556+
OP_REQUIRES(ctx, length,
557+
errors::Unimplemented("All lengths must be positive"));
556558
OP_REQUIRES(
557559
ctx, element_dims[0] % length == 0,
558560
errors::Unimplemented("Buffer size has to be a multiple of length"));

0 commit comments

Comments
 (0)