From 93b8168ad9c54f9acf09a161a6dad1dd99b3bfeb Mon Sep 17 00:00:00 2001
From: Zhichao Lu
Date: Mon, 26 Mar 2018 11:27:24 -0700
Subject: [PATCH] Switch line orders in trainer so that restore_map is called
 after moving average variables are created.

Moving averages are now properly loaded during fine-tuning, instead of
being recreated.

PiperOrigin-RevId: 190496046
---
 research/object_detection/trainer.py | 46 ++++++++++++++--------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/research/object_detection/trainer.py b/research/object_detection/trainer.py
index 0065c9919ef..cf3429a60bf 100644
--- a/research/object_detection/trainer.py
+++ b/research/object_detection/trainer.py
@@ -264,29 +264,6 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
           total_num_replicas=worker_replicas)
       sync_optimizer = training_optimizer
 
-    # Create ops required to initialize the model from a given checkpoint.
-    init_fn = None
-    if train_config.fine_tune_checkpoint:
-      if not train_config.fine_tune_checkpoint_type:
-        # train_config.from_detection_checkpoint field is deprecated. For
-        # backward compatibility, fine_tune_checkpoint_type is set based on
-        # from_detection_checkpoint.
-        if train_config.from_detection_checkpoint:
-          train_config.fine_tune_checkpoint_type = 'detection'
-        else:
-          train_config.fine_tune_checkpoint_type = 'classification'
-      var_map = detection_model.restore_map(
-          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
-          load_all_detection_checkpoint_vars=(
-              train_config.load_all_detection_checkpoint_vars))
-      available_var_map = (variables_helper.
-                           get_variables_available_in_checkpoint(
-                               var_map, train_config.fine_tune_checkpoint))
-      init_saver = tf.train.Saver(available_var_map)
-      def initializer_fn(sess):
-        init_saver.restore(sess, train_config.fine_tune_checkpoint)
-      init_fn = initializer_fn
-
     with tf.device(deploy_config.optimizer_device()):
       regularization_losses = (None if train_config.add_regularization_loss
                                else [])
@@ -354,6 +331,29 @@ def initializer_fn(sess):
     saver = tf.train.Saver(
         keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
 
+    # Create ops required to initialize the model from a given checkpoint.
+    init_fn = None
+    if train_config.fine_tune_checkpoint:
+      if not train_config.fine_tune_checkpoint_type:
+        # train_config.from_detection_checkpoint field is deprecated. For
+        # backward compatibility, fine_tune_checkpoint_type is set based on
+        # from_detection_checkpoint.
+        if train_config.from_detection_checkpoint:
+          train_config.fine_tune_checkpoint_type = 'detection'
+        else:
+          train_config.fine_tune_checkpoint_type = 'classification'
+      var_map = detection_model.restore_map(
+          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
+          load_all_detection_checkpoint_vars=(
+              train_config.load_all_detection_checkpoint_vars))
+      available_var_map = (variables_helper.
+                           get_variables_available_in_checkpoint(
+                               var_map, train_config.fine_tune_checkpoint))
+      init_saver = tf.train.Saver(available_var_map)
+      def initializer_fn(sess):
+        init_saver.restore(sess, train_config.fine_tune_checkpoint)
+      init_fn = initializer_fn
+
     slim.learning.train(
         train_tensor,
         logdir=train_dir,
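
Note: the sketch below is not part of the patch. It is a minimal, self-contained
TensorFlow 1.x illustration (the variable name 'w' is made up) of why the
reordering matters: tf.train.ExponentialMovingAverage.apply() is what creates
the EMA shadow variables, so a restore map built before that call cannot
include them, and they end up freshly initialized instead of loaded from the
fine-tune checkpoint.

    # Minimal TF 1.x sketch (not from trainer.py) of the ordering bug fixed
    # above: EMA shadow variables only exist after apply() runs, so a Saver
    # built from the variables visible *before* that call silently skips
    # them at restore time.
    import tensorflow as tf

    w = tf.Variable(1.0, name='w')

    # Variables visible before the moving average is created -- this is
    # what a too-early restore_map() call would see.
    vars_before = {v.op.name: v for v in tf.global_variables()}

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    maintain_op = ema.apply([w])  # creates 'w/ExponentialMovingAverage'

    vars_after = {v.op.name: v for v in tf.global_variables()}
    print(sorted(set(vars_after) - set(vars_before)))
    # -> ['w/ExponentialMovingAverage']: absent from vars_before, so a
    #    tf.train.Saver(vars_before) would leave the shadow variable
    #    re-initialized instead of restoring it from the checkpoint.

Building the init Saver after the train op (and hence after the moving
averages) exist, as the second hunk does, means the restore map can cover
the shadow variables as well.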