Skip to content

Commit ee3bfa1

Browse files
Vighnesh BirodkarTF Object Detection Team
authored andcommitted
Add options in TF2 launch script for summaries and checkpoints.
PiperOrigin-RevId: 322828673
1 parent 2ae9c3a commit ee3bfa1

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

research/object_detection/model_lib_v2.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import time
2424

2525
import tensorflow.compat.v1 as tf
26+
import tensorflow.compat.v2 as tf2
2627

2728
from object_detection import eval_util
2829
from object_detection import inputs
@@ -414,8 +415,9 @@ def train_loop(
414415
train_steps=None,
415416
use_tpu=False,
416417
save_final_config=False,
417-
checkpoint_every_n=5000,
418+
checkpoint_every_n=1000,
418419
checkpoint_max_to_keep=7,
420+
record_summaries=True,
419421
**kwargs):
420422
"""Trains a model using eager + functions.
421423
@@ -445,6 +447,7 @@ def train_loop(
445447
Checkpoint every n training steps.
446448
checkpoint_max_to_keep:
447449
int, the number of most recent checkpoints to keep in the model directory.
450+
record_summaries: Boolean, whether or not to record summaries.
448451
**kwargs: Additional keyword arguments for configuration override.
449452
"""
450453
## Parse the configs
@@ -531,8 +534,11 @@ def train_dataset_fn(input_context):
531534
# is the chief.
532535
summary_writer_filepath = get_filepath(strategy,
533536
os.path.join(model_dir, 'train'))
534-
summary_writer = tf.compat.v2.summary.create_file_writer(
535-
summary_writer_filepath)
537+
if record_summaries:
538+
summary_writer = tf.compat.v2.summary.create_file_writer(
539+
summary_writer_filepath)
540+
else:
541+
summary_writer = tf2.summary.create_noop_writer()
536542

537543
if use_tpu:
538544
num_steps_per_iteration = 100
@@ -604,7 +610,9 @@ def _dist_train_step(data_iterator):
604610

605611
if num_steps_per_iteration > 1:
606612
for _ in tf.range(num_steps_per_iteration - 1):
607-
_sample_and_train(strategy, train_step_fn, data_iterator)
613+
# Following suggestion on yaqs/5402607292645376
614+
with tf.name_scope(''):
615+
_sample_and_train(strategy, train_step_fn, data_iterator)
608616

609617
return _sample_and_train(strategy, train_step_fn, data_iterator)
610618

research/object_detection/model_main_tf2.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@
6262
'num_workers', 1, 'When num_workers > 1, training uses '
6363
'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
6464
'MirroredStrategy.')
65+
flags.DEFINE_integer(
66+
'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
67+
flags.DEFINE_boolean('record_summaries', True,
68+
('Whether or not to record summaries during'
69+
' training.'))
6570

6671
FLAGS = flags.FLAGS
6772

@@ -100,7 +105,9 @@ def main(unused_argv):
100105
pipeline_config_path=FLAGS.pipeline_config_path,
101106
model_dir=FLAGS.model_dir,
102107
train_steps=FLAGS.num_train_steps,
103-
use_tpu=FLAGS.use_tpu)
108+
use_tpu=FLAGS.use_tpu,
109+
checkpoint_every_n=FLAGS.checkpoint_every_n,
110+
record_summaries=FLAGS.record_summaries)
104111

105112
if __name__ == '__main__':
106113
tf.compat.v1.app.run()

0 commit comments

Comments
 (0)