/
flags.py
114 lines (98 loc) · 4.19 KB
/
flags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The central place to define flags."""
from absl import flags
def define_flags():
  """Defines the common absl command-line flags.

  All flags are defined as optional, but in practice most models use some of
  these flags and so mark_flags_as_required() should be called after calling
  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
  For example:

  ```
  from absl import flags
  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
  ...
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  ```

  The reason all flags are optional is because unit tests often do not set or
  use any of the flags.

  The `gin_file` and `gin_params` flags are defined on a best-effort basis: if
  a flag with the same name has already been registered, the
  `flags.DuplicateFlagError` is swallowed and the existing definition is kept.
  All other flags are defined unconditionally.
  """
  flags.DEFINE_string(
      'experiment',
      default=None,
      help='The experiment type registered, specifying an ExperimentConfig.')

  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval', 'train_and_validate',
          'train_and_post_eval'
      ],
      # Keep this help text in sync with `enum_values` above.
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, `continuous_train_and_eval`, '
      '`train_and_validate` (which is not implemented in '
      'the open source version) and `train_and_post_eval`.')

  flags.DEFINE_string(
      'model_dir',
      default=None,
      # Note the trailing space: adjacent string literals are concatenated
      # verbatim, so each fragment must carry its own separator.
      help='The directory where the model and training/evaluation summaries '
      'are stored.')

  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specify overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')

  flags.DEFINE_string(
      'params_override',
      default=None,
      help='A YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to override '
      'the model parameters, but not the parameters like TPU specific flags. '
      'One canonical use case of `--config_file` and `--params_override` is '
      'users first define a template config file using `--config_file`, then '
      'use `--params_override` to adjust the minimal set of tuning parameters, '
      'for example setting up different `train_batch_size`. The final override '
      'order of parameters: default_model_params --> params from config_file '
      '--> params in params_override. See also the help message of '
      '`--config_file`.')

  # Libraries that rely on gin sometimes define these same flags in their own
  # modules. Defining them again here would raise DuplicateFlagError, so keep
  # whichever definition was registered first instead of failing.
  try:
    flags.DEFINE_multi_string(
        'gin_file', default=None, help='List of paths to the config files.')
  except flags.DuplicateFlagError:
    pass

  try:
    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
  except flags.DuplicateFlagError:
    pass

  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')

  flags.DEFINE_string(
      'tpu_platform', default=None, help='TPU platform type.')