Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add IsScalar / IsVector (rank) checks to input min/max tensors for Fa…
…keQuantWithMinMaxVarsPerChannelGradientOp and FakeQuantWithMinMaxVarsGradientOp.

PiperOrigin-RevId: 462542629
  • Loading branch information
tensorflower-gardener committed Jul 22, 2022
1 parent 7f64135 commit f3cf67a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 4 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/fake_quant_ops.cc
Expand Up @@ -261,6 +261,12 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
InvalidArgument("gradient and input must be the same size"));
const Tensor& min = context->input(2);
const Tensor& max = context->input(3);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(min.shape()),
InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max.shape()),
InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));

Tensor* grad_wrt_input;
OP_REQUIRES_OK(context,
Expand Down Expand Up @@ -414,10 +420,16 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
InvalidArgument("gradient and input must be the same size"));
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
const Tensor& min = context->input(2);
OP_REQUIRES(
context, TensorShapeUtils::IsVector(min.shape()),
InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
OP_REQUIRES(context, min.dim_size(0) == depth,
InvalidArgument("min has incorrect size, expected ", depth,
" was ", min.dim_size(0)));
const Tensor& max = context->input(3);
OP_REQUIRES(
context, TensorShapeUtils::IsVector(max.shape()),
InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
OP_REQUIRES(context, max.dim_size(0) == depth,
InvalidArgument("max has incorrect size, expected ", depth,
" was ", max.dim_size(0)));
Expand Down
Expand Up @@ -77,6 +77,71 @@ def test_invalid_inputs(self):
inputs=inputs, min=[0.0], max=[1.0, 1.1]))


class FakeQuantWithMinMaxVarsGradientOpTest(test_util.TensorFlowTestCase):

@test_util.run_in_graph_and_eager_modes
def test_invalid_inputs(self):
gradients = constant_op.constant(
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
inputs = constant_op.constant(
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be equal rank|must be rank 0"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_gradient(
gradients=gradients,
inputs=inputs,
min=0.0,
max=[[1.0], [2.0], [4.0]]))

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 0"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_gradient(
gradients=gradients,
inputs=inputs,
min=[[1.0], [2.0], [4.0]],
max=[[1.0], [2.0], [4.0]]))


class FakeQuantWithMinMaxVarsPerChannelGradientOpTest(
test_util.TensorFlowTestCase):

@test_util.run_in_graph_and_eager_modes
def test_invalid_inputs(self):
gradients = constant_op.constant(
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
inputs = constant_op.constant(
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Shapes must be equal rank|must be rank 1"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
gradients=gradients, inputs=inputs, min=[[0.0]], max=[1.0]))

with self.assertRaisesRegex(
(ValueError, errors.InvalidArgumentError),
"Dimension 0 in both shapes must be equal|incorrect size"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
gradients=gradients, inputs=inputs, min=[0.0, 0.1], max=[1.0]))

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Shapes must be equal rank|must be rank 1"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
gradients=gradients, inputs=inputs, min=[1.0], max=[[1.0]]))

with self.assertRaisesRegex(
(ValueError, errors.InvalidArgumentError),
"Dimension 0 in both shapes must be equal|incorrect size"):
self.evaluate(
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
gradients=gradients, inputs=inputs, min=[0.0], max=[1.0, 1.1]))


class QuantizedBiasedAddTest(test_util.TensorFlowTestCase):

@test_util.run_in_graph_and_eager_modes
Expand Down Expand Up @@ -337,10 +402,9 @@ def test_invalid_inputs(self):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 0"):
self.evaluate(
math_ops.quantize_down_and_shrink_range(input=inputs,
input_min=[],
input_max=4.0,
out_type=dtypes.quint8))
math_ops.quantize_down_and_shrink_range(
input=inputs, input_min=[], input_max=4.0,
out_type=dtypes.quint8))


if __name__ == "__main__":
Expand Down

0 comments on commit f3cf67a

Please sign in to comment.