Saving model with tf.keras.layers.RNN with unroll=True fails for save_format=tf #36666

@padoremu

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution: Linux Ubuntu 18.04
  • Mobile device if the issue happens on mobile device: -
  • TensorFlow installed from: binary (tf-nightly via docker)
  • TensorFlow version: GIT_VERSION = v1.12.1-23779-g96c5c8a 2.2.0-dev20200202
  • Python version: 3.6.9
  • Bazel version (if compiling from source): -
  • GCC/Compiler version (if compiling from source): -
  • CUDA/cuDNN version: CPU only
  • GPU model and memory: CPU only

Describe the current behavior
Saving a tf.keras.Sequential model that contains a tf.keras.layers.RNN layer with unroll=True fails for save_format=tf (but succeeds for save_format=h5).

Describe the expected behavior
Saving should succeed for save_format=tf as well.

Code to reproduce the issue

import tensorflow as tf

model = tf.keras.Sequential()

model.add(tf.keras.layers.Input(shape=(1, 1,)))

cell = tf.keras.layers.GRUCell(10)

model.add(tf.keras.layers.RNN(cell, unroll=True))
    
model.save("test.tf", save_format='tf') # fails
#model.save("test.h5", save_format='h5') # works

Other info / logs
Unfortunately, saving as h5 (which I would actually prefer) is not an option, since it fails when there is more than one cell (see #36093).

Traceback in case of failure:

  File "test.py", line 11, in <module>
    model.save("test.tf", save_format='tf') # fails
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/network.py", line 999, in save
    signatures, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/save.py", line 138, in save_model
    signatures, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/save.py", line 955, in save
    checkpoint_graph_view)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/saved_model/save.py", line 142, in list_functions
    self._serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 2535, in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py", line 91, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 79, in functions_to_serialize
    serialization_cache).functions_to_serialize)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py", line 53, in _get_serialized_attributes_internal
    serialization_cache))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 103, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 161, in wrap_layer_functions
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 249, in _replace_child_layer_functions
    serialization_cache).functions)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py", line 103, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 171, in wrap_layer_functions
    '{}_layer_call_and_return_conditional_losses'.format(layer.name))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 487, in add_function
    self.add_trace(*self._input_signature)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 402, in add_trace
    trace_with_training(True)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 400, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 531, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 953, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 859, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 505, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2440, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2771, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2661, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 440, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 508, in wrapper
    ret = method(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 170, in wrap_with_training_arg
    lambda: replace_training_and_call(False))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/utils/tf_utils.py", line 59, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 169, in <lambda>
    lambda: replace_training_and_call(True),
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/utils.py", line 165, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py", line 550, in call_and_return_conditional_losses
    return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/recurrent.py", line 734, in call
    raise ValueError('Cannot unroll a RNN if the '
ValueError: Cannot unroll a RNN if the time dimension is undefined. 
- If using a Sequential model, specify the time dimension by passing an `input_shape` or `batch_input_shape` argument to your first layer. If your first layer is an Embedding, you can also use the `input_length` argument.
- If using the functional API, specify the time dimension by passing a `shape` or `batch_shape` argument to your Input layer.
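
The error message asks for a defined time dimension, yet the reproduction script already supplies one through Input(shape=(1, 1)). As a minimal sketch (not a fix), the snippet below rebuilds the same model and checks the time dimension and a forward pass without calling save. The names time_steps and outputs are introduced here for illustration only. If these checks pass while saving still fails, that would suggest the shape information is lost while the layer is traced for serialization rather than missing from the model.

```python
import tensorflow as tf

# Same model as in the reproduction script above.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(1, 1)))
model.add(tf.keras.layers.RNN(tf.keras.layers.GRUCell(10), unroll=True))

# input_shape is (batch, time, features); only the batch axis should be None.
time_steps = model.input_shape[1]
print("time dimension:", time_steps)

# A forward pass on a concrete batch of 2 exercises the unrolled RNN directly,
# independent of any saving code path.
outputs = model(tf.zeros((2, 1, 1)))
print("output shape:", tuple(outputs.shape))
```

Behavior may differ between TensorFlow releases, so this is best run against the same version as reported under System information.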
