Keras model benchmark #4476
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Keras Application Models Benchmark | ||
## Overview | ||
This provides a single scaffold to benchmark the Keras built-in application [models](https://keras.io/applications/). All the models are for image classification applications, and include: | ||
|
||
- Xception | ||
- VGG16 | ||
- VGG19 | ||
- ResNet50 | ||
- InceptionV3 | ||
- InceptionResNetV2 | ||
- MobileNet | ||
- DenseNet (121, 169, and 201) | ||
- NASNet (currently disabled in `MODELS`; see the TODO in benchmark_main.py) | ||
|
||
## Dataset | ||
The benchmark uses a synthetic dataset: all-zero image and label tensors sized for the selected model (generated in `dataset.py`). | ||
|
||
## Callbacks | ||
Two custom callbacks are provided for model benchmarking: ExamplesPerSecondCallback and LoggingMetricCallback. For each callback, `epoch_based` and `batch_based` options are available to set the benchmark level. Check [model_callbacks.py](model_callbacks.py) for more details. | ||
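
As a rough illustration of what a batch-based examples-per-second measurement involves, here is a minimal sketch in plain Python. `ThroughputMeter`, `every_n_steps`, and the injectable `clock` are hypothetical names chosen for this example; this is not the implementation in `model_callbacks.py`, and it does not depend on Keras.

```python
import time


class ThroughputMeter(object):
  """Hypothetical helper that mimics batch-level callback hooks."""

  def __init__(self, batch_size, every_n_steps=1, clock=time.time):
    self.batch_size = batch_size
    self.every_n_steps = every_n_steps
    self._clock = clock  # injectable so the meter can be tested
    self._step = 0
    self._window_start = None
    self.rates = []  # one examples/sec value per reporting window

  def on_train_begin(self):
    # Reset the step counter and start the first timing window.
    self._step = 0
    self._window_start = self._clock()

  def on_batch_end(self):
    # Every `every_n_steps` batches, record throughput for that window.
    self._step += 1
    if self._step % self.every_n_steps == 0:
      now = self._clock()
      elapsed = now - self._window_start
      if elapsed > 0:
        examples = self.batch_size * self.every_n_steps
        self.rates.append(examples / elapsed)
      self._window_start = now
```

An epoch-based variant would presumably do the same bookkeeping in an end-of-epoch hook, using the number of examples seen in the epoch instead of a fixed window.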
|
||
## Running Code | ||
To benchmark a model, use `--model` to specify the model name, and issue the following command: | ||
``` | ||
python benchmark_main.py --model=resnet50 | ||
``` | ||
Arguments: | ||
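* `--model`: The model to benchmark. Valid names are the keys of `MODELS` in [benchmark_main.py](benchmark_main.py), for example `resnet50` or `vgg16`. | ||
* `--callbacks`: A comma-separated list of callback names (case insensitive). Defaults to `ExamplesPerSecondCallback,LoggingMetricCallback`. | ||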
|
||
Use the `--help` or `-h` flag to get a full list of possible arguments. |
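
To make the `--callbacks` format concrete, the sketch below shows one way a comma-separated, case-insensitive value could be mapped to canonical callback names. `resolve_callback_names` and `KNOWN_CALLBACKS` are hypothetical and exist only for this illustration; the actual lookup is in `model_callbacks.py` and may differ.

```python
# Hypothetical registry; the two names come from the README above.
KNOWN_CALLBACKS = ("ExamplesPerSecondCallback", "LoggingMetricCallback")


def resolve_callback_names(flag_value, known=KNOWN_CALLBACKS):
  """Map a comma-separated, case-insensitive string to canonical names."""
  lookup = {name.lower(): name for name in known}
  resolved = []
  for raw in flag_value.split(","):
    key = raw.strip().lower()
    if not key:
      continue  # tolerate stray commas
    if key not in lookup:
      raise ValueError("Unknown callback: {!r}".format(raw.strip()))
    resolved.append(lookup[key])
  return resolved
```

For example, `resolve_callback_names("examplespersecondcallback")` yields the canonical `ExamplesPerSecondCallback` name, while an unrecognized name raises `ValueError`.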
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Benchmark on the keras built-in application models.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
# pylint: disable=g-bad-import-order | ||
import numpy as np | ||
from absl import app as absl_app | ||
Review comment: numpy should be below, right? Alphabetically, at least.
Reply: somehow Kokoro test will give me a lint checking error if absl comes first, I remember. Will try to test it again.
Reply: Ohh, right, because numpy is third-party and absl is Google. I keep forgetting that. Better to leave as-is, thanks. |
||
from absl import flags | ||
import tensorflow as tf | ||
# pylint: enable=g-bad-import-order | ||
|
||
from official.keras_application_models import dataset | ||
from official.keras_application_models import model_callbacks | ||
from official.utils.flags import core as flags_core | ||
from official.utils.logs import logger | ||
|
||
# Define a dictionary that maps model names to their model classes inside Keras | ||
MODELS = { | ||
"vgg16": tf.keras.applications.VGG16, | ||
"vgg19": tf.keras.applications.VGG19, | ||
"inceptionv3": tf.keras.applications.InceptionV3, | ||
"xception": tf.keras.applications.Xception, | ||
"resnet50": tf.keras.applications.ResNet50, | ||
"inceptionresnetv2": tf.keras.applications.InceptionResNetV2, | ||
"mobilenet": tf.keras.applications.MobileNet, | ||
"densenet121": tf.keras.applications.DenseNet121, | ||
"densenet169": tf.keras.applications.DenseNet169, | ||
"densenet201": tf.keras.applications.DenseNet201, | ||
# TODO(b/80431378) | ||
# "nasnetlarge": tf.keras.applications.NASNetLarge, | ||
# "nasnetmobile": tf.keras.applications.NASNetMobile, | ||
} | ||
|
||
|
||
def run_keras_model_benchmark(_): | ||
"""Run the benchmark on keras model.""" | ||
# Ensure a valid model name was supplied via command line argument | ||
if FLAGS.model not in MODELS.keys(): | ||
raise AssertionError("The --model command line argument should " | ||
"be a key in the `MODELS` dictionary.") | ||
|
||
# Load the model | ||
tf.logging.info("Benchmark on {} model...".format(FLAGS.model)) | ||
keras_model = MODELS[FLAGS.model] | ||
model = keras_model(weights=None) | ||
|
||
# Get dataset | ||
dataset_name = "ImageNet" | ||
if FLAGS.use_synthetic_data: | ||
tf.logging.info("Using synthetic dataset...") | ||
dataset_name += "_Synthetic" | ||
train_num_images = FLAGS.batch_size | ||
val_num_images = FLAGS.batch_size | ||
train_dataset = dataset.generate_synthetic_input_dataset( | ||
FLAGS.model, train_num_images) | ||
val_dataset = dataset.generate_synthetic_input_dataset( | ||
FLAGS.model, val_num_images) | ||
else: | ||
raise ValueError("Only synthetic dataset is supported!") | ||
|
||
# If run with multiple GPUs | ||
num_gpus = flags_core.get_num_gpus(FLAGS) | ||
if num_gpus > 1: | ||
model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus) | ||
|
||
# Configure the model | ||
model.compile(loss="categorical_crossentropy", | ||
optimizer="sgd", | ||
metrics=["accuracy"]) | ||
|
||
# Create benchmark logger for benchmark logging | ||
run_params = { | ||
"batch_size": FLAGS.batch_size, | ||
"synthetic_data": FLAGS.use_synthetic_data, | ||
"train_epochs": FLAGS.train_epochs | ||
} | ||
|
||
benchmark_logger = logger.get_benchmark_logger() | ||
benchmark_logger.log_run_info( | ||
model_name=FLAGS.model, | ||
dataset_name=dataset_name, | ||
run_params=run_params, | ||
test_id=FLAGS.benchmark_test_id) | ||
|
||
# Create callbacks that log metric values about the training and evaluation | ||
callbacks = model_callbacks.get_model_callbacks( | ||
FLAGS.callbacks, | ||
batch_size=FLAGS.batch_size, | ||
metric_logger=benchmark_logger) | ||
# Train and evaluate the model | ||
history = model.fit( | ||
train_dataset, | ||
epochs=FLAGS.train_epochs, | ||
callbacks=callbacks, | ||
validation_data=val_dataset, | ||
steps_per_epoch=int(np.ceil(train_num_images / FLAGS.batch_size)), | ||
validation_steps=int(np.ceil(val_num_images / FLAGS.batch_size)) | ||
) | ||
|
||
tf.logging.info("Logging the evaluation results...") | ||
for epoch in range(FLAGS.train_epochs): | ||
eval_results = { | ||
"accuracy": history.history["val_acc"][epoch], | ||
"loss": history.history["val_loss"][epoch], | ||
tf.GraphKeys.GLOBAL_STEP: (epoch + 1) * int(np.ceil( | ||
train_num_images / FLAGS.batch_size)) | ||
} | ||
benchmark_logger.log_evaluation_result(eval_results) | ||
|
||
# Clear the session explicitly to avoid session delete error | ||
tf.keras.backend.clear_session() | ||
|
||
|
||
def define_keras_benchmark_flags(): | ||
"""Add flags for keras built-in application models.""" | ||
flags_core.define_base(hooks=False) | ||
flags_core.define_performance() | ||
flags_core.define_image() | ||
flags_core.define_benchmark() | ||
flags.adopt_module_key_flags(flags_core) | ||
|
||
flags_core.set_defaults( | ||
data_format="channels_last", | ||
Review comment: Why does data format have a non-standard default?
Reply: The default is None in utils, and the models require channels_last data format. |
||
use_synthetic_data=True, | ||
batch_size=32, | ||
train_epochs=2) | ||
|
||
flags.DEFINE_enum( | ||
name="model", default=None, | ||
enum_values=MODELS.keys(), case_sensitive=False, | ||
help=flags_core.help_wrap( | ||
"Model to be benchmarked.")) | ||
|
||
flags.DEFINE_list( | ||
name="callbacks", | ||
default=["ExamplesPerSecondCallback", "LoggingMetricCallback"], | ||
help=flags_core.help_wrap( | ||
"A list of (case insensitive) strings to specify the names of " | ||
"callbacks. For example: `--callbacks ExamplesPerSecondCallback," | ||
"LoggingMetricCallback`")) | ||
|
||
|
||
def main(_): | ||
with logger.benchmark_context(FLAGS): | ||
run_keras_model_benchmark(FLAGS) | ||
|
||
if __name__ == "__main__": | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
define_keras_benchmark_flags() | ||
FLAGS = flags.FLAGS | ||
absl_app.run(main) |
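
The step counts passed to `model.fit` and the global step logged per epoch follow simple ceiling arithmetic. The sketch below restates it in plain Python; `steps_per_epoch` and `global_step_after` are hypothetical helper names, not functions from this PR. With the defaults in this file (`batch_size=32`, and a synthetic dataset of exactly `batch_size` images), each epoch is a single step.

```python
import math


def steps_per_epoch(num_images, batch_size):
  """Batches needed to cover num_images once (the last batch may be partial)."""
  return int(math.ceil(num_images / float(batch_size)))


def global_step_after(epoch, num_images, batch_size):
  """Global step reached at the end of a zero-indexed epoch."""
  return (epoch + 1) * steps_per_epoch(num_images, batch_size)


# With the defaults above: one step per epoch, so two epochs end at step 2.
print(steps_per_epoch(32, 32))
print(global_step_after(1, 32, 32))
```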
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Prepare dataset for keras model benchmark.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
|
||
# Default values for dataset. | ||
_NUM_CHANNELS = 3 | ||
_NUM_CLASSES = 1000 | ||
|
||
|
||
def _get_default_image_size(model): | ||
"""Provide default image size for each model.""" | ||
image_size = (224, 224) | ||
if model in ["inceptionv3", "xception", "inceptionresnetv2"]: | ||
image_size = (299, 299) | ||
elif model in ["nasnetlarge"]: | ||
image_size = (331, 331) | ||
return image_size | ||
|
||
|
||
def generate_synthetic_input_dataset(model, num_imgs): | ||
Review comment: Does this function need to care about the data format? eg channel first vs last?
Reply: By default, all the built-in models use channel_last format. So it's not an option here. |
||
"""Generate synthetic dataset.""" | ||
image_size = _get_default_image_size(model) | ||
input_shape = (num_imgs,) + image_size + (_NUM_CHANNELS,) | ||
|
||
images = tf.zeros(input_shape, dtype=tf.float32) | ||
labels = tf.zeros((num_imgs, _NUM_CLASSES), dtype=tf.float32) | ||
|
||
return tf.data.Dataset.from_tensors((images, labels)).repeat() |
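
The shape logic above can be checked without TensorFlow. The following is a minimal sketch, assuming the model name passed in is one of the `MODELS` keys from benchmark_main.py (for example `inceptionv3`) and that the layout is channels-last; `synthetic_input_shape` and `_IMAGE_SIZES` are hypothetical names used only here.

```python
_NUM_CHANNELS = 3

# Assumed mapping, keyed by the MODELS names in benchmark_main.py.
# Any model not listed falls back to 224x224.
_IMAGE_SIZES = {
    "inceptionv3": (299, 299),
    "xception": (299, 299),
    "inceptionresnetv2": (299, 299),
    "nasnetlarge": (331, 331),
}


def synthetic_input_shape(model, num_imgs):
  """Return the (batch, height, width, channels) shape for a model name."""
  height_width = _IMAGE_SIZES.get(model, (224, 224))
  return (num_imgs,) + height_width + (_NUM_CHANNELS,)


print(synthetic_input_shape("inceptionv3", 32))
print(synthetic_input_shape("resnet50", 32))
```

Keying the lookup on the exact flag values matters: a name that is not found silently falls back to 224x224, which would not match the default input size of the 299x299 models.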
Review comment: Resnet?
Reply: I guess it's ResNet50 only