
tf.custom_gradient for function with kwarg shows unexpected behavior #77559

@jhoydis

Description

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

2.17.0

Custom code

No

OS platform and distribution

Ubuntu 22.04.5 LTS

Mobile device

No response

Python version

3.12.7

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

12.3/8

GPU model and memory

No response

Current behavior?

I have a function that takes two tensors as inputs, one as a positional argument and one as a keyword argument. The function has a custom gradient.

When tape.gradient is called to compute the gradients of the function's output with respect to both input tensors, TensorFlow raises an error saying that only one gradient was expected but two were returned.

When the function is called with both inputs as positional arguments (rather than one as a kwarg), no error is raised.

Standalone code to reproduce the issue

import tensorflow as tf

@tf.custom_gradient
def func(x, y=0):
    z = 2 * x + y
    def grad(dz):
        dx = 2 * dz
        dy = dz
        return dx, dy
    return z, grad

x = tf.constant(2.)
y = tf.constant(3.)
with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = func(x, y=y)  # func(x, y) does not generate the error
grads = tape.gradient(z, [x, y])
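As noted above, the error disappears when both tensors are passed positionally. A minimal workaround sketch (same toy function as the repro; this only sidesteps the issue, it does not fix it):

```python
import tensorflow as tf

@tf.custom_gradient
def func(x, y=0):
    z = 2 * x + y
    def grad(dz):
        return 2 * dz, dz  # gradients for x and y
    return z, grad

x = tf.constant(2.)
y = tf.constant(3.)
with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = func(x, y)  # pass y positionally to avoid the ValueError
grads = tape.gradient(z, [x, y])  # dz/dx = 2.0, dz/dy = 1.0
```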

Relevant log output

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[129], line 14
     12     tape.watch([x, y])
     13     z = func(x, y=y)
---> 14 grads = tape.gradient(z, [x, y])

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/backprop.py:1066, in GradientTape.gradient(self, target, sources, output_gradients, unconnected_gradients)
   1060   output_gradients = (
   1061       composite_tensor_gradient.get_flat_tensors_for_gradients(
   1062           output_gradients))
   1063   output_gradients = [None if x is None else ops.convert_to_tensor(x)
   1064                       for x in output_gradients]
-> 1066 flat_grad = imperative_grad.imperative_grad(
   1067     self._tape,
   1068     flat_targets,
   1069     flat_sources,
   1070     output_gradients=output_gradients,
   1071     sources_raw=flat_sources_raw,
   1072     unconnected_gradients=unconnected_gradients)
   1074 if not self._persistent:
   1075   # Keep track of watched variables before setting tape to None
   1076   self._watched_variables = self._tape.watched_variables()

File ~/.local/lib/python3.12/site-packages/tensorflow/python/eager/imperative_grad.py:67, in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
     63 except ValueError:
     64   raise ValueError(
     65       "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
---> 67 return pywrap_tfe.TFE_Py_TapeGradient(
     68     tape._tape,  # pylint: disable=protected-access
     69     target,
     70     sources,
     71     output_gradients,
     72     sources_raw,
     73     compat.as_str(unconnected_gradients.value))

File ~/.local/lib/python3.12/site-packages/tensorflow/python/ops/custom_gradient.py:588, in _eager_mode_decorator.<locals>.actual_grad_fn(*result_grad_components)
    585 flat_grads = composite_tensor_gradient.get_flat_tensors_for_gradients(
    586     nest.flatten(input_grads))
    587 if len(flat_grads) != arg_count:
--> 588   raise ValueError(
    589       f"custom_gradient function expected to return {arg_count} "
    590       f"gradients, but returned {len(flat_grads)} instead.")
    591 return flat_grads + variable_grads

ValueError: custom_gradient function expected to return 1 gradients, but returned 2 instead.
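The failing check in custom_gradient.py compares the number of gradients returned by grad against a count of the call's arguments, and a tensor passed as a kwarg is apparently not included in that count. A simplified pure-Python illustration of that suspected counting behavior (a sketch, not TensorFlow's actual code; make_custom_gradient is a hypothetical stand-in):

```python
def make_custom_gradient(f):
    """Simplified stand-in for TF's eager-mode custom_gradient wrapper."""
    def wrapper(*args, **kwargs):
        result, grad_fn = f(*args, **kwargs)
        arg_count = len(args)  # kwargs are not counted -- the suspected bug
        def actual_grad_fn(dz):
            grads = grad_fn(dz)
            if not isinstance(grads, tuple):
                grads = (grads,)
            if len(grads) != arg_count:
                raise ValueError(
                    f"custom_gradient function expected to return {arg_count} "
                    f"gradients, but returned {len(grads)} instead.")
            return grads
        return result, actual_grad_fn
    return wrapper

@make_custom_gradient
def func(x, y=0):
    def grad(dz):
        return 2 * dz, dz
    return 2 * x + y, grad

_, g = func(2.0, y=3.0)  # y as kwarg: arg_count == 1, but grad returns 2
try:
    g(1.0)
except ValueError as e:
    print(e)  # mirrors the error in the traceback above

_, g = func(2.0, 3.0)  # both positional: arg_count == 2, no error
print(g(1.0))
```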

Labels

2.17 (issues related to 2.17 release) · comp:ops (ops-related issues) · stat:awaiting tensorflower (awaiting response from tensorflower) · type:bug