diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index b5b6f14bcda301..0b82b72ccc3c0f 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -71,6 +71,7 @@ class GatherNdOp : public OpKernel { // // Same for the GPU kernel. TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); #undef REGISTER_GATHER_ND_CPU diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index 46414a38fb0ecd..836a6aa59926f2 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -100,9 +100,9 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params, } if (slice_size_big > std::numeric_limits::max()) { - return errors::InvalidArgument( - "slice size is too large for indexing: ", slice_size_big, " > ", - std::numeric_limits::max()); + return errors::InvalidArgument("slice size is too large for indexing: ", + slice_size_big, " > ", + std::numeric_limits::max()); } const Index slice_size = static_cast(slice_size_big); diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index cf9817dc3060be..c3d2f70139800f 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -152,6 +152,7 @@ struct GatherNdSlice { REGISTER_GATHER_ND_FULL(type, int64) TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); } // namespace functor diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index 9f8658ef0e81b6..b0b5c958b5a00d 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -57,9 +57,9 @@ namespace { class GatherNdOpTest : public OpsTestBase { protected: - void MakeOp(DataType index_type) { + void MakeOp(DataType param_type, DataType index_type) { TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(param_type)) .Input(FakeInput(index_type)) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); @@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase { }; TEST_F(GatherNdOpTest, Simple) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); @@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(GatherNdOpTest, Quantized_UINT8) { + MakeOp(DT_QUINT8, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); + AddInputFromArray(TensorShape({2, 1}), {3, 4}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_QUINT8, TensorShape({2})); + test::FillValues(&expected, {8, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherNdOpTest, Quantized_INT8) { + MakeOp(DT_QINT8, DT_INT32); + + AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); + AddInputFromArray(TensorShape({2, 1}), {3, 4}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_QINT8, TensorShape({2})); + test::FillValues(&expected, {8, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + constexpr int kLookups = 2000; template