
Potential bug #7

Closed
rainwoodman opened this issue Nov 5, 2021 · 5 comments · Fixed by #8

Comments

@rainwoodman
Contributor

The structure of the train_step in cell 8 of the notebook is very unconventional:

def train_step(self, data):
... first model evaluation
... first tape gradient
... second model evaluation
... update parameters
... second tape gradient      

For the parameter update to affect the second tape's gradient, the update has to come before the second model evaluation:

def train_step(self, data):
... first model evaluation
... first tape gradient
... update parameters
... second model evaluation
... second tape gradient      
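
For concreteness, a minimal sketch of the corrected ordering as a Keras custom train_step (the SAMModel name, rho, and the epsilon computation here are illustrative, not the notebook's exact code):

import tensorflow as tf

class SAMModel(tf.keras.Model):
  def __init__(self, inner_model, rho=0.05):
    super().__init__()
    self.inner_model = inner_model
    self.rho = rho

  def train_step(self, data):
    x, y = data
    trainable = self.inner_model.trainable_variables

    # First model evaluation and first tape gradient.
    with tf.GradientTape() as tape:
      loss = self.compiled_loss(y, self.inner_model(x, training=True))
    grads = tape.gradient(loss, trainable)

    # Perturb the parameters BEFORE the second evaluation, so the
    # second tape actually records the perturbed weights.
    norm = tf.linalg.global_norm(grads)
    eps = [g * self.rho / (norm + 1e-12) for g in grads]
    for v, e in zip(trainable, eps):
      v.assign_add(e)

    # Second model evaluation and second tape gradient, taken at the
    # perturbed parameters.
    with tf.GradientTape() as tape:
      loss = self.compiled_loss(y, self.inner_model(x, training=True))
    sam_grads = tape.gradient(loss, trainable)

    # Restore the original parameters, then apply the SAM gradients.
    for v, e in zip(trainable, eps):
      v.assign_sub(e)
    self.optimizer.apply_gradients(zip(sam_grads, trainable))
    return {"loss": loss}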

@rainwoodman
Contributor Author

For example, the second sequence gives different gradients:

import tensorflow as tf

x = tf.Variable(tf.constant(3.))

@tf.function
def epsilon_before_eval():
  # First evaluation and gradient at x = 3.
  with tf.GradientTape() as tape:
    y = x * x
  g1 = tape.gradient(y, x)  # 2 * 3 = 6

  # Update x BEFORE the second evaluation, so the tape records x = 4.
  x.assign_add(1.0)
  with tf.GradientTape() as tape:
    y = x * x
  g = tape.gradient(y, x)  # 2 * 4 = 8
  x.assign_sub(1.0)
  return g1, g

g1, g = epsilon_before_eval()
assert x == 3.0
assert g1 - g == -2

@rainwoodman
Contributor Author

But the first sequence gives identical gradients, which means the implementation falls back to the underlying non-SAM optimizer: tape.gradient differentiates the computation recorded inside the with block, so an assignment made after the block closes has no effect on the result.

import tensorflow as tf

x = tf.Variable(tf.constant(3.))

@tf.function
def epsilon_after_eval():
  # First evaluation and gradient at x = 3.
  with tf.GradientTape() as tape:
    y = x * x
  g1 = tape.gradient(y, x)  # 2 * 3 = 6

  # Second evaluation, still at x = 3.
  with tf.GradientTape() as tape:
    y = x * x

  # Updating x AFTER the tape has recorded the forward pass does not
  # change the gradient: the tape differentiates the recorded values.
  x.assign_add(1.0)
  g = tape.gradient(y, x)  # still 2 * 3 = 6
  x.assign_sub(1.0)
  return g1, g

g1, g = epsilon_after_eval()
assert x == 3.0
assert g1 - g == -2  # fails because g1 == g

@sayakpaul
Owner

Thank you so much for pointing this out. Lesson learned, indeed!

Would you be interested in sending a PR reflecting this change in the notebook?

@rainwoodman
Contributor Author

Sure. Thanks for confirming! I initially thought this might indicate another function / eager inconsistency in TensorFlow. We've been chasing wildly after such corner cases ;)
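
For reference, the same behavior reproduces in plain eager mode, outside tf.function, so it is not a function/eager inconsistency (a minimal check):

import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
  y = x * x
# Assigning after the tape has closed does not affect the gradient.
x.assign_add(1.0)
g = tape.gradient(y, x)  # 6.0: the tape differentiates the recorded forward pass
x.assign_sub(1.0)
assert g == 6.0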

I am not particularly good at notebook PRs. Is there any suggested process other than editing the JSON directly?

@sayakpaul
Owner

Yeah so, you could first clone this repository.

Then you could open the notebook in Colab directly, make your changes, and commit them to your repository from Colab itself. Then you could raise the PR.

[two screenshots showing how to commit a notebook to GitHub from Colab]

Let me know if anything is unclear.

@rainwoodman rainwoodman mentioned this issue Nov 10, 2021
sayakpaul added a commit that referenced this issue Nov 11, 2021