Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add IsScalar (rank == 0) check to min/max input tensors for QuantizedAdd/Relu/Relu6 ops.

PiperOrigin-RevId: 461902847
  • Loading branch information
tensorflower-gardener committed Jul 19, 2022
1 parent a740437 commit 49b3824
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 12 deletions.
34 changes: 30 additions & 4 deletions tensorflow/core/kernels/quantized_activation_ops.cc
Expand Up @@ -32,8 +32,21 @@ class QuantizedReluOp : public OpKernel {

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const float min_input = context->input(1).flat<float>()(0);
const float max_input = context->input(2).flat<float>()(0);
const Tensor& min_input_tensor = context->input(1);
const Tensor& max_input_tensor = context->input(2);

OP_REQUIRES(
context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
min_input_tensor.dims()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
max_input_tensor.dims()));

const float min_input = min_input_tensor.scalar<float>()();
const float max_input = max_input_tensor.scalar<float>()();

Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
Expand Down Expand Up @@ -65,8 +78,21 @@ class QuantizedRelu6Op : public OpKernel {

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const float min_input = context->input(1).flat<float>()(0);
const float max_input = context->input(2).flat<float>()(0);
const Tensor& min_input_tensor = context->input(1);
const Tensor& max_input_tensor = context->input(2);

OP_REQUIRES(
context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
min_input_tensor.dims()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
max_input_tensor.dims()));

const float min_input = min_input_tensor.scalar<float>()();
const float max_input = max_input_tensor.scalar<float>()();

Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/quantized_activation_ops_test.cc
Expand Up @@ -55,8 +55,8 @@ TEST_F(QuantizedActivationsTest, TestRelu) {

AddInputFromArray<quint8>(input_quantized.shape(),
input_quantized.flat<quint8>());
AddInputFromArray<float>(TensorShape({1}), {input_min});
AddInputFromArray<float>(TensorShape({1}), {input_max});
AddInputFromArray<float>(TensorShape({}), {input_min});
AddInputFromArray<float>(TensorShape({}), {input_max});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
Expand Down Expand Up @@ -86,8 +86,8 @@ TEST_F(QuantizedActivationsTest, TestRelu6) {

AddInputFromArray<quint8>(input_quantized.shape(),
input_quantized.flat<quint8>());
AddInputFromArray<float>(TensorShape({1}), {input_min});
AddInputFromArray<float>(TensorShape({1}), {input_max});
AddInputFromArray<float>(TensorShape({}), {input_min});
AddInputFromArray<float>(TensorShape({}), {input_max});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
Expand Down
27 changes: 23 additions & 4 deletions tensorflow/core/kernels/quantized_add_op.cc
Expand Up @@ -25,6 +25,7 @@ limitations under the License.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
Expand Down Expand Up @@ -457,10 +458,28 @@ class QuantizedAddOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& x = context->input(0);
const Tensor& y = context->input(1);
const float min_x = context->input(2).flat<float>()(0);
const float max_x = context->input(3).flat<float>()(0);
const float min_y = context->input(4).flat<float>()(0);
const float max_y = context->input(5).flat<float>()(0);
const Tensor& min_x_tensor = context->input(2);
const Tensor& max_x_tensor = context->input(3);
const Tensor& min_y_tensor = context->input(4);
const Tensor& max_y_tensor = context->input(5);

OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_x_tensor.shape()),
errors::InvalidArgument("`min_x` must be rank 0 but is rank ",
min_x_tensor.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_x_tensor.shape()),
errors::InvalidArgument("`max_x` must be rank 0 but is rank ",
max_x_tensor.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_y_tensor.shape()),
errors::InvalidArgument("`min_y` must be rank 0 but is rank ",
min_y_tensor.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_y_tensor.shape()),
errors::InvalidArgument("`max_y` must be rank 0 but is rank ",
max_y_tensor.dims()));

const float min_x = min_x_tensor.scalar<float>()();
const float max_x = max_x_tensor.scalar<float>()();
const float min_y = min_y_tensor.scalar<float>()();
const float max_y = max_y_tensor.scalar<float>()();

BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
if (!bcast.IsValid()) {
Expand Down
Expand Up @@ -206,5 +206,60 @@ def test_invalid_inputs(self):
out_type=dtypes.qint8))


class QuantizedAddOpTest(test_util.TensorFlowTestCase):
  """Validation tests for the QuantizedAdd kernel's range inputs."""

  @test_util.run_in_graph_and_eager_modes
  def test_invalid_inputs(self):
    # The kernel requires every quantization range bound (min_x/max_x/
    # min_y/max_y) to be a rank-0 scalar; passing a rank-1 empty list for
    # `min_x` must be rejected with an InvalidArgument-style error.
    lhs = constant_op.constant(
        np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8)
    rhs = constant_op.constant(np.int8(0), shape=[3], dtype=dtypes.quint8)

    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(expected_errors, "must be rank 0"):
      self.evaluate(
          math_ops.quantized_add(
              x=lhs,
              y=rhs,
              min_x=[],
              max_x=1.0,
              min_y=0.0,
              max_y=1.0,
              Toutput=dtypes.qint32))


class QuantizedReluOpTest(test_util.TensorFlowTestCase):
  """Validation tests for the QuantizedRelu kernel's range inputs."""

  @test_util.run_in_graph_and_eager_modes
  def test_invalid_inputs(self):
    # `min_features` must be a rank-0 scalar; a rank-1 empty list should
    # trigger the kernel's IsScalar check rather than a crash.
    features = constant_op.constant(
        np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8)

    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(expected_errors, "must be rank 0"):
      self.evaluate(
          nn_ops.quantized_relu(
              features=features,
              min_features=[],
              max_features=127.0,
              out_type=dtypes.quint8))


class QuantizedRelu6OpTest(test_util.TensorFlowTestCase):
  """Validation tests for the QuantizedRelu6 kernel's range inputs."""

  @test_util.run_in_graph_and_eager_modes
  def test_invalid_inputs(self):
    # `min_features` must be a rank-0 scalar; a rank-1 empty list should
    # trigger the kernel's IsScalar check rather than a crash.
    features = constant_op.constant(
        np.int8(0), shape=[3, 3, 3, 3], dtype=dtypes.quint8)

    expected_errors = (ValueError, errors.InvalidArgumentError)
    with self.assertRaisesRegex(expected_errors, "must be rank 0"):
      self.evaluate(
          nn_ops.quantized_relu6(
              features=features,
              min_features=[],
              max_features=127.0,
              out_type=dtypes.quint8))


# Run all TensorFlowTestCase classes in this module when executed directly.
if __name__ == "__main__":
  googletest.main()

0 comments on commit 49b3824

Please sign in to comment.