From 86adfccccca835b32da093ec265195b9a170f212 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 16 May 2019 11:43:40 -0700 Subject: [PATCH] Adding support for handling uint32 and uint64 dtypes in batch_util CopySliceToElement and MaybeMoveSliceToElement New regression tests added in tensor_slice_dataset_op_test.cc, from_tensor_slices_test.py, and unbatch_test.py --- .../core/kernels/data/dataset_test_base.cc | 2 ++ .../data/tensor_slice_dataset_op_test.cc | 14 ++++++++++++++ tensorflow/core/util/batch_util.cc | 4 ++++ .../kernel_tests/from_tensor_slices_test.py | 15 +++++++++++++++ .../python/data/kernel_tests/unbatch_test.py | 18 ++++++++++++++++++ 5 files changed, 53 insertions(+) diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 6765a5af74dd4c..e3565024ae1424 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -54,6 +54,8 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { break; TF_CALL_NUMBER_TYPES(CASE); TF_CALL_string(CASE); + TF_CALL_uint32(CASE); + TF_CALL_uint64(CASE); // TODO(feihugis): figure out how to support variant tensors. #undef CASE default: diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc index ee2619663f08d7..e356e8fc3f4daa 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc @@ -76,6 +76,12 @@ TestCase PlainTensorTestCase() { {DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {2, 3, 4, 5}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {3, 4, 5, 6}), DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), {37.0, 38.0}), DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), @@ -83,10 +89,18 @@ TestCase PlainTensorTestCase() { /*expected_outputs*/ {DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), DatasetOpsTestBase::CreateTensor(TensorShape({1}), {37.0}), DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"a"}), DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {4, 5}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {5, 6}), DatasetOpsTestBase::CreateTensor(TensorShape({1}), {38.0}), DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"b"})}, /*breakpoints*/ {0, 1, 3}}; diff --git a/tensorflow/core/util/batch_util.cc b/tensorflow/core/util/batch_util.cc index 45556d53a46f9e..8b694d9cc32568 100644 --- a/tensorflow/core/util/batch_util.cc +++ b/tensorflow/core/util/batch_util.cc @@ -156,6 +156,8 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { switch (parent.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); + TF_CALL_uint32(HANDLE_TYPE); + TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopySliceToElement Unhandled data type: ", @@ -180,6 +182,8 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) { switch (parent->dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); + TF_CALL_uint32(HANDLE_TYPE); + TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented( diff --git a/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py index 69b3d9ffbfa39b..214ea508c56dd3 100644 --- a/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py @@ -251,6 +251,21 @@ def testFromTensorSlicesMixedRagged(self): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + def testFromTensorSlicesWithUintDtypes(self): + components = ( + np.tile(np.array([[0], [1]], dtype=np.uint8), 2), + np.tile(np.array([[2], [256]], dtype=np.uint16), 2), + np.tile(np.array([[4], [65536]], dtype=np.uint32), 2), + np.tile(np.array([[8], [4294967296]], dtype=np.uint64), 2), + ) + expected_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) + expected_output = [tuple([c[i] for c in components]) for i in range(2)] + + dataset = dataset_ops.Dataset.from_tensor_slices(components) + self.assertEqual(expected_types, + dataset_ops.get_legacy_output_types(dataset)) + self.assertDatasetProduces(dataset, expected_output) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/unbatch_test.py b/tensorflow/python/data/kernel_tests/unbatch_test.py index 6bc8f442cf992f..16d6d8b8ebf9b9 100644 --- a/tensorflow/python/data/kernel_tests/unbatch_test.py +++ b/tensorflow/python/data/kernel_tests/unbatch_test.py @@ -184,6 +184,24 @@ def testSkipEagerUnbatchDynamicShapeMismatch(self): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_element) + def testUnbatchDatasetWithUintDtypes(self): + components = ( + np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2), + np.tile(np.array([[1], [2], [3], [256]], dtype=np.uint16), 2), + np.tile(np.array([[2], [3], [4], [65536]], dtype=np.uint32), 2), + np.tile(np.array([[3], [4], [5], [4294967296]], dtype=np.uint64), 2), + ) + expected_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) + expected_output = [tuple([c[i] for c in components]) for i in range(4)] + + data = dataset_ops.Dataset.from_tensor_slices(components) + data = data.batch(2) + self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) + + data = data.unbatch() + self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data)) + self.assertDatasetProduces(data, expected_output) + if __name__ == "__main__": test.main()