Skip to content

Commit

Permalink
Adding more validation checks to _ParallelConcatUpdate to avoid NPE.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 402569467
Change-Id: I2db122dab68be2a5e4e8dd3375f5a70c4d2307ec
  • Loading branch information
rohan100jain authored and mihaimaruseac committed Oct 28, 2021
1 parent 8e297ba commit d11f21b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/inplace_ops.cc
Expand Up @@ -71,6 +71,15 @@ class ParallelConcatUpdate : public OpKernel {

void Compute(OpKernelContext* ctx) override {
auto value = ctx->input(0);
// Value should be at least rank 1. Also the 0th dimension should be
// at least loc_.
OP_REQUIRES(ctx, value.dims() >= 1,
errors::InvalidArgument("value should be at least rank 1."));
OP_REQUIRES(
ctx, value.dim_size(0) > loc_,
errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
" is less than loc_=", loc_));

auto update = ctx->input(1);

OP_REQUIRES(
Expand Down
20 changes: 20 additions & 0 deletions tensorflow/python/kernel_tests/stack_op_test.py
Expand Up @@ -20,12 +20,16 @@

import numpy as np

from tensorflow.python import tf2
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
Expand Down Expand Up @@ -75,7 +79,23 @@ def testSimpleParallelCPU(self):
c = array_ops.parallel_stack(xs)
self.assertAllEqual(c, data)

<<<<<<< HEAD:tensorflow/python/kernel_tests/stack_op_test.py
@test_util.run_deprecated_v1
=======
def testParallelConcatShapeZero(self):
if not tf2.enabled():
self.skipTest("only fails in TF2")

@def_function.function
def f():
y = gen_array_ops.parallel_concat(values=[["tf"]], shape=0)
return y

with self.assertRaisesRegex(errors.InvalidArgumentError,
r"0th dimension of value .* is less than"):
f()

>>>>>>> e67caccea81 (Adding more validation checks to _ParallelConcatUpdate to avoid NPE.):tensorflow/python/kernel_tests/array_ops/stack_op_test.py
def testSimpleParallelGPU(self):
np.random.seed(7)
with self.session(use_gpu=True):
Expand Down

0 comments on commit d11f21b

Please sign in to comment.