tf.custom_gradient for function with kwarg shows unexpected behavior #77559
Open
Labels: 2.17 (Issues related to 2.17 release), comp:ops (OPs related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:bug (Bug)
Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
binary
TensorFlow version
2.17.0
Custom code
No
OS platform and distribution
Ubuntu 22.04.5 LTS
Mobile device
No response
Python version
3.12.7
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
12.3/8
GPU model and memory
No response
Current behavior?
I have a function that takes two tensors as inputs, one as argument and one as keyword argument.
The function has a custom gradient.
When tape.gradient is called for the output of the function with respect to both input tensors, TensorFlow throws an error saying that only one gradient is expected, not two.
When the function is called with both inputs as positional arguments (rather than one as a kwarg), no error is thrown.
Standalone code to reproduce the issue
```python
@tf.custom_gradient
def func(x, y=0):
    z = 2*x + y
    def grad(dz):
        dx = 2*dz
        dy = dz
        return dx, dy
    return z, grad

x = tf.constant(2.)
y = tf.constant(3.)
with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = func(x, y=y)  # func(x, y) does not generate the error
grads = tape.gradient(z, [x, y])
```
Relevant log output
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[129], line 14
     12 tape.watch([x, y])
     13 z = func(x, y=y)
---> 14 grads = tape.gradient(z, [x, y])

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/backprop.py:1066, in GradientTape.gradient(self, target, sources, output_gradients, unconnected_gradients)
   1060 output_gradients = (
   1061     composite_tensor_gradient.get_flat_tensors_for_gradients(
   1062         output_gradients))
   1063 output_gradients = [None if x is None else ops.convert_to_tensor(x)
   1064                     for x in output_gradients]
-> 1066 flat_grad = imperative_grad.imperative_grad(
   1067     self._tape,
   1068     flat_targets,
   1069     flat_sources,
   1070     output_gradients=output_gradients,
   1071     sources_raw=flat_sources_raw,
   1072     unconnected_gradients=unconnected_gradients)
   1074 if not self._persistent:
   1075   # Keep track of watched variables before setting tape to None
   1076   self._watched_variables = self._tape.watched_variables()

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/imperative_grad.py:67, in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
     63 except ValueError:
     64   raise ValueError(
     65       "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
---> 67 return pywrap_tfe.TFE_Py_TapeGradient(
     68     tape._tape,  # pylint: disable=protected-access
     69     target,
     70     sources,
     71     output_gradients,
     72     sources_raw,
     73     compat.as_str(unconnected_gradients.value))

File ~/.local/lib/python3.12/site-packages/tensorflow/python/ops/custom_gradient.py:588, in _eager_mode_decorator.<locals>.actual_grad_fn(*result_grad_components)
    585 flat_grads = composite_tensor_gradient.get_flat_tensors_for_gradients(
    586     nest.flatten(input_grads))
    587 if len(flat_grads) != arg_count:
--> 588   raise ValueError(
    589       f"custom_gradient function expected to return {arg_count} "
    590       f"gradients, but returned {len(flat_grads)} instead.")
    591 return flat_grads + variable_grads

ValueError: custom_gradient function expected to return 1 gradients, but returned 2 instead.
```
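As a workaround until this is resolved, passing every tensor input positionally keeps the gradient count that tf.custom_gradient expects in sync with what grad() returns. A minimal sketch based on the reproduction above (only the call site changes):

```python
import tensorflow as tf

@tf.custom_gradient
def func(x, y=0):
    # Forward pass: z = 2*x + y
    z = 2*x + y
    def grad(dz):
        # Gradients for both tensor inputs: dz/dx = 2, dz/dy = 1
        return 2*dz, dz
    return z, grad

x = tf.constant(2.)
y = tf.constant(3.)

with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = func(x, y)  # positional call: no ValueError

grads = tape.gradient(z, [x, y])
print([g.numpy() for g in grads])  # dz/dx = 2.0, dz/dy = 1.0
```

The keyword-argument form fails because only positional tensor arguments appear to be counted when the returned gradients are validated, so the positional call is the safe pattern for differentiable inputs.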