BUG: Keras SaveModel does not properly save optimizer state #44670
This hacked-together fix may be helpful to others with the same problem:

```python
from types import MethodType

from tensorflow import keras


def _temp_create_all_weights(self, var_list):
    self._create_all_weights_orig(var_list)
    try:
        self.set_weights(self._restored_weights)
    except ValueError:
        # Weights don't match, e.g. when the optimizer was pickled before any training
        pass
    delattr(self, "_restored_weights")
    self._create_all_weights = self._create_all_weights_orig


def _restore_optimizer_weights(optimizer, weights) -> None:
    # Defer restoring the weights until the optimizer actually creates its
    # variables, which only happens once it is used on a var_list.
    optimizer._restored_weights = weights
    optimizer._create_all_weights_orig = optimizer._create_all_weights
    optimizer._create_all_weights = MethodType(_temp_create_all_weights, optimizer)


def unpack_keras_optimizer(opt_serialized, weights):
    """Reconstruct an optimizer from its serialized config and weights."""
    optimizer: keras.optimizers.Optimizer = keras.optimizers.deserialize(opt_serialized)
    _restore_optimizer_weights(optimizer, weights)
    return optimizer


def pack_keras_optimizer(optimizer: keras.optimizers.Optimizer):
    """Support for Python's pickle protocol in Keras optimizers."""
    opt_serialized = keras.optimizers.serialize(optimizer)
    weights = optimizer.get_weights()
    return unpack_keras_optimizer, (opt_serialized, weights)
```
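For context on how a `pack`/`unpack` pair like the one above plugs into pickling, here is a minimal stdlib-only sketch using `copyreg`. `ToyOptimizer`, `pack_toy_optimizer`, and `unpack_toy_optimizer` are hypothetical stand-ins (not the Keras API); the pattern is the same: the packer returns a `(callable, args)` tuple, and pickle calls that callable with those args to rebuild the object.

```python
import copyreg
import pickle


class ToyOptimizer:
    """Toy stand-in for an optimizer: a config dict plus a weights list."""
    def __init__(self, config):
        self.config = config
        self.weights = []


def unpack_toy_optimizer(config, weights):
    """Reconstruct the object from its serialized pieces."""
    opt = ToyOptimizer(config)
    opt.weights = weights
    return opt


def pack_toy_optimizer(opt):
    """Reducer: return (callable, args) so pickle can rebuild the object."""
    return unpack_toy_optimizer, (opt.config, opt.weights)


# Register the reducer; pickle now uses it for ToyOptimizer instances.
copyreg.pickle(ToyOptimizer, pack_toy_optimizer)

opt = ToyOptimizer({"learning_rate": 0.001})
opt.weights = [1.0, 2.0]
restored = pickle.loads(pickle.dumps(opt))
print(restored.config, restored.weights)  # {'learning_rate': 0.001} [1.0, 2.0]
```

Alternatively, the same effect can be had without `copyreg` by defining `__reduce__` on the class itself; registering externally is useful when, as with Keras optimizers, you cannot modify the class.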
Hi, any updates here? I believe this is a major bug: if a user saves a model and loads it, they will get unexpected results with no errors.

I was able to reproduce your issue in TF nightly 2.6.0-dev20210530; please find the gist here. Thanks!
It would be really nice if this could be handled before TensorFlow 2.11, which will make the new experimental Optimizer API the default. As far as I can tell, the new API can't be serialized and breaks the existing serialization methods (see keras-team/tf-keras#442).
They’re kinda different issues, and this has been open for 2 years now, so I’d rather leave this open until it gets fixed.

Sure. Thank you!!
Any updates? I still cannot resume the optimizer using `my_model.save('some-path', include_optimizer=True)` and `keras.models.load_model('some-path')`. I would like to track the global step via `my_model.optimizer.iterations`, but each time I load the model, `iterations` resets to zero.
The original example was meant to represent a real-world scenario where someone saves a model and then continues training. A simpler example that directly tests the issue:

```python
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(1, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="categorical_crossentropy")
model.fit([[1]], [0])
model.save("model")

new = tf.keras.models.load_model("model")
assert len(new.optimizer.variables()) == len(model.optimizer.variables())
```
So I had the same issue: I save a model, then load the checkpoint in order to continue training, but my loss and accuracy are worse than they were before I saved the model. I also used the Adam optimizer.

After that, training continues as expected. If I start `new_model.fit()`, stop it, and then check `new_model.optimizer.iterations`, it shows the proper number. It looks like the optimizer only gets built after you call `.fit()`, BUT if I don't call `.load_weights()` explicitly first, the optimizer state gets reset, which is definitely a bug.
Bump. This bug is a nightmare for notebook users. |
EDIT: this looks like a dupe of #42749. I'll leave this up for now, since that issue does not have as reproducible a high-level example, but feel free to close.

This happens at least for Adam (it does not apply to SGD, for example; I did not test other optimizers).
Tested on `tf-nightly` and `tf==2.3.0`.

TL;DR: running a `tf.keras` `Model` through `tf.keras.models.load(model.save)` does not properly preserve the state of certain optimizers (see #42749 for more details).

The docs read:
Full example: