Skip to content

Commit

Permalink
Enables timer callback and disables checkpoint saving in retinanet benchmark test.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 275080469
  • Loading branch information
yeqingli authored and tensorflower-gardener committed Oct 16, 2019
1 parent cb91369 commit 1b77cd8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
13 changes: 7 additions & 6 deletions official/benchmark/retinanet_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,8 @@ def _report_benchmark(self,
}]
if self.timer_callback:
metrics.append({
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
'name': 'exp_per_second',
'value': self.timer_callback.get_examples_per_sec(train_batch_size)
})
else:
metrics.append({
Expand Down Expand Up @@ -134,7 +132,7 @@ def __init__(self, output_dir=None, **kwargs):

def _run_detection_main(self):
  """Starts the detection job.

  Delegates to detection.run() so the benchmark's timer callback is
  attached to the training loop (detection.main() took no callbacks).

  Returns:
    Whatever detection.run() returns for the configured mode
    (e.g. eval summary dict) — passed through unchanged.
  """
  # The scraped diff left both the old call (detection.main('unused_argv'))
  # and the new one; the old line made the new one unreachable. Keep only
  # the post-commit call that forwards the timer callback.
  return detection.run(callbacks=[self.timer_callback])


class RetinanetAccuracy(RetinanetBenchmarkBase):
Expand Down Expand Up @@ -166,7 +164,8 @@ def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
stats=summary,
wall_time_sec=wall_time_sec,
min_ap=min_ap,
max_ap=max_ap)
max_ap=max_ap,
train_batch_size=self.params_override['train']['batch_size'])

def _setup(self):
super(RetinanetAccuracy, self)._setup()
Expand Down Expand Up @@ -228,6 +227,8 @@ def benchmark_8_gpu_coco(self):
params['eval']['eval_samples'] = 8
FLAGS.params_override = json.dumps(params)
FLAGS.model_dir = self._get_model_dir('real_benchmark_8_gpu_coco')
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
if self.timer_callback is None:
logging.error('Cannot measure performance without timer callback')
else:
Expand Down
12 changes: 11 additions & 1 deletion official/modeling/training/distributed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def initialize_common_flags():
flags.DEFINE_integer(
'task_index', 0,
'If multi-worker training, the task_index of this worker.')
flags.DEFINE_integer('save_checkpoint_freq', None,
'Number of steps to save checkpoint.')


def strategy_flags_dict():
Expand Down Expand Up @@ -447,6 +449,12 @@ def _run_callbacks_on_batch_end(batch):
if save_config:
self._save_config(model_dir)

if FLAGS.save_checkpoint_freq:
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
last_save_checkpoint_step = 0

params = self._params
strategy = self._strategy
# To reduce unnecessary send/receive input pipeline operation, we place
Expand Down Expand Up @@ -540,9 +548,11 @@ def _run_callbacks_on_batch_end(batch):
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_steps:
if save_freq > 0 and current_step < total_steps and (
current_step - last_save_checkpoint_step) >= save_freq:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step

if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
Expand Down
21 changes: 16 additions & 5 deletions official/vision/detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@
FLAGS = flags.FLAGS


def run_executor(params, train_input_fn=None, eval_input_fn=None):
def run_executor(params,
train_input_fn=None,
eval_input_fn=None,
callbacks=None):
"""Runs Retinanet model on distribution strategy defined by the user."""

model_builder = model_factory.model_generator(params)
Expand Down Expand Up @@ -92,6 +95,7 @@ def _model_fn(params):
iterations_per_loop=params.train.iterations_per_loop,
total_steps=params.train.total_steps,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval':

Expand Down Expand Up @@ -124,9 +128,7 @@ def _model_fn(params):
raise ValueError('Mode not found: %s.' % FLAGS.mode)


def main(argv):
del argv # Unused.

def run(callbacks=None):
params = config_factory.config_generator(FLAGS.model)

params = params_dict.override_params_dict(
Expand Down Expand Up @@ -171,7 +173,16 @@ def main(argv):
batch_size=params.eval.batch_size,
num_examples=params.eval.eval_samples)
return run_executor(
params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
params,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
callbacks=callbacks)


def main(argv):
  """absl.app entry point: ignores command-line argv and delegates to run()."""
  del argv  # absl passes argv positionally; this entry point does not use it.
  return run()


if __name__ == '__main__':
Expand Down

0 comments on commit 1b77cd8

Please sign in to comment.