ForwardAccumulator fails with experimental_run_functions_eagerly(True) #39075

Closed
hartikainen opened this issue May 1, 2020 · 4 comments
Labels: comp:eager (Eager related issues), TF 2.2 (Issues related to TF 2.2), type:bug (Bug)

Comments

@hartikainen
Contributor

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS Catalina
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.2.0rc4
  • Python version: 3.7.5
  • Bazel version (if compiling from source): n/a
  • GCC/Compiler version (if compiling from source): n/a
  • CUDA/cuDNN version: n/a
  • GPU model and memory: n/a

Describe the current behavior
Running the examples in the tf.autodiff.ForwardAccumulator docs fails with RecursionError: maximum recursion depth exceeded when tf.config.experimental_run_functions_eagerly(True) is set.

Describe the expected behavior
Running the examples in the tf.autodiff.ForwardAccumulator docs with tf.config.experimental_run_functions_eagerly(True) should work the same way as with tf.config.experimental_run_functions_eagerly(False).

Standalone code to reproduce the issue
This is the standard example from https://www.tensorflow.org/api_docs/python/tf/autodiff/ForwardAccumulator, with just the experimental_run_functions_eagerly(True) call added.

import tensorflow as tf

tf.config.experimental_run_functions_eagerly(True)


v = tf.Variable([1., 2.])
with tf.autodiff.ForwardAccumulator(
    v,
    # The "vector" in Hessian-vector product.
    tf.constant([1., 0.])) as acc:
  with tf.GradientTape() as tape:
    y = tf.reduce_sum(v ** 3.)
  backward = tape.gradient(y, v)
backward  # gradient from backprop

acc.jvp(backward)  # forward-over-backward Hessian-vector product
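
For reference, the values the script should produce can be worked out by hand. The sketch below is not part of the original report and does not use TensorFlow. It evaluates the closed-form gradient and Hessian-vector product of y = sum(v ** 3) in plain Python. The helper names `gradient` and `hessian_vector_product` are made up for illustration.

```python
# Closed-form check for y = sum(v_i ** 3):
#   dy/dv_i        = 3 * v_i ** 2
#   d2y/dv_i dv_j  = 6 * v_i if i == j else 0   (diagonal Hessian)


def gradient(v):
    """Gradient of sum(v ** 3) with respect to v."""
    return [3.0 * x ** 2 for x in v]


def hessian_vector_product(v, vec):
    """Product of the (diagonal) Hessian of sum(v ** 3) with vec."""
    return [6.0 * x * t for x, t in zip(v, vec)]


v = [1.0, 2.0]
tangent = [1.0, 0.0]  # the "vector" in the Hessian-vector product

backward = gradient(v)                       # what tape.gradient(y, v) should hold
hvp = hessian_vector_product(v, tangent)     # what acc.jvp(backward) should hold

print(backward)  # [3.0, 12.0]
print(hvp)       # [6.0, 0.0]
```

A working run of the reproduction script should therefore yield a gradient of [3., 12.] and a Hessian-vector product of [6., 0.], regardless of the run_functions_eagerly setting.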

Other info / logs

...
    self._push_tape()
  File "/Users/hartikainen/conda/envs/policy-evaluation/lib/python3.7/site-packages/tensorflow/python/eager/backprop.py", line 849, in _push_tape
    watch_accessed_variables=self._watch_accessed_variables)
  File "/Users/hartikainen/conda/envs/policy-evaluation/lib/python3.7/site-packages/tensorflow/python/eager/tape.py", line 48, in push_new_tape
    return Tape(tape)
RecursionError: maximum recursion depth exceeded
@hartikainen hartikainen added the type:bug Bug label May 1, 2020
@geetachavan1 geetachavan1 added the TF 2.2 Issues related to TF 2.2 label May 1, 2020
@ravikyram
Contributor

ravikyram commented May 4, 2020

@hartikainen

I have tried this in Colab with TF 2.1.0 and 2.2-rc4, and I am able to reproduce the issue with tf.config.experimental_run_functions_eagerly(True). With tf.config.experimental_run_functions_eagerly(False), I do not see any issue. Please find the gist here. Thanks!

@ravikyram ravikyram added the comp:eager Eager related issues label May 4, 2020
@ravikyram ravikyram assigned jvishnuvardhan and unassigned ravikyram May 4, 2020
@hartikainen
Contributor Author

Yep, that's what I see too: fails with tf.config.experimental_run_functions_eagerly(True) and works with tf.config.experimental_run_functions_eagerly(False). Sorry if that was not clear from the title and description.

@jvishnuvardhan jvishnuvardhan added comp:autograph Autograph related issues and removed comp:eager Eager related issues labels May 4, 2020
@mdanatg mdanatg added comp:eager Eager related issues and removed comp:autograph Autograph related issues labels May 4, 2020
@allenlavoie
Member

Thank you for the report. I'll opt that forwardprop utility function out of run_functions_eagerly. The change should land in a few hours.

@geetachavan1 geetachavan1 added this to Done in TensorFlow 2.3.0 Jun 8, 2020