Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix Segfault in Concat V2 shape function.
PiperOrigin-RevId: 412120654
Change-Id: I3ff915faea694f9ad8b00024e9af2de9909011be
  • Loading branch information
ishark authored and tensorflower-gardener committed Nov 24, 2021
1 parent fb51d44 commit 08d7b00
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow/core/framework/common_shape_fns.cc
Expand Up @@ -2005,7 +2005,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
}

// Minimum required number of dimensions.
const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
const int64 min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

ShapeHandle output_before;
ShapeHandle output_after;
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/array_ops/concat_op_test.py
Expand Up @@ -16,6 +16,7 @@

import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
Expand Down Expand Up @@ -570,6 +571,17 @@ def testConcatInvalidAxis(self):
t2 = [2]
gen_array_ops.concat_v2([t1, t2], 1).eval()

def testConcatInvalidAxisInTfFunction(self):

@def_function.function
def concat_wrapper():
y = gen_array_ops.concat_v2(
values=[[1, 2, 3], [4, 5, 6]], axis=0xb500005b)
return y

with self.assertRaises(ValueError):
concat_wrapper()

def testConcatNegativeAxis(self):
with test_util.use_gpu():
t1 = [[1, 2, 3], [4, 5, 6]]
Expand Down

0 comments on commit 08d7b00

Please sign in to comment.