tf.contrib.data.bucket_by_sequence_length fails for nested Dataset element #17932

@lhlmgr

Hello everyone,

I just tried the new function for grouping variable-length inputs in the Dataset API, tf.contrib.data.bucket_by_sequence_length, with a small Estimator model.

I implemented input_fn so that it returns a dataset in which each element is a tuple (feature-dict, label). However, when I run it, I get the following exception:

Traceback (most recent call last):
...
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 960, in apply
dataset = transformation_func(self)
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 198, in _apply_fn
window_size_func=window_size_fn))
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 960, in apply
dataset = transformation_func(self)
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 90, in _apply_fn
window_size_func)
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 239, in __init__
self._make_key_func(key_func, input_dataset)
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 289, in _make_key_func
self._key_func.add_to_graph(ops.get_default_graph())
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 488, in add_to_graph
self._create_definition_if_needed()
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
outputs = self._func(*inputs)
File "/home/leo/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/grouping.py", line 279, in tf_key_func
ret = key_func(*nested_args)
TypeError: element_to_bucket_id() takes 1 positional argument but 2 were given

Here is a link to the function.

Here is a code snippet to reproduce the error:

import tensorflow as tf

def input_fn():
  def generator():
    text = [[1, 2, 3],
            [3, 4, 5, 6, 7],
            [1, 2],
            [8, 9, 0, 2, 3]]
    label = [1, 2, 1, 2]

    for x, y in zip(text, label):
      yield (x, y)

  dataset = tf.data.Dataset.from_generator(generator=generator,
                                           output_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
                                           output_types=(tf.int32, tf.int32))

  dataset = dataset.map(parse_example)
  dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                                    bucket_batch_sizes=[2, 2, 2],
                                                                    bucket_boundaries=[0, 8],
                                                                    pad_to_bucket_boundary=False))

  return dataset

def parse_example(x, y):
  return dict(
    x=x
  ), y

def element_length_fn(element):
  features, label = element
  return tf.shape(features["x"])[0]

if __name__ == '__main__':
  with tf.Session() as sess:
    dataset = input_fn()
    iter = dataset.make_one_shot_iterator()

    print(sess.run(iter.get_next()))
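
For what it's worth, the last traceback frame calls `key_func(*nested_args)`, and the function named in the TypeError is `element_to_bucket_id`, which is not defined in my snippet. My reading is that the (feature-dict, label) tuple gets unpacked into two positional arguments before a one-argument function receives it. Below is a minimal plain-Python sketch of that mechanism. It does not use TensorFlow; `call_like_traceback`, `bucket_id_one_arg` and `bucket_id_var_args` are hypothetical stand-ins, and the rule "unpack only tuples" is my assumption, not something the traceback confirms.

```python
def call_like_traceback(key_func, element):
    """Mimic the `key_func(*nested_args)` call from the last traceback frame.

    Assumption: tuple elements are unpacked into positional arguments,
    anything else is passed through as a single argument.
    """
    if isinstance(element, tuple):
        return key_func(*element)
    return key_func(element)


def bucket_id_one_arg(element):
    # Hypothetical stand-in for a function that expects the whole element.
    features, label = element
    return len(features["x"])


def bucket_id_var_args(*args):
    # Hypothetical variant that accepts the unpacked components.
    features, label = args
    return len(features["x"])


# One (feature-dict, label) element, shaped like the output of parse_example.
element = ({"x": [3, 4, 5, 6, 7]}, 2)

try:
    call_like_traceback(bucket_id_one_arg, element)
    one_arg_error = None
except TypeError as exc:
    one_arg_error = str(exc)

var_args_length = call_like_traceback(bucket_id_var_args, element)

print(one_arg_error)
print(var_args_length)
```

If this reading is right, the one-argument variant fails with the same kind of TypeError as above, while the `*args` variant returns the sequence length.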

My environment specs are logged in tf_env.txt.

Thanks in advance!
