Skip to content
Permalink
Browse files Browse the repository at this point in the history
Re-enable testTensorListReserveWithNonScalarNumElements to work with …
…mlir as well.

PiperOrigin-RevId: 466460987
  • Loading branch information
pak-laura authored and tensorflower-gardener committed Aug 9, 2022
1 parent 8d1b332 commit b5f6fbf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/list_kernels.cc
Expand Up @@ -31,9 +31,11 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

Expand Down Expand Up @@ -322,6 +324,11 @@ class TensorListReserve : public OpKernel {
void Compute(OpKernelContext* c) override {
PartialTensorShape element_shape;
OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
OP_REQUIRES(
c, TensorShapeUtils::IsScalar(c->input(1).shape()),
errors::InvalidArgument(
"The num_elements to reserve must be a tensor size 1, but got ",
c->input(1).shape()));
int32_t num_elements = c->input(1).scalar<int32>()();
OP_REQUIRES(c, num_elements >= 0,
errors::InvalidArgument("The num_elements to reserve must be a "
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/python/kernel_tests/data_structures/list_ops_test.py
Expand Up @@ -94,6 +94,16 @@ def testPopFromEmptyTensorListFails(self, max_num_elements):
l = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.evaluate(l)

def testTensorListReserveWithNonScalarNumElements(self):
# list_kernels.cc in tf/core/kernels raises InvalidArgumentError, and
# tf_ops_n_z.cc in tf/compiler/mlir/tf/ir raises UnknownError.
with self.assertRaises((errors.InvalidArgumentError, errors.UnknownError)):
l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32,
element_shape=[2, 3],
num_elements=constant_op.constant([1, 1]))
self.evaluate(l)

def testPopUninitializedTensorUseListElementShape(self):
l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
Expand Down

0 comments on commit b5f6fbf

Please sign in to comment.