
How to save average weights of checkpoints using Tensorflow 2.X? #60064

Closed
yjiangling opened this issue Mar 22, 2023 · 11 comments
Assignees
Labels
comp:model Model related issues type:support Support issues

Comments

@yjiangling

yjiangling commented Mar 22, 2023

HELP NEEDED !!!

Hi, everyone. I found a script that loads a series of checkpoints for a model and saves their averaged weights in TensorFlow 1:

https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py

It is very useful for improving model performance. But in TensorFlow 2, the saved checkpoints look like this:

[screenshot: TF2 checkpoint files in the model directory]

How can I average the weights of the parameters for them? I tried to write a script, but the output checkpoint is not what I expected. Can anyone help? Thanks a lot in advance. Here is my script:

```python
import os

import numpy as np
import tensorflow as tf
from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoints", "",
                    "Comma-separated list of checkpoints to average.")
flags.DEFINE_integer("num_last_checkpoints", 0,
                     "Average the last N saved checkpoints."
                     " If the checkpoints flag is set, this is ignored.")
flags.DEFINE_string("prefix", "",
                    "Prefix (e.g., directory) to append to each checkpoint.")
flags.DEFINE_string("output_path", "/tmp/averaged.ckpt",
                    "Path to output the averaged checkpoint to.")


def checkpoint_exists(path):
  return tf.io.gfile.exists(path) or tf.io.gfile.exists(path + ".index")


def main(argv):
  if FLAGS.checkpoints:
    # Get the checkpoint list from flags and run some basic checks.
    checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")]
    checkpoints = [c for c in checkpoints if c]
    if not checkpoints:
      raise ValueError("No checkpoints provided for averaging.")
    if FLAGS.prefix:
      checkpoints = [FLAGS.prefix + c for c in checkpoints]
  else:
    assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model"
    assert FLAGS.prefix, ("Prefix must be provided when averaging last"
                          " N checkpoints")
    # Collect the checkpoints in the directory, ordered oldest to newest
    # by their step suffix.
    file_list = os.listdir(FLAGS.prefix)
    checkpoints = [os.path.join(FLAGS.prefix, f)
                   for f in file_list if f.endswith(".index")]
    checkpoints = [c[:-len(".index")] for c in checkpoints]
    checkpoints.sort(key=lambda c: int(c.split("-")[-1]))
    checkpoints = checkpoints[-FLAGS.num_last_checkpoints:]

  checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
  if not checkpoints:
    if FLAGS.checkpoints:
      raise ValueError(
          "None of the provided checkpoints exist. %s" % FLAGS.checkpoints)
    else:
      raise ValueError("Could not find checkpoints at %s" %
                       os.path.dirname(FLAGS.prefix))

  # Read variables from all checkpoints and average them.
  logging.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    logging.info("%s ", c)

  var_list = tf.train.list_variables(checkpoints[0])
  var_values, var_dtypes = {}, {}
  for name, shape in var_list:
    if not name.startswith("save_counter"):
      var_values[name] = np.zeros(shape)

  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = reader.get_tensor(name)
      dtype = reader.get_variable_to_dtype_map()[name]
      if dtype == tf.string:
        var_values[name] = tensor
      else:
        var_values[name] += tensor
      var_dtypes[name] = dtype
    logging.info("Read from checkpoint %s", checkpoint)

  for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
      var_values[name] /= len(checkpoints)

  # Re-create the save_counter variable with its original shape and dtype.
  name = var_list[-1][0]
  assert name.startswith("save_counter")
  shape = reader.get_variable_to_shape_map()[name]
  dtype = reader.get_variable_to_dtype_map()[name]
  var_values[name] = np.zeros(shape)
  var_dtypes[name] = dtype

  for name in var_values:
    var_values[name] = tf.Variable(var_values[name], dtype=var_dtypes[name])
  save = tf.train.Checkpoint()
  save.mapped = var_values
  save_path = save.save(FLAGS.output_path)

  logging.info("Averaged checkpoints saved in %s", save_path)


if __name__ == "__main__":
  app.run(main)
```

@yjiangling
Author

Thanks a lot for the help, but unfortunately the problem still exists. The averaged checkpoint has changed almost every variable name (1. '/' is changed to '.S'; 2. '.ATTRIBUTES/VARIABLE_VALUE' is appended to each variable again). Here is the var_list of a single checkpoint and of the averaged checkpoint, from tf.train.list_variables (left is the single checkpoint, right is the averaged one). I also checked a single checkpoint and an averaged checkpoint in TF1: they have the same variable names. What is wrong?

[screenshot: tf.train.list_variables output for the single checkpoint (left) vs. the averaged checkpoint (right)]
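One way to see where these names could come from: attaching a plain dict of variables to a `tf.train.Checkpoint` (as in `save.mapped = var_values` in the script above) makes the object-based saver build new keys from the object path. The sketch below is only a guess at the resulting naming, inferred from the names reported in this comment ('/' escaped to '.S', '.ATTRIBUTES/VARIABLE_VALUE' appended under the tracked attribute); the exact escaping scheme is an assumption, not taken from the TensorFlow source, and `mapped_checkpoint_key` is a hypothetical helper for illustration only.

```python
def mapped_checkpoint_key(original_name, attr="mapped"):
    """Hypothetical: predict the key a variable appears to get when stored
    via `save.mapped = {original_name: variable}` on a tf.train.Checkpoint.

    Assumption (inferred from this issue, not from TF source): slashes in
    dict keys are escaped to '.S', and the object-based saver appends
    '.ATTRIBUTES/VARIABLE_VALUE' to each leaf variable.
    """
    escaped = original_name.replace("/", ".S")
    return "%s/%s/.ATTRIBUTES/VARIABLE_VALUE" % (attr, escaped)


print(mapped_checkpoint_key("model/dense/kernel"))
# -> mapped/model.Sdense.Skernel/.ATTRIBUTES/VARIABLE_VALUE
```

This would explain why the averaged checkpoint no longer matches the original names: the dict key becomes part of a new object path rather than being reused verbatim.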

@synandi synandi added type:support Support issues comp:model Model related issues labels Mar 23, 2023
@synandi synandi assigned sushreebarsa and unassigned synandi Mar 24, 2023
@sushreebarsa
Contributor

@yjiangling Thank you for raising an issue!
Could you please specify the TF version you are using and provide the standalone code to replicate this one?
Thank you!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Mar 26, 2023
@yjiangling
Author

@yjiangling Thank you for raising an issue! Could you please specify the TF version you are using and provide the standalone code to replicate this one? Thank you!

OK, I am using TensorFlow 2.2 and TensorFlow 2.4. You can use the attached scripts to replicate the issue (in the averaged checkpoint, the variable names have the string "/" replaced with ".S"). Please delete the ".txt" suffix before running them.

avg_checkpoints.py.txt
gen_checkpoints.py.txt

Thanks a lot for the help!!!

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Mar 27, 2023
@sushreebarsa
Contributor

sushreebarsa commented Apr 10, 2023

@yjiangling Thank you for the response!
FYI, older versions of TF are not actively supported; it is recommended to use the latest TF version. I am trying to replicate this on the latest version and will update you soon. Thank you!

@sushreebarsa
Contributor

@yjiangling I tried to replicate this issue on the latest TF version, 2.12, and got different results. Could you please check this gist and confirm?
Thank you!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Apr 12, 2023
@github-actions

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 20, 2023
@yjiangling
Author

@yjiangling I tried to replicate this issue on the latest TF version 2.12 and faced different results. Could you please check this gist and confirm the same? Thank you!

Thanks a lot for the detailed experiments and the reply; I'm sorry for the late response. I solved the problem with the following steps:

1. Build the model (the variables must be created first):

```
model = Model()
pred = model.predict()  # calling model.call() once also works
```

2. Create a training Checkpoint:

```
ckpt = tf.train.Checkpoint(model=model)
```

3. Restore each checkpoint, get its weights, and accumulate the sum:

```
swa_weights = None
for checkpoint in checkpoints:
    ckpt.restore(checkpoint).expect_partial()
    weights = model.get_weights()
    if swa_weights is None:
        swa_weights = weights
    else:
        swa_weights = [s + w for s, w in zip(swa_weights, weights)]
swa_weights = [w / len(checkpoints) for w in swa_weights]
```

4. Set the averaged weights and save the model:

```
model.set_weights(swa_weights)
ckpt.save(save_path)
```

You must call model.call() or model.predict() once beforehand, otherwise the weights of the checkpoint cannot be retrieved. I still do not understand the reason; maybe someone who is interested can look into it.
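The accumulate-and-divide step above is just an element-wise mean over the lists of arrays that `model.get_weights()` returns. A minimal NumPy-only sketch of that arithmetic (no TensorFlow needed; `average_weight_lists` is a hypothetical helper name):

```python
import numpy as np


def average_weight_lists(weight_lists):
    """Element-wise mean of several weight lists, where each list is what
    model.get_weights() returns for one checkpoint (a list of arrays)."""
    if not weight_lists:
        raise ValueError("Need at least one weight list to average.")
    n = len(weight_lists)
    # Sum the corresponding arrays across checkpoints, then divide by n.
    summed = [np.zeros_like(w, dtype=np.float64) for w in weight_lists[0]]
    for weights in weight_lists:
        summed = [s + w for s, w in zip(summed, weights)]
    return [s / n for s in summed]


# Toy demo: two "checkpoints", each holding two weight arrays.
ckpt_a = [np.array([1.0, 3.0]), np.array([[2.0]])]
ckpt_b = [np.array([3.0, 5.0]), np.array([[4.0]])]
avg = average_weight_lists([ckpt_a, ckpt_b])
print(avg[0])  # -> [2. 4.]
print(avg[1])  # -> [[3.]]
```

The same shape-for-shape averaging is what steps 3 and 4 perform on the real model weights before `model.set_weights(...)`.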

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Apr 24, 2023
@yjiangling
Author

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

Thank you, the problem has been solved; feel free to close it.

@sushreebarsa
Contributor

@yjiangling Thank you for the update!
Closing the issue as it has been resolved.
Thank you!

