tf.contrib.data.prefetch_to_device not compatible with tf.data.Iterator.from_structure #19244

Closed
RunOrVeith opened this issue May 12, 2018 · 9 comments
Labels: stat:awaiting response, type:support

RunOrVeith commented May 12, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary):
    tensorflow-gpu binary
  • Bazel Version:
    N/A
  • TensorFlow version (use command below):
    v1.8.0-0-g93bc2e2072 1.8.0
  • Python version:
    3.6.3
  • CUDA/cuDNN version:
    CUDA 9.0 cuDNN 7.0.3
  • GPU model and memory:
    GTX 1070 8 GB VRAM
  • Exact command to reproduce:
import tensorflow as tf

class MyData(object):
    def __call__(self):
        return range(100)  # yields the integers 0..99

expected_shapes = []  # scalar elements
expected_types = tf.int32
iterator = tf.data.Iterator.from_structure(output_types=expected_types, output_shapes=expected_shapes)
dataset = tf.data.Dataset.from_generator(MyData(), output_types=expected_types, output_shapes=expected_shapes)

prefetch_op = tf.contrib.data.prefetch_to_device(device="/gpu:0")
dataset = dataset.apply(prefetch_op)
initializer = iterator.make_initializer(dataset)  # raises NotImplementedError

Describe the problem

This raises NotImplementedError: prefetch_to_device() must be the last transformation in a dataset pipeline.

Applying prefetch_to_device after the initializer has been created does not help either, because apply returns a new dataset instead of modifying the existing one in place.

If one reads through this testcase, it is clear that it works when creating the iterator from the dataset.

It is not clear from the documentation of make_initializer that this function is a transformation of the dataset and thus counts as an additional step after prefetching.
I am not sure if this is a bug/was overlooked, or is known to be not implemented.

Proposed short term solution:

  1. Mention in the documentation of prefetch_to_device that it is not supported in combination with make_initializer.
  2. Mention in the documentation of make_initializer that this operation modifies the dataset
    (although I don't think this is the correct choice of words; the issue is the call to dataset._as_variant_tensor in make_initializer, line 308 of iterator_ops.py).

Proposed longterm solution:

  1. This is already a TODO in line 289 of prefetching_ops.py:
    Implement _as_variant_tensor for _PrefetchToDeviceDataset.

Reason why this is needed:

Creating the data pipeline with from_structure and make_initializer makes it possible to switch the network's input source dynamically, e.g. between the training and test sets after an epoch, without having to reinitialize the graph or fall back to feed dicts.
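For readers unfamiliar with the pattern, here is a minimal sketch of one iterator being re-pointed at two different datasets. It assumes a TensorFlow build that ships `tensorflow.compat.v1`; on 1.8, as in this report, a plain `import tensorflow as tf` without the `disable_eager_execution()` call should be the equivalent. The two `Dataset.range` pipelines and the `drain` helper are illustrative stand-ins and are not part of the original report.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in TF 1.x

# Two stand-in pipelines with the same element structure (scalar int64).
train_ds = tf.data.Dataset.range(5)
test_ds = tf.data.Dataset.range(100, 103)

# The iterator is defined by structure only, so it is not tied to one dataset.
iterator = tf.data.Iterator.from_structure(output_types=tf.int64, output_shapes=[])
next_element = iterator.get_next()

# One initializer op per input source; running it re-points the iterator.
train_init = iterator.make_initializer(train_ds)
test_init = iterator.make_initializer(test_ds)


def drain(sess, init_op):
    """Hypothetical helper: initialize the iterator and read until exhausted."""
    sess.run(init_op)
    values = []
    while True:
        try:
            values.append(int(sess.run(next_element)))
        except tf.errors.OutOfRangeError:
            return values


with tf.Session() as sess:
    train_values = drain(sess, train_init)
    test_values = drain(sess, test_init)
```

The graph is built once; only the initializer op that is run decides which dataset feeds `next_element`.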

Source code / logs

Exact stack trace of the error:

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    initializer = iterator.make_initializer(dataset)
  File "/home/veith/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 308, in make_initializer
    dataset._as_variant_tensor(), self._iterator_resource, name=name)  # pylint: disable=protected-access
  File "/home/veith/.pyenv/versions/3.6.3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/prefetching_ops.py", line 291, in _as_variant_tensor
    raise NotImplementedError("`prefetch_to_device()` must be the last "
NotImplementedError: `prefetch_to_device()` must be the last transformation in a dataset pipeline.
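
The call chain in the trace above can be modeled with a toy sketch in plain Python. This is not TensorFlow's code: every class and function name below is hypothetical, and the only parts taken from the report are that `make_initializer` calls `dataset._as_variant_tensor()`, that the prefetch-to-device dataset raises `NotImplementedError` from that method, and that `apply` returns a new dataset.

```python
class ToyDataset:
    """Stand-in for a regular dataset that can be turned into a graph tensor."""

    def _as_variant_tensor(self):
        return "variant-tensor"

    def apply(self, transformation):
        # Returns a new dataset object; the original is left untouched.
        return transformation(self)


class ToyPrefetchToDeviceDataset(ToyDataset):
    """Stand-in for the dataset produced by prefetch_to_device()."""

    def __init__(self, input_dataset, device):
        self.input_dataset = input_dataset
        self.device = device

    def _as_variant_tensor(self):
        # Mirrors the behavior shown in the trace: no variant tensor available.
        raise NotImplementedError(
            "`prefetch_to_device()` must be the last transformation "
            "in a dataset pipeline.")


def toy_prefetch_to_device(device):
    """Returns a transformation function suitable for ToyDataset.apply()."""
    return lambda dataset: ToyPrefetchToDeviceDataset(dataset, device)


class ToyIterator:
    """Stand-in for an iterator created with from_structure()."""

    def make_initializer(self, dataset):
        # The initializer needs the dataset's variant tensor, so from the
        # dataset's point of view this call is one more consumer after prefetching.
        return ("make_iterator", dataset._as_variant_tensor())


plain = ToyDataset()
prefetched = plain.apply(toy_prefetch_to_device("/gpu:0"))

ok_initializer = ToyIterator().make_initializer(plain)
try:
    ToyIterator().make_initializer(prefetched)
    error_message = None
except NotImplementedError as exc:
    error_message = str(exc)
```

In this model the failure comes from the dataset method that `make_initializer` relies on, not from the iterator itself.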
@tensorflowbutler added the stat:awaiting response label on May 12, 2018
@tensorflowbutler
Member

Thank you for your post. We noticed you have not filled out the following field in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Bazel version

@tensorflowbutler
Member

It has been 14 days with no activity and the awaiting response label was assigned. Is this still an issue?

@RunOrVeith
Author

Yes, this is still an issue.

@tensorflowbutler removed the stat:awaiting response label on May 29, 2018
@facaiy
Member

facaiy commented May 29, 2018

cc @mrry, who I think is the authority on this issue.

@mrry
Contributor

mrry commented May 29, 2018

@rohan100jain is working on a solution to this.

@tensorflowbutler
Member

Nagging Assignee @rohan100jain: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@tensorflowbutler
Member

Nagging Assignee @rohan100jain: It has been 29 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@tensorflowbutler
Member

Nagging Assignee @rohan100jain: It has been 44 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@rohan100jain
Member

Please use CopyToDevice + Prefetch instead of using prefetch_to_device directly.

For example:

ds = ...
ds = ds.apply(prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)

This should give you a regular dataset and all the other iterator / dataset support that comes with it.

@rohan100jain added the type:support and stat:awaiting response labels on Jul 26, 2018