Skip to content

Commit f2c3931

Browse files
rohan100jaintensorflower-gardener
authored andcommitted
Adding more validation checks to _ParallelConcatUpdate to avoid NPE.
PiperOrigin-RevId: 402569467 Change-Id: I2db122dab68be2a5e4e8dd3375f5a70c4d2307ec
1 parent f6da17f commit f2c3931

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

Diff for: tensorflow/core/kernels/inplace_ops.cc

+9
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ class ParallelConcatUpdate : public OpKernel {
7171

7272
void Compute(OpKernelContext* ctx) override {
7373
auto value = ctx->input(0);
74+
// Value should be at least rank 1. Also the 0th dimension should be
75+
// at least loc_.
76+
OP_REQUIRES(ctx, value.dims() >= 1,
77+
errors::InvalidArgument("value should be at least rank 1."));
78+
OP_REQUIRES(
79+
ctx, value.dim_size(0) > loc_,
80+
errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
81+
" is less than loc_=", loc_));
82+
7483
auto update = ctx->input(1);
7584

7685
OP_REQUIRES(

Diff for: tensorflow/python/kernel_tests/array_ops/stack_op_test.py

+17
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
import numpy as np
1818

19+
from tensorflow.python import tf2
1920
from tensorflow.python.eager import context
21+
from tensorflow.python.eager import def_function
2022
from tensorflow.python.framework import constant_op
2123
from tensorflow.python.framework import dtypes
24+
from tensorflow.python.framework import errors
2225
from tensorflow.python.framework import ops
2326
from tensorflow.python.framework import test_util
2427
from tensorflow.python.ops import array_ops
28+
from tensorflow.python.ops import gen_array_ops
2529
from tensorflow.python.ops import gradient_checker_v2
2630
from tensorflow.python.platform import test
2731

@@ -69,6 +73,19 @@ def testSimpleParallelCPU(self):
6973
c = array_ops.parallel_stack(xs)
7074
self.assertAllEqual(c, data)
7175

76+
def testParallelConcatShapeZero(self):
77+
if not tf2.enabled():
78+
self.skipTest("only fails in TF2")
79+
80+
@def_function.function
81+
def f():
82+
y = gen_array_ops.parallel_concat(values=[["tf"]], shape=0)
83+
return y
84+
85+
with self.assertRaisesRegex(errors.InvalidArgumentError,
86+
r"0th dimension of value .* is less than"):
87+
f()
88+
7289
def testSimpleParallelGPU(self):
7390
# tf.parallel_stack is only supported in graph mode.
7491
with ops.Graph().as_default():

0 commit comments

Comments
 (0)