-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add rolling window batch operation for tf.data.Dataset #16123
Merged
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
93b3d7b
ENH: add slide_dataset_op
facaiy ce17bcb
TST: add test case
facaiy f4683a3
DOC: add docment
facaiy 644c5c2
CLN: implement sliding_window_batch
facaiy dfe2762
CLN: hiddent SlideDataset
facaiy f02ff45
CLN: remove Dataset.slide
facaiy 4606b35
DOC: 2017 -> 2018
facaiy a037f8f
CLN: use push_back
facaiy a7dfbd8
DOC: drop the final smaller block
facaiy f9b427e
CLN: rename slide_size -> window_size
facaiy 0e2d542
CLN: rename slide_step -> stride
facaiy 83d2a85
DOC: no default for stride at c++ side
facaiy 3191291
DOC: revise comments
facaiy d0400fc
BLD: expose sliding_window_batch API
facaiy c448b6e
CLN: code style
facaiy 8960be1
DOC: revise documents
facaiy 91d4c0c
Merged from upstream/master
facaiy 002edc1
CLN: move to IteratorContext
facaiy 13d3797
TST: remove contrib.dataset_ops
facaiy d8a26df
Merge remote-tracking branch 'upstream/master' into ENH/rolling_window
facaiy 08421ab
DOC: move desp to api def
facaiy e7f2703
CLN: fix python 2 indent
facaiy cf49555
DOC: used by core.apply method
facaiy File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
242 changes: 242 additions & 0 deletions
242
tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for the experimental input pipeline ops.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
|
||
from tensorflow.contrib.data.python.ops import sliding | ||
from tensorflow.python.data.ops import dataset_ops | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.python.framework import errors | ||
from tensorflow.python.framework import sparse_tensor | ||
from tensorflow.python.ops import array_ops | ||
from tensorflow.python.ops import math_ops | ||
from tensorflow.python.platform import test | ||
|
||
|
||
class SlideDatasetTest(test.TestCase): | ||
|
||
def testSlideDataset(self): | ||
"""Test an dataset that maps a TF function across its input elements.""" | ||
components = (np.arange(7), | ||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], | ||
np.array(37.0) * np.arange(7)) | ||
|
||
count = array_ops.placeholder(dtypes.int64, shape=[]) | ||
window_size = array_ops.placeholder(dtypes.int64, shape=[]) | ||
stride = array_ops.placeholder(dtypes.int64, shape=[]) | ||
|
||
def _map_fn(x, y, z): | ||
return math_ops.square(x), math_ops.square(y), math_ops.square(z) | ||
|
||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) -> | ||
# RepeatDataset(count) -> _SlideDataset(window_size, stride). | ||
iterator = (dataset_ops.Dataset.from_tensor_slices(components) | ||
.map(_map_fn) | ||
.repeat(count) | ||
.apply(sliding.sliding_window_batch(window_size, stride)) | ||
.make_initializable_iterator()) | ||
init_op = iterator.initializer | ||
get_next = iterator.get_next() | ||
|
||
self.assertEqual([[None] + list(c.shape[1:]) for c in components], | ||
[t.shape.as_list() for t in get_next]) | ||
|
||
with self.test_session() as sess: | ||
# Slide over a finite input, where the window_size divides the | ||
# total number of elements. | ||
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) | ||
# Same formula with convolution layer. | ||
num_batches = (20 * 7 - 14) // 7 + 1 | ||
for i in range(num_batches): | ||
result = sess.run(get_next) | ||
for component, result_component in zip(components, result): | ||
for j in range(14): | ||
self.assertAllEqual(component[(i*7 + j) % 7]**2, | ||
result_component[j]) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
# Slide over a finite input, where the window_size does not | ||
# divide the total number of elements. | ||
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) | ||
|
||
num_batches = (20 * 7 - 17) // 9 + 1 | ||
for i in range(num_batches): | ||
result = sess.run(get_next) | ||
for component, result_component in zip(components, result): | ||
for j in range(17): | ||
self.assertAllEqual(component[(i*9 + j) % 7]**2, | ||
result_component[j]) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
# Slide over a finite input, which is less than window_size, | ||
# should fail straight away. | ||
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
# Slide over an empty input should fail straight away. | ||
sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
# Empty window_size should be an initialization time error. | ||
with self.assertRaises(errors.InvalidArgumentError): | ||
sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) | ||
|
||
# Invalid stride should be an initialization time error. | ||
with self.assertRaises(errors.InvalidArgumentError): | ||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) | ||
with self.assertRaises(errors.InvalidArgumentError): | ||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) | ||
with self.assertRaises(errors.InvalidArgumentError): | ||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) | ||
|
||
def assertSparseValuesEqual(self, a, b): | ||
self.assertAllEqual(a.indices, b.indices) | ||
self.assertAllEqual(a.values, b.values) | ||
self.assertAllEqual(a.dense_shape, b.dense_shape) | ||
|
||
def testSlideSparse(self): | ||
|
||
def _sparse(i): | ||
return sparse_tensor.SparseTensorValue( | ||
indices=[[0]], values=(i * [1]), dense_shape=[1]) | ||
|
||
iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( | ||
sliding.sliding_window_batch(5, 3)).make_initializable_iterator() | ||
init_op = iterator.initializer | ||
get_next = iterator.get_next() | ||
|
||
with self.test_session() as sess: | ||
sess.run(init_op) | ||
num_batches = (10 - 5) // 3 + 1 | ||
for i in range(num_batches): | ||
actual = sess.run(get_next) | ||
expected = sparse_tensor.SparseTensorValue( | ||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], | ||
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4], | ||
dense_shape=[5, 1]) | ||
self.assertTrue(sparse_tensor.is_sparse(actual)) | ||
self.assertSparseValuesEqual(actual, expected) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
def testSlideSparseWithDifferentDenseShapes(self): | ||
|
||
def _sparse(i): | ||
return sparse_tensor.SparseTensorValue( | ||
indices=array_ops.expand_dims( | ||
math_ops.range(i, dtype=dtypes.int64), 1), | ||
values=array_ops.fill([math_ops.to_int32(i)], i), | ||
dense_shape=[i]) | ||
|
||
iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( | ||
sliding.sliding_window_batch(5, 3)).make_initializable_iterator() | ||
init_op = iterator.initializer | ||
get_next = iterator.get_next() | ||
|
||
with self.test_session() as sess: | ||
sess.run(init_op) | ||
num_batches = (10 - 5) // 3 + 1 | ||
for i in range(num_batches): | ||
actual = sess.run(get_next) | ||
expected_indices = [] | ||
expected_values = [] | ||
for j in range(5): | ||
for k in range(i * 3 + j): | ||
expected_indices.append([j, k]) | ||
expected_values.append(i * 3 + j) | ||
expected = sparse_tensor.SparseTensorValue( | ||
indices=expected_indices, | ||
values=expected_values, | ||
dense_shape=[5, i * 3 + 5 - 1]) | ||
self.assertTrue(sparse_tensor.is_sparse(actual)) | ||
self.assertSparseValuesEqual(actual, expected) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
def testNestedSlideSparse(self): | ||
|
||
def _sparse(i): | ||
return sparse_tensor.SparseTensorValue( | ||
indices=[[0]], values=(i * [1]), dense_shape=[1]) | ||
|
||
iterator = (dataset_ops.Dataset.range(10) | ||
.map(_sparse) | ||
.apply(sliding.sliding_window_batch(4, 2)) | ||
.apply(sliding.sliding_window_batch(3, 1)) | ||
.make_initializable_iterator()) | ||
init_op = iterator.initializer | ||
get_next = iterator.get_next() | ||
|
||
with self.test_session() as sess: | ||
sess.run(init_op) | ||
# Slide: 1st batch. | ||
actual = sess.run(get_next) | ||
expected = sparse_tensor.SparseTensorValue( | ||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], | ||
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], | ||
[2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], | ||
values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], | ||
dense_shape=[3, 4, 1]) | ||
self.assertTrue(sparse_tensor.is_sparse(actual)) | ||
self.assertSparseValuesEqual(actual, expected) | ||
# Slide: 2nd batch. | ||
actual = sess.run(get_next) | ||
expected = sparse_tensor.SparseTensorValue( | ||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], | ||
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], | ||
[2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], | ||
values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], | ||
dense_shape=[3, 4, 1]) | ||
self.assertTrue(sparse_tensor.is_sparse(actual)) | ||
self.assertSparseValuesEqual(actual, expected) | ||
with self.assertRaises(errors.OutOfRangeError): | ||
sess.run(get_next) | ||
|
||
def testSlideShapeError(self): | ||
|
||
def generator(): | ||
yield [1.0, 2.0, 3.0] | ||
yield [4.0, 5.0, 6.0] | ||
yield [7.0, 8.0, 9.0, 10.0] | ||
|
||
iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, | ||
output_shapes=[None]) | ||
.apply(sliding.sliding_window_batch(3, 1)) | ||
.make_initializable_iterator()) | ||
next_element = iterator.get_next() | ||
|
||
with self.test_session() as sess: | ||
sess.run(iterator.initializer) | ||
with self.assertRaisesRegexp( | ||
errors.InvalidArgumentError, | ||
r"Cannot batch tensors with different shapes in component 0. " | ||
r"First element had shape \[3\] and element 2 had shape \[4\]."): | ||
sess.run(next_element) | ||
|
||
|
||
if __name__ == "__main__": | ||
test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Sliding dataset transformations.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from tensorflow.python.data.ops import dataset_ops | ||
from tensorflow.python.data.util import nest | ||
from tensorflow.python.data.util import sparse | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.framework import tensor_shape | ||
from tensorflow.python.ops import gen_dataset_ops | ||
|
||
|
||
class _SlideDataset(dataset_ops.Dataset): | ||
"""A `Dataset` that passes a sliding window over its input.""" | ||
|
||
def __init__(self, input_dataset, window_size, stride=1): | ||
"""See `sliding_window_batch` for details.""" | ||
super(_SlideDataset, self).__init__() | ||
self._input_dataset = input_dataset | ||
self._window_size = ops.convert_to_tensor( | ||
window_size, dtype=dtypes.int64, name="window_size") | ||
self._stride = ops.convert_to_tensor( | ||
stride, dtype=dtypes.int64, name="stride") | ||
|
||
def _as_variant_tensor(self): | ||
return gen_dataset_ops.slide_dataset( | ||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access | ||
window_size=self._window_size, | ||
stride=self._stride, | ||
output_shapes=nest.flatten( | ||
sparse.as_dense_shapes(self.output_shapes, self.output_classes)), | ||
output_types=nest.flatten( | ||
sparse.as_dense_types(self.output_types, self.output_classes))) | ||
|
||
@property | ||
def output_classes(self): | ||
return self._input_dataset.output_classes | ||
|
||
@property | ||
def output_shapes(self): | ||
input_shapes = self._input_dataset.output_shapes | ||
return nest.pack_sequence_as(input_shapes, [ | ||
tensor_shape.vector(None).concatenate(s) | ||
for s in nest.flatten(self._input_dataset.output_shapes) | ||
]) | ||
|
||
@property | ||
def output_types(self): | ||
return self._input_dataset.output_types | ||
|
||
|
||
def sliding_window_batch(window_size, stride=1): | ||
"""A sliding window with size of `window_size` and step of `stride`. | ||
|
||
This transformation passes a sliding window over this dataset. The | ||
window size is `window_size` and step size is `stride`. If the left | ||
elements cannot fill up the sliding window, this transformation will | ||
drop the final smaller element. For example: | ||
|
||
```python | ||
# NOTE: The following examples use `{ ... }` to represent the | ||
# contents of a dataset. | ||
a = { [1], [2], [3], [4], [5], [6] } | ||
|
||
a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == | ||
{ | ||
[[1], [2], [3]], | ||
[[3], [4], [5]], | ||
} | ||
``` | ||
|
||
Args: | ||
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of | ||
elements in the sliding window. | ||
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the | ||
steps moving the sliding window forward for one iteration. The default | ||
is `1`. It must be in `[1, window_size)`. | ||
|
||
Returns: | ||
A `Dataset` transformation function, which can be passed to | ||
@{tf.data.Dataset.apply}. | ||
""" | ||
def _apply_fn(dataset): | ||
return _SlideDataset(dataset, window_size, stride) | ||
|
||
return _apply_fn |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a function to this file that creates a
SlideDataset
can be used withtf.data.Dataset.apply()
? I'd be inclined to give it a slightly more verbose name than "slide", e.g.sliding_window_batch():
.See e.g. the implementation of
tf.contrib.data.map_and_batch()
:tensorflow/tensorflow/contrib/data/python/ops/batching.py
Line 387 in 23ed83d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@facaiy did you address this comment? Thank you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@drpngx Yes,
sliding_window_batch
has been added as suggested. Some doc build broken, and I fixed the failed tests. Could you help restart all tests? Thank you.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay... this comment is addressed now, thanks!