Random seed not set in graph context of Dataset#map #29101

Closed
eha11 opened this issue May 29, 2019 · 9 comments
Labels: comp:data (tf.data related issues), TF 1.13 (Issues related to TF 1.13), type:bug (Bug)


@eha11

eha11 commented May 29, 2019

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Jupyter notebook on https://colab.research.google.com
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: NA
  • TensorFlow installed from (source or binary): Stock on https://colab.research.google.com
  • TensorFlow version (use command below): b'v1.13.1-2-g09e3b09e69' 1.13.1
  • Python version: 2/3
  • Bazel version (if compiling from source): NA
  • GCC/Compiler version (if compiling from source): NA
  • CUDA/cuDNN version: NA
  • GPU model and memory: NA

You can collect some of this information using our environment capture script.
You can also obtain the TensorFlow version with:
1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior
The random seed set via tf.set_random_seed(seed) is not set in the graph context in which the functions passed to tf.data.Dataset#map are invoked. This happens even in the single-threaded case.
Describe the expected behavior
The random seed set via tf.set_random_seed(seed) should be set in the graph context in which the functions passed to tf.data.Dataset#map are invoked, at least in the single-threaded case.

Code to reproduce the issue

import tensorflow as tf

def seed_assert(elt):
  seed = tf.get_default_graph().seed
  print("Seed is {}".format(seed))
  assert seed is not None, "Random seed is not set. Random graph operations added during mapping will not be reproducible."
  return elt

seed = 37
      
tf.set_random_seed(seed)

ds = tf.data.Dataset.from_generator(lambda: (yield 0), tf.int64)

seed_assert(None)

ds.map(seed_assert)

Can run here:
Seed in Dataset#Map.ipynb

Other info / logs
I originally saw this issue locally but was able to reproduce it in the Colab notebook provided by Google. Here is the output I see when running the above code.

Seed is 37
Seed is None

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-7-38991a9ee77e> in <module>()
     15 seed_assert(None)
     16 
---> 17 ds.map(seed_assert)

8 frames

<ipython-input-7-38991a9ee77e> in seed_assert(elt)
      4   seed = tf.get_default_graph().seed
      5   print("Seed is {}".format(seed))
----> 6   assert seed is not None, "Random seed is not set. Random graph operations added during mapping will not be reproducible."
      7   return elt
      8 

AssertionError: Random seed is not set. Random graph operations added during mapping will not be reproducible.
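The behavior above can be mimicked without TensorFlow in a toy model, assuming (as the output shows) that Dataset#map traces its function into a separate graph which, in 1.13, did not copy the outer graph's seed. `Graph`, `trace` and its `inherit_seed` flag are hypothetical stand-ins, not TensorFlow classes; the thread does not link the actual fix, so the `inherit_seed=True` branch only illustrates the expected behavior.

```python
# Toy model (not TensorFlow code): a graph-level seed is lost when a map
# function is traced into a fresh graph that does not copy it.


class Graph:
    """Hypothetical stand-in for a TF 1.x graph that only carries a seed."""

    def __init__(self, seed=None):
        self.seed = seed


_default_graph = Graph()


def get_default_graph():
    return _default_graph


def trace(fn, inherit_seed):
    """Call fn with a fresh 'function graph' installed as the default graph.

    inherit_seed=False models the reported 1.13 behavior;
    inherit_seed=True models the expected behavior.
    """
    global _default_graph
    outer = _default_graph
    _default_graph = Graph(seed=outer.seed if inherit_seed else None)
    try:
        return fn()
    finally:
        _default_graph = outer  # restore the outer graph after tracing


def current_seed():
    return get_default_graph().seed


get_default_graph().seed = 37  # analogue of tf.set_random_seed(37)
print(current_seed())                           # 37, like seed_assert(None)
print(trace(current_seed, inherit_seed=False))  # None, like ds.map(seed_assert)
print(trace(current_seed, inherit_seed=True))   # 37, the expected behavior
```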
@eha11 eha11 changed the title Random seed set in graph context of Dataset#map Random seed not set in graph context of Dataset#map May 29, 2019
@gadagashwini-zz gadagashwini-zz self-assigned this May 29, 2019
@gadagashwini-zz gadagashwini-zz added TF 1.13 Issues related to TF 1.13 comp:data tf.data related issues labels May 29, 2019
@gadagashwini-zz
Contributor

@eha11 I tried to reproduce the issue on my system and on Google Colab, but the code executed as expected. Can you try again and let us know whether it is still an issue? Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label May 29, 2019
@eha11
Author

eha11 commented May 29, 2019

@gadagashwini as originally written, the example wouldn't fail; it just printed that the seed was None. I've simplified and updated the example so that it asserts when the seed is None. I can reliably reproduce this in the Google Colab notebook linked above and locally.

@TimZaman
Contributor

TimZaman commented May 29, 2019

A related but slightly different issue was posted in #13932. That issue is more about race conditions caused by parallelism.

The issue reported here is also described in a comment on another issue: #23789 (comment)
That issue was closed, with the solution stated as:

the fix here will depend on switching to the new Python function implementation
@shivaniag reports that:
The new python function implementation is in and this issue has been fixed with that
and subsequently closed the issue without linking to the actual fix.

From gathering a few snippets, I have created the following test that shows the issue:

import numpy as np
import tensorflow as tf

def test(threads):
  tf.reset_default_graph()
  np.random.seed(42)
  tf.set_random_seed(42)
  images = np.random.rand(100, 64, 64, 3).astype(np.float32)

  def mapfn(p):
    # No op-level seed: determinism relies on the graph-level seed set above.
    return tf.image.random_hue(p, 0.04)

  dataset = tf.data.Dataset.from_tensor_slices(images)
  dataset = dataset.map(mapfn, num_parallel_calls=threads)
  dataset = dataset.batch(32)
  x = dataset.make_one_shot_iterator().get_next()

  with tf.Session() as sess:
    return sess.run(x)

# Running the same pipeline twice should produce identical batches.
assert np.allclose(test(1), test(1)), "num_parallel_calls=1 undeterministic"
assert np.allclose(test(15), test(15)), "num_parallel_calls=15 undeterministic"

num_parallel_calls == 1 fails

The above fails with

AssertionError: num_parallel_calls=1 undeterministic

We can fix this first test (with 1 parallel call) in two ways:

   def mapfn(p):
-    return tf.image.random_hue(p, 0.04)
+    return tf.image.random_hue(p, 0.04, seed=42)

Or

   def mapfn(p):
+    tf.set_random_seed(42)
     return tf.image.random_hue(p, 0.04)
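Why both workarounds make the single-call test deterministic can be sketched with a simplified model of the seed rules documented for `tf.set_random_seed`: if the graph seed is set, an op without its own seed gets one derived from the graph's op counter; if only the op seed is set, a fixed default graph seed is paired with it; if neither is set, the op is seeded randomly on every run. `resolve_seeds`, `last_op_id` and `DEFAULT_GRAPH_SEED` are hypothetical names for illustration, not TensorFlow API, and the model assumes the map function's graph had no seed in TF 1.13, as reported above.

```python
# Simplified, framework-free model of TF 1.x op-seed resolution.
# Assumption: it follows the documented tf.set_random_seed rules; it is
# not TensorFlow's actual implementation (e.g. seed truncation is omitted).
DEFAULT_GRAPH_SEED = 87654321  # placeholder for the fixed default graph seed


def resolve_seeds(graph_seed, op_seed, last_op_id):
    """Return the (graph_seed, op_seed) pair a random op would run with.

    (None, None) means the kernel draws a fresh seed on every run, so the
    op's output is not reproducible.
    """
    if graph_seed is not None:
        if op_seed is None:
            # Derive the op seed deterministically from the graph's op counter.
            op_seed = last_op_id
        return graph_seed, op_seed
    if op_seed is not None:
        # Only an op seed: pair it with a fixed default graph seed.
        return DEFAULT_GRAPH_SEED, op_seed
    return None, None  # Neither seed set: nondeterministic.


# Inside the map function's graph in TF 1.13, graph_seed was None:
print(resolve_seeds(None, None, last_op_id=5))  # the bug: (None, None)
print(resolve_seeds(None, 42, last_op_id=5))    # workaround 1: seed=42 on the op
print(resolve_seeds(42, None, last_op_id=5))    # workaround 2: tf.set_random_seed(42) inside mapfn
```

Both workarounds move the op out of the `(None, None)` branch, which is the only case that yields a different random sequence on each run.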

@mrry (Derek), it's unclear to me why map does not respect the graph's default seed.

num_parallel_calls > 1 fails

And of course, after fixing the first test as above, the second one will fail with 15 parallel calls:

AssertionError: num_parallel_calls=15 undeterministic

But that case is just a race condition and not necessarily a bug. This is what #13932 is all about.
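With num_parallel_calls > 1, even a correctly seeded stateful random op hands out its draws in whatever order the worker threads reach it, so which element receives which random value can change between runs. One way around that race is to make each element's randomness a pure function of a global seed and the element's position, which is the idea behind TensorFlow's stateless random ops (for example `tf.random.stateless_uniform` in TF 2.x). The sketch below shows the idea with only the standard library; `per_element_value`, `parallel_map` and the seed-mixing formula are hypothetical, and a thread pool stands in for tf.data's parallel map.

```python
import random
from concurrent.futures import ThreadPoolExecutor

GLOBAL_SEED = 42


def per_element_value(index, global_seed=GLOBAL_SEED):
    """Return a pseudo-random value that depends only on (global_seed, index).

    A private RNG per element plays the role of a stateless random op: no
    state is shared between threads, so scheduling order cannot matter.
    """
    rng = random.Random(global_seed * 1_000_003 + index)  # hypothetical seed mixing
    return rng.uniform(-0.04, 0.04)  # e.g. a hue delta, as in random_hue(p, 0.04)


def parallel_map(fn, n_elements, num_threads):
    """Apply fn to element indices on a thread pool; output keeps input order."""
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(fn, range(n_elements)))


one = parallel_map(per_element_value, 100, num_threads=1)
many = parallel_map(per_element_value, 100, num_threads=15)
print(one == many)  # True: the result no longer depends on the thread count
```

The trade-off is that each element needs a stable index or key to derive its seed from; a shared stateful generator avoids that bookkeeping but gives up reproducibility under parallelism.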

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label May 30, 2019
@gadagashwini-zz
Contributor


You are correct. I am able to reproduce the reported issue with TensorFlow 1.13. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 31, 2019
@jsimsa
Contributor

jsimsa commented Jun 1, 2019

@eha11 could you please check whether the issue is present in 1.14? RC0 for 1.14 was released three days ago.

I was not able to reproduce the issue locally running the bleeding edge of TensorFlow, which hopefully means this issue has indeed been fixed.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 1, 2019
@tensorflowbutler
Member

Hi There,

We are checking to see if you still need help on this, as you are using an older version of TensorFlow which is officially considered end of life. We recommend that you upgrade to the latest 2.x version and let us know if the issue still persists in newer versions. Please open a new issue for any help you need against 2.x, and we will get you the right help.

This issue will be closed automatically 7 days from now. If you still need help with this issue, please provide us with more information.

@TimZaman
Contributor


Nice robot. This post has nothing to do with TF1 specifically.

@eha11
Author

eha11 commented Feb 19, 2021

Yes, this was fixed by 1.15 at the latest, so it can be closed. Hi @TimZaman!

@eha11 eha11 closed this as completed Feb 19, 2021