
Problem with custom grad with multiple external variables #43535

Open
RochMollero opened this issue Sep 24, 2020 · 12 comments
Labels: comp:core, comp:ops, stale, stat:contribution welcome, TF 2.11, type:bug

Comments

@RochMollero

RochMollero commented Sep 24, 2020

Describe the current behavior

The "variables" argument in custom grad seems to be buggy. For example this code, which is a complexified version of the "custom_grad" example in the doc:

import tensorflow as tf
import numpy as np
import time

m=1000

varr=tf.Variable(200.0)
varr3=tf.Variable(200.0)
varr2=tf.constant(1.0)

@tf.function
@tf.custom_gradient
def log1pexp(x1, x2):
  #some basic function fluff 
  e=tf.exp(x1)
  e=tf.math.log(1 + e)
  for i in range(m):
    e += 0.0*tf.exp(x1) + 1.0

  #here comes the problem
  print("tracing")
  def grad(dy, dz, variables=[varr, varr3]):
    return [dy * (1 - 1 / (1 + e)) + 0 * varr2, tf.constant(0.0)], [0.0 * varr, 0.0*varr3]

  return [e+varr, varr3], grad

x = tf.constant(1.0)
with tf.GradientTape() as g:
  g.watch(x)
  y=log1pexp(x,0.0)

dy_dx = g.gradient(y, x)
print (dy_dx)

This leads to the error:
ValueError: Must return gradient for each variable from @custom_gradient grad_fn.

Which I seem to be doing (log1pexp has two inputs, two outputs, and two external variables, so each returned list should have size 2?)

However, if I change the return line of grad to
return [dy * (1 - 1 / (1 + e)) + 0 * varr2, tf.constant(0.0)], [0.0 * varr]

then it doesn't give an error, which doesn't make sense to me, since in that case the variable gradients only contain gradient info for ONE of the external variables I registered in the variables parameter.

Also, it doesn't care about varr2 at all...

Describe the expected behavior

The code above should not give an error.

Standalone code to reproduce the issue

See above

Other info / logs

Basic Ubuntu 20; the error occurs with TF 2.2 and 2.3.

@bhack
Contributor

bhack commented Sep 24, 2020

The issue is that actual_grad_fn receives only one variable (in this case e+varr) but two variable_grads:

input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
if len(variable_grads) != len(variables):
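
For context, a minimal standalone sketch (not TensorFlow internals) of why that check fires here: the variable watcher only recorded varr, so variables has length one, while grad_fn returns gradients for both varr and varr3.

# Standalone illustration of the length check quoted above (not actual TF code).
variables = ["varr"]                          # only varr was watched by the tape
variable_grads = ["grad_varr", "grad_varr3"]  # grad_fn returned two variable grads

if len(variable_grads) != len(variables):
    raise ValueError(
        "Must return gradient for each variable from @custom_gradient grad_fn.")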

@RochMollero
Author

RochMollero commented Sep 24, 2020

Are you telling me it's a TF problem or a problem in my code?

I'm not sure what I should read into those two lines: they seem logical. According to my code, the two "variables" actual_grad_fn receives should be varr and varr3 (I'm not sure why you said it only receives "e+varr", since that's an output of fn), and indeed there's only one, I guess.

@bhack
Contributor

bhack commented Sep 24, 2020

Yes, it was just a typo, I meant varr.
Edit this file locally and try adding print(variables) before the mentioned condition.
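
For reference, a rough sketch of the kind of debug print being suggested, assuming the quoted snippet lives in the custom_gradient implementation (tensorflow/python/ops/custom_gradient.py is an assumption here); the goal is just to see which variables the tape actually recorded:

# Hypothetical local edit around the snippet quoted above, for debugging only.
input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
print("watched variables:", variables)              # added: what the tape recorded
print("returned variable grads:", variable_grads)   # added: what grad_fn returned
if len(variable_grads) != len(variables):
    ...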

@RochMollero
Author

RochMollero commented Sep 24, 2020

OK, so I've inspected it with this code:

  variables_in_tape = frozenset([
    v.ref() for v in variable_watcher.watched_variables()
  ])

and I found something:

In the above code, varr3 is not registered in variables_in_tape, seemingly because it's passed through as an "identity" and there's no operation involving it.

For example, if I change my function to return this (just adding 0.0 to varr3, so the function is essentially the same):
return [e+varr, 0.0+varr3], grad

Then TensorFlow will register varr3 and, as expected, will require me to provide gradients for both varr and varr3.

I'm pretty sure this is not the behavior the TensorFlow team expects... not sure how to escalate, but you guys should check this out.

Here's the code to test it:

import tensorflow as tf
import numpy as np
import time

m=1000

varr=tf.Variable(200.0)
varr3=tf.Variable(200.0)
varr2=tf.constant(1.0)

@tf.function
@tf.custom_gradient
def log1pexp(x1, x2):
  
  e=tf.exp(x1)
  e=tf.math.log(1 + e)
  
  for i in range(m):
    e += 0.0*tf.exp(x1) + 1.0

  print("tracing")
  def grad(dy, dz, variables=[varr, varr3]):
    return [dy * (1 - 1 / (1 + e)) + 0 * varr2, tf.constant(0.0)], [0.0 * varr, 0.0 * varr3]#[0.0 * varr, 0.0 * varr3]

  return [e+varr, 0.0+varr3], grad#[e+varr, varr3]



#Debug: works only if tf.function and tf.custom_gradient are commented out.
#Switch between the commented alternatives above to see the difference.
custom_gradd=True
if not custom_gradd:
  from tensorflow.python.util import tf_inspect
  from tensorflow.python.eager import tape as tape_lib
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = log1pexp(0.0,0.0)

  variables_in_tape = frozenset([
    v.ref() for v in variable_watcher.watched_variables()
  ])

  print ("************")
  print (variables_in_tape)
  print ("************")
  print(tf_inspect.getfullargspec(log1pexp(0.0,0.0)[1]))


x = tf.constant(1.0)

#t0 = time.time()
#print(log1pexp(x,0.0))
#t1 = time.time()
#print ("Tracing run:", t1-t0)
#
#t0 = time.time()
#print(log1pexp(x,0.0))
#t1 = time.time()
#print ("Traced run:", t1-t0)
#
#m=2000
#varr.assign(400.0)
#t0 = time.time()
#print(log1pexp(x,0.0))
#t1 = time.time()
#print ("Modified Traced run:", t1-t0)


with tf.GradientTape() as g:
  g.watch(x)
  y=log1pexp(x,0.0)

dy_dx = g.gradient(y, x)
print (dy_dx)

#x = tf.constant(3.0)
#with tf.GradientTape() as g:
#  g.watch(x)
#  y = x * x
#dy_dx = g.gradient(y, x) # Will compute to 6.0

@bhack
Contributor

bhack commented Sep 24, 2020

/cc @rohan100jain

@Saduf2019
Contributor

@RochMollero
I ran your code on TF 2.3 and nightly; I do not face the error reported above. Please find the gist here.

@Saduf2019 Saduf2019 added stat:awaiting response Status - Awaiting response from author TF 2.3 Issues related to TF 2.3 labels Sep 25, 2020
@bhack
Contributor

bhack commented Sep 25, 2020

@Saduf2019 The user is asking something different.

/cc @jaingaurav @wangpengmit can you give some feedback here?

@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Oct 2, 2020
@RochMollero
Author

@RochMollero
I ran your code on TF 2.3 and nightly; I do not face the error reported above. Please find the gist here.

Precisely because you ran only the working version. The non-working version is when you set

return [e+varr, varr3], grad

(without the 0.0 +)

@google-ml-butler google-ml-butler bot removed the stale This label marks the issue/pr stale - to be closed automatically if no activity label Oct 2, 2020
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Oct 4, 2020
@Saduf2019 Saduf2019 added the comp:autograph Autograph related issues label Oct 5, 2020
@Saduf2019 Saduf2019 assigned gowthamkpr and unassigned Saduf2019 Oct 5, 2020
@gowthamkpr gowthamkpr assigned nikitamaia and unassigned gowthamkpr Oct 13, 2020
@nikitamaia nikitamaia added comp:core issues related to core part of tensorflow comp:ops OPs related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed comp:autograph Autograph related issues labels Oct 13, 2020
@rohan100jain
Member

You're right here. We have some logic that detects which variables are used within a custom_gradient function, and it only works when there is an actual variable read (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/pywrap_tfe_src.cc#L3294), not when a variable is simply passed through to the output. Right now we don't have plans to fix this immediately, so I'd recommend just doing 1.0*varr or something so that a read is triggered.
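
For illustration, a minimal sketch of that workaround applied to the example above (names taken from the reporter's code); multiplying varr3 by 1.0 (or adding 0.0) creates a read op, so the variable watcher picks it up and the variable-gradient list of length two is then accepted:

import tensorflow as tf

varr = tf.Variable(200.0)
varr3 = tf.Variable(200.0)

@tf.custom_gradient
def log1pexp(x1, x2):
  e = tf.math.log(1 + tf.exp(x1))

  def grad(dy, dz, variables=None):
    # Gradients w.r.t. (x1, x2), then w.r.t. the two watched variables.
    return [dy * (1 - 1 / (1 + e)), tf.constant(0.0)], [0.0 * varr, 0.0 * varr3]

  # 1.0 * varr3 forces a read, so varr3 is registered as a watched variable.
  return [e + varr, 1.0 * varr3], grad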

On the other hand, contributions are welcome to fix this! One suggestion would be to look at the outputs from running the function (result, grad_fn = f(*args)) and see if there are any variables we missed and add those to the list.
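
Purely as an illustration of that suggestion (not a tested patch; the names f, args and variable_watcher are taken from the snippets quoted earlier in this thread): after running the function, one could flatten the outputs, collect anything that is a tf.Variable, and union it with the variables the watcher recorded.

# Illustrative sketch only - not the actual TensorFlow implementation.
result, grad_fn = f(*args)

variables_in_tape = frozenset(
    v.ref() for v in variable_watcher.watched_variables())

# Also pick up variables that were passed straight through to the outputs
# without any read op being recorded on the tape.
variables_in_outputs = frozenset(
    t.ref() for t in tf.nest.flatten(result) if isinstance(t, tf.Variable))

variables = [v.deref() for v in variables_in_tape | variables_in_outputs]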

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 28, 2020
@nikitamaia nikitamaia added the stat:contribution welcome Status - Contributions welcome label Apr 26, 2021
@sachinprasadhs
Contributor

sachinprasadhs commented May 28, 2021

I was able to reproduce your issue in TF 2.11; please find the gist here. Thanks!

@pjpratik pjpratik added TF 2.11 Issues related to TF 2.11 and removed TF 2.3 Issues related to TF 2.3 labels Dec 30, 2022
@github-actions

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 29, 2023