-
Couldn't load subscription status.
- Fork 45.4k
Unified arg parser #3574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unified arg parser #3574
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| """Collection of parsers which are shared among the official models. | ||
|
|
||
| The parsers in this module are intended to be used as parents to all arg | ||
| parsers in official models. For instance, one might define a new class: | ||
|
|
||
| class ExampleParser(argparse.ArgumentParser): | ||
| def __init__(self): | ||
| super(ExampleParser, self).__init__(parents=[ | ||
| official.utils.arg_parsers.LocationParser(data_dir=True, model_dir=True), | ||
| official.utils.arg_parsers.DummyParser(use_synthetic_data=True), | ||
| ]) | ||
|
|
||
| self.add_argument( | ||
| "--application_specific_arg", "-asa", type=int, default=123, | ||
| help="[default: %(default)s] This arg is application specific.", | ||
| metavar="<ASA>" | ||
| ) | ||
|
|
||
| Notes about add_argument(): | ||
| Argparse will automatically template in default values in help messages if | ||
| the "%(default)s" string appears in the message. Using the example above: | ||
|
|
||
| parser = ExampleParser() | ||
| parser.set_defaults(application_specific_arg=3141592) | ||
| parser.parse_args(["-h"]) | ||
|
|
||
| When the help text is generated, it will display 3141592 to the user. (Even | ||
| though the default was 123 when the flag was created.) | ||
|
|
||
|
|
||
| The metavar variable determines how the flag will appear in help text. If | ||
| not specified, the convention is to use name.upper(). Thus rather than: | ||
|
|
||
| --application_specific_arg APPLICATION_SPECIFIC_ARG, -asa APPLICATION_SPECIFIC_ARG | ||
|
|
||
| if metavar="<ASA>" is set, the user sees: | ||
|
|
||
| --application_specific_arg <ASA>, -asa <ASA> | ||
|
|
||
| """ | ||
|
|
||
| import argparse | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can collapse some of these categories-- all will need location, device, supervised-- those can be lumped into ModelParser or something, with all args True by default. Here we have tiny groups, and also the ability to turn things on and off one by one, which seems redundant. If we collapse those three, then remove the need to turn each option on individually, you substantially reduce the overall code required in resnet above, for example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. I left the default as False for the secondary classes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's go ahead and make default=True for those as well. I would prefer to reduce the total amount of arg_parser code that users have to read through when looking at Resnet. |
||
|
|
||
| class BaseParser(argparse.ArgumentParser): | ||
| """Parser to contain flags which will be nearly universal across models. | ||
|
|
||
| Args: | ||
| add_help: Create the "--help" flag. False if class instance is a parent. | ||
| data_dir: Create a flag for specifying the input data directory. | ||
| model_dir: Create a flag for specifying the model file directory. | ||
| train_epochs: Create a flag to specify the number of training epochs. | ||
| epochs_per_eval: Create a flag to specify the frequency of testing. | ||
| batch_size: Create a flag to specify the batch size. | ||
| multi_gpu: Create a flag to allow the use of all available GPUs. | ||
| """ | ||
|
|
||
| def __init__(self, add_help=False, data_dir=True, model_dir=True, | ||
| train_epochs=True, epochs_per_eval=True, batch_size=True, | ||
| multi_gpu=True): | ||
| super(BaseParser, self).__init__(add_help=add_help) | ||
|
|
||
| if data_dir: | ||
| self.add_argument( | ||
| "--data_dir", "-dd", default="/tmp", | ||
| help="[default: %(default)s] The location of the input data.", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe comment on some of the arg parser magic here? What is the default templating, and metavar? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added to the docstring. |
||
| metavar="<DD>", | ||
| ) | ||
|
|
||
| if model_dir: | ||
| self.add_argument( | ||
| "--model_dir", "-md", default="/tmp", | ||
| help="[default: %(default)s] The location of the model files.", | ||
| metavar="<MD>", | ||
| ) | ||
|
|
||
| if train_epochs: | ||
| self.add_argument( | ||
| "--train_epochs", "-te", type=int, default=1, | ||
| help="[default: %(default)s] The number of epochs used to train.", | ||
| metavar="<TE>" | ||
| ) | ||
|
|
||
| if epochs_per_eval: | ||
| self.add_argument( | ||
| "--epochs_per_eval", "-epe", type=int, default=1, | ||
| help="[default: %(default)s] The number of training epochs to run " | ||
| "between evaluations.", | ||
| metavar="<EPE>" | ||
| ) | ||
|
|
||
| if batch_size: | ||
| self.add_argument( | ||
| "--batch_size", "-bs", type=int, default=32, | ||
| help="[default: %(default)s] Batch size for training and evaluation.", | ||
| metavar="<BS>" | ||
| ) | ||
|
|
||
| if multi_gpu: | ||
| self.add_argument( | ||
| "--multi_gpu", action="store_true", | ||
| help="If set, run across all available GPUs." | ||
| ) | ||
|
|
||
|
|
||
| class PerformanceParser(argparse.ArgumentParser): | ||
| """Default parser for specifying performance tuning arguments. | ||
|
|
||
| Args: | ||
| add_help: Create the "--help" flag. False if class instance is a parent. | ||
| num_parallel_calls: Create a flag to specify parallelism of data loading. | ||
| inter_op: Create a flag to allow specification of inter op threads. | ||
| intra_op: Create a flag to allow specification of intra op threads. | ||
| """ | ||
|
|
||
| def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True, | ||
| intra_op=True, use_synthetic_data=True): | ||
| super(PerformanceParser, self).__init__(add_help=add_help) | ||
|
|
||
| if num_parallel_calls: | ||
| self.add_argument( | ||
| "--num_parallel_calls", "-npc", | ||
| type=int, default=5, | ||
| help="[default: %(default)s] The number of records that are " | ||
| "processed in parallel during input processing. This can be " | ||
| "optimized per data set but for generally homogeneous data " | ||
| "sets, should be approximately the number of available CPU " | ||
| "cores.", | ||
| metavar="<NPC>" | ||
| ) | ||
|
|
||
| if inter_op: | ||
| self.add_argument( | ||
| "--inter_op_parallelism_threads", "-inter", | ||
| type=int, default=0, | ||
| help="[default: %(default)s Number of inter_op_parallelism_threads " | ||
| "to use for CPU. See TensorFlow config.proto for details.", | ||
| metavar="<INTER>" | ||
| ) | ||
|
|
||
| if intra_op: | ||
| self.add_argument( | ||
| "--intra_op_parallelism_threads", "-intra", | ||
| type=int, default=0, | ||
| help="[default: %(default)s Number of intra_op_parallelism_threads " | ||
| "to use for CPU. See TensorFlow config.proto for details.", | ||
| metavar="<INTRA>" | ||
| ) | ||
|
|
||
| if use_synthetic_data: | ||
| self.add_argument( | ||
| "--use_synthetic_data", "-synth", | ||
| action="store_true", | ||
| help="If set, use fake data (zeroes) instead of a real dataset. " | ||
| "This mode is useful for performance debugging, as it removes " | ||
| "input processing steps, but will not learn anything." | ||
| ) | ||
|
|
||
|
|
||
| class ImageModelParser(argparse.ArgumentParser): | ||
| """Default parser for specification image specific behavior. | ||
|
|
||
| Args: | ||
| add_help: Create the "--help" flag. False if class instance is a parent. | ||
| data_format: Create a flag to specify image axis convention. | ||
| """ | ||
|
|
||
| def __init__(self, add_help=False, data_format=True): | ||
| super(ImageModelParser, self).__init__(add_help=add_help) | ||
| if data_format: | ||
| self.add_argument( | ||
| "--data_format", "-df", | ||
| help="A flag to override the data format used in the model. " | ||
| "channels_first provides a performance boost on GPU but is not " | ||
| "always compatible with CPU. If left unspecified, the data " | ||
| "format will be chosen automatically based on whether TensorFlow" | ||
| "was built for CPU or GPU.", | ||
| metavar="<CF>", | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| import argparse | ||
| import unittest | ||
|
|
||
|
|
||
| from official.utils.arg_parsers import parsers | ||
|
|
||
|
|
||
| class TestParser(argparse.ArgumentParser): | ||
| """Class to test canned parser functionality.""" | ||
|
|
||
| def __init__(self): | ||
| super(TestParser, self).__init__(parents=[ | ||
| parsers.BaseParser(), | ||
| parsers.PerformanceParser(num_parallel_calls=True, inter_op=True, | ||
| intra_op=True, use_synthetic_data=True), | ||
| parsers.ImageModelParser(data_format=True) | ||
| ]) | ||
|
|
||
|
|
||
| class BaseTester(unittest.TestCase): | ||
|
|
||
| def test_default_setting(self): | ||
| """Test to ensure fields exist and defaults can be set. | ||
| """ | ||
|
|
||
| defaults = dict( | ||
| data_dir="dfgasf", | ||
| model_dir="dfsdkjgbs", | ||
| train_epochs=534, | ||
| epochs_per_eval=15, | ||
| batch_size=256, | ||
| num_parallel_calls=18, | ||
| inter_op_parallelism_threads=5, | ||
| intra_op_parallelism_thread=10, | ||
| data_format="channels_first" | ||
| ) | ||
|
|
||
| parser = TestParser() | ||
| parser.set_defaults(**defaults) | ||
|
|
||
| namespace_vars = vars(parser.parse_args([])) | ||
| for key, value in defaults.items(): | ||
| assert namespace_vars[key] == value | ||
|
|
||
| def test_booleans(self): | ||
| """Test to ensure boolean flags trigger as expected. | ||
| """ | ||
|
|
||
| parser = TestParser() | ||
| namespace = parser.parse_args(["--multi_gpu", "--use_synthetic_data"]) | ||
|
|
||
| assert namespace.multi_gpu | ||
| assert namespace.use_synthetic_data | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do u also want to test the case that the flag value is specified while the arg parser is not turned on? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This turns out to not be practical. When parse_args fails, argparse fails hard by calling sys.exit(2) rather than raising an exception. While I certainly could invoke subprocess and assert an exit code of 2, this seems like too much machinery for a test that is just intended to show off parse_args(). |
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it matters one way or the other, but a quick check of TF proper implies that they only include the license in init files if there is actual code in them, otherwise empty. Fine to leave empty in this case, probably.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok.