
Fix CHECK failures #53924

Merged · 8 commits · Jan 24, 2022
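
For context, a minimal sketch of the pattern this PR applies across the kernels below: build a TensorShape from runtime data through the status-returning API and fail only the offending op with InvalidArgument, instead of letting the CHECK-failing TensorShape constructor crash the process. This is illustration only, not part of the diff; it assumes a TensorFlow C++ build, and `BuildShapeOrFail` is a hypothetical helper name.

```cpp
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

// Before: TensorShape shape(shape_tensor.vec<int64_t>());
//         CHECK-fails on negative or overflowing dimensions.
// After:  validate through the status-returning builder and report
//         InvalidArgument to the caller of this op only.
void BuildShapeOrFail(OpKernelContext* ctx, const Tensor& shape_tensor,
                      TensorShape* shape) {
  OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
                          shape_tensor.vec<int64_t>(), shape));
}

}  // namespace tensorflow
```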
23 changes: 13 additions & 10 deletions tensorflow/core/framework/tensor_shape.cc
@@ -229,7 +229,7 @@ Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64_t> dim_sizes) {
if (!kIsPartial && !large_size) {
for (auto s : dim_sizes) {
if (TF_PREDICT_FALSE(s < 0)) {
return errors::Internal(
return errors::InvalidArgument(
"Expected shape dimensions to be non-negative, got ", s);
}
}
@@ -411,7 +411,8 @@ template <class Shape>
Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
if (!kIsPartial) {
if (TF_PREDICT_FALSE(size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
return errors::InvalidArgument("Expected a non-negative size, got ",
size);
}
}

@@ -420,7 +421,7 @@ Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
}

if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
return errors::Internal("Too many dimensions in tensor");
return errors::InvalidArgument("Too many dimensions in tensor");
}

int64_t new_num_elements;
@@ -429,9 +430,9 @@ Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
} else {
new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
if (TF_PREDICT_FALSE(new_num_elements < 0)) {
return errors::Internal("Encountered overflow when multiplying ",
num_elements(), " with ", size,
", result: ", new_num_elements);
return errors::InvalidArgument("Encountered overflow when multiplying ",
num_elements(), " with ", size,
", result: ", new_num_elements);
}
}

@@ -522,7 +523,8 @@ template <class Shape>
Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
if (!kIsPartial) {
if (TF_PREDICT_FALSE(size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
return errors::InvalidArgument("Expected a non-negative size, got ",
size);
}
}

@@ -594,13 +596,14 @@ void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
template <class Shape>
Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
if (TF_PREDICT_FALSE(d < 0)) {
return errors::Internal("Index must be non-negative, got ", d);
return errors::InvalidArgument("Index must be non-negative, got ", d);
}
if (TF_PREDICT_FALSE(d >= dims())) {
return errors::Internal("Index must be less than ", dims(), ", got ", d);
return errors::InvalidArgument("Index must be less than ", dims(), ", got ",
d);
}
if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
return errors::Internal("Expected a non-negative size, got ", size);
return errors::InvalidArgument("Expected a non-negative size, got ", size);
}

if (tag() == REP16 && size < kMaxRep16) {
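
The `new_num_elements < 0` test in the AddDimWithStatus hunk above works because MultiplyWithoutOverflow returns a negative value when the product of two non-negative int64 sizes overflows. A standalone illustration of that contract (plain C++, not TensorFlow's implementation):

```cpp
#include <cstdint>
#include <limits>

// Multiply two non-negative int64 values; return -1 on overflow. This mirrors
// the contract the shape code above relies on: a negative result means the
// running element count no longer fits in int64.
int64_t MultiplyNonNegativeOrMinusOne(int64_t x, int64_t y) {
  if (x == 0 || y == 0) return 0;
  const uint64_t ux = static_cast<uint64_t>(x);
  const uint64_t uy = static_cast<uint64_t>(y);
  const uint64_t product = ux * uy;  // wraparound is well defined in uint64
  const uint64_t max64 =
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
  // Overflowed uint64 (division no longer recovers the factor) or int64.
  if (product / uy != ux || product > max64) return -1;
  return static_cast<int64_t>(product);
}
```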
4 changes: 2 additions & 2 deletions tensorflow/core/framework/tensor_shape_test.cc
@@ -214,7 +214,7 @@ TEST(TensorShapeTest, AddDimWithStatus) {
ASSERT_EQ(4, s.dims());

status = s.AddDimWithStatus(-1);
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
}

TEST(TensorShapeTest, Factory) {
@@ -225,7 +225,7 @@ TEST(TensorShapeTest, Factory) {
ASSERT_EQ(3, s.dims());

status = TensorShape::BuildTensorShapeBase({-10, 5, 20}, &s);
EXPECT_EQ(tensorflow::error::INTERNAL, status.code());
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
}

// -----------------------------------------------------------------------
@@ -278,11 +278,12 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
previous_batch_index = next_batch_index;
}
gtl::InlinedVector<int64_t, 8> std_order(dense_shape->NumElements(), 0);
TensorShape shape;
OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
dense_shape->vec<int64_t>(), &shape));
sparse::SparseTensor tensor;
OP_REQUIRES_OK(
ctx, sparse::SparseTensor::Create(
*indices, *values, TensorShape(dense_shape->vec<int64_t>()),
std_order, &tensor));
OP_REQUIRES_OK(ctx, sparse::SparseTensor::Create(*indices, *values, shape,
std_order, &tensor));
*output = new Dataset<T>(ctx, std::move(tensor));
}

6 changes: 5 additions & 1 deletion tensorflow/core/kernels/reshape_util.cc
@@ -23,8 +23,10 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -99,7 +101,9 @@ void ReshapeSparseTensor(OpKernelContext *context,
target_shape_in.shape().DebugString()));

const int64_t output_rank = target_shape_in.NumElements();
const TensorShape input_shape(input_shape_in.vec<int64_t>());
TensorShape input_shape;
OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(
input_shape_in.vec<int64_t>(), &input_shape));
const int64_t dense_size = input_shape.num_elements();
const int64_t nnz = input_indices_in.shape().dim_size(0);

20 changes: 18 additions & 2 deletions tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -18,6 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_

#include <cstdint>

#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/types.h"
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -474,6 +478,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
bool is_mean, bool is_sqrtn,
bool has_num_segments, T default_value)
: OpKernel(context),
dtidx_(DataTypeToEnum<Index>::v()),
is_mean_(is_mean),
is_sqrtn_(is_sqrtn),
has_num_segments_(has_num_segments),
@@ -503,10 +508,20 @@ class SparseSegmentReductionOpBase : public OpKernel {
const auto segment_vec = segment_ids.vec<SegmentId>();
// Note that the current implementation assumes that segment_vec values are
// sorted.
const SegmentId last_segment_id =
num_indices > 0 ? segment_vec(num_indices - 1) : 0;
int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max;

OP_REQUIRES(
context, last_segment_id < limit,
errors::InvalidArgument("Last segment id must be < kintmax, got ",
last_segment_id, " limit ", limit));

const SegmentId last_segment_id_plus_one =
num_indices > 0
? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
: 0;

if (has_num_segments_) {
OP_REQUIRES(
context, output_rows >= last_segment_id_plus_one,
@@ -518,7 +533,7 @@
errors::InvalidArgument("segment ids must be >= 0"));

TensorShape output_shape = input.shape();
output_shape.set_dim(0, output_rows);
OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));

// Note that we do not initialize the output buffer with a default value, so
// we need to explicitly set missing indices to the default value.
@@ -605,6 +620,7 @@
}

private:
const DataType dtidx_;
template <typename Tin>
using EnableIfBfloat16OrHalf =
typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
@@ -1098,7 +1114,7 @@ class SparseSegmentGradOpBase : public OpKernel {
const auto segment_vec = segment_ids.vec<SegmentId>();

TensorShape output_shape = input.shape();
output_shape.set_dim(0, M);
OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M));
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
if (M == 0 || N == 0) return;
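
The new last_segment_id guard above bounds the largest segment id by the maximum of the kernel's index type (kint32max or kint64max, selected via dtidx_), so that quantities derived from it, such as the implied number of output rows (last id + 1), still fit that type. A standalone sketch of the check (plain C++, not TensorFlow code; the boolean parameter stands in for the dtidx_ comparison):

```cpp
#include <cstdint>
#include <limits>

// Returns true when the largest (last, since segment ids are sorted) segment
// id is strictly below the limit of the kernel's index type, mirroring the
// OP_REQUIRES(last_segment_id < limit, ...) check added above.
bool LastSegmentIdFitsIndexType(int64_t last_segment_id, bool index_is_int32) {
  const int64_t limit =
      index_is_int32 ? std::numeric_limits<int32_t>::max()   // kint32max
                     : std::numeric_limits<int64_t>::max();  // kint64max
  return last_segment_id < limit;
}
```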
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/serialize_sparse_op.cc
@@ -23,9 +23,11 @@ limitations under the License.

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
@@ -366,7 +368,10 @@ class SerializeManySparseOp : public OpKernel {
errors::InvalidArgument(
"Rank of input SparseTensor should be > 1, but saw rank: ", rank));

TensorShape tensor_input_shape(input_shape->vec<int64_t>());
TensorShape tensor_input_shape;
OP_REQUIRES_OK(context,
TensorShape::BuildTensorShape(input_shape->vec<int64_t>(),
&tensor_input_shape));
gtl::InlinedVector<int64_t, 8> std_order(rank);
std::iota(std_order.begin(), std_order.end(), 0);
SparseTensor input_st;
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/set_kernels.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
@@ -69,8 +70,9 @@ Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index,
bool validate_indices,
sparse::SparseTensor* tensor) {
// Assume row-major order.
const TensorShape shape =
TensorShape(ctx->input(base_index + 2).vec<int64_t>());
TensorShape shape;
TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape(
ctx->input(base_index + 2).vec<int64_t>(), &shape));
CheckRankAtLeast2(ctx, shape);
std::vector<int64_t> order(shape.dims());
std::iota(order.begin(), order.end(), 0);
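
SparseTensorFromContext returns a Status rather than running inside Compute, so the hunk above propagates the failure with TF_RETURN_IF_ERROR instead of OP_REQUIRES_OK. A minimal sketch of that shape of helper (illustration only; assumes a TensorFlow C++ build, and `ShapeFromShapeTensor` is a hypothetical name):

```cpp
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Status-returning helper: propagate shape-validation errors to the caller
// with TF_RETURN_IF_ERROR; a kernel's Compute would use OP_REQUIRES_OK instead.
Status ShapeFromShapeTensor(OpKernelContext* ctx, int input_index,
                            TensorShape* shape) {
  TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape(
      ctx->input(input_index).vec<int64_t>(), shape));
  return Status::OK();
}

}  // namespace tensorflow
```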
@@ -76,10 +76,18 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
const int64_t total_nnz = values.NumElements();

// Allocate output Tensors.
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
Tensor csr_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
Tensor csr_row_ptr(cpu_allocator(), DT_INT32,
TensorShape({(num_rows + 1) * batch_size}));
TensorShape batch_ptr_shape;
OP_REQUIRES_OK(
ctx, TensorShape::BuildTensorShape({batch_size + 1}, &batch_ptr_shape));
Tensor batch_ptr(cpu_allocator(), DT_INT32, batch_ptr_shape);
TensorShape csr_col_ind_shape;
OP_REQUIRES_OK(
ctx, TensorShape::BuildTensorShape({total_nnz}, &csr_col_ind_shape));
Tensor csr_col_ind(cpu_allocator(), DT_INT32, csr_col_ind_shape);
TensorShape csr_row_ind_shape;
OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(
{(num_rows + 1) * batch_size}, &csr_row_ind_shape));
Tensor csr_row_ptr(cpu_allocator(), DT_INT32, csr_row_ind_shape);

// Fill the row pointers with zeros.
functor::SetZeroFunctor<CPUDevice, int32> set_zero;
12 changes: 10 additions & 2 deletions tensorflow/core/kernels/sparse_reduce_op.cc
@@ -18,8 +18,10 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@@ -172,10 +174,13 @@ class SparseReduceOp : public OpKernel {
// making deep copies here. Remove this if/when we change Reorder()'s
// semantics.
const auto shape_vec = shape_t->vec<int64_t>();
TensorShape shape;
OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(shape_vec, &shape));

SparseTensor sp;
OP_REQUIRES_OK(ctx, SparseTensor::Create(
tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
TensorShape(shape_vec), &sp));
shape, &sp));
ReduceDetails reduction = SparseTensorReduceHelper(
sp, reduction_axes_t->flat<int32>(), keep_dims_);

@@ -275,10 +280,13 @@ class SparseReduceSparseOp : public OpKernel {

OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));

TensorShape shape;
OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape(shape_t->vec<int64_t>(),
&shape));
SparseTensor sp;
OP_REQUIRES_OK(ctx, SparseTensor::Create(tensor::DeepCopy(*indices_t),
tensor::DeepCopy(*values_t),
TensorShape(shape_t->vec<int64_t>()), &sp));
shape, &sp));
ReduceDetails reduction = SparseTensorReduceHelper(
sp, reduction_axes_t->flat<int32>(), keep_dims_);

26 changes: 19 additions & 7 deletions tensorflow/core/kernels/sparse_slice_op.cc
@@ -21,6 +21,7 @@ limitations under the License.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
@@ -39,27 +40,38 @@ struct SparseSliceFunctor<CPUDevice, T> {
const int input_dims = input_shape.NumElements();

sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(
context, sparse::SparseTensor::Create(
input_indices, input_values,
TensorShape(input_shape.vec<int64_t>()), &sparse_tensor));
TensorShape sparse_tensor_shape;
OP_REQUIRES_OK(context,
TensorShapeBase<TensorShape>::BuildTensorShapeBase(
input_shape.vec<int64_t>(), &sparse_tensor_shape));
OP_REQUIRES_OK(context, sparse::SparseTensor::Create(
input_indices, input_values,
sparse_tensor_shape, &sparse_tensor));

const gtl::ArraySlice<int64_t> start(input_start.flat<int64_t>().data(),
input_dims);
const gtl::ArraySlice<int64_t> size(input_size.flat<int64_t>().data(),
input_dims);

const sparse::SparseTensor output =
const StatusOr<sparse::SparseTensor> output_or =
sparse::SparseTensor::Slice<T>(sparse_tensor, start, size);
OP_REQUIRES_OK(context, output_or.status());
auto output = output_or.ValueOrDie();

context->set_output(0, output.indices());
context->set_output(1, output.values());

const TensorShape output_shape(output.shape());
TensorShape output_shape;
OP_REQUIRES_OK(context, TensorShapeBase<TensorShape>::BuildTensorShapeBase(
output.shape(), &output_shape));

TensorShape allocated_shape;
OP_REQUIRES_OK(context, TensorShapeBase<TensorShape>::BuildTensorShapeBase(
{output_shape.dims()}, &allocated_shape));

Tensor* shape = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(2, {output_shape.dims()}, &shape));
context->allocate_output(2, allocated_shape, &shape));
for (int dim = 0; dim < output_shape.dims(); ++dim) {
shape->vec<int64_t>()(dim) = output_shape.dim_size(dim);
}
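
SparseTensor::Slice<T> now returns a StatusOr in the hunk above, and the kernel checks output_or.status() before unwrapping with ValueOrDie(). A minimal sketch of that check-then-unwrap idiom using absl::StatusOr, which tensorflow::StatusOr mirrors (ValueOrDie() was the accessor current when this PR landed; illustration only, not the kernel's code):

```cpp
#include <iostream>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// A fallible operation: it yields either a value or an error status.
absl::StatusOr<std::string> MaybeSlice(bool in_bounds) {
  if (!in_bounds) return absl::InvalidArgumentError("slice start out of range");
  return std::string("sliced values");
}

int main() {
  absl::StatusOr<std::string> result = MaybeSlice(true);
  // Check the status first, as OP_REQUIRES_OK(context, output_or.status()) does.
  if (!result.ok()) {
    std::cerr << result.status() << "\n";
    return 1;
  }
  // Then unwrap the value, analogous to output_or.ValueOrDie() above.
  std::cout << *result << "\n";
  return 0;
}
```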
9 changes: 6 additions & 3 deletions tensorflow/core/kernels/sparse_softmax_op.cc
@@ -21,6 +21,7 @@ limitations under the License.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
@@ -62,14 +63,16 @@ class SparseSoftmaxOp : public OpKernel {
errors::InvalidArgument(
"Input should have rank >= 2, but received shape: ",
shape_t->SummarizeValue(3)));
TensorShape shape;
OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(
shape_t->flat<int64_t>(), &shape));

const int64_t nnz = indices_t->dim_size(0);
const int rank = static_cast<int>(indices_t->dim_size(1));
SparseTensor st;
OP_REQUIRES_OK(
context, SparseTensor::Create(
tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
TensorShape(shape_t->flat<int64_t>()), &st));
context, SparseTensor::Create(tensor::DeepCopy(*indices_t),
tensor::DeepCopy(*values_t), shape, &st));

Tensor *output_values = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nnz}),