Skip to content

Commit

Permalink
Fix quantize ops input validation issues.
Browse files Browse the repository at this point in the history
The majority of these are just missing checks on min/max.

PiperOrigin-RevId: 461800665
  • Loading branch information
cantonios authored and tensorflow-jenkins committed Aug 19, 2022
1 parent 0684525 commit 6840ef9
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 33 deletions.
17 changes: 16 additions & 1 deletion tensorflow/core/kernels/fake_quant_ops.cc
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
// Above is the related header but clang tidy doesn't recognize it.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/protobuf.h"
Expand Down Expand Up @@ -205,6 +206,13 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
const Tensor& min = context->input(1);
const Tensor& max = context->input(2);

OP_REQUIRES(
context, TensorShapeUtils::IsScalar(min.shape()),
InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max.shape()),
InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));

Tensor* output;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
Expand Down Expand Up @@ -342,10 +350,17 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
const Tensor& input = context->input(0);
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
const Tensor& min = context->input(1);
const Tensor& max = context->input(2);

OP_REQUIRES(
context, TensorShapeUtils::IsVector(min.shape()),
InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
OP_REQUIRES(context, min.dim_size(0) == depth,
InvalidArgument("min has incorrect size, expected ", depth,
" was ", min.dim_size(0)));
const Tensor& max = context->input(2);
OP_REQUIRES(
context, TensorShapeUtils::IsVector(max.shape()),
InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
OP_REQUIRES(context, max.dim_size(0) == depth,
InvalidArgument("max has incorrect size, expected ", depth,
" was ", max.dim_size(0)));
Expand Down
29 changes: 25 additions & 4 deletions tensorflow/core/kernels/quantized_bias_add_op.cc
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
Expand All @@ -38,10 +39,30 @@ class QuantizedBiasAddOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const Tensor& bias = context->input(1);
const float input_min = context->input(2).flat<float>()(0);
const float input_max = context->input(3).flat<float>()(0);
const float bias_min = context->input(4).flat<float>()(0);
const float bias_max = context->input(5).flat<float>()(0);

const Tensor& min_input = context->input(2);
const Tensor& max_input = context->input(3);
const Tensor& min_bias = context->input(4);
const Tensor& max_bias = context->input(5);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(min_input.shape()),
errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
min_input.dims()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max_input.shape()),
errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
max_input.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_bias.shape()),
errors::InvalidArgument(
"`min_bias` must be rank 0 but is rank ", min_bias.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_bias.shape()),
errors::InvalidArgument(
"`max_bias` must be rank 0 but is rank ", max_bias.dims()));

const float input_min = min_input.flat<float>()(0);
const float input_max = max_input.flat<float>()(0);
const float bias_min = min_bias.flat<float>()(0);
const float bias_max = max_bias.flat<float>()(0);

OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
errors::InvalidArgument("Input tensor must be at least 2D: ",
Expand Down
16 changes: 8 additions & 8 deletions tensorflow/core/kernels/quantized_bias_add_op_test.cc
Expand Up @@ -74,10 +74,10 @@ TEST_F(QuantizedBiasAddTest, Small) {
input_quantized.flat<quint8>());
AddInputFromArray<quint8>(bias_quantized.shape(),
bias_quantized.flat<quint8>());
AddInputFromArray<float>(TensorShape({1}), {input_min});
AddInputFromArray<float>(TensorShape({1}), {input_max});
AddInputFromArray<float>(TensorShape({1}), {bias_min});
AddInputFromArray<float>(TensorShape({1}), {bias_max});
AddInputFromArray<float>(TensorShape({}), {input_min});
AddInputFromArray<float>(TensorShape({}), {input_max});
AddInputFromArray<float>(TensorShape({}), {bias_min});
AddInputFromArray<float>(TensorShape({}), {bias_max});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
Expand Down Expand Up @@ -156,10 +156,10 @@ TEST_F(QuantizedBiasAddTest, RealData) {
input_quantized.flat<quint8>());
AddInputFromArray<quint8>(bias_quantized.shape(),
bias_quantized.flat<quint8>());
AddInputFromArray<float>(TensorShape({1}), {input_min});
AddInputFromArray<float>(TensorShape({1}), {input_max});
AddInputFromArray<float>(TensorShape({1}), {bias_min});
AddInputFromArray<float>(TensorShape({1}), {bias_max});
AddInputFromArray<float>(TensorShape({}), {input_min});
AddInputFromArray<float>(TensorShape({}), {input_max});
AddInputFromArray<float>(TensorShape({}), {bias_min});
AddInputFromArray<float>(TensorShape({}), {bias_max});
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
Expand Down
14 changes: 11 additions & 3 deletions tensorflow/core/kernels/quantized_instance_norm.cc
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/quantization_utils.h"

#ifdef USE_NEON
Expand Down Expand Up @@ -274,8 +274,16 @@ class QuantizedInstanceNorm : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);

float input_min = context->input(1).flat<float>()(0);
float input_max = context->input(2).flat<float>()(0);
const Tensor& x_min = context->input(1);
const Tensor& x_max = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_min.shape()),
errors::InvalidArgument("`x_min` must be rank 0 but is rank ",
x_min.dims()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_max.shape()),
errors::InvalidArgument("`x_max` must be rank 0 but is rank ",
x_max.dims()));
float input_min = x_min.scalar<float>()();
float input_max = x_max.scalar<float>()();
float input_scale = (input_max - input_min) / 255.0f;

OP_REQUIRES(context, input_min < input_max,
Expand Down
36 changes: 31 additions & 5 deletions tensorflow/core/kernels/requantize.cc
Expand Up @@ -18,9 +18,11 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include <math.h>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
Expand All @@ -38,10 +40,34 @@ class RequantizeOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
const float input_min_float = ctx->input(1).flat<float>()(0);
const float input_max_float = ctx->input(2).flat<float>()(0);
const float requested_output_min_float = ctx->input(3).flat<float>()(0);
const float requested_output_max_float = ctx->input(4).flat<float>()(0);

const Tensor& input_min = ctx->input(1);
const Tensor& input_max = ctx->input(2);
const Tensor& requested_output_min = ctx->input(3);
const Tensor& requested_output_max = ctx->input(4);
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_min.shape()),
errors::InvalidArgument("`input_min` must be rank 0 but is rank ",
input_min.dims()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_max.shape()),
errors::InvalidArgument("`input_max` must be rank 0 but is rank ",
input_max.dims()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(requested_output_min.shape()),
errors::InvalidArgument(
"`requested_output_min` must be rank 0 but is rank ",
requested_output_min.dims()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(requested_output_max.shape()),
errors::InvalidArgument(
"`requested_output_max` must be rank 0 but is rank ",
requested_output_max.dims()));

const float input_min_float = input_min.flat<float>()(0);
const float input_max_float = input_max.flat<float>()(0);
const float requested_output_min_float =
requested_output_min.flat<float>()(0);
const float requested_output_max_float =
requested_output_max.flat<float>()(0);

Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
Expand Down
24 changes: 12 additions & 12 deletions tensorflow/core/kernels/requantize_op_test.cc
Expand Up @@ -53,10 +53,10 @@ TEST_F(RequantizeTest, HandCraftedRequantize) {
// Requantize to -1 to 1.
AddInputFromArray<qint32>(TensorShape({value_count}),
{-(1 << 23), 0, (1 << 23)});
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
AddInputFromArray<float>(TensorShape({1}), {256.0f});
AddInputFromArray<float>(TensorShape({1}), {-1.0f});
AddInputFromArray<float>(TensorShape({1}), {1.0f});
AddInputFromArray<float>(TensorShape({}), {-256.0f});
AddInputFromArray<float>(TensorShape({}), {256.0f});
AddInputFromArray<float>(TensorShape({}), {-1.0f});
AddInputFromArray<float>(TensorShape({}), {1.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({value_count}));
test::FillValues<quint8>(&expected, {0, 128, 255});
Expand All @@ -71,10 +71,10 @@ TEST_F(RequantizeTest, InvalidOutputMin) {

AddInputFromArray<qint32>(TensorShape({value_count}),
{-(1 << 23), 0, (1 << 23)});
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
AddInputFromArray<float>(TensorShape({1}), {256.0f});
AddInputFromArray<float>(TensorShape({1}), {0.01f});
AddInputFromArray<float>(TensorShape({1}), {1.0f});
AddInputFromArray<float>(TensorShape({}), {-256.0f});
AddInputFromArray<float>(TensorShape({}), {256.0f});
AddInputFromArray<float>(TensorShape({}), {0.01f});
AddInputFromArray<float>(TensorShape({}), {1.0f});
EXPECT_EQ("requested_output_min must be <= 0, but got 0.01",
RunOpKernel().error_message());
}
Expand All @@ -85,10 +85,10 @@ TEST_F(RequantizeTest, InvalidOutputMax) {

AddInputFromArray<qint32>(TensorShape({value_count}),
{-(1 << 23), 0, (1 << 23)});
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
AddInputFromArray<float>(TensorShape({1}), {256.0f});
AddInputFromArray<float>(TensorShape({1}), {-10.0f});
AddInputFromArray<float>(TensorShape({1}), {-11.0f});
AddInputFromArray<float>(TensorShape({}), {-256.0f});
AddInputFromArray<float>(TensorShape({}), {256.0f});
AddInputFromArray<float>(TensorShape({}), {-10.0f});
AddInputFromArray<float>(TensorShape({}), {-11.0f});
EXPECT_EQ(
"requested_output_max must be >= requested_output_min, but got -11 and "
"-10",
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/python/kernel_tests/quantization_ops/BUILD
@@ -0,0 +1,24 @@
# Tests of TensorFlow quantization ops written using the Python API.

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"],
)

tf_py_test(
name = "quantization_ops_test",
size = "small",
srcs = ["quantization_ops_test.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//third_party/py/numpy",
],
)

0 comments on commit 6840ef9

Please sign in to comment.