3 changes: 3 additions & 0 deletions official/mnist/mnist.py
@@ -182,6 +182,9 @@ def main(argv):

  model_function = model_fn

  if flags.seed is not None:
    model_helpers.set_random_seed(flags.seed)
Contributor

Let's roll this into one line, and just return as a no-op if the seed is None.
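One way to read this suggestion, as a minimal sketch: the helper itself accepts `None` and returns early, so each call site shrinks to a single unconditional line. Only Python's `random` is seeded here so the sketch runs without TensorFlow; the `tf.set_random_seed` call from the PR is left as a comment, and the `seed=None` default is an assumption.

```python
import random


def set_random_seed(seed=None):
  """Seeds Python's RNG; does nothing when seed is None (sketch only)."""
  if seed is None:
    return  # No-op: leave the RNGs to pick their own seeds.
  if not isinstance(seed, int):
    raise ValueError("Random seed is not an integer: {}".format(seed))
  random.seed(seed)
  # The PR additionally calls tf.set_random_seed(seed) at this point.


# Call sites no longer need the `if flags.seed is not None:` guard.
set_random_seed(None)

set_random_seed(42)
first = random.random()
set_random_seed(42)
second = random.random()  # Same value as `first`, since the seed matches.
```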


  if flags.multi_gpu:
Contributor

Not for here, but, for the record: we should consider at this point rolling some of this logic into the argparser. Something like, set_conditions_from_flags() that handles all of the validation/env-wide setting that we don't need return vals for.

Contributor

Agree. As some models (like MiniGo) may need more flags for different modes, this is indeed necessary.
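A rough sketch of what such a helper could look like. `set_conditions_from_flags` does not exist in the repository; the stand-in validator, its `num_gpus` parameter, and the flag names on the demo parser are all assumptions made so the example is self-contained.

```python
import argparse
import random


def validate_batch_size_for_multi_gpu(batch_size, num_gpus=2):
  """Stand-in validator; the fixed num_gpus is assumed for illustration."""
  if batch_size % num_gpus:
    raise ValueError(
        "Batch size {} is not divisible by {} GPUs".format(batch_size, num_gpus))


def set_conditions_from_flags(flags):
  """Hypothetical helper: validation and env-wide settings, no return value."""
  if getattr(flags, "seed", None) is not None:
    random.seed(flags.seed)  # The PR would call model_helpers.set_random_seed.
  if getattr(flags, "multi_gpu", False):
    validate_batch_size_for_multi_gpu(flags.batch_size)


parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--multi_gpu", action="store_true")
parser.add_argument("--batch_size", type=int, default=32)

flags = parser.parse_args(["--seed", "7", "--multi_gpu"])
set_conditions_from_flags(flags)  # main() then needs only this one call.
```

Keeping the side effects in one function means a model with extra modes would add its checks there rather than in every `main`.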

    validate_batch_size_for_multi_gpu(flags.batch_size)

4 changes: 4 additions & 0 deletions official/resnet/imagenet_main.py
@@ -26,6 +26,7 @@
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
from official.utils.misc import model_helpers

_DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3
@@ -315,6 +316,9 @@ def main(argv):

  flags = parser.parse_args(args=argv[1:])

  if flags.seed is not None:
    model_helpers.set_random_seed(flags.seed)
Contributor

This should probably not be in imagenet, but in the resnet run loop. Otherwise, it has to be duplicated for cifar as well.


  input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn

  resnet_run_loop.resnet_main(
13 changes: 12 additions & 1 deletion official/utils/arg_parsers/parsers.py
@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser):
    batch_size: Create a flag to specify the batch size.
    multi_gpu: Create a flag to allow the use of all available GPUs.
    hooks: Create a flag to specify hooks for logging.
    seed: Create a flag to set random seeds.
Contributor

CC @robieta as this will conflict with #3887

"""

  def __init__(self, add_help=False, data_dir=True, model_dir=True,
               train_epochs=True, epochs_between_evals=True,
               stop_threshold=True, batch_size=True, multi_gpu=True,
-               hooks=True):
+               hooks=True, seed=True):
    super(BaseParser, self).__init__(add_help=add_help)

    if data_dir:
@@ -176,6 +177,16 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
          metavar="<HK>"
      )

    if seed:
      self.add_argument(
          "--seed", "-s", nargs="+", type=int, default=None,
Contributor

Why is this nargs="+"? Does this take a list, or is it copy-pasta?
And, nit: random_seed, rs; seed is too general.
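To make the concern concrete, a small standalone comparison using only `argparse`: with `nargs="+"` the parsed value is a list, which would fail the `isinstance(seed, int)` check in `set_random_seed`. The second parser uses the `--random_seed` / `-rs` naming from the nit above; both parsers are illustrative and not the project's `BaseParser`.

```python
import argparse

# Flag as written in the diff: nargs="+" collects one or more values.
list_parser = argparse.ArgumentParser()
list_parser.add_argument("--seed", "-s", nargs="+", type=int, default=None)

# Flag without nargs, using the suggested names: a single integer.
scalar_parser = argparse.ArgumentParser()
scalar_parser.add_argument("--random_seed", "-rs", type=int, default=None)

as_list = list_parser.parse_args(["--seed", "42"]).seed
as_int = scalar_parser.parse_args(["--random_seed", "42"]).random_seed

print(as_list, isinstance(as_list, int))  # [42] False
print(as_int, isinstance(as_int, int))    # 42 True
```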

          help="[default: %(default)s] An integer to seed random number "
               "generators. If unset, RNGs choose their own seeds resulting "
               "in each run having a different seed.",
Contributor

Does this actually work globally, including numpy? Can we expand on this description to describe the expected effect and also which RNGs get set specifically?
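As the diff stands, the helper seeds only Python's `random` and TensorFlow, so NumPy's global RNG is left untouched. A hedged sketch of a helper that also seeds NumPy when it is installed and reports what it set; `seed_available_rngs` and the returned labels are hypothetical, and the TensorFlow call is omitted so the sketch runs on its own.

```python
import random


def seed_available_rngs(seed):
  """Seeds each RNG it can find and returns the names it seeded (sketch)."""
  seeded = []

  random.seed(seed)
  seeded.append("python.random")

  try:
    import numpy as np  # Optional: only seeded when NumPy is installed.
    np.random.seed(seed)
    seeded.append("numpy.random")
  except ImportError:
    pass

  # The PR would also call tf.set_random_seed(seed) here.
  return seeded


seeded = seed_available_rngs(42)
print("Seeded RNGs:", ", ".join(seeded))
```

Returning the list gives the flag's help text, or a log line, something concrete to say about which generators were affected.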

          metavar="<SEED>"
      )



class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
23 changes: 23 additions & 0 deletions official/utils/misc/model_helpers.py
@@ -19,6 +19,7 @@
from __future__ import print_function

import numbers
import random

import tensorflow as tf

@@ -53,3 +54,25 @@ def past_stop_threshold(stop_threshold, eval_metric):
    return True

  return False


def set_random_seed(seed):
  """Sets the random seeds for available RNGs.
  This seeds RNGs for python's random and for Tensorflow. The intended
Contributor

glint me. I think a \n is required here.

  use case is for this to be called exactly once at the start of execution
Contributor

Which "execution" does it mean here? The execution of the training process? Or the execution of one Session? Or others? Does it act the same in the distributed environment (Multiple GPUs) vs. one single GPU?

  to improve stability and reproducibility between runs.

  Successive calls to re-seed will not behave as expected. This should
  be called at most once.
Contributor

Can/should we prevent this? What happens if you clear the seed in between with set_random_seed(None)? I think that should work.
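Note that with the helper as written in this diff, `set_random_seed(None)` would raise `ValueError`, because `None` fails the `isinstance(seed, int)` check. For Python's own RNG, clearing and re-seeding behaves as sketched below; this says nothing about TensorFlow, which the test in this PR treats as not accepting a re-seed.

```python
import random

random.seed(42)
a = [random.random() for _ in range(3)]

# Passing None clears the fixed seed: Python reseeds from OS entropy
# (or the current time when no entropy source is available).
random.seed(None)
b = [random.random() for _ in range(3)]

# Re-seeding with the original value restores the original sequence.
random.seed(42)
c = [random.random() for _ in range(3)]

print(a == c)  # True
```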

Contributor

Also: what happens if you reset the graph (reset_default_graph)? If we expect the seed to persist, we should check that it does, and if we expect it not to, we should test that it doesn't.


  Args:
    seed: integer, a seed which will be passed to the RNGs.

  Raises:
    ValueError: if the seed is not an integer or if deemed unsuitable for
      seeding the RNGs.
Contributor

or if deemed unsuitable? What does that mean?
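One way to pin down "unsuitable" is to spell the checks out. The sketch below is an assumption about what the docstring might mean, not the PR's behavior: `validate_seed` and the `_MAX_SEED` bound are hypothetical. It also rejects booleans, which would otherwise pass `isinstance(seed, int)` in Python.

```python
import numbers

_MAX_SEED = 2**32 - 1  # Assumed upper bound, chosen for illustration only.


def validate_seed(seed):
  """Returns seed as an int, or raises ValueError with a specific reason."""
  # bool is a subclass of int, so it needs an explicit rejection.
  if isinstance(seed, bool) or not isinstance(seed, numbers.Integral):
    raise ValueError("Random seed is not an integer: {!r}".format(seed))
  if not 0 <= seed <= _MAX_SEED:
    raise ValueError(
        "Random seed {} is outside [0, {}]".format(seed, _MAX_SEED))
  return int(seed)


checked = validate_seed(42)
```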

"""
  if not isinstance(seed, int):
    raise ValueError("Random seed is not an integer: {}".format(seed))
  random.seed(seed)
  tf.set_random_seed(seed)
Contributor

Notably, this will persist across sessions, but not across graphs. With estimators, I think the graph is maintained throughout the life of an estimator, but this might have to get called again in the case of starting a new graph? Not sure.

25 changes: 25 additions & 0 deletions official/utils/misc/model_helpers_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

import random

import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.misc import model_helpers
@@ -64,6 +66,29 @@ def test_past_stop_threshold_not_number(self):
    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold(tf.constant(4), None)

  def test_random_seed(self):
    """It is unclear if this test is a good idea or stable.
    If tests are run in parallel, this could be flaky."""
Contributor

We should figure this out before including. CC @robieta for testing magic.

Contributor

Notably, TF does not seem to test set_random_seed in this way, which implies it is not a good idea. Perhaps we should instead test just that the seeds for python and tf are set as expected by their own reports.
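For the Python half, "by their own reports" could mean comparing RNG state rather than hard-coding sampled values. A minimal sketch under that assumption, using plain `unittest` so it runs without TensorFlow; the class and method names are illustrative. For TensorFlow 1.x, the analogous check would presumably read the graph-level seed back from the default graph's `seed` property, which is not exercised here.

```python
import random
import unittest


class RandomSeedStateTest(unittest.TestCase):

  def test_python_seed_state(self):
    # Seed the module-level RNG, as the helper under review does.
    random.seed(42)
    # An independent generator seeded identically reports the same state,
    # so no sampled values need to be hard-coded in the test.
    reference = random.Random(42)
    self.assertEqual(reference.getstate(), random.getstate())


suite = unittest.defaultTestLoader.loadTestsFromTestCase(RandomSeedStateTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Comparing state avoids depending on the exact numbers a particular library version produces.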

    model_helpers.set_random_seed(42)
    expected_py_random = [int(random.random() * 1000) for i in range(10)]
    tf_random = []
    with tf.Session() as sess:
      for i in range(10):
        a = tf.random_uniform([1])
        tf_random.append(int(sess.run(a)[0] * 1000))

    model_helpers.set_random_seed(42)
    py_random = [int(random.random() * 1000) for i in range(10)]

    # Instead of concerning ourselves with the particular results, we simply
    # want to ensure that the results are reproducible. So, we seed, read,
    # re-seed, re-read.
    self.assertAllEqual(expected_py_random, py_random)

    # TF does not accept being re-seeded.
    expected_tf_random = [637, 689, 961, 969, 321, 390, 919, 681, 112, 187]
    self.assertAllEqual(expected_tf_random, tf_random)


if __name__ == "__main__":
  tf.test.main()
3 changes: 3 additions & 0 deletions official/wide_deep/wide_deep.py
@@ -179,6 +179,9 @@ def main(argv):
  parser = WideDeepArgParser()
  flags = parser.parse_args(args=argv[1:])

  if flags.seed is not None:
    model_helpers.set_random_seed(flags.seed)

  # Clean up the model directory if present
  shutil.rmtree(flags.model_dir, ignore_errors=True)
  model = build_estimator(flags.model_dir, flags.model_type)