14 changes: 14 additions & 0 deletions official/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Contributor: I don't think it matters one way or the other, but a quick check of TF proper implies that they only include the license in init files if there is actual code in them; otherwise they are left empty. Fine to leave empty in this case, probably.

Contributor Author: Ok.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
74 changes: 10 additions & 64 deletions official/resnet/resnet.py
@@ -36,6 +36,8 @@

import tensorflow as tf

from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order

_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5

@@ -779,71 +781,15 @@ class ResnetArgParser(argparse.ArgumentParser):
"""

def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__()
self.add_argument(
'--data_dir', type=str, default='/tmp/resnet_data',
help='The directory where the input data is stored.')

self.add_argument(
'--num_parallel_calls', type=int, default=5,
help='The number of records that are processed in parallel '
'during input processing. This can be optimized per data set but '
'for generally homogeneous data sets, should be approximately the '
'number of available CPU cores.')

self.add_argument(
'--model_dir', type=str, default='/tmp/resnet_model',
help='The directory where the model will be stored.')
super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.PerformanceParser(),
parsers.ImageModelParser(),
])

self.add_argument(
'--resnet_size', type=int, default=50,
'--resnet_size', '-rs', type=int, default=50,
choices=resnet_size_choices,
help='The size of the ResNet model to use.')

self.add_argument(
'--train_epochs', type=int, default=100,
help='The number of epochs to use for training.')

self.add_argument(
'--epochs_per_eval', type=int, default=1,
help='The number of training epochs to run between evaluations.')

self.add_argument(
'--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')

self.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. '
'channels_first provides a performance boost on GPU but '
'is not always compatible with CPU. If left unspecified, '
'the data format will be chosen automatically based on '
'whether TensorFlow was built for CPU or GPU.')

self.add_argument(
'--multi_gpu', action='store_true',
help='If set, run across all available GPUs.')

self.add_argument(
'-v', '--version', type=int, choices=[1, 2], dest="version",
default=DEFAULT_VERSION,
help="Version of ResNet. (1 or 2) See README.md for details."
help='[default: %(default)s]The size of the ResNet model to use.',
metavar='<RS>'
)

# Advanced args
self.add_argument(
'--use_synthetic_data', action='store_true',
help='If set, use fake data (zeroes) instead of a real dataset. '
'This mode is useful for performance debugging, as it removes '
'input processing steps, but will not learn anything.')

self.add_argument(
'--inter_op_parallelism_threads', type=int, default=0,
help='Number of inter_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')

self.add_argument(
'--intra_op_parallelism_threads', type=int, default=0,
help='Number of intra_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
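As a quick illustration of what the refactor above buys, here is a minimal usage sketch of the new ResnetArgParser. The choices list and the sample command-line values are invented for the example; only the flag names come from the diffs in this PR.

```python
# Illustrative sketch only; the resnet_size choices and argv values are made up.
from official.resnet.resnet import ResnetArgParser

parser = ResnetArgParser(resnet_size_choices=[18, 34, 50, 101, 152, 200])
flags = parser.parse_args([
    '--data_dir', '/tmp/imagenet',  # inherited from parsers.BaseParser
    '--batch_size', '64',           # inherited from parsers.BaseParser
    '--use_synthetic_data',         # inherited from parsers.PerformanceParser
    '-rs', '50',                    # declared directly in ResnetArgParser
])
print(flags.data_dir, flags.batch_size, flags.use_synthetic_data, flags.resnet_size)
```

The point of the parents= pattern is that resnet.py no longer declares the shared flags itself; it only keeps the model-specific --resnet_size flag.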
14 changes: 14 additions & 0 deletions official/utils/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Empty file: official/utils/arg_parsers/__init__.py
194 changes: 194 additions & 0 deletions official/utils/arg_parsers/parsers.py
@@ -0,0 +1,194 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Collection of parsers which are shared among the official models.

The parsers in this module are intended to be used as parents to all arg
parsers in official models. For instance, one might define a new class:

class ExampleParser(argparse.ArgumentParser):
def __init__(self):
super(ExampleParser, self).__init__(parents=[
official.utils.arg_parsers.LocationParser(data_dir=True, model_dir=True),
official.utils.arg_parsers.DummyParser(use_synthetic_data=True),
])

self.add_argument(
"--application_specific_arg", "-asa", type=int, default=123,
help="[default: %(default)s] This arg is application specific.",
metavar="<ASA>"
)

Notes about add_argument():
Argparse will automatically template in default values in help messages if
the "%(default)s" string appears in the message. Using the example above:

parser = ExampleParser()
parser.set_defaults(application_specific_arg=3141592)
parser.parse_args(["-h"])

When the help text is generated, it will display 3141592 to the user. (Even
though the default was 123 when the flag was created.)


The metavar variable determines how the flag will appear in help text. If
not specified, the convention is to use name.upper(). Thus rather than:

--application_specific_arg APPLICATION_SPECIFIC_ARG, -asa APPLICATION_SPECIFIC_ARG

if metavar="<ASA>" is set, the user sees:

--application_specific_arg <ASA>, -asa <ASA>

"""

import argparse

Contributor: I think we can collapse some of these categories: all will need location, device, and supervised flags, and those can be lumped into a ModelParser or something, with all args True by default. Here we have tiny groups, plus the ability to turn things on and off one by one, which seems redundant. If we collapse those three and remove the need to turn each option on individually, you substantially reduce the overall code required in resnet above, for example.

Contributor Author: Done. I left the default as False for the secondary classes.

Contributor: Let's go ahead and make default=True for those as well. I would prefer to reduce the total amount of arg_parser code that users have to read through when looking at Resnet.

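For readers following the thread above, a rough sketch of the reviewer's suggestion, a single collapsed parser with everything on by default, might look like the following. ModelParser is the reviewer's hypothetical name; the PR as merged keeps the three separate classes shown below instead.

```python
# Hypothetical sketch of the reviewer's proposal; not part of this PR.
import argparse


class ModelParser(argparse.ArgumentParser):
  """Single parent parser bundling location, performance, and image flags."""

  def __init__(self, add_help=False):
    super(ModelParser, self).__init__(add_help=add_help)
    self.add_argument("--data_dir", "-dd", default="/tmp", metavar="<DD>")
    self.add_argument("--model_dir", "-md", default="/tmp", metavar="<MD>")
    self.add_argument("--batch_size", "-bs", type=int, default=32, metavar="<BS>")
    self.add_argument("--num_parallel_calls", "-npc", type=int, default=5,
                      metavar="<NPC>")
    self.add_argument("--use_synthetic_data", "-synth", action="store_true")
    self.add_argument("--data_format", "-df", metavar="<CF>")
```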

class BaseParser(argparse.ArgumentParser):
"""Parser to contain flags which will be nearly universal across models.

Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_per_eval: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
"""

def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_per_eval=True, batch_size=True,
multi_gpu=True):
super(BaseParser, self).__init__(add_help=add_help)

if data_dir:
self.add_argument(
"--data_dir", "-dd", default="/tmp",
help="[default: %(default)s] The location of the input data.",
Contributor: Maybe comment on some of the arg parser magic here? What is the default templating, and metavar?

Contributor Author: Added to the docstring.

metavar="<DD>",
)

if model_dir:
self.add_argument(
"--model_dir", "-md", default="/tmp",
help="[default: %(default)s] The location of the model files.",
metavar="<MD>",
)

if train_epochs:
self.add_argument(
"--train_epochs", "-te", type=int, default=1,
help="[default: %(default)s] The number of epochs used to train.",
metavar="<TE>"
)

if epochs_per_eval:
self.add_argument(
"--epochs_per_eval", "-epe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<EPE>"
)

if batch_size:
self.add_argument(
"--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.",
metavar="<BS>"
)

if multi_gpu:
self.add_argument(
"--multi_gpu", action="store_true",
help="If set, run across all available GPUs."
)


class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.

Args:
add_help: Create the "--help" flag. False if class instance is a parent.
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    use_synthetic_data: Create a flag to use synthetic data instead of real data.
"""

def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True):
super(PerformanceParser, self).__init__(add_help=add_help)

if num_parallel_calls:
self.add_argument(
"--num_parallel_calls", "-npc",
type=int, default=5,
help="[default: %(default)s] The number of records that are "
"processed in parallel during input processing. This can be "
"optimized per data set but for generally homogeneous data "
"sets, should be approximately the number of available CPU "
"cores.",
metavar="<NPC>"
)

if inter_op:
self.add_argument(
"--inter_op_parallelism_threads", "-inter",
type=int, default=0,
help="[default: %(default)s Number of inter_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTER>"
)

if intra_op:
self.add_argument(
"--intra_op_parallelism_threads", "-intra",
type=int, default=0,
help="[default: %(default)s Number of intra_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTRA>"
)

if use_synthetic_data:
self.add_argument(
"--use_synthetic_data", "-synth",
action="store_true",
help="If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."
)


class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior.

Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_format: Create a flag to specify image axis convention.
"""

def __init__(self, add_help=False, data_format=True):
super(ImageModelParser, self).__init__(add_help=add_help)
if data_format:
self.add_argument(
"--data_format", "-df",
help="A flag to override the data format used in the model. "
"channels_first provides a performance boost on GPU but is not "
"always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.",
metavar="<CF>",
)
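The module docstring above makes two claims about argparse behavior: %(default)s is substituted when the help text is generated, and metavar controls how the flag is rendered. A small self-contained check of both, with an invented flag name, would be:

```python
# Stand-alone demo of the behavior described in the parsers.py docstring.
# The flag name and values are invented for illustration.
import argparse

parent = argparse.ArgumentParser(add_help=False)
parent.add_argument(
    "--application_specific_arg", "-asa", type=int, default=123,
    help="[default: %(default)s] This arg is application specific.",
    metavar="<ASA>")

parser = argparse.ArgumentParser(parents=[parent])
parser.set_defaults(application_specific_arg=3141592)

# The help output shows "--application_specific_arg <ASA>, -asa <ASA>" and
# "[default: 3141592] ...", because %(default)s is filled in when the help
# is rendered, not when the flag was defined.
print(parser.format_help())
print(parser.parse_args([]).application_specific_arg)  # -> 3141592
```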
72 changes: 72 additions & 0 deletions official/utils/arg_parsers/parsers_test.py
@@ -0,0 +1,72 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
import unittest


from official.utils.arg_parsers import parsers


class TestParser(argparse.ArgumentParser):
"""Class to test canned parser functionality."""

def __init__(self):
super(TestParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True),
parsers.ImageModelParser(data_format=True)
])


class BaseTester(unittest.TestCase):

def test_default_setting(self):
"""Test to ensure fields exist and defaults can be set.
"""

defaults = dict(
data_dir="dfgasf",
model_dir="dfsdkjgbs",
train_epochs=534,
epochs_per_eval=15,
batch_size=256,
num_parallel_calls=18,
inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
data_format="channels_first"
)

parser = TestParser()
parser.set_defaults(**defaults)

namespace_vars = vars(parser.parse_args([]))
for key, value in defaults.items():
assert namespace_vars[key] == value

def test_booleans(self):
"""Test to ensure boolean flags trigger as expected.
"""

parser = TestParser()
namespace = parser.parse_args(["--multi_gpu", "--use_synthetic_data"])

assert namespace.multi_gpu
assert namespace.use_synthetic_data

Member: Do you also want to test the case where the flag value is specified while the arg parser is not turned on?

Contributor Author: This turns out to not be practical. When parse_args fails, argparse fails hard by calling sys.exit(2) rather than raising an ordinary exception. While I certainly could invoke subprocess and assert an exit code of 2, this seems like too much machinery for a test that is just intended to show off parse_args().


if __name__ == "__main__":
unittest.main()
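One note on the exchange above: argparse's sys.exit(2) surfaces as a SystemExit exception, so a check like the one the reviewer asked about could also be written in-process without a subprocess. A sketch, not part of the PR, using ImageModelParser with its only flag disabled:

```python
# Sketch only; not part of the PR. sys.exit(2) raises SystemExit, which
# unittest can catch in-process.
import argparse
import unittest

from official.utils.arg_parsers import parsers


class DisabledFlagTest(unittest.TestCase):

  def test_flag_for_disabled_section_exits(self):
    # Build a parser whose image-model flag is turned off.
    parser = argparse.ArgumentParser(parents=[
        parsers.ImageModelParser(data_format=False)])
    with self.assertRaises(SystemExit) as raised:
      parser.parse_args(["--data_format", "channels_first"])
    self.assertEqual(raised.exception.code, 2)


if __name__ == "__main__":
  unittest.main()
```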