Skip to content
Permalink
Browse files Browse the repository at this point in the history
Validate inputs to QuantizedMul
PiperOrigin-RevId: 369756982
Change-Id: I00d960cc3b9316fd7a86bd37a44e341c96e17624
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 21, 2021
1 parent 87cf4d3 commit efea03b
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tensorflow/core/kernels/quantized_mul_op.cc
Expand Up @@ -284,10 +284,22 @@ class QuantizedMulOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& x = context->input(0);
const Tensor& y = context->input(1);
const float min_x = context->input(2).flat<float>()(0);
const float max_x = context->input(3).flat<float>()(0);
const float min_y = context->input(4).flat<float>()(0);
const float max_y = context->input(5).flat<float>()(0);
auto& min_x_tensor = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_x_tensor.shape()),
errors::InvalidArgument("min_x must be a scalar"));
const float min_x = min_x_tensor.flat<float>()(0);
auto& max_x_tensor = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_x_tensor.shape()),
errors::InvalidArgument("max_x must be a scalar"));
const float max_x = max_x_tensor.flat<float>()(0);
auto& min_y_tensor = context->input(4);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_y_tensor.shape()),
errors::InvalidArgument("min_y must be a scalar"));
const float min_y = min_y_tensor.flat<float>()(0);
auto& max_y_tensor = context->input(5);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_y_tensor.shape()),
errors::InvalidArgument("max_y must be a scalar"));
const float max_y = max_y_tensor.flat<float>()(0);

BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
if (!bcast.IsValid()) {
Expand Down

0 comments on commit efea03b

Please sign in to comment.