BUG: Keras SaveModel does not properly save optimizer state #44670

Open
adriangb opened this issue Nov 7, 2020 · 13 comments
Labels
comp:keras · TF 2.3 · type:bug

Comments

@adriangb
Contributor

adriangb commented Nov 7, 2020

EDIT: looks like this is a dupe of #42749. I'll leave this up for now since that issue does not have as reproducible a high-level example, but feel free to close.

This happens at least with Adam (it does not apply to SGD, for example; I did not test other optimizers).

Tested on tf-nightly and tf==2.3.0.

TL;DR: running a tf.keras Model through model.save and tf.keras.models.load_model does not properly preserve the state of certain optimizers (see #42749 for more details).

The docs read:

The savefile includes:

  • The model architecture, allowing to re-instantiate the model.
  • The model weights.
  • The state of the optimizer, allowing to resume training exactly where you left off.

Full example:

import numpy as np
import tensorflow as tf
from tensorflow import keras


# Define a minimal model
inp = keras.layers.Input((1, ))
out = keras.layers.Dense(1)(inp)
m1 = keras.Model(inp, out)
m1.compile(loss="mae", optimizer="adam")

# Create some test data
X, y = np.random.random((100, )), np.random.random((100, ))

# Fit the model to the test data to get everything initialized
m1.fit(X, y, verbose=0)


def roundtrip(model: keras.Model) -> keras.Model:
    save_dir = "/tmp/mymodel"
    model.save(save_dir)
    restored = keras.models.load_model(save_dir)
    return restored


# Create a copy of the fitted m1
m2 = roundtrip(m1)

# Weights are preserved correctly, this passes
np.testing.assert_allclose(m1.predict(X), m2.predict(X))

# Now let's train one more round
m1.fit(X, y, verbose=0)
# Since optimizer weights/state is not preserved, this fit call
# results in different weights in m2, which makes the predictions differ
m2.fit(X, y, verbose=0)
try:
    np.testing.assert_allclose(m1.predict(X), m2.predict(X), rtol=0.1)  # large relative tolerance
except AssertionError:
    print("AssertionError: model predictions differ")

# Diagnosis: optimizer weights are not preserved
weights1 = m1.optimizer.get_weights()
m3 = roundtrip(m1)
weights3 = m3.optimizer.get_weights()

try:
    assert len(weights1) == len(weights3)
    for w1, w3 in zip(weights1, weights3):
        np.testing.assert_allclose(w1, w3)
except AssertionError:
    print("AssertionError: optimizer weights differ")
    print(f"    {weights1}\n    vs\n    {weights3}")

# Further, we can't even restore the weights without training!
try:
    m3.optimizer.set_weights(weights1)
except Exception as e:
    print(str(e).split(" Provided weights:")[0])
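
For anyone who needs an interim workaround, something along these lines should work: persist the optimizer weights yourself and put them back after forcing the optimizer to create its slot variables. This is only a sketch; the extra .pkl path is illustrative, not part of the Keras API.

import pickle

import tensorflow as tf
from tensorflow import keras

# Save the model and, separately, the optimizer weights (illustrative path).
m1.save("/tmp/mymodel")
with open("/tmp/mymodel_opt.pkl", "wb") as f:
    pickle.dump(m1.optimizer.get_weights(), f)

# Restore the model, then force the optimizer to create its slot variables
# by applying all-zero gradients (with zero gradients and freshly zeroed
# slots, Adam's update is zero, so the model weights are untouched).
restored = keras.models.load_model("/tmp/mymodel")
zero_grads = [tf.zeros_like(v) for v in restored.trainable_variables]
restored.optimizer.apply_gradients(zip(zero_grads, restored.trainable_variables))

# Now the shapes line up and the saved optimizer state can be injected.
with open("/tmp/mymodel_opt.pkl", "rb") as f:
    restored.optimizer.set_weights(pickle.load(f))
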
@ravikyram
Contributor

@adriangb

I have tried this in Colab with TF 2.3 and the nightly version (2.5.0-dev20201108). Please find the gist here. Are you also seeing the same behavior?
Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Nov 9, 2020
@adriangb
Contributor Author

adriangb commented Nov 9, 2020

Yes

@ravikyram ravikyram removed the stat:awaiting response Status - Awaiting response from author label Nov 9, 2020
@ravikyram ravikyram assigned jvishnuvardhan and unassigned ravikyram Nov 9, 2020
@adriangb
Contributor Author

This hacked-together fix may be helpful to others with the same problem:
notebook

from types import MethodType

from tensorflow import keras


def _temp_create_all_weights(self, var_list):
    # Run the original weight creation, then try to push the restored weights in.
    self._create_all_weights_orig(var_list)
    try:
        self.set_weights(self._restored_weights)
    except ValueError:
        # Weights don't match, e.g. when the optimizer was pickled before any training
        pass
    delattr(self, "_restored_weights")
    self._create_all_weights = self._create_all_weights_orig


def _restore_optimizer_weights(optimizer, weights) -> None:
    optimizer._restored_weights = weights
    optimizer._create_all_weights_orig = optimizer._create_all_weights
    optimizer._create_all_weights = MethodType(_temp_create_all_weights, optimizer)


def unpack_keras_optimizer(opt_serialized, weights):
    """Reconstruct optimizer.
    """
    optimizer: keras.optimizers.Optimizer = keras.optimizers.deserialize(opt_serialized)
    _restore_optimizer_weights(optimizer, weights)
    return optimizer


def pack_keras_optimizer(optimizer: keras.optimizers.Optimizer):
    """Support for Pythons's Pickle protocol in Keras Optimizers.
    """
    opt_serialized = keras.optimizers.serialize(optimizer)
    weights = optimizer.get_weights()
    return unpack_keras_optimizer, (opt_serialized, weights)
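
One way to wire this into pickling is via copyreg. This is only a sketch: registering for the concrete optimizer class (Adam here) and reusing m1 from the example above are my assumptions.

import copyreg
import pickle

from tensorflow import keras

# Register the reducer for the concrete optimizer class used by the model;
# pickle looks up copyreg registrations by exact type.
copyreg.pickle(keras.optimizers.Adam, pack_keras_optimizer)

# Round-trip a fitted optimizer. The restored copy re-applies the saved
# weights the next time it creates its variables, i.e. on the next fit().
opt_copy = pickle.loads(pickle.dumps(m1.optimizer))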

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 11, 2020
@adriangb
Contributor Author

Hi, any updates here? I believe this is a major bug: if a user saves a model and loads it, they will get unexpected results with no errors.

@sachinprasadhs
Contributor

I was able to reproduce your issue in TF nightly 2.6.0-dev20210530; please find the gist here. Thanks!

@adriangb
Contributor Author

adriangb commented Sep 6, 2022

It would be really nice if this could be handled before TensorFlow 2.11, which will make the new experimental Optimizer API the default, since the new API can't be serialized as far as I can tell and it breaks the existing serialization methods (see keras-team/tf-keras#442).

@gowthamkpr

@adriangb The development of Keras has been moved to the keras-team/keras repo. Can you please go ahead and close this issue, as it's already been discussed here? Thanks!

@gowthamkpr gowthamkpr added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Sep 23, 2022
@adriangb
Contributor Author

They’re kinda different issues, and this has been open for 2 years now, so I’d rather leave this open until it gets fixed.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Sep 28, 2022
@gowthamkpr

Sure. Thank you!!

@gowthamkpr gowthamkpr removed their assignment Oct 5, 2022
@songsuoyuan

They’re kinda different issues and this has been open for 2 years now so I’d rather leave this open until it gets fixed

Any updates? I still cannot resume the optimizer using

my_model.save('some-path', include_optimizer=True) and keras.models.load_model('some-path').

I would like to track the global step using my_model.optimizer.iterations, but each time I load the model, the iteration count resets to zero.

@adriangb
Contributor Author

The original example was meant to represent a real-world scenario in which someone saves a model and then continues training. Here is a simpler example that directly tests the issue:

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(1, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="categorical_crossentropy")
model.fit([[1]], [0])

model.save("model")
new = tf.keras.models.load_model("model")
assert len(new.optimizer.variables()) == len(model.optimizer.variables())
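
For the optimizer.iterations question above, a possible interim workaround is to checkpoint the optimizer explicitly with tf.train.Checkpoint, which saves its variables (including iterations) directly. A rough sketch, continuing from the snippet above; the checkpoint path is illustrative and I have not verified this across TF versions:

# After training: checkpoint the model together with its optimizer.
ckpt_path = tf.train.Checkpoint(model=model, optimizer=model.optimizer).save("/tmp/opt_ckpt")

# After loading: restore the checkpoint on top of the loaded model.
tf.train.Checkpoint(model=new, optimizer=new.optimizer).restore(ckpt_path)
print(int(new.optimizer.iterations))  # global step carried over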

@oksanatkach

So I had the same issue: I save a model, then load the checkpoint in order to continue training, but my loss and accuracy are not where they were before I saved the model. I also used the Adam optimizer.
Following issue 16983, I did this:

new_model = tf.keras.models.load_model(folder_path)
new_model.load_weights(folder_path)

And training continues as expected. If I start new_model.fit(), stop it, and then check new_model.optimizer.iterations, it shows the proper number. It looks like the optimizer only gets built after you call .fit(), BUT if I don't call .load_weights() explicitly first, the optimizer gets reset, which is definitely a bug.
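
Putting the whole sequence together, it looks roughly like this (a sketch: folder_path is the directory the model was saved to, and train_x/train_y stand in for my actual data):

import tensorflow as tf

# Load the architecture and compile config, then explicitly reload the weights.
new_model = tf.keras.models.load_model(folder_path)
new_model.load_weights(folder_path)

# The optimizer is only rebuilt once training starts; after a fit() call the
# restored iteration count is visible again.
new_model.fit(train_x, train_y, epochs=1)
print(int(new_model.optimizer.iterations))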

@Jnorm911

Bump. This bug is a nightmare for notebook users.
