
How to apply gradient clipping in TensorFlow 2.0? #28707

Closed
duancaohui opened this issue May 14, 2019 · 10 comments
Labels
comp:apis Highlevel API related issues contrib Anything that comes under contrib directory TF 2.0 Issues relating to TensorFlow 2.0 type:feature Feature requests

Comments

@duancaohui

Please make sure that this is a feature request. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:feature_template

System information

  • TensorFlow version (you are using): Tensorflow 2.0 Alpha
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.

Will this change the current api? How?

Who will benefit with this feature?

Any Other info.

I want to apply gradient clipping in TF 2.0. In TF 1.x, the best solution was to wrap the optimizer with tf.contrib.estimator.clip_gradients_by_norm.

However, I can't find this function in TF 2.0 after trying many approaches. As far as I know, tf.contrib has been cleaned up in TF 2.0.
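
For reference, the TF 1.x pattern being described is roughly the following (a minimal sketch; the optimizer choice and clip_norm value are only illustrative, and loss is assumed to be an existing scalar tensor):

# TF 1.x only: tf.contrib was removed in TF 2.0.
import tensorflow as tf

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
# Wrap the optimizer so every gradient is clipped by norm before being applied.
optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, clip_norm=5.0)
train_op = optimizer.minimize(loss)  # `loss` is assumed to be defined elsewhere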

@achandraa achandraa self-assigned this May 15, 2019
@achandraa achandraa added 2.0.0-alpha0 contrib Anything that comes under contrib directory comp:apis Highlevel API related issues type:support Support issues labels May 15, 2019
@achandraa achandraa assigned ymodak and unassigned achandraa May 15, 2019
@ymodak ymodak added type:feature Feature requests and removed type:support Support issues labels May 17, 2019
@ymodak ymodak assigned yhliang2018 and unassigned ymodak May 17, 2019
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 17, 2019
@yhliang2018
Contributor

@tanzhenyu Is this already supported in keras optimizers v2?

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 15, 2019
@duancaohui
Author

A simple method to apply gradient clipping in TensorFlow 2.0:

from tensorflow.keras import optimizers
sgd = optimizers.SGD(lr=0.01, clipvalue=0.5)
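
For completeness, a minimal sketch of how such an optimizer is typically used; the model, loss, and data shapes below are just placeholders:

import tensorflow as tf
from tensorflow.keras import optimizers

# Every gradient the optimizer applies is clipped element-wise to [-0.5, 0.5].
sgd = optimizers.SGD(learning_rate=0.01, clipvalue=0.5)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=sgd, loss="mse")
model.fit(tf.random.normal((32, 4)), tf.random.normal((32, 1)), epochs=1)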

@akanyaani

akanyaani commented Aug 19, 2019

Hi,

You can clip the gradients the same way we used to in TF 1.x. Try this code:

with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    loss = get_loss(targets, predictions)

gradients = tape.gradient(loss, model.trainable_variables)
gradients = [tf.clip_by_value(grad, -1.0, 1.0) for grad in gradients]
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

@tanzhenyu
Contributor

There are essentially two ways to do this, as mentioned above, so I will just summarize here:

  1. Pass clipvalue or clipnorm to the optimizer constructor; this clips all gradients.
  2. Do a customized clip: gradients = tape.gradient(...), then gradients = [tf.process_gradient_???(grad) for grad in gradients].

They are both correct; the 2nd option just gives you more flexibility (see the sketch below).
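
A minimal, self-contained sketch of option 2, using tf.clip_by_norm as the per-gradient transform (tf.clip_by_value, tf.clip_by_global_norm, etc. could be substituted; the model, data, and 1.0 threshold are only illustrative):

import tensorflow as tf

# Illustrative model and data; any model/loss works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.SGD(0.01)
inputs = tf.random.normal((8, 4))
targets = tf.random.normal((8, 1))

with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    loss = tf.reduce_mean(tf.square(targets - predictions))

gradients = tape.gradient(loss, model.trainable_variables)
# Clip each gradient tensor by its own L2 norm; swap in any other transform here.
gradients = [tf.clip_by_norm(grad, 1.0) for grad in gradients]
optimizer.apply_gradients(zip(gradients, model.trainable_variables))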

@lvenugopalan lvenugopalan added the TF 2.0 Issues relating to TensorFlow 2.0 label Apr 29, 2020
@thedomdom

As mentioned here, there are cases where the optimizer's clipvalue and clipnorm are completely ignored in TF 2.0 and 2.1.
However, this is fixed in TF 2.2 (see here).

@caixxiong

caixxiong commented Sep 2, 2020

Regarding option 1: if we pass clipnorm=1.0 to the constructor of a keras.optimizers optimizer, it clips the gradient of each Variable by that Variable's own norm, not by the global norm across all Variables. An example follows.

import tensorflow as tf
from tensorflow import keras

# get_gradients() builds graph-mode gradient tensors, so with TF 2.x installed
# this needs TF 1.x-style graph execution via tf.compat.v1.
tf.compat.v1.disable_eager_execution()

x = tf.Variable([3.0, 4.0])
y = tf.Variable([1.0, 1.0, 1.0, 1.0])
z = tf.reduce_sum(x ** 2) + tf.reduce_sum(y)
adam = keras.optimizers.Adam(0.01, clipnorm=1.0)
grads = adam.get_gradients(z, [x, y])

sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
print(sess.run(grads))
# outputs: [0.6, 0.8], [0.5, 0.5, 0.5, 0.5]
# dz/dx = [6, 8] (norm 10) and dz/dy = [1, 1, 1, 1] (norm 2) are each rescaled
# to norm 1, i.e. clipnorm acts per Variable, not on the global norm.
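
For comparison, the same per-variable behavior can be reproduced in plain eager mode with tf.clip_by_norm (a minimal sketch reusing the toy variables above):

import tensorflow as tf

x = tf.Variable([3.0, 4.0])
y = tf.Variable([1.0, 1.0, 1.0, 1.0])

with tf.GradientTape() as tape:
    z = tf.reduce_sum(x ** 2) + tf.reduce_sum(y)

grads = tape.gradient(z, [x, y])                    # [6, 8] and [1, 1, 1, 1]
clipped = [tf.clip_by_norm(g, 1.0) for g in grads]  # clip each Variable's gradient by its own norm
print([g.numpy() for g in clipped])
# [array([0.6, 0.8], ...), array([0.5, 0.5, 0.5, 0.5], ...)]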

@tanzhenyu
Contributor

tanzhenyu commented Sep 2, 2020

global_clipnorm is your solution. Though that's a 2.4 feature (or tf-nightly), given we recently pushed it.
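
For illustration, a sketch of both routes to global-norm clipping, reusing the toy variables from the example above (the 1.0 threshold is arbitrary):

import tensorflow as tf

x = tf.Variable([3.0, 4.0])
y = tf.Variable([1.0, 1.0, 1.0, 1.0])

with tf.GradientTape() as tape:
    z = tf.reduce_sum(x ** 2) + tf.reduce_sum(y)
grads = tape.gradient(z, [x, y])

# Any TF 2.x version: clip manually by the norm across *all* gradients,
# here sqrt(6^2 + 8^2 + 4 * 1^2) = sqrt(104) ≈ 10.2.
clipped, global_norm = tf.clip_by_global_norm(grads, 1.0)
print(global_norm.numpy())           # ~10.198
print([g.numpy() for g in clipped])  # every gradient scaled by 1.0 / 10.198

# TF >= 2.4 (or tf-nightly at the time): let the optimizer do it.
opt = tf.keras.optimizers.Adam(0.01, global_clipnorm=1.0)
opt.apply_gradients(zip(grads, [x, y]))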

@caixxiong

Thanks! How do I perform global gradient clipping in TF 1.x?

@tanzhenyu
Contributor

If you're using tf-nightly, then use tf.compat.v1 for everything.
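
In TF 1.x (or through tf.compat.v1 on TF 2.x), the usual global-clipping pattern sits between compute_gradients and apply_gradients; a minimal self-contained sketch with a toy loss:

import tensorflow.compat.v1 as tf  # on TF 1.x, just `import tensorflow as tf` and skip the next line
tf.disable_eager_execution()

x = tf.get_variable("x", initializer=[3.0, 4.0])
loss = tf.reduce_sum(x ** 2)

optimizer = tf.train.AdamOptimizer(0.01)
grads_and_vars = optimizer.compute_gradients(loss)
grads, variables = zip(*grads_and_vars)
clipped, _ = tf.clip_by_global_norm(grads, 1.0)  # clip by the norm across all gradients
train_op = optimizer.apply_gradients(zip(clipped, variables))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)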

@jhanilesh96

The fix mentioned above (for clipvalue/clipnorm being ignored) doesn't seem to work in TF 2.2, and it is also not working in TF 2.3.
