Skip to content
Permalink
Browse files Browse the repository at this point in the history
[Tensorflow] Fix security vulnerability with TensorListSplitOp
PiperOrigin-RevId: 506441188
  • Loading branch information
jcai19 authored and tensorflower-gardener committed Feb 1, 2023
1 parent 7119c6a commit 728113a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow/compiler/tests/tensor_list_ops_test.py
Expand Up @@ -236,6 +236,17 @@ def testZerosLikeForTensorList(self):
self.assertAllEqual(z.shape.as_list(), [None])
self.assertAllEqual(z, [0.0, 0.0])

def testInvalidSplitLength(self):
with self.session(), self.test_scope():
tensor_list_split = list_ops.tensor_list_split(
tensor=[1], element_shape=[-1], lengths=[0]
)
with self.assertRaisesRegex(
errors.UnimplementedError, "All lengths must be positive"
):
self.evaluate(tensor_list_split)


if __name__ == "__main__":
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " +
os.environ.get("TF_XLA_FLAGS", ""))
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
Expand Up @@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel {
OP_REQUIRES(ctx, len == length,
errors::Unimplemented("All lengths have to be the same"));
}
OP_REQUIRES(ctx, length,
errors::Unimplemented("All lengths must be positive"));
OP_REQUIRES(
ctx, element_dims[0] % length == 0,
errors::Unimplemented("Buffer size has to be a multiple of length"));
Expand Down

0 comments on commit 728113a

Please sign in to comment.