
Tensorflow: how to save/restore tf.data.Dataset? #15019

Closed
taehyunkim1527 opened this issue Dec 1, 2017 · 8 comments
Labels: stat:awaiting response (Status - Awaiting response from author), type:bug (Bug)

@taehyunkim1527

I made a model that uses tf.data.Dataset() as its data IO function.

Then I exported the graph and tried to restore it from the meta_graph file, but it failed with the error messages below.

I think tf.data.Dataset() creates a C++ object instead of the Python queue used before, and the graph_def only holds a handle that references that C++ object, so the graph_def alone, without the real C++ object, can't load the complete graph.

How can I load an executable graph that uses tf.data.Dataset()? Or is it impossible for now?

File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Function
_make_dataset_5150cb86 is not defined.
         [[Node: batch_processing/OneShotIterator = OneShotIterator[container="", dataset_factory=_make_dataset_5150cb86[], output_shapes=[[?,1], [?,299,299,3]], output_types=[DT_INT32, DT_FLOAT], shared_name="",
_device="/job:workers/replica:0/task:0/device:CPU:0"]()]]

In short, all of my TensorFlow graphs without tf.data.Dataset work when I add the following code:

graph = tf.train.export_meta_graph()
tf.reset_default_graph()
tf.train.import_meta_graph(graph)

But graphs with tf.data.Dataset produce the error message above.
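
For reference, a minimal self-contained sketch of that round trip (TF 1.x graph mode; the toy pipeline below is only an illustration, not the original benchmark code):

import tensorflow as tf

# Build a graph whose input pipeline uses tf.data.Dataset with a mapped function.
dataset = tf.data.Dataset.range(10).map(lambda x: x * 2).batch(2)
next_batch = dataset.make_one_shot_iterator().get_next()

# Export the MetaGraphDef, reset the default graph, and re-import it.
meta_graph_def = tf.train.export_meta_graph()
tf.reset_default_graph()
tf.train.import_meta_graph(meta_graph_def)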

@tensorflowbutler added the stat:awaiting response (Status - Awaiting response from author) label Dec 1, 2017
@tensorflowbutler
Member

Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Have I written custom code
OS Platform and Distribution
TensorFlow installed from
TensorFlow version
Bazel version
CUDA/cuDNN version
GPU model and memory
Exact command to reproduce

@suryasumukh

I have a similar issue.
I'm trying to load a trained model for inference. In the code snippet below, features:0 is the name of one of the tensors returned by the dataset iterator. Is it possible to feed that tensor directly without having to initialize the iterator?

Code

import os
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as session:
    graph_meta = tf.train.latest_checkpoint(model_dir) + '.meta'
    saver = tf.train.import_meta_graph(os.path.join(model_dir, graph_meta))
    saver.restore(session, tf.train.latest_checkpoint(model_dir))
    feed_dict = {
        'features:0': x_test # shape 102x13
    }
    predictions = session.run('logits:0', feed_dict)
    print(predictions.shape)

Error

FailedPreconditionError: GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,13], [?], [?]], output_types=[DT_DOUBLE, DT_DOUBLE, DT_DOUBLE], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

@mrry self-assigned this Dec 5, 2017
@mrry added the type:bug (Bug) label Dec 5, 2017
@mrry
Contributor

mrry commented Dec 5, 2017

@taehyunkim1527 Can you share a complete reproducible example of the problem? I was unable to reproduce the problem with the code fragment in your example, although it was possible to reproduce it by adding arguments to tf.train.export_meta_graph(). I am in the process of fixing the latter problems, but it would be great to confirm that the fix works for your issue too.

@suryasumukh I think you're running into a different problem. You certainly can feed values over the tensors returned from Iterator.get_next(), but you need to ensure that you feed all of the tensors returned by the iterator that might be used in the run() call. If you continue to have problems with this, please post a question on Stack Overflow.
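
For illustration, a rough sketch of feeding over the get_next() tensors for the snippet above (model_dir and x_test are assumed to be defined as before; labels:0 is a hypothetical name for a second iterator output, since the error shows three outputs):

import numpy as np
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as session:
    ckpt = tf.train.latest_checkpoint(model_dir)
    saver = tf.train.import_meta_graph(ckpt + '.meta')
    saver.restore(session, ckpt)
    # Feed a value for every get_next() output the fetched op can reach,
    # so the IteratorGetNext op never has to run.
    feed_dict = {
        'features:0': x_test,                         # shape (102, 13)
        'labels:0': np.zeros(102, dtype=np.float64),  # hypothetical second output
    }
    predictions = session.run('logits:0', feed_dict)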

@taehyunkim1527
Author

@mrry When I used that export-and-import sequence in tensorflow/benchmarks' ImageNet tasks, I got this error message.
I tried to write a simple piece of code today that reproduces the same error, but I couldn't.
I think it is not a problem with tf.data.Dataset() itself; let me debug again.
Thank you for your kindness.

@mrry
Contributor

mrry commented Dec 6, 2017

Thanks for confirming that the problem doesn't arise with that exact code! I have a change in the pipeline that will make this path work in more cases (e.g. when using clear_devices=True or a scope prefix), so I'll reopen this issue until it lands.
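
For context, a sketch of the two paths in question (whether each works end-to-end depends on that change landing):

import tensorflow as tf

# Export with devices cleared, then re-import under a name-scope prefix:
# the two cases the upcoming change is meant to cover.
meta_graph_def = tf.train.export_meta_graph(clear_devices=True)
tf.reset_default_graph()
tf.train.import_meta_graph(meta_graph_def, import_scope='restored')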

@mrry reopened this Dec 6, 2017
@gunan closed this as completed in af8a550 Dec 7, 2017
@taehyunkim1527
Author

taehyunkim1527 commented Jan 12, 2018

@mrry
Hi, I'm the one who originally opened this question, and I found what the problem was.
If we pass clear_devices=True to export_scoped_meta_graph() in python/framework/meta_graph.py,
it builds an empty graph_def and copies the previous node_defs into it one by one.
In this process, the code does not take graph_def.library into account,
but the information about the map function lives in graph_def.library.
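
As a rough illustration of that explanation (a workaround sketch, not the upstream fix), one can copy the function library back into the exported MetaGraphDef by hand, assuming a TF 1.x graph that already contains a Dataset.map function:

import tensorflow as tf

# Export with clear_devices=True, which (before the fix) rebuilds graph_def
# without its FunctionDefLibrary.
meta_graph_def = tf.train.export_meta_graph(clear_devices=True)

# Copy the function library (where Dataset.map functions such as
# _make_dataset_* live) back into the exported graph_def.
meta_graph_def.graph_def.library.CopyFrom(
    tf.get_default_graph().as_graph_def().library)

tf.reset_default_graph()
tf.train.import_meta_graph(meta_graph_def)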

@karthi0804

karthi0804 commented Apr 7, 2019

@suryasumukh In order to retrain a pre-trained model, the initializer of the data iterator can be declared as a named tf.Operation when training for the first time:

    data_iter = dataset.make_initializable_iterator()
    data_iter_init = data_iter.make_initializer(dataset, name='Data_itr_init')
    next_batch = data_iter.get_next()

Then it can be run by name with sess.run and fed the training data:
sess.run('Data_itr_init', feed_dict={"Model_in:0": train_X, "Model_out:0": train_Y})
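
For completeness, a rough sketch of the restore side under the same assumptions (checkpoint_path is a hypothetical path; Data_itr_init, Model_in:0, and Model_out:0 are the names from the snippet above):

    import tensorflow as tf

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
        saver.restore(sess, checkpoint_path)
        # Re-initialize the saved iterator by its op name, feeding new data
        # through the named input placeholders.
        sess.run('Data_itr_init',
                 feed_dict={'Model_in:0': train_X, 'Model_out:0': train_Y})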

@suryasumukh

(quoting @karthi0804's comment above in full)

Not the point
