TF 2.0: Cannot use recurrent_dropout with LSTMs/GRUs #29187

Closed
sbagroy986 opened this issue May 30, 2019 · 12 comments
Labels: comp:keras, TF 2.0, type:bug

Comments


sbagroy986 commented May 30, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No (a one-line modification to a stock example script)
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 14.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tensorflow-gpu==2.0.0-alpha0 (also fails with every other tf 2.0 build I have explored)
  • Python version: 3.6
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: Tried multiple
  • GPU model and memory: Tried multiple

Describe the current behavior
The program crashes with a TypeError as below:

TypeError: An op outside of the function building code is being passed a "Graph" tensor. It is possible to have Graph tensors leak out of the function building context by including a tf.init_scope in your function building code. For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: encoder/unified_gru/ones_like:0

This occurs when trying to backpropagate gradients through an LSTM/GRU layer that has recurrent_dropout enabled.

Describe the expected behavior
No error

Code to reproduce the issue
Since this problem shows up at training time, one needs the entire training pipeline (dataset, model, etc.) set up to demonstrate the bug. I therefore used the Neural Machine Translation tutorial from TensorFlow and modified its model to include recurrent_dropout. The full code can be found in this Colab notebook; run the code blocks up to and including the training block to see the bug.
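For reference, a minimal standalone sketch of the failing pattern (a GRU with recurrent_dropout inside a @tf.function training step; the sizes, toy data, and loss below are illustrative placeholders, not the notebook's code) would look roughly like this:

import tensorflow as tf

# Toy data; shapes are arbitrary placeholders.
x = tf.random.normal((32, 10, 8))   # (batch, time, features)
y = tf.random.normal((32, 16))

gru = tf.keras.layers.GRU(16, recurrent_dropout=0.2)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        outputs = gru(inputs, training=True)
        loss = tf.reduce_mean(tf.square(outputs - targets))
    grads = tape.gradient(loss, gru.trainable_variables)
    optimizer.apply_gradients(zip(grads, gru.trainable_variables))
    return loss

train_step(x, y)  # on the affected builds this raises the TypeError above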

Other info / logs

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
 in ()
      8 
      9   for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
---> 10     batch_loss = train_step(inp, targ, enc_hidden)
     11     total_loss += batch_loss
     12 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    436         # Lifting succeeded, so variables are initialized and we can run the
    437         # stateless function.
--> 438         return self._stateless_fn(*args, **kwds)
    439     else:
    440       canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   1286     """Calls a graph function specialized to the inputs."""
   1287     graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 1288     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   1289 
   1290   @property

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
    572     """
    573     return self._call_flat(
--> 574         (t for t in nest.flatten((args, kwargs))
    575          if isinstance(t, (ops.Tensor,
    576                            resource_variable_ops.ResourceVariable))))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args)
    625     # Only need to override the gradient in graph mode and when we have outputs.
    626     if context.executing_eagerly() or not self.outputs:
--> 627       outputs = self._inference_function.call(ctx, args)
    628     else:
    629       self._register_gradient()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args)
    413             attrs=("executor_type", executor_type,
    414                    "config_proto", config),
--> 415             ctx=ctx)
    416       # Replace empty list with None
    417       outputs = outputs or None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     68     if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
     69       raise core._SymbolicException
---> 70     raise e
     71   # pylint: enable=protected-access
     72   return tensors

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
     59                                                op_name, inputs, attrs,
---> 60                                                num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: encoder/unified_gru/ones_like:0
achandraa self-assigned this May 31, 2019
achandraa added the 2.0.0-alpha0, comp:keras, and type:bug labels May 31, 2019
@achandraa

I have tried with TensorFlow version 2.0.0-alpha and was able to reproduce the issue.

jvishnuvardhan added the stat:awaiting tensorflower label May 31, 2019

qlzh727 commented Jun 3, 2019

Thanks for reporting the issue, let me take a look.

tensorflowbutler removed the stat:awaiting tensorflower label Jun 4, 2019
pull bot pushed a commit to Cache-Cloud/tensorflow that referenced this issue Jun 4, 2019
They were missing for the non-defun branch.

See tensorflow#29187 for more details.

PiperOrigin-RevId: 251442578

qlzh727 commented Jun 4, 2019

Thanks for reporting the issue; it should now be fixed by 180f28a

qlzh727 closed this as completed Jun 4, 2019


qlzh727 commented Jun 4, 2019

Btw, the current Colab might not apply the dropout correctly if you only enable dropout/recurrent_dropout on the GRU layer. Under the hood, the Keras layer checks whether the current context is training or inference, and only applies the dropout during training. If the GRU layer is used by a Keras model together with model.fit/eval/predict, the training context is applied correctly. However, if the user is writing their own custom training loop, the training context needs to be set manually, e.g. by:

tf.keras.backend.set_learning_phase(1)  # training
run_train_step()

tf.keras.backend.set_learning_phase(0)  # inference
run_eval_step()

The other alternative is to make sure the encoder/decoder's call() method is training-state aware, e.g. the method could take a new kwarg training=None and be passed different values during training and inference. The training value needs to be propagated to the GRU's call() method as well; see the sketch below.
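A rough sketch of that second approach (the class, attribute names, and sizes here are illustrative placeholders, not the tutorial's exact code):

import tensorflow as tf

class Encoder(tf.keras.Model):
    # Illustrative encoder; names and sizes are placeholders.
    def __init__(self, vocab_size, embedding_dim, enc_units):
        super(Encoder, self).__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_dropout=0.2)

    def call(self, x, hidden, training=None):
        x = self.embedding(x)
        # Forward the training flag so recurrent_dropout is applied only
        # during training and skipped at inference time.
        output, state = self.gru(x, initial_state=hidden, training=training)
        return output, state

# Inside the custom training loop:
#   enc_output, enc_hidden = encoder(inp, enc_hidden, training=True)
# and during evaluation/inference:
#   enc_output, enc_hidden = encoder(inp, enc_hidden, training=False)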


sbagroy986 commented Jun 6, 2019

@qlzh727: Thanks a ton for your help on this!

Quick follow-up: has this been fixed in the GPU version as well? I tried the (nightly) version from yesterday and it didn't seem to work.

sleighsoft pushed a commit to sleighsoft/tensorflow that referenced this issue Jun 12, 2019

Zoltrix commented Jul 31, 2019

The issue still persists in the beta release


njwfish commented Aug 12, 2019

I still have this issue in beta 2.0.0b1


qlzh727 commented Aug 12, 2019

For anyone still facing the issue, could you provide a snippet that reproduces it?


knobel-dk commented Aug 31, 2019

could you provide a snippet that reproduces it?

I see a similar error even without a GRU.

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

# feature_columns, train_ds and val_ds are defined earlier in the notebook.
cp_callback = ModelCheckpoint(filepath="checkpoints/", save_weights_only=False, verbose=0, save_best_only=True)

feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

model = tf.keras.Sequential([
  feature_layer,
  layers.Dense(128, activation='relu'),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'], run_eagerly=True)
model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=[cp_callback])


qlzh727 commented Sep 9, 2019

@knobel-dk, I am a bit confused by your message; this issue was about recurrent_dropout for the LSTM/GRU layer, but your code doesn't have any LSTM/GRU layer in it.

Could you be more specific about the error you are facing?

@knobel-dk

@knobel-dk, I am a bit confused by your message; this issue was about recurrent_dropout for the LSTM/GRU layer, but your code doesn't have any LSTM/GRU layer in it.

Could you be more specific about the error you are facing?

Thanks. Yes, I confused myself too (those stateful Jupyter notebooks...). I fixed my problem by updating the TF2 version. Thanks.

lvenugopalan added the TF 2.0 label Apr 29, 2020