Merge branch 'r2.4' into cherrypicks_7P4OH
mihaimaruseac committed May 23, 2021
2 parents c104e6d + cd3bbd4 commit 07c4828
Showing 32 changed files with 553 additions and 109 deletions.
11 changes: 11 additions & 0 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2028,6 +2028,12 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {

Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
NodeDef* producer;

if (consumer->input_size() < 1) {
return errors::FailedPrecondition("Node ", consumer->name(),
" lacks inputs");
}

TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
const bool producer_is_cast = IsCastLike(*producer);
const bool can_optimize =
@@ -2430,6 +2436,11 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
~ReplaceMulWithSquare() override = default;

bool IsSupported(const NodeDef* node) const override {
if (!node || node->input_size() < 2) {
// Invalid node
return false;
}

return IsAnyMul(*node) && node->input(0) == node->input(1);
}

6 changes: 6 additions & 0 deletions tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -68,6 +68,12 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
// The output values of this node may be needed.
return false;
}

if (node.input_size() < 1) {
// Node lacks any input, so it is invalid
return false;
}

const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
CHECK(input != nullptr) << "node = " << node.name()
<< " input = " << node.input(0);
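
A minimal sketch of the kind of malformed graph these two grappler guards reject: a node that arrives with fewer inputs than its op requires. The node name is hypothetical and the snippet only illustrates the data shape involved; it is not taken from the commit.

```python
# Hand-craft a GraphDef whose Mul node records zero inputs. Before the
# checks above, optimizer stages indexed node.input(0) unconditionally.
from tensorflow.core.framework import graph_pb2

graph = graph_pb2.GraphDef()
node = graph.node.add()
node.name = "mul"  # hypothetical node name
node.op = "Mul"
print(len(node.input))  # 0 -- input(0) would be an out-of-bounds read
```
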
13 changes: 13 additions & 0 deletions tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -495,6 +495,14 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
const int filter_total_size = dims.spatial_dims[0].filter_size *
dims.spatial_dims[1].filter_size *
dims.in_depth;
OP_REQUIRES(
context,
filter_total_size * dims.out_depth == filter_backprop->NumElements(),
errors::InvalidArgument(
"filter_size does not have enough elements, requested ",
filter_total_size * dims.out_depth, ", got ",
filter_backprop->NumElements()));

// The output image size is the spatial size of the output.
const int output_image_size =
dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
@@ -518,6 +526,11 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {

const size_t work_unit_size = size_A + size_B + size_C;

OP_REQUIRES(
context, work_unit_size != 0,
errors::InvalidArgument(
"Work size for convolution would be 0, which is not acceptable"));

const size_t shard_size =
(target_working_set_size + work_unit_size - 1) / work_unit_size;

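
The two checks above validate the filter shape and stop a division by zero: `shard_size` is computed as `(target_working_set_size + work_unit_size - 1) / work_unit_size`, so an empty work unit was fatal. A hedged sketch of a call that could previously reach that division (the shapes are assumptions, not taken from the commit):

```python
import tensorflow as tf

# All-empty tensors make size_A + size_B + size_C == 0; the kernel now
# fails with InvalidArgument instead of dividing by work_unit_size == 0.
tf.raw_ops.Conv2DBackpropFilter(
    input=tf.zeros([0, 0, 0, 0], dtype=tf.float32),
    filter_sizes=tf.constant([0, 0, 0, 0], dtype=tf.int32),
    out_backprop=tf.zeros([0, 0, 0, 0], dtype=tf.float32),
    strides=[1, 1, 1, 1],
    padding="SAME")
```
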
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/conv_grad_shape_utils.cc
@@ -127,6 +127,10 @@ Status ConvBackpropComputeDimensionsV2(
// dimensions of the filter Tensor.
VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
<< filter_shape.dim_size(num_dims - 2);
if (filter_shape.dim_size(num_dims - 2) <= 0) {
return errors::InvalidArgument(
label, ": filter depth must be strictly greater than zero");
}
if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
return errors::InvalidArgument(
label, ": input depth must be evenly divisible by filter depth");
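
The guard sits immediately before the divisibility test, whose `%` was a modulo by zero whenever the filter depth was 0. A sketch of a call that exercises the shared helper through one of the conv-backprop ops (shapes are assumptions):

```python
import tensorflow as tf

tf.raw_ops.Conv2DBackpropInput(
    input_sizes=tf.constant([1, 2, 2, 0], dtype=tf.int32),
    filter=tf.zeros([1, 1, 0, 0], dtype=tf.float32),  # filter depth == 0
    out_backprop=tf.zeros([1, 2, 2, 0], dtype=tf.float32),
    strides=[1, 1, 1, 1],
    padding="SAME")
```
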
22 changes: 21 additions & 1 deletion tensorflow/core/kernels/count_ops.cc
@@ -192,10 +192,22 @@ class SparseCount : public OpKernel {
"; values shape: ", values.shape().DebugString()));
}

OP_REQUIRES(context, shape.NumElements() != 0,
errors::InvalidArgument(
"The shape argument requires at least one element."));

bool is_1d = shape.NumElements() == 1;
- int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
+ auto shape_vector = shape.flat<int64>();
+ int num_batches = is_1d ? 1 : shape_vector(0);
int num_values = values.NumElements();

for (int b = 0; b < shape_vector.size(); b++) {
OP_REQUIRES(context, shape_vector(b) >= 0,
errors::InvalidArgument(
"Elements in dense_shape must be >= 0. Instead got: ",
shape.DebugString()));
}

OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
@@ -212,6 +224,14 @@ class SparseCount : public OpKernel {

for (int idx = 0; idx < num_values; ++idx) {
int batch = is_1d ? 0 : indices_values(idx, 0);
OP_REQUIRES(context, batch < num_batches,
errors::InvalidArgument(
"Indices value along the first dimension must be ",
"lower than the first dimension of the shape. ", "Got ",
batch, " as batch and ", num_batches,
" as the first dimension of the shape."));
const auto& value = values_values(idx);
if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
if (binary_output_) {
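
Taken together, the checks require a non-empty, non-negative dense_shape and batch indices below its first dimension. A sketch of the empty-shape case that is now rejected up front (argument values are assumptions):

```python
import tensorflow as tf

tf.raw_ops.SparseCountSparseOutput(
    indices=tf.zeros([0, 0], dtype=tf.int64),
    values=tf.zeros([0], dtype=tf.int64),
    dense_shape=tf.constant([], dtype=tf.int64),  # zero elements: rejected
    weights=tf.zeros([0], dtype=tf.int64),
    binary_output=False)
```
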
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -232,6 +232,8 @@ class CTCGreedyDecoderOp : public OpKernel {
int prev_indices = -1;
for (int t = 0; t < seq_len_t(b); ++t) {
int max_class_indices;
OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
errors::InvalidArgument("Invalid input dimensions."));
log_prob_t(b, 0) +=
-RowMax<T>(input_list_t[t], b, &max_class_indices);
if (max_class_indices != blank_index &&
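
RowMax scans dimension 1 of each timestep slice, so an input with zero classes previously caused an out-of-bounds read. A sketch, with assumed shapes, of input that now fails with InvalidArgument:

```python
import tensorflow as tf

tf.raw_ops.CTCGreedyDecoder(
    inputs=tf.zeros([1, 1, 0], dtype=tf.float32),  # num_classes == 0
    sequence_length=tf.constant([1], dtype=tf.int32),
    merge_repeated=False)
```
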
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/data/experimental/io_ops.cc
@@ -253,7 +253,11 @@ class LoadDatasetOp::Dataset : public DatasetBase {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}

- ~Iterator() override { input_->Unref(); }
+ ~Iterator() override {
+ if (input_) {
+ input_->Unref();
+ }
+ }

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
@@ -330,7 +334,7 @@ class LoadDatasetOp::Dataset : public DatasetBase {
}

mutex mu_;
- DatasetBase* input_ TF_GUARDED_BY(mu_);
+ DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
};
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -80,6 +80,10 @@ class FractionalAvgPoolOp : public OpKernel {
std::vector<int> output_size(tensor_in_and_out_dims);
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
input_size[i] = tensor_in.dim_size(i);
OP_REQUIRES(
context, pooling_ratio_[i] <= input_size[i],
errors::InvalidArgument(
"Pooling ratio cannot be bigger than input tensor dim size."));
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
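
A pooling ratio larger than the matching input dimension produces a zero-sized output and, before this check, out-of-bounds behavior further down. A sketch with assumed values (the first and last ratios must be 1.0 for this op):

```python
import tensorflow as tf

tf.raw_ops.FractionalAvgPool(
    value=tf.zeros([1, 2, 2, 1], dtype=tf.float32),
    pooling_ratio=[1.0, 10.0, 10.0, 1.0],  # ratios exceed dims 1 and 2
    pseudo_random=False,
    overlapping=False)
```
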
48 changes: 36 additions & 12 deletions tensorflow/core/kernels/image/draw_bounding_box_op.cc
@@ -147,22 +147,46 @@ class DrawBoundingBoxesOp : public OpKernel {

// At this point, {min,max}_box_{row,col}_clamp are inside the
// image.
- CHECK_GE(min_box_row_clamp, 0);
- CHECK_GE(max_box_row_clamp, 0);
- CHECK_LT(min_box_row_clamp, height);
- CHECK_LT(max_box_row_clamp, height);
- CHECK_GE(min_box_col_clamp, 0);
- CHECK_GE(max_box_col_clamp, 0);
- CHECK_LT(min_box_col_clamp, width);
- CHECK_LT(max_box_col_clamp, width);
+ OP_REQUIRES(
+ context, min_box_row_clamp >= 0,
+ errors::InvalidArgument("Min box row clamp is less than 0."));
+ OP_REQUIRES(
+ context, max_box_row_clamp >= 0,
+ errors::InvalidArgument("Max box row clamp is less than 0."));
+ OP_REQUIRES(context, min_box_row_clamp <= height,
+ errors::InvalidArgument(
+ "Min box row clamp is greater than height."));
+ OP_REQUIRES(context, max_box_row_clamp <= height,
+ errors::InvalidArgument(
+ "Max box row clamp is greater than height."));
+
+ OP_REQUIRES(
+ context, min_box_col_clamp >= 0,
+ errors::InvalidArgument("Min box col clamp is less than 0."));
+ OP_REQUIRES(
+ context, max_box_col_clamp >= 0,
+ errors::InvalidArgument("Max box col clamp is less than 0."));
+ OP_REQUIRES(context, min_box_col_clamp <= width,
+ errors::InvalidArgument(
+ "Min box col clamp is greater than width."));
+ OP_REQUIRES(context, max_box_col_clamp <= width,
+ errors::InvalidArgument(
+ "Max box col clamp is greater than width."));

// At this point, the min_box_row and min_box_col are either
// in the image or above/left of it, and max_box_row and
// max_box_col are either in the image or below/right of it.
- CHECK_LT(min_box_row, height);
- CHECK_GE(max_box_row, 0);
- CHECK_LT(min_box_col, width);
- CHECK_GE(max_box_col, 0);
+
+ OP_REQUIRES(
+ context, min_box_row <= height,
+ errors::InvalidArgument("Min box row is greater than height."));
+ OP_REQUIRES(context, max_box_row >= 0,
+ errors::InvalidArgument("Max box row is less than 0."));
+ OP_REQUIRES(
+ context, min_box_col <= width,
+ errors::InvalidArgument("Min box col is greater than width."));
+ OP_REQUIRES(context, max_box_col >= 0,
+ errors::InvalidArgument("Max box col is less than 0."));

// Draw top line.
if (min_box_row >= 0) {
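
The point of trading CHECK_* for OP_REQUIRES is that a failed CHECK aborts the whole process, and user-supplied boxes (for example NaN coordinates, which clamp unpredictably) could trip it. A sketch of such input, with assumed values:

```python
import tensorflow as tf

images = tf.zeros([1, 4, 4, 3], dtype=tf.float32)
boxes = tf.constant([[[float("nan")] * 4]], dtype=tf.float32)  # NaN corners
colors = tf.constant([[1.0, 0.0, 0.0, 1.0]], dtype=tf.float32)
tf.raw_ops.DrawBoundingBoxesV2(images=images, boxes=boxes, colors=colors)
```
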
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/image/encode_png_op.cc
@@ -54,6 +54,8 @@ class EncodePngOp : public OpKernel {
OP_REQUIRES(context, image.dims() == 3,
errors::InvalidArgument("image must be 3-dimensional",
image.shape().DebugString()));
OP_REQUIRES(context, image.NumElements() > 0,
errors::Internal("Invalid image provided."));
OP_REQUIRES(
context,
FastBoundsCheck(image.NumElements(), std::numeric_limits<int32>::max()),
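
A sketch of the zero-element image the new check turns into a clean error (the shape is an assumption):

```python
import tensorflow as tf

tf.raw_ops.EncodePng(image=tf.zeros([0, 0, 3], dtype=tf.uint8))
```
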
19 changes: 16 additions & 3 deletions tensorflow/core/kernels/linalg/matrix_diag_op.cc
@@ -192,9 +192,22 @@ class MatrixDiagOp : public OpKernel {
upper_diag_index = diag_index.flat<int32>()(1);
}
}
- num_rows = context->input(2).flat<int32>()(0);
- num_cols = context->input(3).flat<int32>()(0);
- padding_value = context->input(4).flat<T>()(0);
+
+ auto& num_rows_tensor = context->input(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_rows_tensor.shape()),
+ errors::InvalidArgument("num_rows must be a scalar"));
+ num_rows = num_rows_tensor.flat<int32>()(0);
+
+ auto& num_cols_tensor = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_cols_tensor.shape()),
+ errors::InvalidArgument("num_cols must be a scalar"));
+ num_cols = num_cols_tensor.flat<int32>()(0);
+
+ auto& padding_value_tensor = context->input(4);
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(padding_value_tensor.shape()),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ padding_value = padding_value_tensor.flat<T>()(0);
}

// Size validations.
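
Each of num_rows, num_cols, and padding_value was read with flat<...>()(0) without confirming the tensor is a scalar, so an empty or higher-rank tensor meant an out-of-bounds read. A sketch with a non-scalar num_rows (values assumed):

```python
import tensorflow as tf

tf.raw_ops.MatrixDiagV3(
    diagonal=tf.constant([1, 2, 3], dtype=tf.int32),
    k=tf.constant(0, dtype=tf.int32),
    num_rows=tf.constant([3, 3], dtype=tf.int32),  # vector, not scalar
    num_cols=tf.constant(3, dtype=tf.int32),
    padding_value=tf.constant(0, dtype=tf.int32),
    align="RIGHT_LEFT")
```
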
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -160,7 +160,17 @@ class QuantizeAndDequantizeV4GradientOp : public OpKernel {
errors::InvalidArgument("gradient and input must be the same size"));
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
const Tensor& input_min_tensor = ctx->input(2);
OP_REQUIRES(ctx,
input_min_tensor.dims() == 0 || input_min_tensor.dims() == 1,
errors::InvalidArgument(
"Input min tensor must have rank 0 or 1. Received ",
input_min_tensor.dims(), "."));
const Tensor& input_max_tensor = ctx->input(3);
OP_REQUIRES(ctx,
input_max_tensor.dims() == 0 || input_max_tensor.dims() == 1,
errors::InvalidArgument(
"Input max tensor must have rank 0 or 1. Received ",
input_max_tensor.dims(), "."));
if (axis_ != -1) {
OP_REQUIRES(
ctx, input_min_tensor.dim_size(0) == depth,
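
A sketch of a rank-2 input_min that the new rank checks reject (shapes are assumptions):

```python
import tensorflow as tf

tf.raw_ops.QuantizeAndDequantizeV4Grad(
    gradients=tf.zeros([2, 2], dtype=tf.float32),
    input=tf.zeros([2, 2], dtype=tf.float32),
    input_min=tf.zeros([2, 2], dtype=tf.float32),  # rank 2: rejected
    input_max=tf.zeros([2, 2], dtype=tf.float32),
    axis=-1)
```
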
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/quantized_add_op.cc
@@ -538,6 +538,8 @@ class QuantizedAddOp : public OpKernel {
tensor_min = min_x;
tensor_max = max_x;
}
OP_REQUIRES(context, vector_num_elements > 0,
errors::InvalidArgument("Must have some elements to add"));
VectorTensorAddition<T, Toutput>(
vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
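
A sketch of the empty-operand case the guard now rejects, written in the style of the public TFSA advisories for this family of fixes; the exact values and the quint8 construction are assumptions:

```python
import tensorflow as tf

x = tf.constant([68, 228], shape=[2, 1], dtype=tf.quint8)
y = tf.constant([], shape=[2, 0], dtype=tf.quint8)  # empty vector operand
tf.raw_ops.QuantizedAdd(x=x, y=y, min_x=0.0, max_x=75.0,
                        min_y=0.0, max_y=100.0)
```
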
77 changes: 67 additions & 10 deletions tensorflow/core/kernels/quantized_batch_norm_op.cc
@@ -173,20 +173,50 @@ class QuantizedBatchNormOp : public OpKernel {

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- const float input_min = context->input(1).flat<float>()(0);
- const float input_max = context->input(2).flat<float>()(0);
+ const auto& input_min_tensor = context->input(1);
+ OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_min must have 1 element"));
+ const float input_min = input_min_tensor.flat<float>()(0);
+ const auto& input_max_tensor = context->input(2);
+ OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_max must have 1 element"));
+ const float input_max = input_max_tensor.flat<float>()(0);
const Tensor& mean = context->input(3);
- const float mean_min = context->input(4).flat<float>()(0);
- const float mean_max = context->input(5).flat<float>()(0);
+ const auto& mean_min_tensor = context->input(4);
+ OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_min must have 1 element"));
+ const float mean_min = mean_min_tensor.flat<float>()(0);
+ const auto& mean_max_tensor = context->input(5);
+ OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_max must have 1 element"));
+ const float mean_max = mean_max_tensor.flat<float>()(0);
const Tensor& var = context->input(6);
- const float var_min = context->input(7).flat<float>()(0);
- const float var_max = context->input(8).flat<float>()(0);
+ const auto& var_min_tensor = context->input(7);
+ OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_min must have 1 element"));
+ const float var_min = var_min_tensor.flat<float>()(0);
+ const auto& var_max_tensor = context->input(8);
+ OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_max must have 1 element"));
+ const float var_max = var_max_tensor.flat<float>()(0);
const Tensor& beta = context->input(9);
- const float beta_min = context->input(10).flat<float>()(0);
- const float beta_max = context->input(11).flat<float>()(0);
+ const auto& beta_min_tensor = context->input(10);
+ OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_min must have 1 element"));
+ const float beta_min = beta_min_tensor.flat<float>()(0);
+ const auto& beta_max_tensor = context->input(11);
+ OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_max must have 1 element"));
+ const float beta_max = beta_max_tensor.flat<float>()(0);
const Tensor& gamma = context->input(12);
- const float gamma_min = context->input(13).flat<float>()(0);
- const float gamma_max = context->input(14).flat<float>()(0);
+ const auto& gamma_min_tensor = context->input(13);
+ OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_min must have 1 element"));
+ const float gamma_min = gamma_min_tensor.flat<float>()(0);
+ const auto& gamma_max_tensor = context->input(14);
+ OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_max must have 1 element"));
+ const float gamma_max = gamma_max_tensor.flat<float>()(0);

OP_REQUIRES(context, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
@@ -203,6 +233,33 @@ class QuantizedBatchNormOp : public OpKernel {
OP_REQUIRES(context, gamma.dims() == 1,
errors::InvalidArgument("gamma must be 1-dimensional",
gamma.shape().DebugString()));
OP_REQUIRES(context, mean.NumElements() > 1,
errors::InvalidArgument("Must have at least a mean value"));
const auto last_dim = input.shape().dims() - 1;
OP_REQUIRES(context,
mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
errors::InvalidArgument("Must provide as many means as the "
"last dimension of the input tensor: ",
mean.shape().DebugString(), " vs. ",
input.shape().DebugString()));
OP_REQUIRES(
context, mean.shape().dim_size(0) == var.shape().dim_size(0),
errors::InvalidArgument(
"Mean and variance tensors must have the same shape: ",
mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
OP_REQUIRES(
context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
errors::InvalidArgument(
"Mean and beta tensors must have the same shape: ",
mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
OP_REQUIRES(
context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
errors::InvalidArgument(
"Mean and gamma tensors must have the same shape: ",
mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));

Tensor* output = nullptr;
OP_REQUIRES_OK(context,
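
A sketch of a mean tensor that does not match the input depth, which the new shape checks reject (all values, including the quint8 construction, are assumptions):

```python
import tensorflow as tf

t = tf.constant([1, 2, 3, 4], shape=[1, 1, 2, 2], dtype=tf.quint8)
mean = tf.constant([0], shape=[1], dtype=tf.quint8)  # input depth is 2
var = tf.constant([1], shape=[1], dtype=tf.quint8)
beta = tf.constant([0], shape=[1], dtype=tf.quint8)
gamma = tf.constant([1], shape=[1], dtype=tf.quint8)
tf.raw_ops.QuantizedBatchNormWithGlobalNormalization(
    t=t, t_min=0.0, t_max=1.0, m=mean, m_min=0.0, m_max=1.0,
    v=var, v_min=0.0, v_max=1.0, beta=beta, beta_min=0.0, beta_max=1.0,
    gamma=gamma, gamma_min=0.0, gamma_max=1.0,
    out_type=tf.qint32, variance_epsilon=0.001,
    scale_after_normalization=False)
```
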
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/quantized_bias_add_op.cc
@@ -56,6 +56,8 @@ class QuantizedBiasAddOp : public OpKernel {
"Must provide as many biases as the last dimension "
"of the input tensor: ",
bias.shape().DebugString(), " vs. ", input.shape().DebugString()));
OP_REQUIRES(context, bias.NumElements() > 0,
errors::InvalidArgument("Must provide at least 1 bias"));

Tensor* output = nullptr;
OP_REQUIRES_OK(context,
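
A sketch of the empty-bias case: with a zero last input dimension the earlier as-many-biases check passes trivially, and only the new non-empty check catches it (shapes and quint8 construction assumed):

```python
import tensorflow as tf

inp = tf.constant([], shape=[1, 0], dtype=tf.quint8)  # last dim is 0
bias = tf.constant([], shape=[0], dtype=tf.quint8)    # zero biases
tf.raw_ops.QuantizedBiasAdd(input=inp, bias=bias,
                            min_input=0.0, max_input=1.0,
                            min_bias=0.0, max_bias=1.0,
                            out_type=tf.qint32)
```
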
