
Dataset: "Shuffle" doesn't work #13446

Closed
jundengdeng opened this issue Oct 2, 2017 · 5 comments
Labels
type:support Support issues

Comments

@jundengdeng

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below):
  • Python version: python 3.5
  • Bazel version (if compiling from source):
    Build label: 0.5.4
    Build target: bazel-out/local-fastbuild/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
    Build time: Fri Aug 25 10:00:00 2017 (1503655200)
    Build timestamp: 1503655200
    Build timestamp as int: 1503655200
  • CUDA/cuDNN version:
  • GPU model and memory:
  • Exact command to reproduce:

You can collect some of this information using our environment capture script:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

You can obtain the TensorFlow version with

python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"

== tensorflow import ============================================
tf.VERSION = 1.3.0
tf.GIT_VERSION = b'v1.3.0-rc1-2408-ge9d5ee1'
tf.COMPILER_VERSION = b'v1.3.0-rc1-2408-ge9d5ee1'

Describe the problem

"Shuffle" from Dataset doesn't work.

Source code / logs

The following files can be used to reproduce the problem.

import tensorflow as tf
tf.set_random_seed(123)

def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = (tf.contrib.data.TextLineDataset(filenames)
               .map(lambda line: tf.decode_csv(
                    line, record_defaults=[['1'], ['1'], ['1']], field_delim='-'))
               .shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.
               .batch(batch_size))

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()

filenames = ['1.txt']
batch_size = 3
num_epochs = 3
iterator = input_pipeline(filenames, batch_size)

# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator.
a1, a2, a3 = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(num_epochs):
        # Resets the iterator at the beginning of an epoch.
        sess.run(iterator.initializer)
        print('epoch:%d' % (epoch))
        try:
            while True:
                a, b, c = sess.run([a1, a2, a3])
                print(a, b, c)
        except tf.errors.OutOfRangeError:
            # This will be raised when you reach the end of an epoch (i.e. the
            # iterator has no more elements).
            pass

The corresponding file: "1.txt"

1,2-3,4-A
7,8-9,10-B
12,13-14,15-C
17,18-19,20-D
22,23-24,25-E
27,28-29,30-F
32,33-34,35-G
37,38-39,40-H

The output:

2017-10-02 15:06:43.523320: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:892] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2017-10-02 15:06:43.523788: I tensorflow/core/common_runtime/gpu/gpu_device.cc:965] Found device 0 with properties: 
name: GeForce GTX 1060 major: 6 minor: 1 memoryClockRate(GHz): 1.6705
pciBusID: 0000:01:00.0
totalMemory: 5.93GiB freeMemory: 5.44GiB
2017-10-02 15:06:43.523800: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1055] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1060, pci bus id: 0000:01:00.0, compute capability: 6.1)
epoch:0
[b'27,28' b'17,18' b'7,8'] [b'29,30' b'19,20' b'9,10'] [b'F' b'D' b'B']
[b'32,33' b'22,23' b'12,13'] [b'34,35' b'24,25' b'14,15'] [b'G' b'E' b'C']
[b'1,2' b'37,38'] [b'3,4' b'39,40'] [b'A' b'H']
epoch:1
[b'27,28' b'17,18' b'7,8'] [b'29,30' b'19,20' b'9,10'] [b'F' b'D' b'B']
[b'32,33' b'22,23' b'12,13'] [b'34,35' b'24,25' b'14,15'] [b'G' b'E' b'C']
[b'1,2' b'37,38'] [b'3,4' b'39,40'] [b'A' b'H']
epoch:2
[b'27,28' b'17,18' b'7,8'] [b'29,30' b'19,20' b'9,10'] [b'F' b'D' b'B']
[b'32,33' b'22,23' b'12,13'] [b'34,35' b'24,25' b'14,15'] [b'G' b'E' b'C']
[b'1,2' b'37,38'] [b'3,4' b'39,40'] [b'A' b'H']
@mrry
Contributor

mrry commented Oct 2, 2017

Can you describe the problem? From the shuffled output, it does seem to be working as intended.

(Perhaps you need to pass a different seed for each epoch? When you call tf.set_random_seed(123) at the beginning of your program, that will make the output of the shuffle deterministic.)
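For illustration, here is a minimal sketch of that behavior (using the same 1.3-era tf.contrib.data API; the Dataset.range pipeline is illustrative, not taken from the original report). With a graph-level seed set, re-initializing the iterator replays the identical shuffle order:

import tensorflow as tf

tf.set_random_seed(123)  # graph-level seed: the shuffle order becomes deterministic

dataset = tf.contrib.data.Dataset.range(5).shuffle(buffer_size=5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(2):
        sess.run(iterator.initializer)
        order = []
        try:
            while True:
                order.append(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            pass
        # Both epochs print the same permutation of 0..4.
        print('epoch %d order: %s' % (epoch, order))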

@jundengdeng
Author

Thanks a lot for your instant response!
I expected the order of the shuffled output to be different in each epoch. When a seed is set with tf.set_random_seed, the training process is supposed to be reproducible, but the current shuffle always produces the same order in every epoch.

@mrry
Contributor

mrry commented Oct 2, 2017

Right, it is trying very hard to be reproducible, but there's no indication in the code of when it should reshuffle. (Consider that if it did reshuffle on each iteration, you wouldn't be able to reproduce the same sequence within the same session.) If you're using tf.set_random_seed(), you need to do a little more work to get a different order on each epoch. The easiest workaround is to specify an additional seed as a placeholder, which you then feed with a different value on each epoch:

def input_pipeline(filenames, batch_size, seed=None):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = (tf.contrib.data.TextLineDataset(filenames)
               .map(lambda line: tf.decode_csv(
                    line, record_defaults=[['1'], ['1'], ['1']], field_delim='-'))
               .shuffle(buffer_size=10, seed=seed)  # Equivalent to min_after_dequeue=10.
               .batch(batch_size))

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()

filenames = ['1.txt']
batch_size = 3
num_epochs = 3
seed = tf.placeholder(tf.int64, shape=())
iterator = input_pipeline(filenames, batch_size, seed)

# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator.
a1, a2, a3 = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(num_epochs):
        # Resets the iterator at the beginning of an epoch.
        sess.run(iterator.initializer, feed_dict={seed: epoch})
        print('epoch:%d' % (epoch))
        try:
            while True:
                a, b, c = sess.run([a1, a2, a3])
                print(a, b, c)
        except tf.errors.OutOfRangeError:
            # This will be raised when you reach the end of an epoch (i.e. the
            # iterator has no more elements).
            pass
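Feeding seed=epoch makes the order differ between epochs while keeping the whole run reproducible: re-running the program feeds the same per-epoch seeds, so each epoch's shuffle order is replayed exactly.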

@asimshankar asimshankar added stat:awaiting response Status - Awaiting response from author type:support Support issues labels Oct 2, 2017
@jundengdeng
Author

I appreciate the suggested solution. Do you think there is a more elegant way?

@aselle aselle removed the stat:awaiting response Status - Awaiting response from author label Oct 2, 2017
@mrry
Contributor

mrry commented Oct 2, 2017

I think the workaround is about as elegant as it will get with the current API. (One could imagine adding something like per-run() seeds and threading those through the random operations, but that would be a big change!)
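For reference, a more elegant option arrived in later TensorFlow releases, where Dataset.shuffle accepts a reshuffle_each_iteration argument; combined with repeat(), each pass over the data is shuffled differently while the full sequence stays reproducible from the seed. A minimal sketch, assuming a build of tf.data that has this argument (the repeat and batch counts are illustrative):

import tensorflow as tf

tf.set_random_seed(123)

# Sketch: shuffle reshuffles on every pass when reshuffle_each_iteration=True,
# so repeat(3) yields three differently ordered epochs from one fixed seed.
dataset = (tf.data.TextLineDataset(['1.txt'])
           .map(lambda line: tf.decode_csv(
                line, record_defaults=[['1'], ['1'], ['1']], field_delim='-'))
           .shuffle(buffer_size=10, seed=123, reshuffle_each_iteration=True)
           .repeat(3)   # three epochs in one pipeline
           .batch(3))   # note: batches may straddle epoch boundaries

next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            a, b, c = sess.run(next_element)
            print(a, b, c)
    except tf.errors.OutOfRangeError:
        pass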

@mrry mrry closed this as completed Oct 2, 2017