Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for handling uint32 and uint64 dtypes in batch_util CopySliceToElement and MaybeMoveSliceToElement #28776

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/data/dataset_test_base.cc
Expand Up @@ -54,6 +54,8 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
break;
TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_string(CASE);
TF_CALL_uint32(CASE);
TF_CALL_uint64(CASE);
// TODO(feihugis): figure out how to support variant tensors.
#undef CASE
default:
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc
Expand Up @@ -76,17 +76,31 @@ TestCase PlainTensorTestCase() {
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to the test here, could you also add a Python-level test to from_tensor_slices_test.py? Thanks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2, 2}),
{1, 2, 3, 4}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {2, 3}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2, 2}),
{2, 3, 4, 5}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2, 2}),
{3, 4, 5, 6}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({2, 1}),
{37.0, 38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({2, 1}),
{"a", "b"})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {1, 2}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {2, 3}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({}), {3}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {37.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"a"}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2}),
DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {3, 4}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({}), {3}),
DatasetOpsTestBase::CreateTensor<uint32>(TensorShape({2}), {4, 5}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({}), {4}),
DatasetOpsTestBase::CreateTensor<uint64>(TensorShape({2}), {5, 6}),
DatasetOpsTestBase::CreateTensor<double>(TensorShape({1}), {38.0}),
DatasetOpsTestBase::CreateTensor<string>(TensorShape({1}), {"b"})},
/*breakpoints*/ {0, 1, 3}};
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/util/batch_util.cc
Expand Up @@ -156,6 +156,8 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
switch (parent.dtype()) {
TF_CALL_ALL_TYPES(HANDLE_TYPE);
TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
TF_CALL_uint32(HANDLE_TYPE);
TF_CALL_uint64(HANDLE_TYPE);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch_util::MaybeMoveSliceToElement is also missing these cases, I didn't add them here because I haven't verified it's ok

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can go ahead and add them there too. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can verify that this works by adding a test to unbatch_test.py that uses uint32 / uint64 elements.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will do

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#undef HANDLE_TYPE
default:
return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
Expand All @@ -180,6 +182,8 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
switch (parent->dtype()) {
TF_CALL_ALL_TYPES(HANDLE_TYPE);
TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
TF_CALL_uint32(HANDLE_TYPE);
TF_CALL_uint64(HANDLE_TYPE);
#undef HANDLE_TYPE
default:
return errors::Unimplemented(
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/data/kernel_tests/from_tensor_slices_test.py
Expand Up @@ -251,6 +251,21 @@ def testFromTensorSlicesMixedRagged(self):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())

def testFromTensorSlicesWithUintDtypes(self):
components = (
np.tile(np.array([[0], [1]], dtype=np.uint8), 2),
np.tile(np.array([[2], [256]], dtype=np.uint16), 2),
np.tile(np.array([[4], [65536]], dtype=np.uint32), 2),
np.tile(np.array([[8], [4294967296]], dtype=np.uint64), 2),
)
expected_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
expected_output = [tuple([c[i] for c in components]) for i in range(2)]

dataset = dataset_ops.Dataset.from_tensor_slices(components)
self.assertEqual(expected_types,
dataset_ops.get_legacy_output_types(dataset))
self.assertDatasetProduces(dataset, expected_output)


if __name__ == "__main__":
test.main()
18 changes: 18 additions & 0 deletions tensorflow/python/data/kernel_tests/unbatch_test.py
Expand Up @@ -184,6 +184,24 @@ def testSkipEagerUnbatchDynamicShapeMismatch(self):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(next_element)

def testUnbatchDatasetWithUintDtypes(self):
components = (
np.tile(np.array([[0], [1], [2], [3]], dtype=np.uint8), 2),
np.tile(np.array([[1], [2], [3], [256]], dtype=np.uint16), 2),
np.tile(np.array([[2], [3], [4], [65536]], dtype=np.uint32), 2),
np.tile(np.array([[3], [4], [5], [4294967296]], dtype=np.uint64), 2),
)
expected_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
expected_output = [tuple([c[i] for c in components]) for i in range(4)]

data = dataset_ops.Dataset.from_tensor_slices(components)
data = data.batch(2)
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))

data = data.unbatch()
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
self.assertDatasetProduces(data, expected_output)


if __name__ == "__main__":
test.main()