
[TF 2.0] Nested Gradient Tape - unconnected graphs #34335

Closed
janbolle opened this issue Nov 16, 2019 · 20 comments
Labels
2.6.0 comp:keras Keras related issues type:bug Bug

Comments

@janbolle

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution: MacOS 10.15.1
  • TensorFlow installed from binary (pip 19.3.1)
  • TensorFlow version: v2.0.0-rc2-26-g64c3d382ca 2.0.0
  • Python version: Python 3.6.5

Describe the current behavior
A copy of my model (model_copy) should be trained for one step, then my meta_model needs to be trained with the loss of model_copy. It seems that the two graphs are unconnected.
It only works if I use the meta_model for the training step.

Describe the expected behavior
I would expect that model_copy is known to both gradient tapes and can be used without going through meta_model.

Code to reproduce the issue

import tensorflow as tf
import tensorflow.keras.backend as keras_backend
import tensorflow.keras as keras

class MetaModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden1 = keras.layers.Dense(5, input_shape=(1,))
        self.out = keras.layers.Dense(1)
    def forward(self, x):
        x = keras.activations.relu(self.hidden1(x))
        x = self.out(x)
        return x

def copy_model(model, x):
    copied_model = MetaModel()
    copied_model.forward(x)
    copied_model.set_weights(model.get_weights())
    return copied_model

def compute_loss(model, x, y):
    logits = model.forward(x)  # prediction of my model
    mse = keras_backend.mean(keras.losses.mean_squared_error(y, logits))  # compute loss between prediction and label/truth
    return mse, logits

optimizer_outer = keras.optimizers.Adam()
alpha = 0.01
with tf.GradientTape() as g:
    # meta_model to learn in outer gradient tape
    meta_model = MetaModel()
    # inputs for training
    x = tf.constant(3.0, shape=(1, 1, 1))
    y = tf.constant(3.0, shape=(1, 1, 1))

    meta_model.forward(x)
    model_copy = copy_model(meta_model, x)
    with tf.GradientTape() as gg:
        loss, _ = compute_loss(model_copy, x, y)
        gradients = gg.gradient(loss, model_copy.trainable_variables)
        k = 0
        for layer in range(len(model_copy.layers)):
            """ If I use meta-model for updating, this works """
            # model_copy.layers[layer].kernel = tf.subtract(meta_model.layers[layer].kernel,
            #                                               tf.multiply(alpha, gradients[k]))
            # model_copy.layers[layer].bias = tf.subtract(meta_model.layers[layer].bias,
            #                                             tf.multiply(alpha, gradients[k + 1]))

            """ If I use model-copy for updating instead, gradients_meta always will be [None,None,...]"""
            model_copy.layers[layer].kernel = tf.subtract(model_copy.layers[layer].kernel,
                                                          tf.multiply(alpha, gradients[k]))
            model_copy.layers[layer].bias = tf.subtract(model_copy.layers[layer].bias,
                                                        tf.multiply(alpha, gradients[k + 1]))

            k += 2

    # calculate loss of model_copy
    test_loss, _ = compute_loss(model_copy, x, y)
    # build gradients for meta_model update
    gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)
    """ gradients always None !?!!11 elf """
    optimizer_outer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))

Other info / logs
Is it intended to work as above? If so, that would prevent me from using a different optimizer in the inner loop, since the networks somehow need to stay connected.
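
For context on why the gradients come back as None: the inner update overwrites model_copy.layers[i].kernel with a plain tensor computed only from model_copy's own, freshly copied variables, so the outer tape has no path from test_loss back to meta_model.trainable_variables. A common MAML-style workaround is to keep the adapted weights as tensors derived from meta_model's variables and thread them through a functional forward pass. The sketch below assumes that restructuring; forward_with_weights is a hypothetical helper (not part of the report), and the weight ordering is assumed to follow meta_model.trainable_variables.

import tensorflow as tf
import tensorflow.keras as keras

def forward_with_weights(x, w):
    # hypothetical helper: same computation as MetaModel.forward, but with the weights
    # passed in explicitly as [hidden1 kernel, hidden1 bias, out kernel, out bias]
    h = tf.nn.relu(tf.matmul(x, w[0]) + w[1])
    return tf.matmul(h, w[2]) + w[3]

alpha = 0.01
optimizer_outer = keras.optimizers.Adam()

meta_model = MetaModel()               # class from the report above
meta_model.forward(tf.zeros((1, 1)))   # build the variables once
x = tf.constant(3.0, shape=(1, 1))
y = tf.constant(3.0, shape=(1, 1))

with tf.GradientTape() as g:
    with tf.GradientTape() as gg:
        loss = tf.reduce_mean(keras.losses.mean_squared_error(
            y, forward_with_weights(x, meta_model.trainable_variables)))
    grads = gg.gradient(loss, meta_model.trainable_variables)
    # the adapted ("fast") weights stay tensors that depend on meta_model's variables
    fast_weights = [w - alpha * dw for w, dw in zip(meta_model.trainable_variables, grads)]
    test_loss = tf.reduce_mean(keras.losses.mean_squared_error(
        y, forward_with_weights(x, fast_weights)))

gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)  # no longer all None
optimizer_outer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))

In this form a different inner optimizer is still possible in principle, as long as its update can be written as differentiable tensor operations on meta_model's variables.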

@rmothukuru rmothukuru self-assigned this Nov 18, 2019
@rmothukuru rmothukuru added the TF 2.0 Issues relating to TensorFlow 2.0 label Nov 18, 2019
@rmothukuru
Contributor

@janbolle,
When trying to reproduce your issue, I encountered the error ValueError: No gradients provided for any variable: ['dense_8/kernel:0', 'dense_8/bias:0', 'dense_9/kernel:0', 'dense_9/bias:0']. Can you please help us reproduce the issue? Here is the Gist. Thanks!

@rmothukuru rmothukuru added stat:awaiting response Status - Awaiting response from author comp:keras Keras related issues type:support Support issues labels Nov 18, 2019
@janbolle
Author

@rmothukuru , thanks for your reply.
That is exactly my problem: gradients_meta is always [None, None, ...],
so TF tells me that there are no gradients provided.

@rmothukuru
Contributor

@janbolle,
So do you mean you are encountering the same error as I am? Please confirm.

@janbolle
Author

@rmothukuru , yes, same error on my side.

@Rahulmishra07

Rahulmishra07 commented Nov 18, 2019 via email

@rmothukuru
Contributor

Could reproduce the error with TF Version 2.0. Here is the Gist. Thanks!

@janbolle
Author

janbolle commented Nov 18, 2019

Also, what is very odd: if I print the weights of the layers before and after training, they are available. But if I call model_copy.get_weights(), it returns an empty array.

The following code:

        k = 0
        for layer in range(len(model_copy.layers)):
            # calculate adapted parameters w/ gradient descent
            # \theta_i' = \theta - \alpha * gradients
            print("pre: ", model_copy.layers[layer].kernel.shape, model_copy.layers[layer].kernel)
            model_copy.layers[layer].kernel = tf.subtract(model_copy.layers[layer].kernel,
                                                          tf.multiply(alpha, gradients[k]))
            model_copy.layers[layer].bias = tf.subtract(model_copy.layers[layer].bias,
                                                        tf.multiply(alpha, gradients[k + 1]))
            print("post: ", model_copy.layers[layer].kernel.shape, model_copy.layers[layer].kernel)
            k += 2
    print(model_copy.get_weights())  # results in empty array
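
The empty get_weights() result fits the assignment pattern above: after the layer is built, layer.kernel holds a tf.Variable, and tf.subtract produces a plain tensor, so the assignment replaces the tracked variable with an ordinary tensor. A minimal illustration of the type change (a hypothetical snippet, not from the thread):

import tensorflow as tf
import tensorflow.keras as keras

layer = keras.layers.Dense(5)
layer.build((None, 1))              # creates the kernel and bias variables
print(type(layer.kernel))           # a tf.Variable (ResourceVariable)
print(len(layer.get_weights()))     # 2 -> kernel and bias are tracked

layer.kernel = layer.kernel - 0.01  # the subtraction result is a plain tf.Tensor
print(type(layer.kernel))           # EagerTensor, no longer a tf.Variable

Once the attribute holds a tensor instead of a variable, Keras's weight tracking no longer has a variable to report there, which is consistent with the empty get_weights(); and nothing ties those tensors back to meta_model's variables, which is consistent with gradients_meta being [None, None, ...].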

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Nov 19, 2019
@janbolle
Author

janbolle commented Nov 25, 2019

@jvishnuvardhan do you need further information?
@rmothukuru did you connect the right person?

Also, I think this is a bug, not a support case :-/

Maybe related to #29535

@jvishnuvardhan jvishnuvardhan added type:bug Bug comp:keras Keras related issues and removed comp:keras Keras related issues type:support Support issues labels Nov 26, 2019
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 26, 2019
@gadagashwini-zz gadagashwini-zz added the TF 2.1 for tracking issues in 2.1 release label Mar 19, 2020
@gadagashwini-zz gadagashwini-zz self-assigned this Mar 19, 2020
@gadagashwini-zz
Contributor

I was able to replicate the issue with Tf-nightly==2.2.0.dev20200318.
Please find the gist here. Thanks!

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 21, 2020
@ravikyram ravikyram added the TF 2.2 Issues related to TF 2.2 label Jun 12, 2020
@ravikyram
Contributor

I was able to replicate the issue with tf-nightly==2.3.0-dev20200612. Please find the gist here. Thanks!

@Saduf2019
Contributor

I was able to replicate the issue with tf-nightly==2.4.0-dev20200806. Please find the gist here.

@velocirabbit

velocirabbit commented Sep 2, 2020

Have there been any updates on this issue? I'm running into a similar None-gradients case when using nested tf.GradientTapes.

@interactivetech

Looking for updates on this as well! I am able to get my U-Net model to do an inner update, but this issue means it's only possible to do one inner update.

@sushreebarsa
Contributor

Was able to replicate the issue in TF v2.5. Please find the gist here. Thanks!

@sushreebarsa
Contributor

Was able to replicate the issue with TF 2.6.0-dev20210606. Please find the gist here. Thanks!

@kumariko

I could reproduce the issue with TF 2.6. Please find the gist here. Thanks!

@kumariko kumariko added 2.6.0 and removed TF 2.0 Issues relating to TensorFlow 2.0 TF 2.1 for tracking issues in 2.1 release TF 2.2 Issues related to TF 2.2 labels Aug 26, 2021
@tensorflowbutler
Member

Hi There,

This is a stale issue. As you are using an older version of TensorFlow, we are checking to see if you still need help on this issue. Please test the issue with the latest TensorFlow (TF 2.7 and tf-nightly). If the issue still persists with the newer versions of TF, please feel free to open it in the keras-team/keras repository, providing details about the issue and standalone code to reproduce it. Thanks!

Please note that Keras development has moved to the separate keras-team/keras repository to focus entirely on Keras. Thanks!

@andrewemendez

Still an issue with TF 2.7. Gist here


@evanfwelch

Was there ever a resolution here? Curious why the issue was closed after the last poster confirmed it was still an issue in TF 2.7.
