-
Notifications
You must be signed in to change notification settings - Fork 45.5k
Adding flag to set random seeds. #3956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,6 +182,9 @@ def main(argv): | |
|
||
model_function = model_fn | ||
|
||
if flags.seed is not None: | ||
model_helpers.set_random_seed(flags.seed) | ||
|
||
if flags.multi_gpu: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not for here, but, for the record: we should consider at this point rolling some of this logic into the argparser. Something like, set_conditions_from_flags() that handles all of the validation/env-wide setting that we don't need return vals for. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. As some models (like MiniGo) may need more flags for different modes, this is indeed necessary. |
||
validate_batch_size_for_multi_gpu(flags.batch_size) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
from official.resnet import imagenet_preprocessing | ||
from official.resnet import resnet_model | ||
from official.resnet import resnet_run_loop | ||
from official.utils.misc import model_helpers | ||
|
||
_DEFAULT_IMAGE_SIZE = 224 | ||
_NUM_CHANNELS = 3 | ||
|
@@ -315,6 +316,9 @@ def main(argv): | |
|
||
flags = parser.parse_args(args=argv[1:]) | ||
|
||
if flags.seed is not None: | ||
model_helpers.set_random_seed(flags.seed) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably not be in imagenet, but in the resnet run loop. Otherwise, has to be duped for cifar as well. |
||
|
||
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn | ||
|
||
resnet_run_loop.resnet_main( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser): | |
batch_size: Create a flag to specify the batch size. | ||
multi_gpu: Create a flag to allow the use of all available GPUs. | ||
hooks: Create a flag to specify hooks for logging. | ||
seed: Create a flag to set random seeds. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
""" | ||
|
||
def __init__(self, add_help=False, data_dir=True, model_dir=True, | ||
train_epochs=True, epochs_between_evals=True, | ||
stop_threshold=True, batch_size=True, multi_gpu=True, | ||
hooks=True): | ||
hooks=True, seed=True): | ||
super(BaseParser, self).__init__(add_help=add_help) | ||
|
||
if data_dir: | ||
|
@@ -176,6 +177,16 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True, | |
metavar="<HK>" | ||
) | ||
|
||
if seed: | ||
self.add_argument( | ||
"--seed", "-s", nargs="+", type=int, default=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this nargs=+? Does this take a list, or copy-pasta? |
||
help="[default: %(default)s] An integer to seed random number" | ||
"generators. If unset, RNGs choose their own seeds resulting " | ||
"in each run having a different seed.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this actually work globally, including numpy? Can we expand on this description to describe the expected effect and also which RNGs get set specifically? |
||
metavar="<SEED>" | ||
) | ||
|
||
|
||
|
||
class PerformanceParser(argparse.ArgumentParser): | ||
"""Default parser for specifying performance tuning arguments. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
from __future__ import print_function | ||
|
||
import numbers | ||
import random | ||
|
||
import tensorflow as tf | ||
|
||
|
@@ -53,3 +54,25 @@ def past_stop_threshold(stop_threshold, eval_metric): | |
return True | ||
|
||
return False | ||
|
||
|
||
def set_random_seed(seed): | ||
"""Sets the random seeds for available RNGs. | ||
This seeds RNGs for python's random and for Tensorflow. The intended | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. glint me. I think a \n is required here. |
||
use case is for this to be called exactly once at the start of execution | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which "execution" does it mean here? The execution of the training process? Or the execution of one Session? Or others? Does it act the same in the distributed environment (Multiple GPUs) vs. one single GPU? |
||
to improve stability and reproducability between runs. | ||
|
||
Successive calls to re-seed will not behave as expected. This should | ||
be called at most once. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can/should we prevent this? What happens if you clear the seed in between with set_random_seed(None)? I think that should work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also: does what happens if you reset the graph (reset_default_graph)? If we expect the seed to persist, we should check that it does, and if expect it to not, we should test that it doesn't. |
||
|
||
Args: | ||
seed: integer, a seed which will be passed to the RNGs. | ||
|
||
Raises: | ||
ValueError: if the seed is not an integer or if deemed unsuitable for | ||
seeding a the RNGs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or if deemed unsuitable? What does that mean? |
||
""" | ||
if not isinstance(seed, int): | ||
raise ValueError("Random seed is not an integer: {}".format(seed)) | ||
random.seed(seed) | ||
tf.set_random_seed(seed) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notably, this will persist across sessions, but not across graphs. With estimators, I think the graph is maintained throughout the life of an estimator, but this might have to get called again in the case of starting a new graph? Not sure. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import random | ||
|
||
import tensorflow as tf # pylint: disable=g-bad-import-order | ||
|
||
from official.utils.misc import model_helpers | ||
|
@@ -64,6 +66,29 @@ def test_past_stop_threshold_not_number(self): | |
with self.assertRaises(ValueError): | ||
model_helpers.past_stop_threshold(tf.constant(4), None) | ||
|
||
def test_random_seed(self): | ||
"""It is unclear if this test is a good idea or stable. | ||
If tests are run in parallel, this could be flakey.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should figure this out before including. CC @robieta for testing magic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notably, TF does not seem to test set_random_seed in this way, which implies it is not a good idea. Perhaps we should instead test just that the seeds for python and tf are set as expected by their own reports. |
||
model_helpers.set_random_seed(42) | ||
expected_py_random = [int(random.random() * 1000) for i in range(10)] | ||
tf_random = [] | ||
with tf.Session() as sess: | ||
for i in range(10): | ||
a = tf.random_uniform([1]) | ||
tf_random.append(int(sess.run(a)[0] * 1000)) | ||
|
||
model_helpers.set_random_seed(42) | ||
py_random = [int(random.random() * 1000) for i in range(10)] | ||
|
||
# Instead of concerning ourselves with the particular results, we simply | ||
# want to ensure that the results are reproducible. So, we seed, read, | ||
# re-seed, re-read. | ||
self.assertAllEqual(expected_py_random, py_random) | ||
|
||
# TF does not accept being re-seeded. | ||
expected_tf_random = [637, 689, 961, 969, 321, 390, 919, 681, 112, 187] | ||
self.assertAllEqual(expected_tf_random, tf_random) | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's roll this into one line, and just return with a no-op if none.