
LSTM return_state=True fails with tf.keras.Sequential model #36624

Closed
durandg12 opened this issue Feb 10, 2020 · 3 comments
Assignees
Labels
comp:keras Keras related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.1 for tracking issues in 2.1 release type:bug Bug

Comments

@durandg12

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.13.6
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.1.0-rc2-17-ge5bf8de410 2.1.0
  • Python version: v3.6.7:6ec5cf24b7, Oct 20 2018, 03:02:14

Describe the current behavior

The call method of a tf.keras.Sequential object fails and throws an error when one of its layers is an instance of the tf.keras.layers.LSTM class constructed with return_state=True. Given the error message, I believe this is because the call method of such an LSTM layer returns a list instead of a Tensor, and the call method of Sequential does not know what to do with a list.

Describe the expected behavior

I think that the call method of Sequential should know that the Tensor output of LSTM is the first element of the list when return_state=True.

Code to reproduce the issue
Setup:

import tensorflow as tf
import numpy as np

print('Using Tensorflow version {} (git version {})'.format(tf.version.VERSION, tf.version.GIT_VERSION))

batch_size = 3
ts = 9
input_dim = 2
nump = np.arange(batch_size * ts * input_dim, dtype=np.float32).reshape(batch_size, ts, input_dim)
dataset = tf.data.Dataset.from_tensor_slices(nump).batch(batch_size)
for x in dataset:
    print(x.shape)
return_state = True

Output:

Using Tensorflow version 2.1.0 (git version v2.1.0-rc2-17-ge5bf8de410)
(3, 9, 2)

Error with Sequential:

model_seq = tf.keras.Sequential([tf.keras.layers.LSTM(3, return_state=return_state)])
for x in dataset:
    print(model_seq(x))

Output:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-57-5500870ab2fc> in <module>
      1 model_seq = tf.keras.Sequential([tf.keras.layers.LSTM(3, return_state=return_state)])
      2 for x in dataset:
----> 3     print(model_seq(x))

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    820           with base_layer_utils.autocast_context_manager(
    821               self._compute_dtype):
--> 822             outputs = self.call(cast_inputs, *args, **kwargs)
    823           self._handle_activity_regularization(inputs, outputs)
    824           self._set_mask_metadata(inputs, outputs, input_masks)

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/sequential.py in call(self, inputs, training, mask)
    283       # `outputs` will be the inputs to the next layer.
    284       inputs = outputs
--> 285       mask = outputs._keras_mask
    286 
    287     return outputs

AttributeError: 'list' object has no attribute '_keras_mask'

It works when constructing the model with the Functional API:

def lstm_model(return_state, ts, input_dim):
    inp = tf.keras.Input(shape=(ts, input_dim))
    out = tf.keras.layers.LSTM(3, return_state=return_state)(inp)
    return tf.keras.Model(inputs=inp, outputs=out)
    
model_func = lstm_model(return_state, ts, input_dim)

for x in dataset:
    print(model_func(x))

Output:

[<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-8.8475537e-01,  2.9517543e-03, -9.9753261e-01],
       [-9.7553629e-01,  9.5521700e-06, -9.9959475e-01],
       [-9.9497062e-01,  3.0903845e-08, -9.9979442e-01]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-8.8475537e-01,  2.9517543e-03, -9.9753261e-01],
       [-9.7553629e-01,  9.5521700e-06, -9.9959475e-01],
       [-9.9497062e-01,  3.0903845e-08, -9.9979442e-01]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-7.6066346e+00,  2.9581292e-03, -3.3488092e+00],
       [-8.9999275e+00,  9.5521846e-06, -4.2520967e+00],
       [-9.0000000e+00,  3.0903848e-08, -4.5915442e+00]], dtype=float32)>]

Related question
In my Functional API example, lstm_model fails if I use inp = tf.keras.Input(shape=(ts, None)) instead of providing the explicit input dimension. The error message I get is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-64-9b042ffca48d> in <module>
      4     return tf.keras.Model(inputs=inp, outputs=out)
      5 
----> 6 model_func = lstm_model(return_state, ts, input_dim)
      7 
      8 for x in dataset:

<ipython-input-64-9b042ffca48d> in lstm_model(return_state, ts, input_dim)
      1 def lstm_model(return_state, ts, input_dim):
      2     inp = tf.keras.Input(shape=(ts, None))
----> 3     out = tf.keras.layers.LSTM(3, return_state=return_state)(inp)
      4     return tf.keras.Model(inputs=inp, outputs=out)
      5 

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
    642 
    643     if initial_state is None and constants is None:
--> 644       return super(RNN, self).__call__(inputs, **kwargs)
    645 
    646     # If any of `initial_state` or `constants` are specified and are Keras

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    746           # Build layer if applicable (if the `build` method has been
    747           # overridden).
--> 748           self._maybe_build(inputs)
    749           cast_inputs = self._maybe_cast_inputs(inputs)
    750 

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
   2114         # operations.
   2115         with tf_utils.maybe_init_scope(self):
-> 2116           self.build(input_shapes)
   2117       # We must set self.built since user defined build functions are not
   2118       # constrained to set self.built.

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/layers/recurrent.py in build(self, input_shape)
    562     if isinstance(self.cell, Layer):
    563       if not self.cell.built:
--> 564         self.cell.build(step_input_shape)
    565 
    566     # set or validate state_spec

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
    304     if input_shape is not None:
    305       input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306     output_shape = fn(instance, input_shape)
    307     # Return shapes from `fn` as TensorShapes.
    308     if output_shape is not None:

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/layers/recurrent.py in build(self, input_shape)
   2299         regularizer=self.kernel_regularizer,
   2300         constraint=self.kernel_constraint,
-> 2301         caching_device=default_caching_device)
   2302     self.recurrent_kernel = self.add_weight(
   2303         shape=(self.units, self.units * 4),

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint, partitioner, use_resource, synchronization, aggregation, **kwargs)
    444         synchronization=synchronization,
    445         aggregation=aggregation,
--> 446         caching_device=caching_device)
    447     backend.track_variable(variable)
    448 

~/path/to/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _add_variable_with_custom_getter(self, name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
    742         dtype=dtype,
    743         initializer=initializer,
--> 744         **kwargs_for_getter)
    745 
    746     # If we set an initializer and the variable processed it, tracking will not

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in make_variable(name, shape, dtype, initializer, trainable, caching_device, validate_shape, constraint, use_resource, collections, synchronization, aggregation, partitioner)
    140       synchronization=synchronization,
    141       aggregation=aggregation,
--> 142       shape=variable_shape if variable_shape else None)
    143 
    144 

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/variables.py in __call__(cls, *args, **kwargs)
    256   def __call__(cls, *args, **kwargs):
    257     if cls is VariableV1:
--> 258       return cls._variable_v1_call(*args, **kwargs)
    259     elif cls is Variable:
    260       return cls._variable_v2_call(*args, **kwargs)

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/variables.py in _variable_v1_call(cls, initial_value, trainable, collections, validate_shape, caching_device, name, variable_def, dtype, expected_shape, import_scope, constraint, use_resource, synchronization, aggregation, shape)
    217         synchronization=synchronization,
    218         aggregation=aggregation,
--> 219         shape=shape)
    220 
    221   def _variable_v2_call(cls,

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/variables.py in <lambda>(**kwargs)
    195                         shape=None):
    196     """Call on Variable class. Useful to force the signature."""
--> 197     previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    198     for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
    199       previous_getter = _make_getter(getter, previous_getter)

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/variable_scope.py in default_variable_creator(next_creator, **kwargs)
   2594         synchronization=synchronization,
   2595         aggregation=aggregation,
-> 2596         shape=shape)
   2597   else:
   2598     return variables.RefVariable(

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/variables.py in __call__(cls, *args, **kwargs)
    260       return cls._variable_v2_call(*args, **kwargs)
    261     else:
--> 262       return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    263 
    264 

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py in __init__(self, initial_value, trainable, collections, validate_shape, caching_device, name, dtype, variable_def, import_scope, constraint, distribute_strategy, synchronization, aggregation, shape)
   1409           aggregation=aggregation,
   1410           shape=shape,
-> 1411           distribute_strategy=distribute_strategy)
   1412 
   1413   def _init_from_args(self,

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py in _init_from_args(self, initial_value, trainable, collections, caching_device, name, dtype, constraint, synchronization, aggregation, distribute_strategy, shape)
   1540           with ops.name_scope("Initializer"), device_context_manager(None):
   1541             initial_value = ops.convert_to_tensor(
-> 1542                 initial_value() if init_from_fn else initial_value,
   1543                 name="initial_value", dtype=dtype)
   1544           if shape is not None:

~/path/to/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in <lambda>()
    120           (type(init_ops.Initializer), type(init_ops_v2.Initializer))):
    121         initializer = initializer()
--> 122       init_val = lambda: initializer(shape, dtype=dtype)
    123       variable_dtype = dtype.base_dtype
    124   if use_resource is None:

~/path/to/python3.6/site-packages/tensorflow_core/python/ops/init_ops_v2.py in __call__(self, shape, dtype)
    413       scale /= max(1., fan_out)
    414     else:
--> 415       scale /= max(1., (fan_in + fan_out) / 2.)
    416     if self.distribution == "truncated_normal":
    417       # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)

TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

Is it normal? If so, why is that?
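This appears expected: the LSTM's input kernel shape depends on the feature dimension (the traceback above dies while computing fan-in for the kernel initializer), so the last axis of the input must be static when the layer builds. The timestep axis, by contrast, may be left as None. A sketch, assuming that reading of the traceback:

```python
import numpy as np
import tensorflow as tf

# The feature axis must be known for the kernel to be created,
# but the timestep axis can stay unspecified.
inp = tf.keras.Input(shape=(None, 2))  # variable timesteps: fine
out = tf.keras.layers.LSTM(3)(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

for ts in (5, 9):
    x = np.zeros((3, ts, 2), dtype=np.float32)
    print(model(x).shape)  # (3, 3) regardless of ts
```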

@amahendrakar
Contributor

Was able to reproduce the issue. Please find the Gist here. Thanks!

@amahendrakar amahendrakar added comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release type:support Support issues labels Feb 11, 2020
@jvishnuvardhan jvishnuvardhan added type:bug Bug and removed type:support Support issues labels Feb 11, 2020
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 11, 2020
@qlzh727
Member

qlzh727 commented Feb 14, 2020

For a Sequential model, we expect each layer within it to have exactly one input and one output. An LSTM layer with return_state=True returns more than one output, which violates this rule.

I think the Sequential model code needs to be updated to raise a more explicit error for this case. We already raise it if your model has an input_shape (which triggers a model build under the hood), but we missed it in the deferred-build case (input_shape is not provided by the layers, but inferred when the actual input arrives).
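Given the one-output rule above, one user-side workaround is to wrap the LSTM in a layer that keeps only the first element of its output list. This is a hypothetical sketch (the wrapper name LSTMOutputOnly is illustrative, not a Keras API), for cases where the states are not needed downstream:

```python
import numpy as np
import tensorflow as tf

class LSTMOutputOnly(tf.keras.layers.Layer):
    """Illustrative wrapper: runs an LSTM with return_state=True
    but returns only the output tensor, so the layer again has a
    single output and can live inside a Sequential model."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.lstm = tf.keras.layers.LSTM(units, return_state=True)

    def call(self, inputs):
        output, state_h, state_c = self.lstm(inputs)
        return output  # drop the states to satisfy Sequential

model_seq = tf.keras.Sequential([LSTMOutputOnly(3)])
x = np.zeros((3, 9, 2), dtype=np.float32)
print(model_seq(x).shape)  # (3, 3)
```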

tensorflow-copybara pushed a commit that referenced this issue Feb 15, 2020
…deferred mode.

This is an update for #36624: we should show an explicit error rather than let the code proceed and fail down the road.

PiperOrigin-RevId: 295248095
Change-Id: I9cf9d0267f222c3994f3199bd9b49d86196ebb3b
@qlzh727 qlzh727 closed this as completed Feb 18, 2020