
Failed to load model with scalar inputs #35446

Closed
loveychen opened this issue Dec 27, 2019 · 3 comments
Labels: comp:keras (Keras related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:support (Support issues)

@loveychen


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Both Win10 and Ubuntu 18.04
  • TensorFlow installed from (source or binary): binary, via pip
  • TensorFlow version: 2.0.0
  • Python version: 3.7
  • CUDA/cuDNN version: cudnn-10.2-windows10-x64-v7.6.5.32
  • GPU model and memory: 24G


Describe the current behavior

I defined a model with two inputs: one of shape (seq_len, fea_size) and the other of shape () (one scalar per example).
Training and saving the model work fine, but loading the model from the export directory fails with the following error:

Traceback (most recent call last):
  File "D:/Work/Python/track_by_classification/src/test/test_model.py", line 101, in <module>
    test_model_with_scalar_input()
  File "D:/Work/Python/track_by_classification/src/test/test_model.py", line 90, in test_model_with_scalar_input
    new_model = tf.keras.models.load_model(model_home)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\save.py", line 150, in load_model
    return saved_model_load.load(filepath, compile)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\load.py", line 86, in load
    model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\saved_model\load.py", line 541, in load_internal
    export_dir)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\load.py", line 103, in __init__
    self._finalize()
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\load.py", line 132, in _finalize
    node._set_inputs(inputs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 2709, in _set_inputs
    outputs = self(inputs, **kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 842, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\utils.py", line 57, in return_outputs_and_add_losses
    outputs, losses = fn(inputs, *args, **kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\utils.py", line 111, in wrap_with_training_arg
    lambda: replace_training_and_call(False))
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\utils\tf_utils.py", line 59, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\framework\smart_cond.py", line 59, in smart_cond
    name=name)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py", line 1174, in cond
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\ops\cond_v2.py", line 84, in cond_v2
    op_return_value=pred)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\utils.py", line 110, in <lambda>
    lambda: replace_training_and_call(True),
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\keras\saving\saved_model\utils.py", line 106, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 408, in _initialize
    *args, **kwds))
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "D:\WorkPrograms\Anaconda3\envs\tf2.0\lib\site-packages\tensorflow_core\python\saved_model\function_deserialization.py", line 262, in restored_function_body
    "\n\n".join(signature_descriptions)))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * [<tf.Tensor 'inputs:0' shape=(None, 10, 128) dtype=float32>, <tf.Tensor 'inputs_1:0' shape=(None, 1) dtype=float32>]
    * True
    * None
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (3 total):
    * [TensorSpec(shape=(None, 10, 128), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/1')]
    * False
    * None
  Keyword arguments: {}

Option 2:
  Positional arguments (3 total):
    * [TensorSpec(shape=(None, 10, 128), dtype=tf.float32, name='input_1'), TensorSpec(shape=(None,), dtype=tf.float32, name='input_2')]
    * False
    * None
  Keyword arguments: {}

Option 3:
  Positional arguments (3 total):
    * [TensorSpec(shape=(None, 10, 128), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/1')]
    * True
    * None
  Keyword arguments: {}

Option 4:
  Positional arguments (3 total):
    * [TensorSpec(shape=(None, 10, 128), dtype=tf.float32, name='input_1'), TensorSpec(shape=(None,), dtype=tf.float32, name='input_2')]
    * True
    * None
  Keyword arguments: {}
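Comparing the received arguments with Options 3 and 4, the training flag (True) and the third argument (None) already match; the only difference is the second input, which arrives as (None, 1) while every saved signature expects (None,). A plain-Python sketch of that comparison, assuming the lookup reduces to a per-dimension shape check with None as a wildcard (shapes_compatible is a hypothetical helper; the real lookup in function_deserialization.py also compares dtypes and the non-tensor arguments):

```python
def shapes_compatible(actual, spec):
    """Hypothetical helper: True if `actual` can match `spec`, i.e. both have
    the same rank and each dimension is equal or unknown (None) on either side."""
    if len(actual) != len(spec):
        return False  # a rank mismatch can never match
    return all(a is None or s is None or a == s for a, s in zip(actual, spec))


# Shapes copied from the error message above.
saved_specs = [(None, 10, 128), (None,)]      # what the SavedModel signatures expect
loader_inputs = [(None, 10, 128), (None, 1)]  # what the TF 2.0 loader passed in

for got, expected in zip(loader_inputs, saved_specs):
    # prints "(None, 10, 128) vs (None, 10, 128) -> True", then "(None, 1) vs (None,) -> False"
    print(got, "vs", expected, "->", shapes_compatible(got, expected))
```

Because a rank-2 tensor can never satisfy a rank-1 spec, none of the four options can match.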

After digging into the TensorFlow source code, I found that the bug comes from training_utils.ModelInputs.get_symbolic_inputs:

  def get_symbolic_inputs(self, return_single_as_list=False):
    """Returns inputs to be set as self.inputs for a model."""
    # TODO(karmel): There is a side-effect here where what you get
    # with as_list and as_dict depends on whether you have called this
    # method first, since it modifies in place.
    for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)):
      if isinstance(v, (list, float, int)):
        v = np.asarray(v)
        if v.ndim == 1:
          v = np.expand_dims(v, 1)

      if isinstance(v, (np.ndarray, ops.EagerTensor)):
        # We fix the placeholder shape except the batch size.
        # This is suboptimal, but it is the best we can do with the info
        # we have. The user should call `model._set_inputs(placeholders)`
        # to specify custom placeholders if the need arises.
        shape = (None,) + tuple(v.shape[1:])
        if shape == (None,):
          shape = (None, 1)
        dtype = dtypes.as_dtype(v.dtype)
        if dtype.is_floating:
          dtype = K.floatx()
        v = K.placeholder(shape=shape, name=k, dtype=dtype)
      elif isinstance(v, tensor_spec.TensorSpec):
        shape = (None,) + tuple(v.shape.as_list()[1:])
        if shape == (None,):
          shape = (None, 1)  # <<< this is where the error comes from
        v = K.placeholder(shape=shape, name=k, dtype=v.dtype)

      self._flattened_inputs[i] = v

    if self._is_dict:
      return dict(zip(self._input_names, self._flattened_inputs))
    if self._is_single_input and not return_single_as_list:
      return self._flattened_inputs[0]
    return self._flattened_inputs
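Reduced to shapes, the TensorSpec branch quoted above takes a saved spec of (None,), computes (None,) + () = (None,), and then promotes it to (None, 1), which matches the inputs_1:0 tensor reported in the error. The sketch below contrasts that rule with one hypothetical fix that keeps the spec's rank; both function names are illustrative, and this is not necessarily how TF 2.1 actually resolved the issue.

```python
def placeholder_shape_tf20(spec_shape):
    """Shape rule of the TensorSpec branch as quoted above (TF 2.0.0):
    replace the batch dim with None, then promote (None,) to (None, 1)."""
    shape = (None,) + tuple(spec_shape[1:])
    if shape == (None,):
        shape = (None, 1)  # this promotion changes the rank of scalar-per-example inputs
    return shape


def placeholder_shape_patched(spec_shape):
    """Hypothetical fix: a TensorSpec already records the rank the model was
    traced with, so keep it and only replace the batch dim with None."""
    return (None,) + tuple(spec_shape[1:])


for spec in [(None, 10, 128), (None,)]:
    # prints "(None, 10, 128) (None, 10, 128) (None, 10, 128)", then "(None,) (None, 1) (None,)"
    print(spec, placeholder_shape_tf20(spec), placeholder_shape_patched(spec))
```

Only the scalar-per-example input is affected; higher-rank inputs such as (None, 10, 128) come out the same under both rules.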

Describe the expected behavior

Loading the saved model should succeed, just as it does on 2.1.0-rc1.

Code to reproduce the issue

import numpy as np
import tensorflow as tf


def test_model_with_scalar_input():
    model_home = "models"
    input_shape = 5, 10, 128  # (batch_size, seq_len, fea_size)
    n_class = 5

    batch_size, seq_len, fea_size = input_shape

    # Input 1: per-step features of shape (seq_len, fea_size).
    # Input 2: one scalar per example (the sequence length), i.e. shape ().
    input_fea = tf.keras.layers.Input(shape=input_shape[1:])
    input_seq = tf.keras.layers.Input(shape=())

    name = "test-model"
    forward_layer = tf.keras.layers.LSTM(units=fea_size, return_sequences=True, name=name + "-forward")

    backward_layer = tf.keras.layers.LSTM(units=fea_size, return_sequences=True, go_backwards=True,
                                          name=name + "-backward")

    bi_lstm = tf.keras.layers.Bidirectional(forward_layer, backward_layer=backward_layer, name=name + "-bi-lstm")

    fc = tf.keras.layers.Dense(units=n_class, activation="softmax", name="fc-pred")

    # Boolean mask of shape (batch, seq_len): True for steps within each sequence's length.
    mask = tf.sequence_mask(input_seq, seq_len)

    y = bi_lstm(input_fea, mask=mask)
    y = fc(y)

    model = tf.keras.models.Model(inputs=[input_fea, input_seq], outputs=y, name=name)

    features = np.random.normal(size=[batch_size * 100, seq_len, fea_size])
    labels = np.random.randint(0, n_class, size=[batch_size * 100, seq_len])
    # seq = np.random.randint(0, seq_len, size=[batch_size * 100])
    seq = [seq_len] * (batch_size * 100)

    dataset = tf.data.Dataset.from_tensor_slices(((features, seq), labels))
    dataset = dataset.batch(batch_size)

    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

    # Saving (here in the SavedModel format) succeeds.
    saver = tf.keras.callbacks.ModelCheckpoint(model_home, monitor="loss", verbose=1, save_best_only=True)
    model.fit(dataset, epochs=5, callbacks=[saver], steps_per_epoch=10)

    # This is the call that raises the ValueError on TF 2.0.0.
    new_model = tf.keras.models.load_model(model_home)

    # print(new_model.layers)

    pred = new_model.predict(dataset)
    pred = np.argmax(pred, axis=-1)
    print(pred)


if __name__ == "__main__":
    test_model_with_scalar_input()
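In this reproduction, input_seq feeds tf.sequence_mask, so the rank of that input determines the rank of the mask handed to the Bidirectional LSTM. A rough NumPy stand-in, assuming the tf.sequence_mask semantics mask[..., j] = j < lengths[...] (sequence_mask_np is a hypothetical helper, not part of TensorFlow), shows what a (batch,) length tensor and a (batch, 1) length tensor each produce:

```python
import numpy as np


def sequence_mask_np(lengths, maxlen):
    """Hypothetical NumPy sketch of tf.sequence_mask: mask[..., j] is True
    where j < lengths[...]; the output shape is lengths.shape + (maxlen,)."""
    lengths = np.asarray(lengths)
    # Broadcast the step indices (maxlen,) against lengths with a trailing axis.
    return np.arange(maxlen) < lengths[..., np.newaxis]


seq_len = 10
per_example = np.array([10, 3, 7])   # shape (batch,), like Input(shape=())
column = per_example[:, np.newaxis]  # shape (batch, 1)

print(sequence_mask_np(per_example, seq_len).shape)  # (3, 10)
print(sequence_mask_np(column, seq_len).shape)       # (3, 1, 10)
```

A (batch,) length tensor gives the (batch, seq_len) mask the recurrent layer expects. A (batch, 1) tensor, the shape the TF 2.0 loader builds for this input, would instead yield a (batch, 1, seq_len) mask.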

Other info / logs

The code shown above works as expected on 2.1.0-rc1.

oanush self-assigned this on Dec 30, 2019
oanush added the comp:keras and TF 2.0 labels on Dec 30, 2019

oanush commented Dec 31, 2019

@loveychen,
Hi, can you please try tf-nightly 2.1.0.dev20191230? It will become the stable version going forward. See the gist of the Colab for reference. Thanks!

oanush added the stat:awaiting response and type:bug labels on Dec 31, 2019
@loveychen (Author)

Hi @oanush,

I ran the code with tf-nightly 2.1.0.dev20191230, and yes, it works as expected.

What is the timeline for the stable TF 2.1 release?

tensorflowbutler removed the stat:awaiting response label on Dec 31, 2019
oanush added the type:support label and removed the type:bug label on Jan 8, 2020

oanush commented Jan 8, 2020

@loveychen,
The timeline will be announced soon; right now there is no official update. I will close this issue since it is resolved. Please feel free to open an issue if you run into any further problems. Thanks!
