
Forward-mode autodiff gives an incorrect NaN gradient when argument segment_ids[0] is not zero for API tf.math.segment_max #57002

Closed
eeDigitalSeal opened this issue Aug 3, 2022 · 5 comments
Labels
comp:ops, TF 2.9, stat:awaiting response, stale, type:bug

Comments


eeDigitalSeal commented Aug 3, 2022


Issue Type

Bug

Source

binary

Tensorflow Version

tf 2.9

Custom Code

Yes

OS Platform and Distribution

Linux Ubuntu 20.04

Mobile device

No response

Python version

3.9

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

If the first element of segment_ids is not zero, forward-mode autodiff gives NaN as the gradient, which is incorrect. It should give the same result as reverse mode does.

Standalone code to reproduce the issue

import tensorflow as tf

data = tf.constant(
    [[0.47682607, 0.2620497, 0.22771002],
     [0.59265005, 0.62705662, 0.81160802]], shape=(2, 3), dtype=tf.float64)

# Reverse mode: gradient of the summed output w.r.t. data.
with tf.GradientTape(persistent=True) as g:
    g.watch(data)
    res_backward = tf.math.segment_max(data, [1, 1])
grad_backward = g.gradient(res_backward, data)
print("reverse-mode:\n", grad_backward)

# Forward mode: one JVP per input element, using one-hot tangents.
grad_fwd_arr = []
for i in range(tf.size(data)):
    tangents = tf.reshape(
        tf.one_hot(i, tf.size(data), dtype=tf.float64), shape=data.shape)
    with tf.autodiff.ForwardAccumulator(data, tangents) as acc:
        res_forward = tf.math.segment_max(data, [1, 1])
        jvp = acc.jvp(res_forward)
        grad_fwd_arr.append(tf.reduce_sum(jvp))

grad_fwd = tf.reshape(tf.convert_to_tensor(grad_fwd_arr), shape=data.shape)
print("forward-mode:\n", grad_fwd)

Relevant log output

reverse-mode:
 tf.Tensor(
[[0. 0. 0.]
 [1. 1. 1.]], shape=(2, 3), dtype=float64)
forward-mode:
 tf.Tensor(
[[nan nan nan]
 [nan nan nan]], shape=(2, 3), dtype=float64)
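
For reference, and not something confirmed by the maintainers in this thread: the title above implies the NaN only appears when segment_ids[0] is nonzero, so here is a minimal sketch of the same forward-mode loop with zero-based segment_ids (the [0, 0] choice is purely illustrative). It produces finite values instead of NaN.

import tensorflow as tf

data = tf.constant(
    [[0.47682607, 0.2620497, 0.22771002],
     [0.59265005, 0.62705662, 0.81160802]], shape=(2, 3), dtype=tf.float64)

# Same forward-mode loop as the repro above, but with segment_ids starting at 0.
segment_ids = [0, 0]

grad_fwd_arr = []
for i in range(tf.size(data)):
    # One-hot tangent selects a single input element per iteration.
    tangents = tf.reshape(
        tf.one_hot(i, tf.size(data), dtype=tf.float64), shape=data.shape)
    with tf.autodiff.ForwardAccumulator(data, tangents) as acc:
        res_forward = tf.math.segment_max(data, segment_ids)
        grad_fwd_arr.append(tf.reduce_sum(acc.jvp(res_forward)))

grad_fwd = tf.reshape(tf.convert_to_tensor(grad_fwd_arr), shape=data.shape)
print(grad_fwd)  # finite 0/1 entries, no NaN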

sushreebarsa commented Aug 4, 2022

@gadagashwini I was able to replicate the issue on colab, please find the gist here. Thank you!

@sushreebarsa sushreebarsa added comp:ops OPs related issues TF 2.9 Issues found in the TF 2.9 release (or RCs) labels Aug 4, 2022
@sachinprasadhs

@eeDigitalSeal, could you please refer to the comment here about gradient computation behavior in forward and backward mode?
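
The comment referenced above is only linked, not quoted, in this thread. Purely as an illustration (my own sketch, not code from that comment), one way to cross-check forward- and reverse-mode gradient behavior for this op is to compare full Jacobians; compare_modes below is a hypothetical helper written for that purpose.

import tensorflow as tf

def compare_modes(data, segment_ids):
    """Returns reverse-mode and forward-mode Jacobians of tf.math.segment_max."""
    data = tf.convert_to_tensor(data, dtype=tf.float64)
    n_in = int(tf.size(data))

    # Reverse mode: one VJP per output element gives one Jacobian row.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(data)
        out = tf.math.segment_max(data, segment_ids)
    n_out = int(tf.size(out))
    rows = []
    for j in range(n_out):
        cotangent = tf.reshape(tf.one_hot(j, n_out, dtype=tf.float64), out.shape)
        rows.append(tf.reshape(
            tape.gradient(out, data, output_gradients=cotangent), [-1]))
    jac_reverse = tf.stack(rows)  # shape (n_out, n_in)

    # Forward mode: one JVP per input element gives one Jacobian column.
    cols = []
    for i in range(n_in):
        tangent = tf.reshape(tf.one_hot(i, n_in, dtype=tf.float64), data.shape)
        with tf.autodiff.ForwardAccumulator(data, tangent) as acc:
            out_fwd = tf.math.segment_max(data, segment_ids)
            cols.append(tf.reshape(acc.jvp(out_fwd), [-1]))
    jac_forward = tf.transpose(tf.stack(cols))  # shape (n_out, n_in)

    return jac_reverse, jac_forward

jr, jf = compare_modes(
    [[0.5, 0.3, 0.2], [0.6, 0.7, 0.8]], segment_ids=[1, 1])
print(jr)
print(jf)  # NaN entries here reproduce the mismatch reported above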

@sachinprasadhs sachinprasadhs added the stat:awaiting response Status - Awaiting response from author label Aug 10, 2022
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Aug 17, 2022
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

