Allow stride > window_size for sliding_window_batch #20223

Merged
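For context: `sliding_window_batch` previously required `stride` in `[1, window_size)`; this PR relaxes the check so any positive stride is accepted. A minimal sketch of the newly allowed case, assuming the TF 1.x contrib API (editor's illustration, not part of the diff):

```python
# Editor's sketch (TF 1.x contrib API): with stride > window_size the
# windows no longer overlap and the elements between them are skipped.
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.contrib.data.sliding_window_batch(window_size=3, stride=4))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_element))  # [0 1 2], then [4 5 6]
    except tf.errors.OutOfRangeError:
        pass  # elements 3 and 7 are skipped; the short window [8 9] is dropped
```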
@@ -58,6 +58,7 @@ def _map_fn(x, y, z):
[t.shape.as_list() for t in get_next])

with self.test_session() as sess:
# stride < window_size.
# Slide over a finite input, where the window_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
@@ -71,11 +72,9 @@ def _map_fn(x, y, z):
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Slide over a finite input, where the window_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})

num_batches = (20 * 7 - 17) // 9 + 1
for i in range(num_batches):
result = sess.run(get_next)
@@ -86,6 +85,41 @@ def _map_fn(x, y, z):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# stride == window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14})
num_batches = 20 * 7 // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# stride > window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14})
num_batches = 20 * 7 // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(10):
self.assertAllEqual(component[(i*14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Drop the last batch, which is smaller than window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19})
num_batches = (20 * 7 - 7) // 19 # == (20 * 7 - 14) // 19 + 1 == 7
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*19 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

# Slide over a finite input that has fewer elements than window_size;
# this should fail straight away.
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
@@ -108,10 +142,6 @@ def _map_fn(x, y, z):
# Invalid stride should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})

def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
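Aside (editor's note, not part of the diff): every `num_batches` value in the tests above follows the usual complete-window count, since windows shorter than `window_size` are dropped:

```python
# Number of complete windows over n elements, matching the tests above.
def num_batches(n, window_size, stride):
    return (n - window_size) // stride + 1 if n >= window_size else 0

assert num_batches(20 * 7, 14, 7) == 19   # stride < window_size
assert num_batches(20 * 7, 17, 9) == 14   # window_size does not divide 140
assert num_batches(20 * 7, 14, 14) == 10  # stride == window_size
assert num_batches(20 * 7, 10, 14) == 10  # stride > window_size
assert num_batches(20 * 7, 14, 19) == 7   # short final window dropped
```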
2 changes: 1 addition & 1 deletion tensorflow/contrib/data/python/ops/sliding.py
@@ -86,7 +86,7 @@ def sliding_window_batch(window_size, stride=1):
elements in the sliding window.
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
number of steps to move the sliding window forward on each iteration. The default
is `1`. It must be in `[1, window_size)`.
is `1`. It must be positive.

Returns:
A `Dataset` transformation function, which can be passed to
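The kernel change below warns that `stride == window_size` behaves like plain batching. One difference worth noting (editor's sketch, assuming the TF 1.x API): `sliding_window_batch` drops a final window shorter than `window_size`, while `batch` emits the partial batch:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(7)
slid = ds.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=3))
# slid yields [0 1 2], [3 4 5]; the trailing element 6 is dropped.
batched = ds.batch(3)
# batched yields [0 1 2], [3 4 5], [6]; the partial batch is kept.
```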
@@ -11,7 +11,7 @@ END
name: "stride"
description: <<END
A scalar representing the number of steps the sliding window moves
forward in one iteration. It must be in `[1, window_size)`.
forward in one iteration. It must be positive.
END
}
summary: "Creates a dataset that passes a sliding window over `input_dataset`."
50 changes: 38 additions & 12 deletions tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
@@ -32,16 +33,25 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
int64 stride = 1;
int64 stride = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
OP_REQUIRES(
ctx, stride > 0 && stride < window_size,
errors::InvalidArgument("Stride must be in [1, window_size)."));
ctx, stride > 0,
errors::InvalidArgument("Stride must be greater than zero."));
if (stride == window_size) {
LOG(WARNING) << "stride: " << stride
<< " is equal to window_size: " << window_size
<< ", to use `batch` instead.";
} else if (stride > window_size) {
LOG(WARNING) << "stride: " << stride
<< " is greater than window_size: " << window_size
<< ", you will lose some data.";
}

*output = new Dataset(ctx, window_size, stride, input);
}
@@ -124,12 +134,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
batch_elements.reserve(window_size);
const bool first_call = cache_.empty();
if (first_call) {
cache_.reserve(window_size);
} else {
// Reuse cache in the previous iteration.
cache_.swap(batch_elements);
// Use cache if stride < window_size.
if (stride < window_size) {
const bool first_call = cache_.empty();
if (first_call) {
cache_.reserve(window_size);
} else {
// Reuse the cache from the previous iteration.
cache_.swap(batch_elements);
}
}
// Fill up with new elements.
*end_of_sequence = false;
@@ -149,9 +162,22 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
DCHECK(*end_of_sequence);
return Status::OK();
}
// Cache the data used for the next iteration.
for (size_t i = stride; i < window_size; ++i) {
cache_.emplace_back(batch_elements[i]);

if (stride < window_size) {
// Cache the data used for the next iteration.
for (size_t i = stride; i < window_size; ++i) {
cache_.emplace_back(batch_elements[i]);
}
} else if (stride > window_size) {
// Drop the skipped elements before the next iteration.
std::vector<Tensor> batch_element_tuple;
for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
end_of_sequence));
if (*end_of_sequence) {
input_impl_.reset();
}
}
}
}

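Summary of the control flow above (editor's sketch, not part of the PR): when `stride < window_size` the iterator seeds each window from a cache of the previous window's tail; when `stride > window_size` it discards the skipped elements after emitting each window. A pure-Python model of the same logic:

```python
# Pure-Python model of the kernel's windowing logic (editor's sketch).
def sliding_windows(elements, window_size, stride):
    assert window_size > 0 and stride > 0
    it = iter(elements)
    cache = []
    while True:
        window = cache  # reused overlap when stride < window_size, else empty
        cache = []
        try:
            while len(window) < window_size:
                window.append(next(it))
        except StopIteration:
            return  # a final window shorter than window_size is dropped
        yield list(window)
        if stride < window_size:
            cache = window[stride:]  # keep the overlapping tail
        elif stride > window_size:
            try:
                for _ in range(stride - window_size):
                    next(it)  # drop the elements between windows
            except StopIteration:
                return

print(list(sliding_windows(range(10), 3, 4)))  # [[0, 1, 2], [4, 5, 6]]
```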