Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rolling window batch operation for tf.data.Dataset #16123

Merged
merged 23 commits into from
Mar 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions tensorflow/contrib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
@@rejection_resample
@@scan
@@shuffle_and_repeat
@@sliding_window_batch
@@sloppy_interleave
@@unbatch

Expand Down Expand Up @@ -67,6 +68,9 @@
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_single_example
# pylint: enable=unused-import

from tensorflow.python.util.all_util import remove_undocumented
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/contrib/data/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,23 @@ py_test(
],
)

tf_py_test(
name = "slide_dataset_op_test",
size = "small",
srcs = ["slide_dataset_op_test.py"],
additional_deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
"//third_party/py/numpy",
],
)

filegroup(
name = "all_files",
srcs = glob(
Expand Down
242 changes: 242 additions & 0 deletions tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.data.python.ops import sliding
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class SlideDatasetTest(test.TestCase):

def testSlideDataset(self):
"""Test an dataset that maps a TF function across its input elements."""
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))

count = array_ops.placeholder(dtypes.int64, shape=[])
window_size = array_ops.placeholder(dtypes.int64, shape=[])
stride = array_ops.placeholder(dtypes.int64, shape=[])

def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)

# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count) -> _SlideDataset(window_size, stride).
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.map(_map_fn)
.repeat(count)
.apply(sliding.sliding_window_batch(window_size, stride))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])

with self.test_session() as sess:
# Slide over a finite input, where the window_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
# Same formula with convolution layer.
num_batches = (20 * 7 - 14) // 7 + 1
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*7 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Slide over a finite input, where the window_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})

num_batches = (20 * 7 - 17) // 9 + 1
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(17):
self.assertAllEqual(component[(i*9 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Slide over a finite input, which is less than window_size,
# should fail straight away.
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Slide over an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Empty window_size should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0})

# Invalid stride should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})

def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)

def testSlideSparse(self):

def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])

iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
dense_shape=[5, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testSlideSparseWithDifferentDenseShapes(self):

def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=array_ops.expand_dims(
math_ops.range(i, dtype=dtypes.int64), 1),
values=array_ops.fill([math_ops.to_int32(i)], i),
dense_shape=[i])

iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
actual = sess.run(get_next)
expected_indices = []
expected_values = []
for j in range(5):
for k in range(i * 3 + j):
expected_indices.append([j, k])
expected_values.append(i * 3 + j)
expected = sparse_tensor.SparseTensorValue(
indices=expected_indices,
values=expected_values,
dense_shape=[5, i * 3 + 5 - 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testNestedSlideSparse(self):

def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])

iterator = (dataset_ops.Dataset.range(10)
.map(_sparse)
.apply(sliding.sliding_window_batch(4, 2))
.apply(sliding.sliding_window_batch(3, 1))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()

with self.test_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
[2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
# Slide: 2nd batch.
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
[2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

def testSlideShapeError(self):

def generator():
yield [1.0, 2.0, 3.0]
yield [4.0, 5.0, 6.0]
yield [7.0, 8.0, 9.0, 10.0]

iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
output_shapes=[None])
.apply(sliding.sliding_window_batch(3, 1))
.make_initializable_iterator())
next_element = iterator.get_next()

with self.test_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Cannot batch tensors with different shapes in component 0. "
r"First element had shape \[3\] and element 2 had shape \[4\]."):
sess.run(next_element)


if __name__ == "__main__":
test.main()
1 change: 1 addition & 0 deletions tensorflow/contrib/data/python/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ py_library(
"interleave_ops.py",
"resampling.py",
"scan_ops.py",
"sliding.py",
"stats_ops.py",
"unique.py",
],
Expand Down
102 changes: 102 additions & 0 deletions tensorflow/contrib/data/python/ops/sliding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sliding dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops


class _SlideDataset(dataset_ops.Dataset):
"""A `Dataset` that passes a sliding window over its input."""

def __init__(self, input_dataset, window_size, stride=1):
"""See `sliding_window_batch` for details."""
super(_SlideDataset, self).__init__()
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_size")
self._stride = ops.convert_to_tensor(
stride, dtype=dtypes.int64, name="stride")

def _as_variant_tensor(self):
return gen_dataset_ops.slide_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
window_size=self._window_size,
stride=self._stride,
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)))

@property
def output_classes(self):
return self._input_dataset.output_classes

@property
def output_shapes(self):
input_shapes = self._input_dataset.output_shapes
return nest.pack_sequence_as(input_shapes, [
tensor_shape.vector(None).concatenate(s)
for s in nest.flatten(self._input_dataset.output_shapes)
])

@property
def output_types(self):
return self._input_dataset.output_types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a function to this file that creates a SlideDataset can be used with tf.data.Dataset.apply()? I'd be inclined to give it a slightly more verbose name than "slide", e.g. sliding_window_batch():.

See e.g. the implementation of tf.contrib.data.map_and_batch():

def map_and_batch(map_func, batch_size, num_parallel_batches=1):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@facaiy did you address this comment? Thank you.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@drpngx Yes, sliding_window_batch has been added as suggested. Some doc build broken, and I fixed the failed tests. Could you help restart all tests? Thank you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay... this comment is addressed now, thanks!



def sliding_window_batch(window_size, stride=1):
"""A sliding window with size of `window_size` and step of `stride`.

This transformation passes a sliding window over this dataset. The
window size is `window_size` and step size is `stride`. If the left
elements cannot fill up the sliding window, this transformation will
drop the final smaller element. For example:

```python
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { [1], [2], [3], [4], [5], [6] }

a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) ==
{
[[1], [2], [3]],
[[3], [4], [5]],
}
```

Args:
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
elements in the sliding window.
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
steps moving the sliding window forward for one iteration. The default
is `1`. It must be in `[1, window_size)`.

Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return _SlideDataset(dataset, window_size, stride)

return _apply_fn