Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add assert_element_shape method for tf.contrib.data #17480

Merged
merged 4 commits into from
Apr 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/contrib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

@@Counter

@@assert_element_shape
@@batch_and_drop_remainder
@@bucket_by_sequence_length
@@dense_to_sparse_batch
Expand All @@ -50,6 +51,7 @@

# pylint: disable=unused-import

from tensorflow.contrib.data.python.ops.batching import assert_element_shape
from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
from tensorflow.contrib.data.python.ops.batching import map_and_batch
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/data/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test

Expand Down Expand Up @@ -539,5 +541,73 @@ def fill_tuple(x):
lambda: build_dataset(seq_lens2), 8)


class RestructuredDatasetTest(test.TestCase):
  """Tests for the `batching.assert_element_shape` transformation."""

  def _make_unknown_shape_dataset(self, num_elements):
    """Builds a two-component dataset with statically unknown shapes.

    Each element is produced through `py_func`, and the helper asserts that
    both components report `TensorShape(None)` before returning.

    Args:
      num_elements: Number of elements in the underlying `Dataset.range`.

    Returns:
      The mapped dataset.
    """

    def _py_components(x):
      return script_ops.py_func(
          lambda _: (np.ones(2, dtype=np.float32),
                     np.zeros((3, 4), dtype=np.int32)),
          [x],
          [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(num_elements).map(_py_components)
    self.assertEqual(
        (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)),
        dataset.output_shapes)
    return dataset

  def test_assert_element_shape(self):
    dataset = self._make_unknown_shape_dataset(5)

    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    # The asserted shapes must be reflected in the static output shapes.
    self.assertEqual(expected_shapes, result.output_shapes)

    iterator = result.make_initializable_iterator()
    initializer = iterator.initializer
    next_element = iterator.get_next()
    with self.test_session() as sess:
      sess.run(initializer)
      # All five elements satisfy the asserted shapes, so each one evaluates.
      for _ in range(5):
        sess.run(next_element)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def test_assert_wrong_element_shape(self):

    def _static_components(_):
      return (array_ops.ones(2, dtype=dtypes.float32),
              array_ops.zeros((3, 4), dtype=dtypes.int32))

    dataset = dataset_ops.Dataset.range(3).map(_static_components)
    mismatched_shapes = (tensor_shape.TensorShape(2),
                         tensor_shape.TensorShape((3, 10)))
    # Shapes are statically known here, so the mismatch surfaces at apply time.
    with self.assertRaises(ValueError):
      dataset.apply(batching.assert_element_shape(mismatched_shapes))

  def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
    dataset = self._make_unknown_shape_dataset(3)

    mismatched_shapes = (tensor_shape.TensorShape(2),
                         tensor_shape.TensorShape((3, 10)))
    checked = dataset.apply(batching.assert_element_shape(mismatched_shapes))
    iterator = checked.make_initializable_iterator()
    initializer = iterator.initializer
    next_element = iterator.get_next()
    with self.test_session() as sess:
      sess.run(initializer)
      # Static shapes are unknown, so the mismatch is only detected when the
      # element is evaluated.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(next_element)


# Run this module's tests through the TensorFlow test runner when the file is
# executed directly as a script.
if __name__ == "__main__":
test.main()
1 change: 1 addition & 0 deletions tensorflow/contrib/data/python/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ py_library(
deps = [
":contrib_op_loader",
":gen_dataset_ops",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
Expand Down
40 changes: 40 additions & 0 deletions tensorflow/contrib/data/python/ops/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import with_shape
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
Expand Down Expand Up @@ -345,6 +346,45 @@ def output_shapes(self):
return self._output_shapes


def assert_element_shape(expected_shapes):
  """Assert the shape of each element of this `Dataset`.

  ```python
  shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
  result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
  print(result.output_shapes)  # ==> "((16, 256), <unknown>)"
  ```

  If the dataset shapes and `expected_shapes` are both fully defined, assert
  that they match. Otherwise, add an assert op that will validate the shapes
  when the tensors are evaluated, and set the expected shapes on the tensors.
  Components whose expected shape has unknown rank (such as
  `tf.TensorShape(None)` above) carry no constraint and are passed through
  unchecked.

  Args:
    expected_shapes: A nested structure of `tf.TensorShape` objects.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}
  """

  def _has_unknown_rank(shape):
    # Only `TensorShape`-like leaves expose `ndims`; any other leaf is treated
    # as known and handed to `with_shape` exactly as before.
    return getattr(shape, "ndims", 0) is None

  def _check_shape(*elements):
    flat_tensors = nest.flatten(elements)
    flat_shapes = nest.flatten(expected_shapes)
    # `with_shape` cannot process an expected shape of unknown rank, and such
    # a shape asserts nothing, so those tensors are returned untouched.
    checked_tensors = [
        tensor if _has_unknown_rank(shape) else with_shape(shape, tensor)
        for shape, tensor in zip(flat_shapes, flat_tensors)]
    return nest.pack_sequence_as(elements, checked_tensors)

  def _apply_fn(dataset):
    # Wrap the checked dataset so that its static output shapes report the
    # expected shapes.
    return _RestructuredDataset(
        dataset.map(_check_shape),
        dataset.output_types,
        output_shapes=expected_shapes,
        output_classes=dataset.output_classes)

  return _apply_fn


class _MapAndBatchDataset(dataset_ops.MapDataset):
"""A `Dataset` that maps a function over a batch of elements."""

Expand Down