Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add IsScalar (rank == 0) check to input_min/max tensors for QuantizeD…
…ownAndShrinkRangeOp.

PiperOrigin-RevId: 462401306
  • Loading branch information
tensorflower-gardener committed Jul 21, 2022
1 parent 81b7782 commit 73ad181
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
16 changes: 14 additions & 2 deletions tensorflow/core/kernels/quantize_down_and_shrink_range.cc
Expand Up @@ -40,8 +40,20 @@ class QuantizeDownAndShrinkRangeOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
const float input_min_float = ctx->input(1).flat<float>()(0);
const float input_max_float = ctx->input(2).flat<float>()(0);
const Tensor& input_min = ctx->input(1);
const Tensor& input_max = ctx->input(2);

OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_min.shape()),
errors::InvalidArgument("`input_min` must be rank 0 but is rank ",
input_min.dims()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_max.shape()),
errors::InvalidArgument("`input_max` must be rank 0 but is rank ",
input_max.dims()));

const float input_min_float = input_min.scalar<float>()();
const float input_max_float = input_max.scalar<float>()();
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
Tensor* output_min = nullptr;
Expand Down
Expand Up @@ -53,8 +53,8 @@ TEST_F(QuantizeDownAndShrinkRangeTest, HandCrafted) {
const int value_count = 3;
AddInputFromArray<qint32>(TensorShape({value_count}),
{-(1 << 23), 0, (1 << 23)});
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
AddInputFromArray<float>(TensorShape({1}), {256.0f});
AddInputFromArray<float>(TensorShape({}), {-256.0f});
AddInputFromArray<float>(TensorShape({}), {256.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({value_count}));
test::FillValues<quint8>(&expected, {0, 128, 255});
Expand Down
Expand Up @@ -261,5 +261,21 @@ def test_invalid_inputs(self):
out_type=dtypes.quint8))


class QuantizeDownAndShrinkRangeOpTest(test_util.TensorFlowTestCase):

@test_util.run_in_graph_and_eager_modes
def test_invalid_inputs(self):
inputs = constant_op.constant(
np.int32(0), shape=[3, 3, 3, 3], dtype=dtypes.qint32)

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 0"):
self.evaluate(
math_ops.quantize_down_and_shrink_range(input=inputs,
input_min=[],
input_max=4.0,
out_type=dtypes.quint8))


if __name__ == "__main__":
googletest.main()

0 comments on commit 73ad181

Please sign in to comment.