import os
import logging

import numpy as np
import tensorflow as tf


def checkpoint_exists(path):
    """Return True if a checkpoint is present at `path`.

    Args:
        path: Checkpoint prefix (without the ".index" suffix) or a plain
            file path.

    Returns:
        True when either `path` itself or `path + ".index"` exists
        according to `tf.io.gfile.exists`.
    """
    return (tf.io.gfile.exists(path) or
            tf.io.gfile.exists(path + ".index"))


def main(checkpoint_dir, output_dir):
    """Average the variables of every checkpoint found in `checkpoint_dir`.

    Every "*.index" file in `checkpoint_dir` is treated as one checkpoint.
    The variable names and shapes are taken from the first checkpoint (in
    sorted file-name order); variables whose name starts with "save_counter"
    are skipped. Non-string variables are summed over all checkpoints and
    divided by the number of checkpoints; string variables keep the value of
    the last checkpoint read. The result is saved under
    `output_dir/average.ckpt`.

    Args:
        checkpoint_dir: Directory containing the checkpoints to average.
        output_dir: Directory the averaged checkpoint is written to. It is
            created (including missing parents) when it does not exist.

    Returns:
        The path returned by `save_variables` for the written checkpoint.

    Raises:
        ValueError: If no checkpoint is found in `checkpoint_dir`.
    """
    index_suffix = ".index"
    # Sorted so that the checkpoint used as the variable template
    # (checkpoints[0]) does not depend on os.listdir()'s arbitrary order.
    checkpoints = [
        os.path.join(checkpoint_dir, file_name[:-len(index_suffix)])
        for file_name in sorted(os.listdir(checkpoint_dir))
        if file_name.endswith(index_suffix)
    ]
    checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
    if not checkpoints:
        raise ValueError(
            "None of the provided checkpoints exist. %s" % checkpoint_dir)

    # Read variables from all checkpoints and average them.
    logging.info("Reading variables and averaging checkpoints:")
    for c in checkpoints:
        logging.info("%s ", c)

    var_list = tf.train.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for name, shape in var_list:
        if not name.startswith("save_counter"):
            # np.zeros defaults to float64, which is the accumulator dtype.
            var_values[name] = np.zeros(shape)

    for checkpoint in checkpoints:
        reader = tf.train.load_checkpoint(checkpoint)
        # Build the dtype map once per checkpoint instead of once per variable.
        dtype_map = reader.get_variable_to_dtype_map()
        for name in var_values:
            tensor = reader.get_tensor(name)
            dtype = dtype_map[name]
            if dtype == tf.string:
                # Strings cannot be averaged; keep the latest value read.
                var_values[name] = tensor
            else:
                var_values[name] += tensor
            var_dtypes[name] = dtype
        logging.info("Read from checkpoint %s", checkpoint)

    # Average.
    for name in var_values:
        if var_dtypes[name] != tf.string:
            var_values[name] /= len(checkpoints)

    # makedirs handles nested output paths and an already-existing directory.
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "average.ckpt")
    saved_path = save_variables(var_values, var_dtypes, save_path)
    logging.info("Averaged checkpoints saved in %s", output_dir)
    return saved_path


def save_variables(var_values, var_dtypes, save_path):
    """Write the given values to a TensorFlow 2 checkpoint.

    Args:
        var_values: Mapping from variable name to its value.
        var_dtypes: Mapping from variable name to the dtype the saved
            variable is created with.
        save_path: Checkpoint prefix passed to `tf.train.Checkpoint.save`.

    Returns:
        The value returned by `tf.train.Checkpoint.save`, i.e. the path of
        the checkpoint that was written.
    """
    variables = {
        name: tf.Variable(value, dtype=var_dtypes[name])
        for name, value in var_values.items()
    }
    checkpoint = tf.train.Checkpoint(**variables)
    return checkpoint.save(save_path)


if __name__ == '__main__':
    # Hide all GPUs from TensorFlow for this script.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Step 1: if you already have TensorFlow 2 checkpoints on your machine,
    # run the main function directly:
    # checkpoint_dir = ""  # path of the checkpoints
    # output_dir = "./average_ckpt"  # where to save the averaged checkpoint
    # main(checkpoint_dir, output_dir)

    # Step 2: if you have no TensorFlow 2 checkpoints on your machine, you
    # can use the following script to generate some.
    from gen_checkpoints import generate_tf2_checkpoints
    generate_tf2_checkpoints()
    checkpoint_dir = "./tf_ckpts"
    output_dir = "./average_ckpt"
    averaged_checkpoint = main(checkpoint_dir, output_dir)

    # Step 3: check the saved averaged checkpoint. The path comes from
    # main() rather than being hard-coded as "./average_ckpt/average.ckpt-1".
    from tensorflow.python.tools import inspect_checkpoint as chkp
    chkp.print_tensors_in_checkpoint_file(averaged_checkpoint,
                                          tensor_name=None,
                                          all_tensors=True,
                                          all_tensor_names=True)