Eager execution guide: using GradientTape with keras.model and tf.keras.layer #20630
Comments
Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Sorry, filled in now.
@skeydan: Thanks for bringing this up. The documentation you're pointing to is a bit misleading; we shouldn't override `predict`. I'll update the getting started guide. And as for other examples, there are a bunch in https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples, and we'll be adding more to www.tensorflow.org soon. As for an explanation, here is a working version that overrides `call`:

```python
import tensorflow as tf

tf.enable_eager_execution()

n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])
y = x * 3 + 2 + noise

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.dense1 = tf.keras.layers.Dense(units=1, activation='relu')

    def call(self, inputs):
        return self.dense1(inputs)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
model = Model()

with tf.GradientTape() as tape:
    error = model(x) - y  # NOT model.predict(x) - y
    loss_value = tf.reduce_mean(tf.square(error))

gradients = tape.gradient(loss_value, model.variables)
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                          global_step=tf.train.get_or_create_global_step())
```

And similarly, if we change the first example to override `call` instead of `predict`:

```python
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

n = 10
x = tf.random_normal([n, 2])
noise = tf.random_normal([n, 2])
y = x * 3 + 2 + noise

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.W = tfe.Variable(5., name='weight')
        self.B = tfe.Variable(10., name='bias')

    # Overriding call, not predict
    def call(self, inputs):
        return inputs * self.W + self.B

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
model = Model()

with tf.GradientTape() as tape:
    error = model(x) - y
    loss_value = tf.reduce_mean(tf.square(error))

gradients = tape.gradient(loss_value, model.variables)
print(gradients)
optimizer.apply_gradients(zip(gradients, model.variables),
                          global_step=tf.train.get_or_create_global_step())
```

FYI @random-forests @fchollet @pavithrasv @yashk2810 regarding documentation or other improvements that we might be able to make to reduce confusion between `call` and `predict`.

Hope that helps.
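For anyone reading this on TF 2.x, where `tf.enable_eager_execution()` and `tf.contrib` no longer exist, the same pattern can be sketched with the 2.x API names; `tf.keras.optimizers.SGD` stands in here for the thread's `GradientDescentOptimizer`, and `trainable_variables` for `model.variables`:

```python
import tensorflow as tf  # assumes TF 2.x, where eager execution is the default

n = 10
x = tf.random.normal([n, 2])
noise = tf.random.normal([n, 2])
y = x * 3 + 2 + noise

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(units=1)

    # Override call, not predict: predict returns NumPy arrays,
    # which GradientTape cannot trace.
    def call(self, inputs):
        return self.dense1(inputs)

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
model = Model()

with tf.GradientTape() as tape:
    error = model(x) - y            # model(x), NOT model.predict(x)
    loss_value = tf.reduce_mean(tf.square(error))

gradients = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```

No `global_step` argument is needed in 2.x; `apply_gradients` takes only the gradient/variable pairs.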
Thank you for the detailed explanation!
…t()). Fixes tensorflow#20630 PiperOrigin-RevId: 204036333
System information
Describe the problem
The eager execution doc (https://www.tensorflow.org/programmers_guide/eager) does not provide a simple, complete example showing how to use GradientTape optimization with Keras layers (instead of tfe.Variables).

It would be great if that could be added, as I seem to be getting gradients that are None when, in the following running example (copied from the doc but further simplified for ease of experimentation), I replace the weight and bias tfe.Variables by a Keras layer (based on, but again simplified from, the MNISTModel in the doc). In this version the gradients are None, and the same happens if I use another construct from the doc (again simplified).
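The issue's original snippets are not reproduced here, but the failure mode can be sketched with a minimal, hypothetical reconstruction: any step that leaves the tape's traced computation (for example dropping to NumPy via `.numpy()`, which is also what `model.predict` does internally) makes the resulting gradients `None`:

```python
import tensorflow as tf  # assumes TF 2.x, where eager execution is the default

v = tf.Variable(5.0)
x = tf.constant(3.0)

with tf.GradientTape() as tape:
    y = v * x
    # Leaving the tape's world (here via .numpy(), analogous to computing
    # the loss from model.predict() output) severs the link back to v.
    loss = tf.square(tf.constant(y.numpy()) - 1.0)

g = tape.gradient(loss, v)
print(g)  # None: the loss is no longer traced back to v
```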
It would be great if the doc could be extended to show a complete tf.keras.layers example with GradientTape. Thank you!