
Non-determinism from tf.data.Dataset.map with random ops #13932

Closed

dusenberrymw opened this issue Oct 24, 2017 · 9 comments
Labels
stat:awaiting response Status - Awaiting response from author

Comments

@dusenberrymw

dusenberrymw commented Oct 24, 2017

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes -- please see the minimal reproducible example script below.
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.12, Linux CentOS 7 (4.6.6-300.el7.centos.x86_64)
  • TensorFlow installed from (source or binary): pip3 install tf-nightly (also happens when built from source)
  • TensorFlow version (use command below): v1.3.0-rc1-3690-g9b9cbbe 1.5.0-dev20171023
  • Python version: 3.6.3
  • Bazel version (if compiling from source): N/A since nightly build reproduces the issue (but when built from source, I use 0.6.1-homebrew)
  • CUDA/cuDNN version: a GPU is not needed to reproduce the issue (however, it has also been tested with CUDA 8.0.61 / cuDNN 7.0.1)
  • GPU model and memory: N/A -- a GPU is not needed to reproduce the issue (however, it has also been tested with Tesla K80s)
  • Exact command to reproduce: See minimal reproducible example below

Describe the problem

The new tf.data.Dataset API contains a map function with a num_parallel_calls parameter, which allows elements to be processed in parallel by multiple threads. Although not explicitly mentioned in the API docs, prior discussions (such as a comment from today) have indicated that the map function should be deterministic (w.r.t. the graph seed) even if num_parallel_calls > 1. I have observed that if the function being mapped contains only non-random ops, this determinism holds (see step 2 below). However, if the function being mapped contains a random op, the results become non-deterministic for all values of num_parallel_calls > 1. This is unexpected, and it prevents training experiments from being reproducible unless num_parallel_calls == 1. Also, please note that the example below is a minimal example to reproduce the issue; the real scenario involves running data augmentation during training.

Source code / logs

  1. pip3 install tf-nightly
  2. Run the following code to observe that map functions with only non-random ops are deterministic for all values of num_parallel_calls, which is the expected behavior:
import numpy as np
import tensorflow as tf

def test(threads):
  np.random.seed(42)
  tf.set_random_seed(42)
  images = np.random.rand(100, 64, 64, 3).astype(np.float32)

  def get_data():
    dataset = tf.data.Dataset.from_tensor_slices(images)  # some initial dataset
    dataset = dataset.map(lambda x: x * 2, num_parallel_calls=threads)  # this works fine always
    dataset = dataset.batch(32)
    x = dataset.make_one_shot_iterator().get_next()
    return x

  # execution 1
  x = get_data()
  with tf.Session() as sess:
    x_batch1 = sess.run(x)

  # clear out everything
  tf.reset_default_graph()

  # execution 2
  x = get_data()
  with tf.Session() as sess:
    x_batch2 = sess.run(x)

  # results should be equivalent
  assert np.allclose(x_batch1, x_batch2)

test(1)  # works with 1 thread!
test(15)  # works with >1 threads!
  3. Run the following code to observe that map functions with random ops are deterministic if num_parallel_calls == 1, but non-deterministic for values of num_parallel_calls > 1, which seems to me to be unexpected behavior:
import numpy as np
import tensorflow as tf

def test(threads):
  np.random.seed(42)
  tf.set_random_seed(42)
  images = np.random.rand(100, 64, 64, 3).astype(np.float32)

  def get_data():
    dataset = tf.data.Dataset.from_tensor_slices(images)  # some initial dataset
    # ONLY DIFFERENCE IS THE BELOW LINE:
    dataset = dataset.map(lambda image: tf.image.random_hue(image, 0.04, seed=42), num_parallel_calls=threads)
    # ONLY DIFFERENCE IS THE ABOVE LINE ^^^:
    dataset = dataset.batch(32)
    x = dataset.make_one_shot_iterator().get_next()
    return x

  # execution 1
  x = get_data()
  with tf.Session() as sess:
    x_batch1 = sess.run(x)

  # clear out everything
  tf.reset_default_graph()

  # execution 2
  x = get_data()
  with tf.Session() as sess:
    x_batch2 = sess.run(x)

  # results should be equivalent
  assert np.allclose(x_batch1, x_batch2)

test(1)  # works with 1 thread!
test(15)  # fails with >1 threads!
  4. Observe that swapping out the map line above for an entirely different random op, such as dataset = dataset.map(lambda x: x * tf.random_normal([64, 64, 3], seed=42), num_parallel_calls=threads), is also non-deterministic for values of num_parallel_calls > 1 (a sketch of this variant follows below).
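
A minimal sketch of the get_data() variant for step 4 (everything else is identical to the script in step 3; only the map line changes):

  def get_data():
    dataset = tf.data.Dataset.from_tensor_slices(images)  # some initial dataset
    # A different random op; still non-deterministic whenever num_parallel_calls > 1:
    dataset = dataset.map(lambda x: x * tf.random_normal([64, 64, 3], seed=42),
                          num_parallel_calls=threads)
    dataset = dataset.batch(32)
    x = dataset.make_one_shot_iterator().get_next()
    return x
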
@skye
Member

skye commented Oct 24, 2017

@mrry

@skye skye added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 24, 2017
@dusenberrymw
Author

Additionally, I would like to note that for steps 3 and 4, an op-level seed must be set on the random ops used within the map function, regardless of whether or not a graph-level seed is set. This appears to be inconsistent with the documentation for tf.set_random_seed(), quoted below; a minimal illustration follows the quote:

  1. If the graph-level seed is set, but the operation seed is not: The system deterministically picks an operation seed in conjunction with the graph-level seed so that it gets a unique random sequence.
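
For concreteness, this is the variant of the map line from step 3 that relies only on the graph-level seed; per the observation above, it is not reproducible across runs, even with num_parallel_calls == 1:

# Hypothetical variant: no op-level seed on random_hue, only tf.set_random_seed(42).
dataset = dataset.map(lambda image: tf.image.random_hue(image, 0.04),
                      num_parallel_calls=1)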

@viirya
Contributor

viirya commented Nov 1, 2017

I'm not familiar with the TensorFlow codebase, but I tried to trace this. It looks like, unless we can assign a specific thread in the thread pool to each input element, we can't guarantee that parallel map functions with random ops are deterministic. However, pinning elements to particular threads runs counter to the nature of a thread pool.

@mrry
Contributor

mrry commented Nov 1, 2017

Unfortunately, this is "expected behavior" due to the way tf.random_uniform() (used inside tf.image.random_hue()) and the other RNG ops are implemented. The parallel invocations of map will race to access the mutable RNG state inside the op, and different invocations will see a non-deterministically chosen element of the same sequence. Currently, the only way to ensure deterministic results from Dataset.map() that contains an RNG op is to set num_parallel_calls=1.

In principle, you could split your map() function so that the random number generation happens in a serial map and the compute-intensive part of the op happens in a parallel map. For example, it's possible to do this manually for tf.image.random_hue(), because it is simply a composition of tf.image.adjust_hue(..., tf.random_uniform(...)):

import numpy as np
import tensorflow as tf

def test(threads):
  np.random.seed(42)
  tf.set_random_seed(42)
  images = np.random.rand(100, 64, 64, 3).astype(np.float32)

  def get_data():
    dataset = tf.data.Dataset.from_tensor_slices(images)
    # Perform the random number generation in a single-threaded map().
    dataset = dataset.map(
        lambda image: (image, tf.random_uniform([], -0.04, 0.04, seed=42)),
        num_parallel_calls=1)
    # Perform the compute-intensive hue adjustment in a multi-threaded map().
    dataset = dataset.map(
        lambda image, adjustment: tf.image.adjust_hue(image, adjustment),
        num_parallel_calls=threads)
    dataset = dataset.batch(32)
    x = dataset.make_one_shot_iterator().get_next()
    return x

  # execution 1
  x = get_data()
  with tf.Session() as sess:
    x_batch1 = sess.run(x)

  # clear out everything
  tf.reset_default_graph()

  # execution 2
  x = get_data()
  with tf.Session() as sess:
    x_batch2 = sess.run(x)

  # results should be equivalent
  assert np.allclose(x_batch1, x_batch2)

test(1)  # works with 1 thread!
test(15)  # works with >1 threads!

However, this manual approach might not scale to a real program. In our CNN benchmarks, we've been using a sequence number to deterministically map "random" perturbations onto input images. In future we might consider doing this kind of slicing automatically, but that's probably some way off.
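
This is not the benchmark code, but a minimal sketch of the sequence-number idea (the hash and ranges are illustrative): derive the "random" perturbation purely from each element's index, so no RNG state is shared between threads and the parallel map stays reproducible.

import numpy as np
import tensorflow as tf

def get_data(threads):
  images = np.random.rand(100, 64, 64, 3).astype(np.float32)
  dataset = tf.data.Dataset.from_tensor_slices(images)
  # Pair each image with its sequence number.
  indexed = tf.data.Dataset.zip((tf.data.Dataset.range(100), dataset))

  def adjust(index, image):
    # Map the index to a pseudo-random hue shift in roughly [-0.04, 0.04] via a
    # simple integer hash; the result depends only on the element's position.
    bucket = tf.mod(index * 2654435761, 1000)
    adjustment = (tf.to_float(bucket) / 1000.0 - 0.5) * 0.08
    return tf.image.adjust_hue(image, adjustment)

  dataset = indexed.map(adjust, num_parallel_calls=threads)  # safe to parallelize
  dataset = dataset.batch(32)
  return dataset.make_one_shot_iterator().get_next()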

Hope this helps though!

@mrry mrry added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Nov 1, 2017
@tensorflowbutler
Member

It has been 14 days with no activity and the awaiting response label was assigned. Is this still an issue? Please update the label and/or status accordingly.

@zaccharieramzi
Contributor

Hi @mrry,

I just stumbled on this behaviour, and I wanted to understand whether it is something that could (and would) be fixed in the future.

If not, I think it would be nice to have a warning in the docs, especially since map documents a deterministic keyword argument. I could submit a PR for that if needed.

@xfffrank

xfffrank commented Dec 10, 2020

@zaccharieramzi @mrry
I encountered this non-determinism in map as well, when it is used with a random augmentation function such as tf.image.random_brightness and num_parallel_calls > 1.

I tried setting deterministic=True but it didn't work.

By the way, I've already called the following function to make results deterministic.

import os
import tensorflow as tf

def seed_everything(seed_value):
    tf.random.set_seed(seed_value)  # seed TF's global RNG
    os.environ['TF_DETERMINISTIC_OPS'] = '1'  # request deterministic op implementations

And the tf version I'm using is tf-nightly-gpu 2.5.0-dev20201130.

@duncanriach
Contributor

duncanriach commented Jul 22, 2021

@xfffrank,

The work-around suggested by @mrry can be extended using the stateless random image ops. For example, an early stage in your tf.data pipeline could append a (deterministic) random seed to each example using a single-threaded (num_parallel_calls=1) map. Then, any subsequent stateless random image op in a parallel stage (num_parallel_calls > 1) could use the seed associated with the example. This would require replacing tf.image.random_brightness with tf.image.stateless_random_brightness in your example.

The advantage of using the relatively newly added stateless random image ops in this way is that you only have to inject one random seed per example into the pipeline, and that seed can be used for all the stateless random image ops (as each op's seed parameter).
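
A rough sketch of this pattern (not from this thread; names and parameters are illustrative). Here the per-example seed is derived from the element index via Dataset.enumerate() instead of a serial random map, which keeps the whole pipeline stateless and parallel-safe:

import numpy as np
import tensorflow as tf

BASE_SEED = 42

def augment(index, image):
    # Build a shape-[2] integer seed from a fixed base seed and the element
    # index, and feed it to a stateless random op; the output depends only on
    # (BASE_SEED, index), so any degree of parallelism gives the same result.
    seed = tf.stack([tf.cast(BASE_SEED, tf.int64), index])
    return tf.image.stateless_random_brightness(image, max_delta=0.2, seed=seed)

images = np.random.rand(100, 64, 64, 3).astype(np.float32)
dataset = (tf.data.Dataset.from_tensor_slices(images)
           .enumerate()                                        # append a sequence number
           .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # parallel and deterministic
           .batch(32))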
