Add missing validation in QuantizedBatchNormWithGlobalNormalization #49034

Merged 1 commit on May 10, 2021
77 changes: 67 additions & 10 deletions tensorflow/core/kernels/quantized_batch_norm_op.cc
@@ -173,20 +173,50 @@ class QuantizedBatchNormOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
-    const float input_min = context->input(1).flat<float>()(0);
-    const float input_max = context->input(2).flat<float>()(0);
+    const auto& input_min_tensor = context->input(1);
+    OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
+                errors::InvalidArgument("input_min must have 1 element"));
+    const float input_min = input_min_tensor.flat<float>()(0);
+    const auto& input_max_tensor = context->input(2);
+    OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
+                errors::InvalidArgument("input_max must have 1 element"));
+    const float input_max = input_max_tensor.flat<float>()(0);
     const Tensor& mean = context->input(3);
-    const float mean_min = context->input(4).flat<float>()(0);
-    const float mean_max = context->input(5).flat<float>()(0);
+    const auto& mean_min_tensor = context->input(4);
+    OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
+                errors::InvalidArgument("mean_min must have 1 element"));
+    const float mean_min = mean_min_tensor.flat<float>()(0);
+    const auto& mean_max_tensor = context->input(5);
+    OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
+                errors::InvalidArgument("mean_max must have 1 element"));
+    const float mean_max = mean_max_tensor.flat<float>()(0);
     const Tensor& var = context->input(6);
-    const float var_min = context->input(7).flat<float>()(0);
-    const float var_max = context->input(8).flat<float>()(0);
+    const auto& var_min_tensor = context->input(7);
+    OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
+                errors::InvalidArgument("var_min must have 1 element"));
+    const float var_min = var_min_tensor.flat<float>()(0);
+    const auto& var_max_tensor = context->input(8);
+    OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
+                errors::InvalidArgument("var_max must have 1 element"));
+    const float var_max = var_max_tensor.flat<float>()(0);
     const Tensor& beta = context->input(9);
-    const float beta_min = context->input(10).flat<float>()(0);
-    const float beta_max = context->input(11).flat<float>()(0);
+    const auto& beta_min_tensor = context->input(10);
+    OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
+                errors::InvalidArgument("beta_min must have 1 element"));
+    const float beta_min = beta_min_tensor.flat<float>()(0);
+    const auto& beta_max_tensor = context->input(11);
+    OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
+                errors::InvalidArgument("beta_max must have 1 element"));
+    const float beta_max = beta_max_tensor.flat<float>()(0);
     const Tensor& gamma = context->input(12);
-    const float gamma_min = context->input(13).flat<float>()(0);
-    const float gamma_max = context->input(14).flat<float>()(0);
+    const auto& gamma_min_tensor = context->input(13);
+    OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
+                errors::InvalidArgument("gamma_min must have 1 element"));
+    const float gamma_min = gamma_min_tensor.flat<float>()(0);
+    const auto& gamma_max_tensor = context->input(14);
+    OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
+                errors::InvalidArgument("gamma_max must have 1 element"));
+    const float gamma_max = gamma_max_tensor.flat<float>()(0);
 
     OP_REQUIRES(context, input.dims() == 4,
                 errors::InvalidArgument("input must be 4-dimensional",
@@ -203,6 +233,33 @@ class QuantizedBatchNormOp : public OpKernel {
     OP_REQUIRES(context, gamma.dims() == 1,
                 errors::InvalidArgument("gamma must be 1-dimensional",
                                         gamma.shape().DebugString()));
+    OP_REQUIRES(context, mean.NumElements() > 1,
+                errors::InvalidArgument("Must have at least a mean value"));
+    const auto last_dim = input.shape().dims() - 1;
+    OP_REQUIRES(context,
+                mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
+                errors::InvalidArgument("Must provide as many means as the "
+                                        "last dimension of the input tensor: ",
+                                        mean.shape().DebugString(), " vs. ",
+                                        input.shape().DebugString()));
+    OP_REQUIRES(
+        context, mean.shape().dim_size(0) == var.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Mean and variance tensors must have the same shape: ",
+            mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
+    OP_REQUIRES(
+        context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Mean and beta tensors must have the same shape: ",
+            mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
+    OP_REQUIRES(
+        context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
+        errors::InvalidArgument(
+            "Mean and gamma tensors must have the same shape: ",
+            mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));
 
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
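
Why these guards matter: before this change, Compute() read element 0 of each min/max input via flat<float>()(0) without confirming the tensor holds exactly one element, so a graph carrying an empty or non-scalar min/max tensor could trigger an out-of-bounds read inside the kernel; the new shape checks likewise guard the per-channel indexing that follows. Below is a minimal standalone sketch of the same validate-before-dereference pattern. It is not TensorFlow code: FakeTensor, REQUIRES, and Validate are hypothetical stand-ins for tensorflow::Tensor, OP_REQUIRES, and the kernel's Compute(), reduced to just what the checks touch.

// Standalone sketch of the validate-before-dereference pattern added above.
// FakeTensor and REQUIRES are hypothetical stand-ins for tensorflow::Tensor
// and OP_REQUIRES; this is illustrative, not TensorFlow code.
#include <cstdio>
#include <vector>

struct FakeTensor {
  std::vector<float> data;  // flattened values
  int NumElements() const { return static_cast<int>(data.size()); }
  float Flat0() const { return data[0]; }  // out-of-bounds if data is empty
};

// Like OP_REQUIRES: on failure, report InvalidArgument and abort the call.
#define REQUIRES(cond, msg)                                 \
  do {                                                      \
    if (!(cond)) {                                          \
      std::fprintf(stderr, "InvalidArgument: %s\n", (msg)); \
      return false;                                         \
    }                                                       \
  } while (0)

// Mirrors the fixed Compute(): check every scalar before reading it, and
// require per-channel tensors to agree on their length.
bool Validate(const FakeTensor& input_min, const FakeTensor& mean,
              const FakeTensor& var, float* out_input_min) {
  REQUIRES(input_min.NumElements() == 1, "input_min must have 1 element");
  *out_input_min = input_min.Flat0();  // safe only after the check above
  REQUIRES(mean.NumElements() > 1, "Must have at least a mean value");
  REQUIRES(mean.NumElements() == var.NumElements(),
           "Mean and variance tensors must have the same shape");
  return true;
}

int main() {
  FakeTensor empty_min{{}};  // malformed input: zero elements
  FakeTensor mean{{0.5f, 1.5f}};
  FakeTensor var{{1.0f, 1.0f}};
  float input_min = 0.0f;
  // Unchecked code would read empty_min.data[0] here; with the guard the
  // request fails cleanly instead of touching memory out of bounds.
  if (!Validate(empty_min, mean, var, &input_min)) {
    std::puts("rejected malformed input");
  }
  return 0;
}

Compiled as-is, this prints "InvalidArgument: input_min must have 1 element" and then "rejected malformed input" rather than performing the out-of-bounds read that the unpatched kernel risked.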