Func with custom gradient using tf.numpy_function/tf.py_function incompatible with tf.vectorized_map #53726

Open
antalszava opened this issue Jan 11, 2022 · 3 comments
Assignees
Labels
comp:eager Eager related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.7 Issues related to TF 2.7.0 type:bug Bug

@antalszava


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): installed via pip
  • TensorFlow version (use command below): v2.7.0-rc1-69-gc256c071bb2 2.7.0
  • Python version: 3.8.5
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: -
  • GPU model and memory: -


Describe the current behavior

When tf.vectorized_map is applied to a function that defines a custom gradient and calls either tf.numpy_function or tf.py_function, the gradient is not computed correctly:

  • tf.numpy_function: the gradient is a vector of zeros;
  • tf.py_function: an UnknownError: KeyError: b'pyfunc_12' error is raised.

Describe the expected behavior
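The gradient computed through tf.vectorized_map matches the one obtained without vectorization, i.e. by calling the function on each element in a Python loop.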
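For the reproduction below (ten float64 inputs, all equal to 1, summed into a scalar loss), the expected values can be checked by hand:

$$
L = \sum_{i=1}^{10} \sin(x_i), \qquad \frac{\partial L}{\partial x_i} = \cos(x_i), \qquad x_i = 1 \;\Rightarrow\; L = 10\sin(1) \approx 8.4147, \quad \frac{\partial L}{\partial x_i} = \cos(1) \approx 0.5403
$$

So the loop-based gradient of 0.54030231 per element is the correct result, while a vector of zeros is not.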

Contributing

  • Do you want to contribute a PR? (yes/no): no
  • Briefly describe your candidate solution(if contributing):

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np

@tf.function
@tf.custom_gradient
def sin(x):

    # Note: change the following line to use tf.numpy_function to
    # reproduce the other issue mentioned
    res = tf.py_function(func=np.sin, inp=(x,), Tout=tf.float64)

    def grad_fn(dy):
        j = tf.cos(x)
        return dy * j

    return res, grad_fn

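inputs = tf.Variable(tf.ones((10,), dtype=tf.float64))

# Baseline: loop over the elements explicitly; the custom gradient is applied correctly.
with tf.GradientTape() as tape:
    loss = tf.reduce_sum([sin(x) for x in inputs])

print("loss:", loss)
print("gradient:", tape.gradient(loss, inputs))

# Same computation via tf.vectorized_map: the gradient is all zeros
# (tf.numpy_function) or an UnknownError is raised (tf.py_function).
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.vectorized_map(sin, inputs))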

print("Vectorized loss:", loss)
print("Vectorized gradient:", tape.gradient(loss, inputs))

Other info / logs

Output when using tf.numpy_function:

loss: tf.Tensor(8.414709848078965, shape=(), dtype=float64)
gradient: tf.Tensor(
[0.54030231 0.54030231 0.54030231 0.54030231 0.54030231 0.54030231
 0.54030231 0.54030231 0.54030231 0.54030231], shape=(10,), dtype=float64)
Vectorized loss: tf.Tensor(8.414709848078965, shape=(), dtype=float64)
Vectorized gradient: tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float64)

Output when using tf.py_function:

loss: tf.Tensor(8.414709848078965, shape=(), dtype=float64)
gradient: tf.Tensor(
[0.54030231 0.54030231 0.54030231 0.54030231 0.54030231 0.54030231
 0.54030231 0.54030231 0.54030231 0.54030231], shape=(10,), dtype=float64)
Vectorized loss: tf.Tensor(8.414709848078965, shape=(), dtype=float64)

The vectorized gradient computation then fails with UnknownError: KeyError: b'pyfunc_12'.

In both cases, the warning WARNING:tensorflow:Using a while_loop for converting EagerPyFunc is also emitted.

@mohantym
Contributor

Hi @Saduf2019! Could you please look at this issue? Attaching gists for TF 2.6 and 2.7 for reference. Thanks!

@mohantym mohantym assigned Saduf2019 and unassigned mohantym Jan 12, 2022
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 10, 2022
@JW1992
Contributor

JW1992 commented Feb 10, 2022

Hi @antalszava, thanks for reporting this! This does look like a bug.

The fix might take a while. Would removing tf.py_function (for example) unblock you for now?

@antalszava
Author

Hi @JW1992, thank you! Unfortunately not: the use case I have in mind requires tf.py_function, as in the example code.
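One possible workaround sketch, not taken from the thread, keeps tf.py_function but avoids tf.vectorized_map entirely. It assumes the wrapped Python function already operates elementwise on a whole batch, as np.sin does, so the custom-gradient function can be applied to the full tensor in one call. The helper name batched_sin is hypothetical.

```python
import numpy as np
import tensorflow as tf


@tf.custom_gradient
def batched_sin(x):
    # Hypothetical helper: call the NumPy function once on the whole batch.
    # np.sin is already elementwise, so tf.vectorized_map is not needed.
    res = tf.py_function(func=np.sin, inp=[x], Tout=tf.float64)

    def grad_fn(dy):
        # Custom gradient: d/dx sin(x) = cos(x), applied elementwise.
        return dy * tf.cos(x)

    return res, grad_fn


x = tf.ones((10,), dtype=tf.float64)

with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant tensor, so it must be watched explicitly
    loss = tf.reduce_sum(batched_sin(x))

grad = tape.gradient(loss, x)
print("loss:", loss)
print("gradient:", grad)
```

Unlike tf.vectorized_map, this invokes the Python function once per batch rather than once per element, so it only helps when that function can handle the batched input itself.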
