69 changes: 32 additions & 37 deletions official/mnist/mnist.py
@@ -20,13 +20,16 @@
import argparse
import sys

from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers


LEARNING_RATE = 1e-4


@@ -86,6 +89,16 @@ def create_model(data_format):
])


def define_mnist_flags():
flags_core.define_base()
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
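
For reference, flags_core.set_defaults only changes the registered default values; they can still be overridden on the command line. A minimal, hypothetical check of that behavior (not part of this change), assuming define_base() registers --batch_size and --train_epochs as the rest of this file expects:

# Hypothetical snippet; run in a fresh interpreter so the flags are not yet parsed.
from absl import flags

define_mnist_flags()
flags.FLAGS(['mnist', '--batch_size=256'])  # parse; the CLI value wins over set_defaults
assert flags.FLAGS.batch_size == 256
assert flags.FLAGS.train_epochs == 40       # untouched default from set_defaults above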


def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
model = create_model(params['data_format'])
@@ -172,31 +185,28 @@ def validate_batch_size_for_multi_gpu(batch_size):
raise ValueError(err)


def main(argv):
parser = MNISTArgParser()
flags = parser.parse_args(args=argv[1:])

def main(flags_obj):
model_function = model_fn

if flags.multi_gpu:
validate_batch_size_for_multi_gpu(flags.batch_size)
if flags_obj.multi_gpu:
validate_batch_size_for_multi_gpu(flags_obj.batch_size)

# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_fn, loss_reduction=tf.losses.Reduction.MEAN)
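
Step (2) happens inside model_fn, outside this hunk; for orientation, a minimal sketch of that optimizer wrapping, using the contrib-era API that pairs with replicate_model_fn (params here is the dict passed to the Estimator below):

# Sketch only: inside model_fn, the optimizer is wrapped for multi-GPU towers.
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
if params.get('multi_gpu'):
  # TowerOptimizer aggregates per-tower gradients before applying updates.
  optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)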

data_format = flags.data_format
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags.model_dir,
model_dir=flags_obj.model_dir,
params={
'data_format': data_format,
'multi_gpu': flags.multi_gpu
'multi_gpu': flags_obj.multi_gpu
})

# Set up training and evaluation input functions.
@@ -206,57 +216,42 @@ def train_input_fn():
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

# Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session.
ds = ds.repeat(flags.epochs_between_evals)
ds = ds.repeat(flags_obj.epochs_between_evals)
return ds

def eval_input_fn():
return dataset.test(flags.data_dir).batch(
flags.batch_size).make_one_shot_iterator().get_next()
return dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size).make_one_shot_iterator().get_next()

# Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size)
flags_obj.hooks, batch_size=flags_obj.batch_size)

# Train and evaluate model.
for _ in range(flags.train_epochs // flags.epochs_between_evals):
for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)

if model_helpers.past_stop_threshold(flags.stop_threshold,
if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']):
break

# Export the model
if flags.export_dir is not None:
if flags_obj.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image,
})
mnist_classifier.export_savedmodel(flags.export_dir, input_fn)


class MNISTArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model."""

def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.ImageModelParser(),
])

self.set_defaults(
data_dir='/tmp/mnist_data',
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)


if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv)
define_mnist_flags()
absl_app.run(main)
125 changes: 57 additions & 68 deletions official/mnist/mnist_eager.py
@@ -26,17 +26,20 @@
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time

import tensorflow as tf # pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# pylint: enable=g-bad-import-order

from official.mnist import dataset as mnist_dataset
from official.mnist import mnist
from official.utils.arg_parsers import parsers
from official.utils.flags import core as flags_core


def loss(logits, labels):
@@ -95,38 +98,36 @@ def test(model, dataset):
tf.contrib.summary.scalar('accuracy', accuracy.result())


def main(argv):
parser = MNISTEagerArgParser()
flags = parser.parse_args(args=argv[1:])

def main(flags_obj):
tf.enable_eager_execution()

# Automatically determine device and data_format
(device, data_format) = ('/gpu:0', 'channels_first')
if flags.no_gpu or not tf.test.is_gpu_available():
if flags_obj.no_gpu or not tf.test.is_gpu_available():
(device, data_format) = ('/cpu:0', 'channels_last')
# If data_format is defined in FLAGS, overwrite automatically set value.
if flags.data_format is not None:
data_format = flags.data_format
if flags_obj.data_format is not None:
data_format = flags_obj.data_format
print('Using device %s, and data format %s.' % (device, data_format))

# Load the datasets
train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
flags.batch_size)
test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
flags_obj.batch_size)
test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size)

# Create the model and optimizer
model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)

# Create file writers for writing TensorBoard summaries.
if flags.output_dir:
if flags_obj.output_dir:
# Create directories to which summaries will be written
# tensorboard --logdir=<output_dir>
# can then be used to see the recorded summaries.
train_dir = os.path.join(flags.output_dir, 'train')
test_dir = os.path.join(flags.output_dir, 'eval')
tf.gfile.MakeDirs(flags.output_dir)
train_dir = os.path.join(flags_obj.output_dir, 'train')
test_dir = os.path.join(flags_obj.output_dir, 'eval')
tf.gfile.MakeDirs(flags_obj.output_dir)
else:
train_dir = None
test_dir = None
@@ -136,19 +137,20 @@ def main(argv):
test_dir, flush_millis=10000, name='test')

# Create and restore checkpoint (if one exists on the path)
checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
step_counter = tf.train.get_or_create_global_step()
checkpoint = tfe.Checkpoint(
model=model, optimizer=optimizer, step_counter=step_counter)
# Restore variables on creation if a checkpoint exists.
checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))

# Train and evaluate for a set number of epochs.
with tf.device(device):
for _ in range(flags.train_epochs):
for _ in range(flags_obj.train_epochs):
start = time.time()
with summary_writer.as_default():
train(model, optimizer, train_ds, step_counter, flags.log_interval)
train(model, optimizer, train_ds, step_counter,
flags_obj.log_interval)
end = time.time()
print('\nTrain time for epoch #%d (%d total steps): %f' %
(checkpoint.save_counter.numpy() + 1,
@@ -159,50 +161,37 @@
checkpoint.save(checkpoint_prefix)


class MNISTEagerArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model with eager training loop."""

def __init__(self):
super(MNISTEagerArgParser, self).__init__(parents=[
parsers.EagerParser(),
parsers.ImageModelParser()])

self.add_argument(
'--log_interval', '-li',
type=int,
default=10,
metavar='N',
help='[default: %(default)s] batches between logging training status')
self.add_argument(
'--output_dir', '-od',
type=str,
default=None,
metavar='<OD>',
help='[default: %(default)s] Directory to write TensorBoard summaries')
self.add_argument(
'--lr', '-lr',
type=float,
default=0.01,
metavar='<LR>',
help='[default: %(default)s] learning rate')
self.add_argument(
'--momentum', '-m',
type=float,
default=0.5,
metavar='<M>',
help='[default: %(default)s] SGD momentum')
self.add_argument(
'--no_gpu', '-nogpu',
action='store_true',
default=False,
help='disables GPU usage even if a GPU is available')

self.set_defaults(
data_dir='/tmp/tensorflow/mnist/input_data',
model_dir='/tmp/tensorflow/mnist/checkpoints/',
batch_size=100,
train_epochs=10,
)
def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager()
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
Reviewer comment (Contributor): I've read through flags_core, and I still don't understand what this is doing. What does this do, and can we somehow make that activity more clear?

Reply (Contributor, Author): It makes arguments show up in --help instead of just --helpfull. I think there is just going to have to be some familiarizing with absl concepts like key flags.
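
To make that concrete, a small hypothetical standalone script (not part of this change) showing the effect of flags.adopt_module_key_flags; the only assumption is that flags_core.define_base() registers --batch_size, which the rest of this diff relies on:

# demo.py -- illustration of absl key flags.
from absl import app
from absl import flags

from official.utils.flags import core as flags_core

flags_core.define_base()                  # flags are registered in flags_core's module
flags.adopt_module_key_flags(flags_core)  # re-declare them as key flags of this module,
                                          # so `python demo.py --help` lists them;
                                          # otherwise they only appear under --helpfull

def main(_):
  print(flags.FLAGS.batch_size)

if __name__ == '__main__':
  app.run(main)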


flags.DEFINE_integer(
name='log_interval', short_name='li', default=10,
help=flags_core.help_wrap('batches between logging training status'))

flags.DEFINE_string(
name='output_dir', short_name='od', default=None,
help=flags_core.help_wrap('Directory to write TensorBoard summaries'))

flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
help=flags_core.help_wrap('Learning rate.'))

flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
help=flags_core.help_wrap('SGD momentum.'))

flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
help=flags_core.help_wrap(
'disables GPU usage even if a GPU is available'))

flags_core.set_defaults(
data_dir='/tmp/tensorflow/mnist/input_data',
model_dir='/tmp/tensorflow/mnist/checkpoints/',
batch_size=100,
train_epochs=10,
)

if __name__ == '__main__':
main(argv=sys.argv)
define_mnist_eager_flags()
absl_app.run(main=main)
31 changes: 18 additions & 13 deletions official/resnet/cifar10_main.py
@@ -21,8 +21,11 @@
import os
import sys

from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.resnet import resnet_model
from official.resnet import resnet_run_loop

@@ -224,25 +227,27 @@ def loss_filter_fn(_):
)


def main(argv):
parser = resnet_run_loop.ResnetArgParser()
# Set defaults that are reasonable for this model.
parser.set_defaults(data_dir='/tmp/cifar10_data',
model_dir='/tmp/cifar10_model',
resnet_size=32,
train_epochs=250,
epochs_between_evals=10,
batch_size=128)
def define_cifar_flags():
resnet_run_loop.define_resnet_flags()
flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(data_dir='/tmp/cifar10_data',
model_dir='/tmp/cifar10_model',
resnet_size='32',
train_epochs=250,
epochs_between_evals=10,
batch_size=128)

flags = parser.parse_args(args=argv[1:])

input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
def main(flags_obj):
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn)
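
The and/or chain is the old pre-ternary idiom; assuming get_synth_input_fn() returns a (truthy) function object, as its name suggests, it is equivalent to:

input_function = (get_synth_input_fn() if flags_obj.use_synthetic_data
                  else input_fn)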

resnet_run_loop.resnet_main(
flags, cifar10_model_fn, input_function,
flags_obj, cifar10_model_fn, input_function,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])


if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv)
define_cifar_flags()
absl_app.run(main)
5 changes: 5 additions & 0 deletions official/resnet/cifar10_test.py
@@ -37,6 +37,11 @@ class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet.
"""

@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
cifar10_main.define_cifar_flags()
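
Registering the flags once in setUpClass matters because absl flags are process-global and raise DuplicateFlagError if defined twice. A hypothetical test body could then parse and override them explicitly before use:

# Hypothetical usage inside a test method, after define_cifar_flags() has run.
from absl import flags

flags.FLAGS.unparse_flags()                        # reset any earlier parse
flags.FLAGS(['cifar10_test', '--batch_size=128'])  # re-parse with an override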

def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())