diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
index 22fd727f9..5682ae820 100644
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -70,6 +70,13 @@
                      "How many recent checkpoints to keep.")
 flags.DEFINE_bool("experimental_optimize_placement", False,
                   "Optimize ops placement with experimental session options.")
+flags.DEFINE_integer("keep_checkpoint_every_n_hours", 10000,
+                     "Number of hours between each checkpoint to be saved. "
+                     "The default value of 10,000 hours effectively disables the feature.")
+flags.DEFINE_integer("save_checkpoints_secs", 0,
+                     "Save checkpoints every this many seconds. "
+                     "Default=0 means let tensorflow.contrib.learn.python.learn decide, "
+                     "which is currently equivalent to 600, i.e. 10 minutes.")
 
 # Distributed training flags
 flags.DEFINE_string("master", "", "Address of TensorFlow master.")
@@ -203,7 +210,9 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name):
           model_dir=output_dir,
          gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction,
           session_config=session_config(),
-          keep_checkpoint_max=FLAGS.keep_checkpoint_max))
+          keep_checkpoint_max=FLAGS.keep_checkpoint_max,
+          keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
+          save_checkpoints_secs=FLAGS.save_checkpoints_secs))
   # Store the hparams in the estimator as well
   estimator.hparams = hparams
   return estimator, {
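
For context (not part of the patch): a minimal sketch of the run-config call that the second hunk appears to make once the new flags are threaded through. The constructor name (`tf.contrib.learn.RunConfig`) and the concrete values are assumptions for illustration; only the checkpoint-related keyword arguments are taken from the diff.

```python
# Minimal sketch, not part of the diff: assumed values and an assumed
# tf.contrib.learn.RunConfig constructor; the checkpoint-related keyword
# arguments mirror those passed in create_experiment_components after this patch.
import tensorflow as tf

run_config = tf.contrib.learn.RunConfig(
    model_dir="/tmp/t2t_train",           # output_dir (hypothetical path)
    session_config=tf.ConfigProto(),      # stand-in for session_config()
    keep_checkpoint_max=20,               # FLAGS.keep_checkpoint_max
    keep_checkpoint_every_n_hours=10000,  # new flag; 10,000 hours keeps the feature off
    save_checkpoints_secs=600,            # new flag; e.g. checkpoint every 10 minutes
)
```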