Merge branch 'r2.4' into cherrypicks_SRGF7
mihaimaruseac committed May 26, 2021
2 parents 2578f8c + e0a5bb3 commit 9ccfe82
Showing 21 changed files with 237 additions and 37 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -70,6 +70,9 @@ class CTCDecodeHelper {
if (inputs_shape.dims() != 3) {
return errors::InvalidArgument("inputs is not a 3-Tensor");
}
if (inputs_shape.num_elements() == 0) {
return errors::InvalidArgument("inputs must not be empty");
}

const int64 max_time = inputs_shape.dim_size(0);
const int64 batch_size = inputs_shape.dim_size(1);
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/ctc_loss_op.cc
@@ -100,11 +100,18 @@ class CTCLossOp : public OpKernel {
errors::InvalidArgument("sequence_length is not a vector"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
errors::InvalidArgument("labels_indices is not a matrix"));
OP_REQUIRES(ctx, labels_indices->dim_size(1) > 1,
errors::InvalidArgument(
"labels_indices second dimension must be >= 1. Received ",
labels_indices->dim_size(1)));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
errors::InvalidArgument("labels_values is not a vector"));

const TensorShape& inputs_shape = inputs->shape();
const int64 max_time = inputs_shape.dim_size(0);
OP_REQUIRES(ctx, max_time != 0,
errors::InvalidArgument(
"Max time or first dimension of input cannot be 0."));
const int64 batch_size = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
OP_REQUIRES(
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/dequantize_op.cc
@@ -98,6 +98,18 @@ class DequantizeOp : public OpKernel {
if (axis_ > -1) {
num_slices = input.dim_size(axis_);
}
OP_REQUIRES(ctx, input_min_tensor.NumElements() == num_slices,
errors::InvalidArgument(
"input_min_tensor must have as many elements as input on "
"the dequantization axis (",
axis_, "), got ", input_min_tensor.NumElements(),
", expected ", num_slices));
OP_REQUIRES(ctx, input_max_tensor.NumElements() == num_slices,
errors::InvalidArgument(
"input_max_tensor must have as many elements as input on "
"the dequantization axis (",
axis_, "), got ", input_max_tensor.NumElements(),
", expected ", num_slices));

Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
13 changes: 13 additions & 0 deletions tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -250,6 +250,19 @@ class FractionalAvgPoolGradOp : public OpKernel {
const int64 out_cols = out_backprop.dim_size(2);
const int64 out_depth = out_backprop.dim_size(3);

OP_REQUIRES(context, row_seq_tensor.NumElements() > out_rows,
errors::InvalidArgument("Given out_backprop shape ",
out_backprop.shape().DebugString(),
", row_seq_tensor must have at least ",
out_rows + 1, " elements, but got ",
row_seq_tensor.NumElements()));
OP_REQUIRES(context, col_seq_tensor.NumElements() > out_cols,
errors::InvalidArgument("Given out_backprop shape ",
out_backprop.shape().DebugString(),
", col_seq_tensor must have at least ",
out_cols + 1, " elements, but got ",
col_seq_tensor.NumElements()));

auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
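A note on the "out_rows + 1" wording above: row_seq_tensor and col_seq_tensor hold pooling-interval boundaries, so producing N output rows takes N + 1 fence posts, and interval i reads element i + 1. A minimal standalone C++ sketch of that indexing (illustrative only, not the kernel code; names are made up):

#include <cstdio>
#include <vector>

int main() {
  // 4 boundaries define 3 pooling intervals [seq[i], seq[i+1]).
  std::vector<long> row_seq = {0, 2, 5, 7};
  const int out_rows = 3;
  for (int i = 0; i < out_rows; ++i) {
    // Interval i reads row_seq[i + 1]; hence NumElements() > out_rows.
    std::printf("rows [%ld, %ld)\n", row_seq[i], row_seq[i + 1]);
  }
  return 0;
}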
14 changes: 14 additions & 0 deletions tensorflow/core/kernels/fractional_max_pool_op.cc
@@ -235,6 +235,20 @@ class FractionalMaxPoolGradOp : public OpKernel {

// Just to make it similar to FractionalMaxPoolOp.
constexpr int tensor_in_and_out_dims = 4;
OP_REQUIRES(
context, tensor_in.dims() == tensor_in_and_out_dims,
errors::InvalidArgument("orig_input should be a tensor of rank 4, got ",
tensor_in.DebugString()));
OP_REQUIRES(context, tensor_in.NumElements() > 0,
errors::InvalidArgument("orig_input must not be empty, got ",
tensor_in.DebugString()));
OP_REQUIRES(context, tensor_out.dims() == tensor_in_and_out_dims,
errors::InvalidArgument(
"orig_output should be a tensor of rank 4, got ",
tensor_out.DebugString()));
OP_REQUIRES(context, tensor_out.NumElements() > 0,
errors::InvalidArgument("orig_output must not be empty, got ",
tensor_out.DebugString()));
std::vector<int64> input_size(tensor_in_and_out_dims);
std::vector<int64> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
28 changes: 27 additions & 1 deletion tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1279,6 +1279,32 @@ class FusedBatchNormOpBase : public OpKernel {
errors::InvalidArgument("Error during tensor copy."));
}

const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
OP_REQUIRES(
context, scale.NumElements() == num_channels,
errors::InvalidArgument("scale must have the same number of elements "
"as the channels of x, got ",
scale.NumElements(), " and ", num_channels));
OP_REQUIRES(
context, offset.NumElements() == num_channels,
errors::InvalidArgument("offset must have the same number of elements "
"as the channels of x, got ",
offset.NumElements(), " and ", num_channels));
if (estimated_mean.NumElements() != 0) {
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
errors::InvalidArgument(
"mean must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_mean.NumElements(), " and ", num_channels));
}
if (estimated_variance.NumElements() != 0) {
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
errors::InvalidArgument(
"variance must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_variance.NumElements(), " and ", num_channels));
}

if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
@@ -1291,7 +1317,7 @@ class FusedBatchNormOpBase : public OpKernel {
// NOTE(ezhulenev): This requirement is coming from implementation
// details of cudnnBatchNormalizationForwardTrainingEx.
OP_REQUIRES(
-   context, !is_training_ || x.dim_size(3) % 4 == 0,
+   context, !is_training_ || num_channels % 4 == 0,
errors::InvalidArgument("FusedBatchNorm with activation requires "
"channel dimension to be a multiple of 4."));
}
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/image/draw_bounding_box_op.cc
@@ -73,6 +73,12 @@ class DrawBoundingBoxesOp : public OpKernel {
errors::InvalidArgument("Channel depth should be either 1 (GRY), "
"3 (RGB), or 4 (RGBA)"));

OP_REQUIRES(
context, boxes.dim_size(2) == 4,
errors::InvalidArgument(
"The size of the third dimension of the box must be 4. Received: ",
boxes.dim_size(2)));

const int64 batch_size = images.dim_size(0);
const int64 height = images.dim_size(1);
const int64 width = images.dim_size(2);
6 changes: 5 additions & 1 deletion tensorflow/core/kernels/maxpooling_op.cc
@@ -192,7 +192,9 @@ static void SpatialMaxPoolWithArgMaxHelper(
// CHECK(input_backprop_index >= in_start && input_backprop_index <
// in_end)
FastBoundsCheck(input_backprop_index - in_start, in_end - in_start);
- input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
+ if (index < out_backprop.NumElements()) {
+   input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
+ }
}
}
};
@@ -1077,6 +1079,8 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, out_shape, &grad_out));

if (out_shape.num_elements() == 0) return; // nothing to be done

LaunchMaxPoolingGradWithArgmax<Device, T>::launch(
context, params, grad_in, argmax, grad_out, include_batch_in_index_);
}
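Both hunks harden the scatter-add that accumulates out_backprop values at argmax positions. A standalone C++ sketch of the same bounds-guarded pattern (illustrative only, not the TensorFlow kernel):

#include <cstddef>
#include <vector>

// Skip contributions whose indices fall outside either buffer instead of
// reading or writing out of bounds.
void guarded_scatter_add(std::vector<float>& grad_in,
                         const std::vector<float>& grad_out,
                         const std::vector<std::size_t>& argmax) {
  for (std::size_t i = 0; i < argmax.size(); ++i) {
    if (i < grad_out.size() && argmax[i] < grad_in.size()) {
      grad_in[argmax[i]] += grad_out[i];
    }
  }
}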
25 changes: 25 additions & 0 deletions tensorflow/core/kernels/pooling_ops_3d.cc
@@ -383,6 +383,19 @@ struct LaunchAvgPooling3dGradOp<CPUDevice, T> {
const std::array<int64, 3>& output_shape,
const std::array<int64, 3>& padding,
TensorFormat data_format, Tensor* output) {
OP_REQUIRES(
context, tensor_in_shape.dim_size(0) == out_backprop.dim_size(0),
errors::InvalidArgument(
"Expected first dimension of tensor_in_shape and "
"out_backprop to match, got ",
tensor_in_shape.dim_size(0), " and ", out_backprop.dim_size(0)));
OP_REQUIRES(
context, tensor_in_shape.dim_size(4) == out_backprop.dim_size(4),
errors::InvalidArgument(
"Expected last dimension of tensor_in_shape and "
"out_backprop to match, got ",
tensor_in_shape.dim_size(4), " and ", out_backprop.dim_size(4)));

output->flat<T>().setZero();
std::array<int64, 3> input_size = {{tensor_in_shape.dim_size(3),
tensor_in_shape.dim_size(2),
@@ -693,6 +706,7 @@ class MaxPooling3dGradGradOp : public OpKernel {

Pool3dParameters params{context, ksize_, stride_,
padding_, data_format_, tensor_in.shape()};
if (!context->status().ok()) return; // params is invalid

Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -710,6 +724,17 @@
context, out_grad_backprop.NumElements() > 0,
errors::InvalidArgument("received empty tensor out_grad_backprop: ",
out_grad_backprop.DebugString()));
OP_REQUIRES(context,
tensor_in.NumElements() == out_grad_backprop.NumElements(),
errors::InvalidArgument("tensor_in and out_grad_backprop must "
"have same number of elements, got <",
tensor_in.DebugString(), "> and <",
out_grad_backprop.DebugString(), ">"));
OP_REQUIRES(
context, tensor_out.NumElements() == output->NumElements(),
errors::InvalidArgument(
"tensor_out and output must have same number of elements, got <",
tensor_out.DebugString(), "> and <", output->DebugString(), ">"));

LaunchMaxPooling3dGradGradOp<Device, T>::launch(
context, params, tensor_in, tensor_out, out_grad_backprop, output);
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/requantization_range_op.cc
@@ -46,6 +46,10 @@ class RequantizationRangeOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
OP_REQUIRES(ctx, ctx->input(1).NumElements() > 0,
errors::InvalidArgument("Input min must not be empty."));
OP_REQUIRES(ctx, ctx->input(2).NumElements() > 0,
errors::InvalidArgument("Input max must not be empty."));
const float input_min_float = ctx->input(1).flat<float>()(0);
const float input_max_float = ctx->input(2).flat<float>()(0);
Tensor* output_min = nullptr;
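The check matters because flat<float>()(0) on an empty tensor reads memory the tensor does not own. The same failure mode in plain C++ (a sketch, not TensorFlow code):

#include <optional>
#include <vector>

// Reading element 0 of an empty buffer is undefined behavior, so test for
// emptiness first -- the analogue of the InvalidArgument path above.
std::optional<float> first_element(const std::vector<float>& v) {
  if (v.empty()) return std::nullopt;
  return v[0];
}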
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/reverse_sequence_op.cc
@@ -115,6 +115,10 @@ class ReverseSequenceOp : public OpKernel {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
OP_REQUIRES(context, batch_dim_ >= 0,
errors::InvalidArgument("Invalid batch_dim ", batch_dim_));
OP_REQUIRES(context, seq_dim_ >= 0,
errors::InvalidArgument("Invalid seq_dim ", seq_dim_));
}

void Compute(OpKernelContext* context) override {
36 changes: 36 additions & 0 deletions tensorflow/core/kernels/sdca_internal.cc
@@ -99,17 +99,31 @@ Status ModelWeights::Initialize(OpKernelContext* const context) {
OpInputList sparse_weights_inputs;
TF_RETURN_IF_ERROR(
context->input_list("sparse_weights", &sparse_weights_inputs));
if (sparse_indices_inputs.size() != sparse_weights_inputs.size())
return errors::InvalidArgument(
"sparse_indices and sparse_weights must have the same length, got ",
sparse_indices_inputs.size(), " and ", sparse_weights_inputs.size());
OpInputList dense_weights_inputs;
TF_RETURN_IF_ERROR(
context->input_list("dense_weights", &dense_weights_inputs));

OpOutputList sparse_weights_outputs;
TF_RETURN_IF_ERROR(context->output_list("out_delta_sparse_weights",
&sparse_weights_outputs));
if (sparse_weights_outputs.size() != sparse_weights_inputs.size())
return errors::InvalidArgument(
"out_delta_sparse_weights and sparse_weights must have the same "
"length, got ",
sparse_weights_outputs.size(), " and ", sparse_weights_inputs.size());

OpOutputList dense_weights_outputs;
TF_RETURN_IF_ERROR(
context->output_list("out_delta_dense_weights", &dense_weights_outputs));
if (dense_weights_outputs.size() != dense_weights_inputs.size())
return errors::InvalidArgument(
"out_delta_dense_weights and dense_weights must have the same length, "
"got ",
dense_weights_outputs.size(), " and ", dense_weights_inputs.size());

for (int i = 0; i < sparse_weights_inputs.size(); ++i) {
Tensor* delta_t;
@@ -327,13 +341,28 @@ Status Examples::Initialize(OpKernelContext* const context,
OpInputList sparse_example_indices_inputs;
TF_RETURN_IF_ERROR(context->input_list("sparse_example_indices",
&sparse_example_indices_inputs));
if (sparse_example_indices_inputs.size() != num_sparse_features)
return errors::InvalidArgument(
"Expected ", num_sparse_features,
" tensors in sparse_example_indices but got ",
sparse_example_indices_inputs.size());
OpInputList sparse_feature_indices_inputs;
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_indices",
&sparse_feature_indices_inputs));
if (sparse_feature_indices_inputs.size() != num_sparse_features)
return errors::InvalidArgument(
"Expected ", num_sparse_features,
" tensors in sparse_feature_indices but got ",
sparse_feature_indices_inputs.size());
OpInputList sparse_feature_values_inputs;
if (num_sparse_features_with_values > 0) {
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_values",
&sparse_feature_values_inputs));
if (sparse_feature_values_inputs.size() != num_sparse_features_with_values)
return errors::InvalidArgument(
"Expected ", num_sparse_features_with_values,
" tensors in sparse_feature_values but got ",
sparse_feature_values_inputs.size());
}

const Tensor* example_weights_t;
@@ -400,6 +429,13 @@ Status Examples::CreateSparseFeatureRepresentation(
sparse_example_indices_inputs[i].template flat<int64>();
auto feature_indices =
sparse_feature_indices_inputs[i].template flat<int64>();
if (example_indices.size() != feature_indices.size()) {
mutex_lock l(mu);
result = errors::InvalidArgument(
"Found mismatched example_indices and feature_indices [",
example_indices, "] vs [", feature_indices, "]");
return;
}

// Parse features for each example. Features for a particular example
// are at the offsets (start_id, end_id]
13 changes: 10 additions & 3 deletions tensorflow/core/kernels/sparse_split_op.cc
@@ -63,11 +63,18 @@ class SparseSplitOp : public OpKernel {
input_shape.vec<int64>()(axis),
"), got ", num_split_));

// Prevent overflow by constructing the dense shape separately
TensorShape dense_shape;
const auto input_shape_flat = input_shape.flat<int64>();
for (int i = 0; i < input_shape.NumElements(); i++) {
OP_REQUIRES_OK(context,
dense_shape.AddDimWithStatus(input_shape_flat(i)));
}

sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(context,
- sparse::SparseTensor::Create(
-     input_indices, input_values,
-     TensorShape(input_shape.vec<int64>()), &sparse_tensor));
+ sparse::SparseTensor::Create(input_indices, input_values,
+                              dense_shape, &sparse_tensor));

std::vector<sparse::SparseTensor> outputs;
OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
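The overflow the comment refers to is the element-count product of the dimension sizes wrapping around int64 for adversarial shapes; AddDimWithStatus surfaces that as an error instead. A standalone C++ sketch of the checked-accumulation idea (illustrative only; assumes non-negative dimension sizes):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Returns false instead of silently wrapping when the product overflows.
bool checked_mul(int64_t a, int64_t b, int64_t* out) {
  if (a != 0 && b > std::numeric_limits<int64_t>::max() / a) return false;
  *out = a * b;
  return true;
}

int main() {
  std::vector<int64_t> dims = {int64_t{1} << 32, int64_t{1} << 32};
  int64_t num_elements = 1;
  for (int64_t d : dims) {
    if (!checked_mul(num_elements, d, &num_elements)) {
      std::cout << "shape product overflows int64; reject the input\n";
      return 1;
    }
  }
  std::cout << "num_elements = " << num_elements << "\n";
  return 0;
}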
13 changes: 10 additions & 3 deletions tensorflow/lite/core/subgraph.cc
@@ -1033,10 +1033,17 @@ TfLiteStatus Subgraph::Invoke() {
TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index));
}
if (tensor->data.raw == nullptr && tensor->bytes > 0) {
- if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) {
+ if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1 &&
+     tensor->dims->size != 1) {
  // In general, having a tensor here with no buffer will be an error.
- // However, for the reshape operator, the second input tensor is only
- // used for the shape, not for the data. Thus, null buffer is ok.
+ // However, for the reshape operator, the second input tensor is
+ // sometimes only used for the shape, not for the data. Thus, null
+ // buffer is ok in this situation.
+ // The situation where null buffer is not ok for reshape operator is
+ // only when there are 2 inputs given to the node and the one
+ // corresponding to the shape (i == 1) is a vector that contains all
+ // dimensions. See `GetOutputShape()` function in
+ // `tensorflow/lite/kernels/reshape.cc`
continue;
} else {
// In all other cases, we need to return an error as otherwise we will
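Restating the new rule as a predicate (a sketch, not TFLite code): Reshape's second input may legitimately have a null data buffer only when its contents are never read, which the added condition approximates by requiring that the tensor not be the 1-D vector holding the output dimensions.

// Mirrors the condition `builtin_code == kTfLiteBuiltinReshape && i == 1 &&
// tensor->dims->size != 1` from the hunk above.
bool null_buffer_ok(bool is_reshape, int input_index, int input_rank) {
  return is_reshape && input_index == 1 && input_rank != 1;
}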
3 changes: 3 additions & 0 deletions tensorflow/lite/kernels/gather_nd.cc
@@ -152,6 +152,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

// Prevent division by 0 in the helper
TF_LITE_ENSURE(context, NumElements(params) > 0);

switch (indices->type) {
case kTfLiteInt32:
return EvalGatherNd<int32_t>(context, params, indices, output);
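Per the added comment, the helper performs a division that an empty params tensor would turn into a division by zero. A generic sketch of guarding such a computed divisor (plain C++, not the TFLite helper; names are made up):

#include <cstdint>

// Dividing by a zero slice count is undefined behavior; fail first, as the
// TF_LITE_ENSURE above does.
bool safe_slice_size(int64_t total_elements, int64_t num_slices,
                     int64_t* slice_size) {
  if (num_slices <= 0) return false;
  *slice_size = total_elements / num_slices;
  return true;
}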
