Description
The function tf.train.import_meta_graph fails with KeyError: 'InfeedEnqueueTuple' when importing the meta file of a graph trained on TPU with TPUEstimator. For graphs trained on CPU (also with TPUEstimator), it works correctly. Is this expected behaviour? How can I load a graph whose parameters were trained on TPU for evaluation on CPU?
I noticed a similar TPU-related error in another GitHub project, but it has no solution: tensorflow/minigo#426.
System information
- Reproduced on MacOS and on Colab.
- tensorflow 1.12.0 (from PyPI)
- Python 3.6
- CUDA/cuDNN/GPU not used
Describe the current behavior
Raises KeyError: 'InfeedEnqueueTuple'.
Describe the expected behavior
Should load the graph without an exception.
Code to reproduce the issue
import tensorflow as tf
tf.train.import_meta_graph(META_PATH, clear_devices=True)
where META_PATH is the path to a .meta file saved by TPUEstimator.
Other info / logs
Full traceback:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-9-962933bf6153> in <module>()
1
----> 2 tf.train.import_meta_graph(META_PATH, clear_devices=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in import_meta_graph(meta_graph_or_file, clear_devices, import_scope, **kwargs)
1672 """ # pylint: disable=g-doc-exception
1673 return _import_meta_graph_with_return_elements(
-> 1674 meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
1675
1676
/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
1694 import_scope=import_scope,
1695 return_elements=return_elements,
-> 1696 **kwargs))
1697
1698 saver = _create_saver_from_imported_meta_graph(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/meta_graph.py in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
804 input_map=input_map,
805 producer_op_list=producer_op_list,
--> 806 return_elements=return_elements)
807
808 # Restores all the other collections.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
486 'in a future version' if date is None else ('after %s' % date),
487 instructions)
--> 488 return func(*args, **kwargs)
489 return tf_decorator.make_decorator(func, new_func, 'deprecated',
490 _add_deprecated_arg_notice_to_docstring(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
389 if producer_op_list is not None:
390 # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
--> 391 _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
392
393 graph = ops.get_default_graph()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py in _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
156 # Remove any default attr values that aren't in op_def.
157 if node.op in producer_op_dict:
--> 158 op_def = op_dict[node.op]
159 producer_op_def = producer_op_dict[node.op]
160 # We make a copy of node.attr to iterate through since we may modify
KeyError: 'InfeedEnqueueTuple'
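The last frames point at the cause: _RemoveDefaultAttrs walks the graph's nodes, and for every op that appears in producer_op_list (the ops recorded when the meta file was saved) it looks the op up in op_dict, the ops registered in the importing process. A TPU-only op such as InfeedEnqueueTuple is apparently in the first list but not in the second, so the lookup raises before any device handling matters, which would also explain why clear_devices=True does not help. The sketch below is a simplified pure-Python model of that lookup, not TensorFlow's code: remove_default_attrs, op and the SimpleNamespace stand-ins for op and node protos are illustrative assumptions, and the attribute comparison the real function performs afterwards is left out.

```python
from types import SimpleNamespace


def remove_default_attrs(op_dict, producer_op_list, graph_def):
    """Simplified model of the lookup in importer._RemoveDefaultAttrs.

    op_dict:          op name -> op definition registered in the importing process
    producer_op_list: op definitions recorded by the process that saved the graph
    graph_def:        object whose .node list holds nodes with an .op name
    """
    producer_op_dict = {op_def.name: op_def for op_def in producer_op_list}
    checked = []
    for node in graph_def.node:
        # Only the producer's op list is consulted before the lookup below,
        if node.op in producer_op_dict:
            # so an op the saver knew about but this process never
            # registered raises KeyError here, as in the traceback above.
            op_def = op_dict[node.op]
            checked.append((node.op, op_def, producer_op_dict[node.op]))
    return checked


def op(name):
    """Hypothetical stand-in for an OpDef proto; only the name matters here."""
    return SimpleNamespace(name=name)


graph_def = SimpleNamespace(node=[SimpleNamespace(op="MatMul"),
                                  SimpleNamespace(op="InfeedEnqueueTuple")])
producer_ops = [op("MatMul"), op("InfeedEnqueueTuple")]  # saved by a TPU job
cpu_registry = {"MatMul": op("MatMul")}  # no TPU ops registered locally

try:
    remove_default_attrs(cpu_registry, producer_ops, graph_def)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'InfeedEnqueueTuple'

# Once the TPU op is registered, the same lookup goes through.
tpu_registry = dict(cpu_registry, InfeedEnqueueTuple=op("InfeedEnqueueTuple"))
print([name for name, _, _ in remove_default_attrs(tpu_registry, producer_ops, graph_def)])
# ['MatMul', 'InfeedEnqueueTuple']
```

In this model, adding the op to the local registry is enough for the same call to succeed, which suggests looking at whether the TPU op definitions are registered in the process that imports the graph (for example by importing the TensorFlow module that defines them). Which module does that in TensorFlow 1.12 is not confirmed here.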