Saving model with tf.keras.layers.RNN with unroll=True fails for save_format=tf #36666
Closed
Labels
TF 2.3, comp:keras, type:bug
Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution: Linux Ubuntu 18.04
- Mobile device if the issue happens on mobile device: -
- TensorFlow installed from: binary (tf-nightly via docker)
- TensorFlow version: GIT_VERSION = v1.12.1-23779-g96c5c8a 2.2.0-dev20200202
- Python version: 3.6.9
- Bazel version (if compiling from source): -
- GCC/Compiler version (if compiling from source): -
- CUDA/cuDNN version: CPU only
- GPU model and memory: CPU only
Describe the current behavior
Saving a tf.keras.Sequential model that contains a tf.keras.layers.RNN layer with unroll=True fails for save_format=tf (but succeeds for save_format=h5).
Describe the expected behavior
Saving should succeed for save_format=tf as well.
Code to reproduce the issue
import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(1, 1,)))
cell = tf.keras.layers.GRUCell(10)
model.add(tf.keras.layers.RNN(cell, unroll=True))
model.save("test.tf", save_format='tf') # fails
#model.save("test.h5", save_format='h5') # works
Other info / logs
Unfortunately, saving as h5 (which I would actually prefer) is not an option, since it fails when the RNN layer has more than one cell (see #36093).
Traceback in case of failure:
File "test.py", line 11, in <module>
model.save("test.tf", save_format='tf') # fails
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/network.py", line 999, in save
signatures, options)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/save.py", line 138, in save_model
signatures, options)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save.py", line 78, in save
save_lib.save(model, filepath, signatures, options)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/save.py", line 955, in save
checkpoint_graph_view)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
functions = saveable_view.list_functions(saveable_view.root)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/save.py", line 142, in list_functions
self._serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 2535, in _list_functions_for_serialization
.list_functions_for_serialization(serialization_cache))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py", line 91, in list_functions_for_serialization
fns = self.functions_to_serialize(serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 79, in functions_to_serialize
serialization_cache).functions_to_serialize)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py", line 53, in _get_serialized_attributes_internal
serialization_cache))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 103, in _get_serialized_attributes_internal
functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 161, in wrap_layer_functions
original_fns = _replace_child_layer_functions(layer, serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 249, in _replace_child_layer_functions
serialization_cache).functions)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 103, in _get_serialized_attributes_internal
functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 171, in wrap_layer_functions
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 487, in add_function
self.add_trace(*self._input_signature)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 402, in add_trace
trace_with_training(True)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 400, in trace_with_training
fn.get_concrete_function(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 531, in get_concrete_function
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 953, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 859, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 505, in _initialize
*args, **kwds))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2440, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2771, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2661, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 440, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 508, in wrapper
ret = method(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 170, in wrap_with_training_arg
lambda: replace_training_and_call(False))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/utils/tf_utils.py", line 59, in smart_cond
pred, true_fn=true_fn, false_fn=false_fn, name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/smart_cond.py", line 54, in smart_cond
return true_fn()
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
lambda: replace_training_and_call(True),
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
return wrapped_call(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 550, in call_and_return_conditional_losses
return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/recurrent.py", line 734, in call
raise ValueError('Cannot unroll a RNN if the '
ValueError: Cannot unroll a RNN if the time dimension is undefined.
- If using a Sequential model, specify the time dimension by passing an `input_shape` or `batch_input_shape` argument to your first layer. If your first layer is an Embedding, you can also use the `input_length` argument.
- If using the functional API, specify the time dimension by passing a `shape` or `batch_shape` argument to your Input layer.
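The traceback suggests that the export step retraces the RNN layer's call with an input signature whose time dimension is undefined, even though the model itself was built with shape=(1, 1). The following is only a simplified pure-Python sketch of that kind of guard, not the actual code in recurrent.py; the function name check_unrollable and the two example shapes are hypothetical and chosen for illustration.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def check_unrollable(input_shape: Shape, unroll: bool) -> Optional[int]:
    """Return the time dimension, raising if unrolling is impossible.

    Hypothetical stand-in for the guard that raises in the traceback above.
    Assumes a batch-major shape: (batch, time, features).
    """
    timesteps = input_shape[1]
    if unroll and timesteps is None:
        # Unrolling needs a static number of steps to build the graph.
        raise ValueError(
            "Cannot unroll a RNN if the time dimension is undefined.")
    return timesteps


# Shape seen when the model is built from Input(shape=(1, 1)):
# the time dimension is known.
build_shape: Shape = (None, 1, 1)

# Assumed shape seen while tracing for export: the time dimension
# is undefined (an inference from the traceback, not confirmed).
trace_shape: Shape = (None, None, 1)

steps_at_build = check_unrollable(build_shape, unroll=True)

try:
    check_unrollable(trace_shape, unroll=True)
    export_failed = False
except ValueError:
    export_failed = True
```

Under these assumptions the same layer passes the check at build time and fails it during export, which would match the observation that model construction succeeds and only model.save with save_format='tf' raises.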