keras optimizer.iterations are not properly saved & restored #48947

Open
blackyang opened this issue May 6, 2021 · 11 comments
Labels
comp:keras (Keras related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.16, type:bug (Bug)

Comments

@blackyang
Contributor

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Any
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: NA
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.5.0rc2
  • Python version: 3.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

You can collect some of this information using our environment capture script.
You can also obtain the TensorFlow version with:

  1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
  2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior

optimizer.iterations is not properly saved & restored. After restoring, it is reset to 0, which leads to a wrong learning rate from the lr_scheduler.

Describe the expected behavior

Provide an option to either reset it or not; for backward compatibility, maybe default to reset.

Contributing
  • Do you want to contribute a PR? (yes/no):
  • Briefly describe your candidate solution (if contributing):

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook.

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='Custom', name='MyScheduler')
class MyScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, **kwargs):
        super(MyScheduler, self).__init__(**kwargs)

    def __call__(self, step):
        return step

    def get_config(self):
        return {}


inputs = tf.keras.Input(10)
outputs = tf.keras.layers.Dense(10)(inputs)
model = tf.keras.Model(inputs, outputs)

lr_scheduler = MyScheduler()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_scheduler)
model.compile(optimizer=optimizer, loss="mse")


def get_dataset(repeat):
    inputs_data = tf.ones([16, 10])
    labels_data = tf.ones([16, 10])
    dataset = (
        tf.data.Dataset.from_tensors(inputs_data)
        .map(
            lambda x: (
                inputs_data,
                labels_data,
                None,
            )
        ).repeat(repeat)
    )
    return dataset


model.fit(get_dataset(3), epochs=1)
print(model.optimizer.iterations, lr_scheduler(model.optimizer.iterations))

path = "./foo/"
model.save(path)
loaded = tf.keras.models.load_model(path)
loaded.fit(get_dataset(4), epochs=1)
print(loaded.optimizer.iterations, lr_scheduler(loaded.optimizer.iterations))

The last print shows 4, but it should show 3 + 4 = 7.
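
(A possible workaround, sketched from the repro above rather than taken from the report: record the step count before saving and assign it back after loading. optimizer.iterations is a tf.Variable, so it accepts assign(); whether lazily restored optimizer state later overwrites the assigned value is an assumption left unverified here.)

# Workaround sketch only, reusing `model`, `path` and `get_dataset` from the repro above.
saved_steps = int(model.optimizer.iterations.numpy())  # 3 after the first fit

restored = tf.keras.models.load_model(path)
restored.optimizer.iterations.assign(saved_steps)  # carry the counter over by hand

restored.fit(get_dataset(4), epochs=1)
print(restored.optimizer.iterations)  # expected to show 3 + 4 = 7 with this workaround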

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@blackyang blackyang added the type:bug Bug label May 6, 2021
@UsharaniPagadala UsharaniPagadala added comp:keras Keras related issues TF 2.5 Issues related to TF 2.5 labels May 10, 2021
@UsharaniPagadala

@sachinprasadhs

I was able to reproduce the error in TF 2.4, TF 2.5rc2 and tf-nightly. Please find the gist. Thanks!

@sachinprasadhs
Contributor

sachinprasadhs commented May 17, 2021

@blackyang I think this prints the number of individual batches where updates have been performed; in your case, if you are doing n iterations on your data, then optimizer.iterations prints n.

@chrisbutner

@blackyang I think this prints the number of individual batches where updates have been performed; in your case, if you are doing n iterations on your data, then optimizer.iterations prints n.

The problem is that OptimizerV2.iterations is exactly what gets passed to a LearningRateSchedule to determine the learning rate. If it reports 4 when it should be 7, then training isn't going to work properly after saving and resuming.

local_step = math_ops.cast(self.iterations, var_dtype)
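
(To make that concrete, a small hedged sketch using a stock ExponentialDecay schedule rather than the MyScheduler from the repro: the optimizer evaluates its schedule at its own iterations counter, which is the local_step cast quoted above, so a counter that restarts at 0 restarts the learning rate with it.)

import tensorflow as tf

# Sketch: attach a decaying schedule to an optimizer; the schedule is evaluated
# at `optimizer.iterations`, so the learning rate tracks that counter.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=10, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)

# At step 0 this is 0.1; it decays as training advances. If loading resets the
# counter to 0, this value jumps back to 0.1 as well.
print(float(schedule(optimizer.iterations)))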

@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 25, 2021
@blackyang
Contributor Author

@sachinprasadhs thank you for your reply! Yes, I completely understand why it prints 4 instead of 7. I was saying that it should print 7 (or there should be an option to choose between resetting to 4 and keeping 7); otherwise the learning rate from the scheduler is wrong.

Another way is to update the learning rate scheduler so it does not rely on this iteration counter.

Any thoughts? Thank you!
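
(A hedged sketch of that second idea, purely illustrative and not a Keras API: a schedule that keeps its own step offset in get_config, so the effective step can survive a save/load even if the optimizer counter restarts. The offset field and the manual bookkeeping before save() are assumptions of this sketch.)

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package='Custom', name='OffsetScheduler')
class OffsetScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Like MyScheduler above, but carries a serialized step offset."""

    def __init__(self, offset=0):
        self.offset = offset

    def __call__(self, step):
        # Shift the (possibly reset) optimizer step by the stored offset.
        return tf.cast(step, tf.float32) + float(self.offset)

    def get_config(self):
        # The offset round-trips through the saved model's config.
        return {"offset": self.offset}

# Before saving, fold the completed steps into the offset so it gets serialized:
#   lr_scheduler.offset += int(model.optimizer.iterations.numpy())
#   model.save(path)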

@google-ml-butler google-ml-butler bot removed the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 26, 2021

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 2, 2021
@blackyang
Contributor Author

Any updates? Thanks!

@google-ml-butler google-ml-butler bot removed the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 4, 2021

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 11, 2021
@jvishnuvardhan jvishnuvardhan removed the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 16, 2021

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 23, 2021
@jvishnuvardhan jvishnuvardhan added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author stale This label marks the issue/pr stale - to be closed automatically if no activity labels Jun 23, 2021
@kumariko

I was able to reproduce the error in TF 2.6. Please find the gist here. Thanks!

@kumariko kumariko added 2.6.0 and removed TF 2.5 Issues related to TF 2.5 labels Sep 29, 2021
@bhack
Contributor

bhack commented Sep 29, 2021

@Venkat6871 Venkat6871 added TF 2.16 and removed 2.6.0 labels Apr 2, 2024