File tree 2 files changed +13
-0
lines changed
2 files changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -236,6 +236,17 @@ def testZerosLikeForTensorList(self):
236236 self .assertAllEqual (z .shape .as_list (), [None ])
237237 self .assertAllEqual (z , [0.0 , 0.0 ])
238238
239+ def testInvalidSplitLength (self ):
240+ with self .session (), self .test_scope ():
241+ tensor_list_split = list_ops .tensor_list_split (
242+ tensor = [1 ], element_shape = [- 1 ], lengths = [0 ]
243+ )
244+ with self .assertRaisesRegex (
245+ errors .UnimplementedError , "All lengths must be positive"
246+ ):
247+ self .evaluate (tensor_list_split )
248+
249+
239250if __name__ == "__main__" :
240251 os .environ ["TF_XLA_FLAGS" ] = ("--tf_xla_min_cluster_size=2 " +
241252 os .environ .get ("TF_XLA_FLAGS" , "" ))
Original file line number Diff line number Diff line change @@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel {
553553 OP_REQUIRES (ctx, len == length,
554554 errors::Unimplemented (" All lengths have to be the same" ));
555555 }
556+ OP_REQUIRES (ctx, length,
557+ errors::Unimplemented (" All lengths must be positive" ));
556558 OP_REQUIRES (
557559 ctx, element_dims[0 ] % length == 0 ,
558560 errors::Unimplemented (" Buffer size has to be a multiple of length" ));
You can’t perform that action at this time.
0 commit comments