Inject enable_runtime_flags into benchmarks.
This will help general debugging by enabling custom execution with --benchmark_method_steps.

E.g., --benchmark_method_steps=train_steps=7 will run the benchmark for only 7 steps without modifying the benchmark code.

PiperOrigin-RevId: 282396875
sganeshb authored and tensorflower-gardener committed Nov 25, 2019
1 parent 9c50e96 commit bcce419
Showing 9 changed files with 30 additions and 1 deletion.
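
For context, here is a minimal sketch of what such a wrapper could look like. The decorator name (enable_runtime_flags) and the --benchmark_method_steps flag come from this commit; the flag registration, parsing, and type handling below are illustrative assumptions, not the actual official/utils/testing/benchmark_wrappers.py implementation.

# Hypothetical sketch of the runtime-flag wrapper; details are assumed.
import functools

from absl import flags

FLAGS = flags.FLAGS

# Assumed flag: key=value overrides applied just before the wrapped
# benchmark method runs, e.g. --benchmark_method_steps=train_steps=7.
flags.DEFINE_multi_string(
    'benchmark_method_steps', [],
    'key=value FLAGS overrides applied at benchmark runtime.')


def enable_runtime_flags(decorated_func):
  """Applies any runtime overrides to FLAGS, then runs the benchmark method."""

  @functools.wraps(decorated_func)
  def wrapper(self, *args, **kwargs):
    for override in FLAGS.benchmark_method_steps:
      key, _, value = override.partition('=')
      if hasattr(FLAGS, key):
        current = getattr(FLAGS, key)
        # Keep integer-valued flags such as train_steps as integers.
        setattr(FLAGS, key, int(value) if isinstance(current, int) else value)
    return decorated_func(self, *args, **kwargs)

  return wrapper

With a decorator along these lines on each _run_and_report_benchmark, passing --benchmark_method_steps=train_steps=7 overrides FLAGS.train_steps right before the run, so the benchmark body itself stays unchanged.
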
3 changes: 3 additions & 0 deletions official/benchmark/bert_benchmark.py
@@ -35,6 +35,7 @@
from official.nlp.bert import input_pipeline
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.utils.testing import benchmark_wrappers

# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -130,6 +131,7 @@ def __init__(self, output_dir=TMP_DIR, **kwargs):
self.num_steps_per_epoch = 110
self.num_epochs = 1

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0,
@@ -308,6 +310,7 @@ def __init__(self, output_dir=TMP_DIR, **kwargs):

super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.84,
4 changes: 4 additions & 0 deletions official/benchmark/bert_squad_benchmark.py
@@ -32,6 +32,8 @@
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.testing import benchmark_wrappers


# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -132,6 +134,7 @@ def _setup(self):
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 1

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
@@ -341,6 +344,7 @@ def _setup(self):
FLAGS.num_train_epochs = 2
FLAGS.steps_per_loop = 1

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
3 changes: 3 additions & 0 deletions official/benchmark/keras_cifar_benchmark.py
@@ -23,6 +23,7 @@
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.benchmark import keras_benchmark
from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_cifar_main

MIN_TOP_1_ACCURACY = 0.929
@@ -197,6 +198,7 @@ def benchmark_graph_2_gpu(self):
FLAGS.dtype = 'fp32'
self._run_and_report_benchmark()

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
@@ -222,6 +224,7 @@ def __init__(self, output_dir=None, default_flags=None):
flag_methods=flag_methods,
default_flags=default_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_cifar_main.run(FLAGS)
8 changes: 7 additions & 1 deletion official/benchmark/keras_imagenet_benchmark.py
@@ -22,6 +22,7 @@
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.benchmark import keras_benchmark
from official.utils.testing import benchmark_wrappers
from official.vision.image_classification import resnet_imagenet_main

MIN_TOP_1_ACCURACY = 0.76
@@ -171,6 +172,7 @@ def benchmark_xla_8_gpu_fp16_dynamic(self):
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark(top_1_min=0.736)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY):
@@ -201,6 +203,7 @@ def __init__(self, output_dir=None, default_flags=None):
flag_methods=flag_methods,
default_flags=default_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, skip_steps=None):
start_time_sec = time.time()
stats = resnet_imagenet_main.run(FLAGS)
@@ -307,7 +310,7 @@ def benchmark_graph_1_gpu_no_dist_strat(self):
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
FLAGS.batch_size = 96 # BatchNorm is less efficient in legacy graph mode
# due to its reliance on v1 cond.
# due to its reliance on v1 cond.
self._run_and_report_benchmark()

def benchmark_1_gpu(self):
@@ -863,6 +866,7 @@ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
super(Resnet50KerasBenchmarkRemoteData, self).__init__(
output_dir=output_dir, default_flags=def_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
# skip the first epoch for performance measurement.
super(Resnet50KerasBenchmarkRemoteData,
@@ -891,6 +895,7 @@ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
flag_methods=flag_methods,
default_flags=def_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_imagenet_main.run(FLAGS)
@@ -1023,6 +1028,7 @@ def _benchmark_common(self, eager, num_workers, all_reduce_alg):

self._run_and_report_benchmark()

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
top_1_min=MIN_TOP_1_ACCURACY,
top_1_max=MAX_TOP_1_ACCURACY):
3 changes: 3 additions & 0 deletions official/benchmark/resnet_ctl_imagenet_benchmark.py
@@ -25,6 +25,7 @@
from official.vision.image_classification import common
from official.vision.image_classification import resnet_ctl_imagenet_main
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
from official.utils.testing import benchmark_wrappers
from official.utils.flags import core as flags_core

MIN_TOP_1_ACCURACY = 0.76
@@ -169,6 +170,7 @@ def benchmark_8_gpu_amp(self):
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(flags.FLAGS)
@@ -197,6 +199,7 @@ def __init__(self, output_dir=None, default_flags=None):
flag_methods=flag_methods,
default_flags=default_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
start_time_sec = time.time()
stats = resnet_ctl_imagenet_main.run(FLAGS)
2 changes: 2 additions & 0 deletions official/benchmark/retinanet_benchmark.py
@@ -32,6 +32,7 @@

from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.utils.flags import core as flags_core
from official.utils.testing import benchmark_wrappers
from official.vision.detection import main as detection

TMP_DIR = os.getenv('TMPDIR')
@@ -151,6 +152,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
def __init__(self, output_dir=TMP_DIR, **kwargs):
super(RetinanetAccuracy, self).__init__(output_dir=output_dir)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, min_ap=0.325, max_ap=0.35):
"""Starts RetinaNet accuracy benchmark test."""

4 changes: 4 additions & 0 deletions official/benchmark/xlnet_benchmark.py
@@ -31,6 +31,8 @@
from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
from official.utils.testing import benchmark_wrappers


# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
@@ -76,6 +78,7 @@ def __init__(self, output_dir=None, **kwargs):

super(XLNetClassifyAccuracy, self).__init__(output_dir=output_dir)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.95,
@@ -149,6 +152,7 @@ def __init__(self, output_dir=None, **kwargs):

super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=87.0,
2 changes: 2 additions & 0 deletions official/recommendation/ncf_keras_benchmark.py
@@ -28,6 +28,7 @@
from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main
from official.utils.flags import core
from official.utils.testing import benchmark_wrappers

FLAGS = flags.FLAGS
NCF_DATA_DIR_NAME = 'movielens_data'
@@ -59,6 +60,7 @@ def _setup(self):
else:
flagsaver.restore_flag_values(NCFKerasBenchmarkBase.local_flags)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, hr_at_10_min=0, hr_at_10_max=0):
start_time_sec = time.time()
stats = ncf_keras_main.run_ncf(FLAGS)
2 changes: 2 additions & 0 deletions official/transformer/v2/transformer_benchmark.py
@@ -26,6 +26,7 @@
from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as transformer_main
from official.utils.flags import core as flags_core
from official.utils.testing import benchmark_wrappers
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark

TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
@@ -71,6 +72,7 @@ def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
default_flags=default_flags,
flag_methods=flag_methods)

@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
bleu_max=None,
bleu_min=None,
