Skip to content

Commit f3cf67a

Browse files
Add IsScalar / IsVector (rank) checks to input min/max tensors for FakeQuantWithMinMaxVarsPerChannelGradientOp and FakeQuantWithMinMaxVarsGradientOp.
PiperOrigin-RevId: 462542629
1 parent 7f64135 commit f3cf67a

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

Diff for: tensorflow/core/kernels/fake_quant_ops.cc

+12
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
261261
InvalidArgument("gradient and input must be the same size"));
262262
const Tensor& min = context->input(2);
263263
const Tensor& max = context->input(3);
264+
OP_REQUIRES(
265+
context, TensorShapeUtils::IsScalar(min.shape()),
266+
InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
267+
OP_REQUIRES(
268+
context, TensorShapeUtils::IsScalar(max.shape()),
269+
InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));
264270

265271
Tensor* grad_wrt_input;
266272
OP_REQUIRES_OK(context,
@@ -414,10 +420,16 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
414420
InvalidArgument("gradient and input must be the same size"));
415421
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
416422
const Tensor& min = context->input(2);
423+
OP_REQUIRES(
424+
context, TensorShapeUtils::IsVector(min.shape()),
425+
InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
417426
OP_REQUIRES(context, min.dim_size(0) == depth,
418427
InvalidArgument("min has incorrect size, expected ", depth,
419428
" was ", min.dim_size(0)));
420429
const Tensor& max = context->input(3);
430+
OP_REQUIRES(
431+
context, TensorShapeUtils::IsVector(max.shape()),
432+
InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
421433
OP_REQUIRES(context, max.dim_size(0) == depth,
422434
InvalidArgument("max has incorrect size, expected ", depth,
423435
" was ", max.dim_size(0)));

Diff for: tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py

+68-4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,71 @@ def test_invalid_inputs(self):
7777
inputs=inputs, min=[0.0], max=[1.0, 1.1]))
7878

7979

80+
class FakeQuantWithMinMaxVarsGradientOpTest(test_util.TensorFlowTestCase):
81+
82+
@test_util.run_in_graph_and_eager_modes
83+
def test_invalid_inputs(self):
84+
gradients = constant_op.constant(
85+
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
86+
inputs = constant_op.constant(
87+
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
88+
89+
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
90+
"must be equal rank|must be rank 0"):
91+
self.evaluate(
92+
array_ops.fake_quant_with_min_max_vars_gradient(
93+
gradients=gradients,
94+
inputs=inputs,
95+
min=0.0,
96+
max=[[1.0], [2.0], [4.0]]))
97+
98+
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
99+
"must be rank 0"):
100+
self.evaluate(
101+
array_ops.fake_quant_with_min_max_vars_gradient(
102+
gradients=gradients,
103+
inputs=inputs,
104+
min=[[1.0], [2.0], [4.0]],
105+
max=[[1.0], [2.0], [4.0]]))
106+
107+
108+
class FakeQuantWithMinMaxVarsPerChannelGradientOpTest(
109+
test_util.TensorFlowTestCase):
110+
111+
@test_util.run_in_graph_and_eager_modes
112+
def test_invalid_inputs(self):
113+
gradients = constant_op.constant(
114+
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
115+
inputs = constant_op.constant(
116+
value=[[1.0], [2.0], [4.0]], dtype=dtypes.float32)
117+
118+
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
119+
"Shapes must be equal rank|must be rank 1"):
120+
self.evaluate(
121+
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
122+
gradients=gradients, inputs=inputs, min=[[0.0]], max=[1.0]))
123+
124+
with self.assertRaisesRegex(
125+
(ValueError, errors.InvalidArgumentError),
126+
"Dimension 0 in both shapes must be equal|incorrect size"):
127+
self.evaluate(
128+
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
129+
gradients=gradients, inputs=inputs, min=[0.0, 0.1], max=[1.0]))
130+
131+
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
132+
"Shapes must be equal rank|must be rank 1"):
133+
self.evaluate(
134+
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
135+
gradients=gradients, inputs=inputs, min=[1.0], max=[[1.0]]))
136+
137+
with self.assertRaisesRegex(
138+
(ValueError, errors.InvalidArgumentError),
139+
"Dimension 0 in both shapes must be equal|incorrect size"):
140+
self.evaluate(
141+
array_ops.fake_quant_with_min_max_vars_per_channel_gradient(
142+
gradients=gradients, inputs=inputs, min=[0.0], max=[1.0, 1.1]))
143+
144+
80145
class QuantizedBiasedAddTest(test_util.TensorFlowTestCase):
81146

82147
@test_util.run_in_graph_and_eager_modes
@@ -337,10 +402,9 @@ def test_invalid_inputs(self):
337402
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
338403
"must be rank 0"):
339404
self.evaluate(
340-
math_ops.quantize_down_and_shrink_range(input=inputs,
341-
input_min=[],
342-
input_max=4.0,
343-
out_type=dtypes.quint8))
405+
math_ops.quantize_down_and_shrink_range(
406+
input=inputs, input_min=[], input_max=4.0,
407+
out_type=dtypes.quint8))
344408

345409

346410
if __name__ == "__main__":

0 commit comments

Comments
 (0)