Hello everyone,
I just tried the new function for grouping variable-length inputs in the Dataset API, tf.contrib.data.bucket_by_sequence_length, in a small Estimator model.
I implemented the input_fn so that it returns a dataset in which each element is a tuple (feature-dict, label). However, when I run it, I get the following exception:
Traceback (most recent call last):
  ...
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 960, in apply
    dataset = transformation_func(self)
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 198, in _apply_fn
    window_size_func=window_size_fn))
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 960, in apply
    dataset = transformation_func(self)
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 90, in _apply_fn
    window_size_func)
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 239, in __init__
    self._make_key_func(key_func, input_dataset)
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 289, in _make_key_func
    self._key_func.add_to_graph(ops.get_default_graph())
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 488, in add_to_graph
    self._create_definition_if_needed()
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 279, in tf_key_func
    ret = key_func(*nested_args)
TypeError: element_to_bucket_id() takes 1 positional argument but 2 were given
Here is a link to the function.
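As far as I can tell from the last frame (`ret = key_func(*nested_args)`), a tuple element is unpacked into separate positional arguments before the key function is called, so a function that takes exactly one parameter fails on a (feature-dict, label) element. Below is a minimal pure-Python sketch of that mismatch. It does not use TensorFlow, and `call_key_func`, `one_arg_key` and `var_arg_key` are hypothetical names I made up for illustration, not part of the library:

```python
def call_key_func(key_func, element):
    """Mimic the unpacking seen in the last traceback frame."""
    # Assumption: tuple elements are splatted into positional arguments.
    nested_args = element if isinstance(element, tuple) else (element,)
    return key_func(*nested_args)


def one_arg_key(element):
    # Exactly one positional parameter, like the failing function.
    features, label = element
    return len(features["x"])


def var_arg_key(*args):
    # Accepts the already-unpacked components instead of a single tuple.
    features, label = args
    return len(features["x"])


element = ({"x": [1, 2, 3]}, 1)

try:
    call_key_func(one_arg_key, element)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
print(call_key_func(var_arg_key, element))
```

If this reading is right, the one-parameter function can only work for datasets whose elements are not tuples, while a `*args` signature would handle both cases.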
Here is a code snippet to reproduce the error:
import tensorflow as tf


def input_fn():
    def generator():
        text = [[1, 2, 3],
                [3, 4, 5, 6, 7],
                [1, 2],
                [8, 9, 0, 2, 3]]
        label = [1, 2, 1, 2]
        for x, y in zip(text, label):
            yield (x, y)

    dataset = tf.data.Dataset.from_generator(
        generator=generator,
        output_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
        output_types=(tf.int32, tf.int32))
    dataset = dataset.map(parse_example)
    dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(
        element_length_func=element_length_fn,
        bucket_batch_sizes=[2, 2, 2],
        bucket_boundaries=[0, 8],
        pad_to_bucket_boundary=False))
    return dataset


def parse_example(x, y):
    return dict(
        x=x
    ), y


def element_length_fn(element):
    features, label = element
    return tf.shape(features["x"])[0]


if __name__ == '__main__':
    with tf.Session() as sess:
        dataset = input_fn()
        iter = dataset.make_one_shot_iterator()
        print(sess.run(iter.get_next()))
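For reference, this is roughly the grouping I expect from the snippet, written as a pure-Python sketch. It assumes that bucket i collects sequences whose length lies between boundary i-1 (inclusive) and boundary i (exclusive), which is my understanding of the documented behavior and not a copy of the TensorFlow implementation. `bucket_id` and `bucket_sequences` are hypothetical helper names, and batching and padding are left out:

```python
from bisect import bisect_right
from collections import defaultdict


def bucket_id(length, bucket_boundaries):
    # Assumption: boundaries are sorted and each one is the inclusive
    # lower edge of the next bucket, so N boundaries give N + 1 buckets.
    return bisect_right(bucket_boundaries, length)


def bucket_sequences(sequences, bucket_boundaries):
    # Group sequences by the bucket their length falls into.
    buckets = defaultdict(list)
    for seq in sequences:
        buckets[bucket_id(len(seq), bucket_boundaries)].append(seq)
    return dict(buckets)


text = [[1, 2, 3],
        [3, 4, 5, 6, 7],
        [1, 2],
        [8, 9, 0, 2, 3]]

grouped = bucket_sequences(text, bucket_boundaries=[0, 8])
print(grouped)
```

With bucket_boundaries=[0, 8] there are three buckets, which is why bucket_batch_sizes has three entries. Under the assumption above, all four sequences (lengths 3, 5, 2 and 5) would land in the middle bucket.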
My environment specs are logged in tf_env.txt.
Thanks in advance!