File tree 2 files changed +26
-0
lines changed
python/kernel_tests/array_ops
2 files changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,15 @@ class ParallelConcatUpdate : public OpKernel {
7171
7272 void Compute (OpKernelContext* ctx) override {
7373 auto value = ctx->input (0 );
74+ // Value should be at least rank 1. Also the 0th dimension should be
75+ // at least loc_.
76+ OP_REQUIRES (ctx, value.dims () >= 1 ,
77+ errors::InvalidArgument (" value should be at least rank 1." ));
78+ OP_REQUIRES (
79+ ctx, value.dim_size (0 ) > loc_,
80+ errors::InvalidArgument (" 0th dimension of value = " , value.dim_size (0 ),
81+ " is less than loc_=" , loc_));
82+
7483 auto update = ctx->input (1 );
7584
7685 OP_REQUIRES (
Original file line number Diff line number Diff line change 1616
1717import numpy as np
1818
19+ from tensorflow .python import tf2
1920from tensorflow .python .eager import context
21+ from tensorflow .python .eager import def_function
2022from tensorflow .python .framework import constant_op
2123from tensorflow .python .framework import dtypes
24+ from tensorflow .python .framework import errors
2225from tensorflow .python .framework import ops
2326from tensorflow .python .framework import test_util
2427from tensorflow .python .ops import array_ops
28+ from tensorflow .python .ops import gen_array_ops
2529from tensorflow .python .ops import gradient_checker_v2
2630from tensorflow .python .platform import test
2731
@@ -69,6 +73,19 @@ def testSimpleParallelCPU(self):
6973 c = array_ops .parallel_stack (xs )
7074 self .assertAllEqual (c , data )
7175
76+ def testParallelConcatShapeZero (self ):
77+ if not tf2 .enabled ():
78+ self .skipTest ("only fails in TF2" )
79+
80+ @def_function .function
81+ def f ():
82+ y = gen_array_ops .parallel_concat (values = [["tf" ]], shape = 0 )
83+ return y
84+
85+ with self .assertRaisesRegex (errors .InvalidArgumentError ,
86+ r"0th dimension of value .* is less than" ):
87+ f ()
88+
7289 def testSimpleParallelGPU (self ):
7390 # tf.parallel_stack is only supported in graph mode.
7491 with ops .Graph ().as_default ():
You can’t perform that action at this time.
0 commit comments