How to save average weights of checkpoints using Tensorflow 2.X? #60064
Thanks a lot for the help, but unfortunately the problem still exists. The averaged checkpoint has changed almost every variable name (1. '/' is replaced with '.S'; 2. '.ATTRIBUTES/VARIABLE_VALUE' is appended to each variable again). Here is the var_list of the single checkpoint and the averaged checkpoint from tf.train.list_variables (left is the single checkpoint, right is the averaged one). I also checked both checkpoints in TF1, where they have the same variable names. What is going wrong?
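For readers hitting the same name mangling: TF2 object-based checkpoints key variables by object path (with a `/.ATTRIBUTES/VALUE_VALUE`-style suffix), and a '/' inside an attribute name gets escaped as '.S'. A minimal, purely string-based sketch of undoing that mapping; the helper name and the example key are mine, not from the checkpoints in this thread, so adjust to what tf.train.list_variables actually prints for you:

```python
def normalize_checkpoint_key(key):
    """Map a TF2 object-based checkpoint key back to a TF1-style name."""
    suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
    if key.endswith(suffix):
        key = key[:-len(suffix)]
    # TF escapes '/' inside attribute names as '.S'; undo that escaping.
    return key.replace(".S", "/")

print(normalize_checkpoint_key(
    "encoder.Slayer_0.Skernel/.ATTRIBUTES/VARIABLE_VALUE"))
# encoder/layer_0/kernel
```

Keys without the suffix (e.g. a bare counter variable) pass through unchanged.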
@yjiangling Thank you for raising an issue!
OK, I used TensorFlow 2.2 and TensorFlow 2.4. You can use the attached script to replicate the issue (the variable names in the averaged checkpoint have the string "/" replaced with ".S"). Please delete the ".txt" suffix before running it. avg_checkpoints.py.txt Thanks a lot for the help!!!
@yjiangling Thank you for the response!
@yjiangling I tried to replicate this issue on the latest TF version 2.12 and got different results. Could you please check this gist and confirm the same?
This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you. |
Thanks a lot for the detailed experiments, and I'm sorry for the late reply. I solved this problem with the following step:
You must call model.call() or model.predict() once first, otherwise you can't get the weights from the checkpoint. I still don't understand the reason; maybe someone who is interested can look into it.
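The likely reason is that Keras creates a layer's variables lazily, on the first call, so before any forward pass there are simply no weights for a checkpoint restore to fill in. A minimal sketch of that lazy-build behavior (a single Dense layer stands in for the full model):

```python
import tensorflow as tf

# Variables of a Keras layer/model are created lazily, on the first call.
layer = tf.keras.layers.Dense(4)
print(len(layer.get_weights()))  # 0: kernel and bias do not exist yet

# One forward pass (what model.call()/model.predict() does) builds them.
layer(tf.zeros((1, 3)))
print(len(layer.get_weights()))  # 2: kernel and bias now exist
```

So restoring (or reading) checkpoint weights only works after the model has been called or built once.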
Thank you, the problem has been solved, feel free to close it.
@yjiangling Thank you for the update! |
HELP NEEDED !!!
Hi, everyone. I found a script that loads a series of checkpoints for a model and saves their averaged weights in TensorFlow 1:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py
which is very useful for improving model performance. But in TensorFlow 2, the saved checkpoints look like this:
How can I average their weights? I tried to write a script, but the output checkpoint is not what we expected. Can anyone help? Thanks a lot in advance. Here is my script:
```python
import os
import numpy as np
import tensorflow as tf
from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoints", "",
                    "Comma-separated list of checkpoints to average.")
flags.DEFINE_integer("num_last_checkpoints", 0,
                     "Average the last N saved checkpoints."
                     " If the checkpoints flag is set, this is ignored.")
flags.DEFINE_string("prefix", "",
                    "Prefix (e.g., directory) to append to each checkpoint.")
flags.DEFINE_string("output_path", "/tmp/averaged.ckpt",
                    "Path to output the averaged checkpoint to.")


def checkpoint_exists(path):
    return tf.io.gfile.exists(path) or tf.io.gfile.exists(path + ".index")


def main(argv):
    if FLAGS.checkpoints:
        # Get the checkpoints list from flags and run some basic checks.
        checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")]
        checkpoints = [c for c in checkpoints if c]
        if not checkpoints:
            raise ValueError("No checkpoints provided for averaging.")
        if FLAGS.prefix:
            checkpoints = [FLAGS.prefix + c for c in checkpoints]
    else:
        assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model"
        assert FLAGS.prefix, ("Prefix must be provided when averaging last"
                              " N checkpoints")
        # checkpoint_state = tf.train.get_checkpoint_state(
        #     os.path.dirname(FLAGS.prefix))
        # # Checkpoints are ordered from oldest to newest.
        # checkpoints = checkpoint_state.all_model_checkpoint_paths[
        #     -FLAGS.num_last_checkpoints:]
        file_list = os.listdir(FLAGS.prefix)
        checkpoints = [os.path.join(FLAGS.prefix, file) for file
                       in file_list if file.endswith(".index")]
        checkpoints = [checkpoint[:-len(".index")] for checkpoint in checkpoints]
        checkpoints.sort(key=lambda checkpoint: int(checkpoint.split('-')[-1]))
        checkpoints = checkpoints[-FLAGS.num_last_checkpoints:]


if __name__ == '__main__':
    app.run(main)
```
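The script above stops after collecting the checkpoint paths; the averaging step itself can be kept simple by reading every checkpoint's variables into plain NumPy arrays and averaging those elementwise. A minimal sketch under that assumption (the function names are mine, not from the script):

```python
import numpy as np


def average_var_dicts(var_dicts):
    """Elementwise-average a list of {variable_name: np.ndarray} dicts."""
    names = var_dicts[0].keys()
    assert all(d.keys() == names for d in var_dicts), "checkpoints must match"
    return {name: np.mean([d[name] for d in var_dicts], axis=0)
            for name in names}


def read_checkpoint(path):
    """Read all variables of one checkpoint into a {name: array} dict."""
    import tensorflow as tf  # local import so the averaging core stays pure
    reader = tf.train.load_checkpoint(path)
    return {name: reader.get_tensor(name)
            for name, _ in tf.train.list_variables(path)}
```

To avoid the '.S' name mangling when saving, one option is to assign the averaged arrays back into a built model's variables (variable.assign(...)) and write the result with model.save_weights, so the model's own variable names are reused.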