Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fixing null pointer read in TensorArrayConcat when the step container…
… is missing.

PiperOrigin-RevId: 504332194
  • Loading branch information
jsmeredith authored and tensorflower-gardener committed Jan 24, 2023
1 parent 87f68cc commit 239139d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/tensor_array_ops.cc
Expand Up @@ -80,8 +80,9 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
ResourceMgr* rm = ctx->resource_manager();
if (rm == nullptr) return errors::Internal("No resource manager.");
TF_RETURN_IF_ERROR(
ctx->step_container()->Lookup(rm, container + ta_handle, tensor_array));
ScopedStepContainer* sc = ctx->step_container();
if (sc == nullptr) return errors::Internal("No step container.");
TF_RETURN_IF_ERROR(sc->Lookup(rm, container + ta_handle, tensor_array));
return OkStatus();
} else {
return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array);
Expand Down
Expand Up @@ -1846,6 +1846,22 @@ def testStackShapeOnStaticSize(self):
ta = ta.write(0, [0])
self.assertEqual([42, 1], ta.stack().shape.as_list())

def testTensorArrayConcatFailsWhenMissingStepContainer(self):
@def_function.function
def func():
y = data_flow_ops.TensorArrayConcatV2(
handle=["a", "b"],
flow_in=0.1,
dtype=dtypes.int32,
element_shape_except0=1,
)
return y

with self.assertRaisesRegex(
errors.NotFoundError, "Container .* does not exist"
):
self.evaluate(func())


class TensorArrayBenchmark(test.Benchmark):

Expand Down

0 comments on commit 239139d

Please sign in to comment.