
Eager execution guide: using GradientTape with keras.model and tf.keras.layer #20630

Closed
skeydan opened this issue Jul 8, 2018 · 5 comments

skeydan commented Jul 8, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Fedora 28
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.9
  • Python version: 3.6
  • Bazel version: N/A
  • CUDA/cuDNN version: N/A
  • GPU model and memory: N/A
  • Exact command to reproduce: N/A

Describe the problem

The eager execution doc

https://www.tensorflow.org/programmers_guide/eager

does not provide a simple, complete example that shows how to use GradientTape optimization with Keras layers (instead of tfe.Variables).
It would be great if that could be added, as I seem to be getting gradients that are None when I try to replace the tfe.Variables in the following working example (copied from the doc, but simplified further for ease of experimentation):

import tensorflow as tf
tf.enable_eager_execution()
tfe = tf.contrib.eager
n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])
y = x * 3 + 2 + noise

class Model(tf.keras.Model):
  def __init__(self):
    super(Model, self).__init__()
    self.W = tfe.Variable(5., name='weight')
    self.B = tfe.Variable(10., name='bias')
  def predict(self, inputs):
    return inputs * self.W + self.B

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

model = Model()
with tf.GradientTape() as tape:
  error = model.predict(x) - y
  loss_value = tf.reduce_mean(tf.square(error))
gradients = tape.gradient(loss_value, model.variables)
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                            global_step=tf.train.get_or_create_global_step())

with a Keras layer instead of a weight and a bias (based on the MNISTModel from the doc, again simplified):

import tensorflow as tf
tf.enable_eager_execution()

n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])

y = x * 3 + 2 + noise

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.dense1 = tf.keras.layers.Dense(units=1, activation='relu')
    def call(self, inputs):
        return self.dense1(inputs)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

model = Model()
with tf.GradientTape() as tape:
  error = model.predict(x) - y
  loss_value = tf.reduce_mean(tf.square(error))
gradients = tape.gradient(loss_value, model.variables)
# now gradients are None
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                            global_step=tf.train.get_or_create_global_step())
                            

In this version, the gradients are None. Same if I use another construct from the doc (again, simplified):

model = tf.keras.Sequential([
  tf.keras.layers.Dense(1, input_shape=(2,))  # must declare input shape
])
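
Roughly, the complete reproduction for this variant (reusing x and y from above) is:

with tf.GradientTape() as tape:
  error = model.predict(x) - y
  loss_value = tf.reduce_mean(tf.square(error))
gradients = tape.gradient(loss_value, model.variables)
print(gradients)  # None here as well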

It would be great if the doc could be extended to show a complete tf.keras.layer example with GradientTape. Thank you!

tensorflowbutler added the stat:awaiting response (Status - Awaiting response from author) label on Jul 9, 2018
tensorflowbutler (Member) commented:

Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.

  • Bazel version
  • CUDA/cuDNN version
  • GPU model and memory
  • Exact command to reproduce


skeydan commented Jul 9, 2018

Sorry, filled in now.

tensorflowbutler removed the stat:awaiting response (Status - Awaiting response from author) label on Jul 9, 2018
asimshankar (Contributor) commented:

@skeydan: Thanks for bringing this up. The documentation you're pointing to is a bit misleading: we shouldn't override predict there, we should override call, since tf.keras.Model.predict isn't meant to be overridden.

I'll update the getting started guide. And as for other examples, there are a bunch in https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples, and we'll be adding more to www.tensorflow.org soon.

As for an explanation: tf.keras.Model.predict returns a numpy.ndarray, not a tf.Tensor, and TensorFlow cannot differentiate through numpy conversions or operations. So, if you change your second example to use model(x) instead of model.predict(x), it should work out fine:

import tensorflow as tf
tf.enable_eager_execution()

n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])

y = x * 3 + 2 + noise

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.dense1 = tf.keras.layers.Dense(units=1, activation='relu')
    def call(self, inputs):
        return self.dense1(inputs)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

model = Model()
with tf.GradientTape() as tape:
  error = model(x) - y  # NOT model.predict(x) - y
  loss_value = tf.reduce_mean(tf.square(error))
gradients = tape.gradient(loss_value, model.variables)
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                            global_step=tf.train.get_or_create_global_step())
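
To make the numpy-vs-Tensor distinction concrete, a quick check along these lines (just a sketch, reusing the model above) shows the difference in return types:

out_call = model(x)             # a tf.Tensor; under a GradientTape its computation is recorded
out_predict = model.predict(x)  # a numpy.ndarray; the tape cannot trace through numpy values
print(type(out_call))
print(type(out_predict))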

And similarly, if we change the first example to override call instead of predict, it would work out fine too (which is the documentation fix to be made):

import tensorflow as tf
tf.enable_eager_execution()
tfe = tf.contrib.eager
n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])
y = x * 3 + 2 + noise

class Model(tf.keras.Model):
  def __init__(self):
    super(Model, self).__init__()
    self.W = tfe.Variable(5., name='weight')
    self.B = tfe.Variable(10., name='bias')
  # Override call, not predict
  def call(self, inputs):
    return inputs * self.W + self.B

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

model = Model()
with tf.GradientTape() as tape:
  error = model(x) - y
  loss_value = tf.reduce_mean(tf.square(error))
gradients = tape.gradient(loss_value, model.variables)
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                            global_step=tf.train.get_or_create_global_step())
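
For completeness, here is a rough sketch of how that single update step could be wrapped in a small training loop (the step count and print frequency are just illustrative):

for step in range(200):
  with tf.GradientTape() as tape:
    loss_value = tf.reduce_mean(tf.square(model(x) - y))
  gradients = tape.gradient(loss_value, model.variables)
  optimizer.apply_gradients(zip(gradients, model.variables),
                            global_step=tf.train.get_or_create_global_step())
  if step % 50 == 0:
    print("Step %d: loss %f" % (step, loss_value.numpy()))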

FYI @random-forests @fchollet @pavithrasv @yashk2810 - regarding documentation or other improvements that we might be able to make to reduce confusion between tf.keras.Model.predict(x) and tf.keras.Model(x).

Hope that helps.


skeydan commented Jul 11, 2018

Thank you for the detailed explanation!

lamberta pushed a commit to lamberta/tensorflow that referenced this issue Jul 11, 2018
pavithrasv (Member) commented:

  • FYI @raymond-yuan (who is also working on creating some examples for our documentation).

lamberta pushed a commit to lamberta/tensorflow that referenced this issue Jul 16, 2018